A Brief Overview of Google’s Trax Library

What is Trax?

Trax is an end-to-end deep learning library focused on clear, understandable code. Code written in Trax usually has a much simpler structure than the equivalent, often much longer, code in other libraries such as TensorFlow and PyTorch. It is actively used and maintained by the Google Brain team.

It draws on ideas from several libraries, but mainly follows the TensorFlow style.

Installation:-

Trax can be installed from PyPI with “pip install trax”.

Fig:- Installation of Trax

As you can see above, using it feels similar to using TensorFlow and NumPy.

In TensorFlow we write “import tensorflow as tf”; in this library we simply write “import trax”.

It includes several standard deep learning models, including ones essential for NLP tasks (for example, LSTM, ResNet and Transformer). It is used as a research library for continuously developing new models and testing them on datasets, including TensorFlow Datasets and the T2T (Tensor2Tensor) datasets.

The library can be used from Python notebooks to work on custom models, as well as from the shell, through commands, to train models, including building on pre-trained ones.

How does Trax work?

Trax grew out of the TensorFlow ecosystem and uses core Python libraries such as NumPy. It has some packages that provide more efficient ways to write code in Trax.

  1. Trax and fastmath

Trax models work on array-based structures known as tensors, which are usually manipulated with functions modeled on NumPy's “numpy.array” API. Running these operations through accelerated backends speeds up computation by making use of GPUs and TPUs. The same backends also make it possible to calculate gradients on tensors automatically. Both capabilities are packaged into the “trax.fastmath” package, thanks to its backends: JAX and TensorFlow NumPy.

The following shows the basic code of fastmath and Trax's NumPy.

Fig:- How it works (https://github.com/google/trax)

  2. Layers

Layers are the essential building blocks of this library. A layer computes a function with zero or more inputs and zero or more outputs. The inputs and outputs are tensors, which behave like JAX and NumPy arrays. A Trax layer that has no weights or sublayers can be used without being initialized.

Layers are also defined as objects that implement the “__call__” method, which makes them easy to use: a layer can be applied directly to input data, just like a function.

The code below is from its documentation.

These layers play the same role as in other frameworks such as TensorFlow or PyTorch; what sets them apart is how few lines of code they are written in.

Let us now go straight to the implementation of its layers. In TensorFlow, a model is defined using “Sequential”; in Trax, it is done using “Serial”. “Serial” is a combinator that composes its sublayers according to the inputs and outputs of each layer. It stacks the layers so that the outputs of one layer are passed as inputs to the next.

Example:

It looks just like a “tf.keras.Sequential” model, but it is the internal structure of the layers that makes it run fast.

It also allows you to define your own layers and sublayers.

Here’s how a Trax model looks with actual layers:

The following code block is from one of the course notebooks in which I learnt about Trax.

Conclusion:-

As Trax is still in development, I was fortunate to get hands-on experience with this library in an online course. I can say that it made NLP tasks, and more generally implementing deep learning models and neural networks such as RNNs, much easier.

References:-

  1. This link is about “How Trax came into existence”.

https://coursera.org/share/1bdab833b3fbbee79133006f2cab236f

  2. This link goes through the Trax documentation in detail.

https://trax-ml.readthedocs.io/en/latest/notebooks/trax_intro.html

Written By: Pushpraj Maraje

Reviewed By: Viswanadh
