Llama in Jax

Introduction

If you are not familiar with deep learning frameworks, start by reading my first post about Jax. In this post, we will build the Llama model in Jax. Llama is a state-of-the-art open-source LLM from Meta’s AI research lab, a model in the same family as ChatGPT. We are building Llama3.2-1B-Instruct: version 3.2, with 1 billion parameters, fine-tuned to follow instructions (Instruct).

Getting started

There are two things to understand about Llama: the architecture and the weights. You can think of them as the metadata and the data, or the table of contents and the actual content. We will start by creating the architecture, and then we will load the weights into it.
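In Jax terms, the architecture is a pure function and the weights are the pytree of arrays that function consumes. Here is a minimal sketch of that split, using a single linear layer as a stand-in (the names and shapes are illustrative, not Llama's real ones):

```python
import jax
import jax.numpy as jnp

# "Architecture": a pure function describing the computation.
def forward(params, x):
    # A single linear layer stands in for the full model.
    return x @ params["w"] + params["b"]

# "Weights": a pytree of arrays that plugs into the architecture.
key = jax.random.PRNGKey(0)
params = {
    "w": jax.random.normal(key, (4, 2)),
    "b": jnp.zeros((2,)),
}

x = jnp.ones((1, 4))
print(forward(params, x).shape)  # (1, 2)
```

Loading Llama's real weights later amounts to replacing this toy `params` pytree with the released checkpoint arrays, while `forward` encodes the architecture.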

Architecture

Llama’s models are transformers, but a few things have changed since the original paper was published.

  • Rotary positional embeddings (RoPE) instead of the original sinusoidal encoding.
  • 8
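The RoPE change above can be sketched in a few lines of Jax: each pair of features is rotated by an angle that depends on the token's position. This is an illustrative sketch, not Llama's exact implementation (the frequency base and pairing layout vary between codebases):

```python
import jax.numpy as jnp

def rope(x, base=10000.0):
    """Apply rotary positional embeddings to x of shape (seq_len, dim).

    dim must be even: features are split into pairs, and each pair is
    rotated by a position-dependent angle.
    """
    seq_len, dim = x.shape
    half = dim // 2
    # One frequency per feature pair, decaying geometrically.
    freqs = base ** (-jnp.arange(half) / half)           # (half,)
    angles = jnp.arange(seq_len)[:, None] * freqs[None]  # (seq_len, half)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    # 2D rotation applied to each (x1, x2) pair.
    return jnp.concatenate([x1 * cos - x2 * sin,
                            x1 * sin + x2 * cos], axis=-1)

x = jnp.ones((6, 8))
print(rope(x).shape)  # (6, 8)
```

Note that position 0 gets angle 0 everywhere, so the first token's features pass through unchanged; later positions are rotated progressively more.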