The maximum number of tokens that can currently be fed into the LLaMA model is 2048. For many NLP tasks, such as long chat conversations, summarizing long documents, or question answering, this window limit is a frequent pain point. Researchers from Meta have found a method to extend the context window of an already-trained model to 16 times its original length with only about a thousand training steps.
We’ve previously covered a few methods for extending the context window of LLMs, including FlashAttention, a method to more efficiently compute the attention at the core of the Transformer architecture employed by these models. While FlashAttention (check that post out if you haven’t yet) is used in this new approach, the main algorithm enabling such a large increase in context window is called Position Interpolation.
Sidenote - When did Meta somehow become the good guy in this tale of open-source LLM research? Many of the issues of Let’s Talk Text cover open-source models or research done by Meta AI.
Before we get to Position Interpolation and how it works, we need to review position embeddings and how they’re computed.
Position embeddings explained
Transformers use a clever positional encoding scheme in which each position in the input sequence is mapped to a vector. That means the position embedding for a sequence of tokens is effectively a matrix, in which each row corresponds to a position in the input sequence.
The values of this matrix are calculated in a specific way. They use the index of the token in the sequence (denoted k), as well as a few other constants such as the number of columns in the output matrix. We don’t need to worry about these for now; just take a look at the equations below and know that each token in the input is effectively mapped onto a set of sin/cos curves:
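These are the standard sinusoidal formulas from the original Transformer paper, where d is the embedding dimension (the number of columns in the matrix) and i indexes the sin/cos column pairs:

```latex
PE(k, 2i)   = \sin\!\left( \frac{k}{10000^{2i/d}} \right), \qquad
PE(k, 2i+1) = \cos\!\left( \frac{k}{10000^{2i/d}} \right)
```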
That means for each index k, a different combination of sin and cos values is produced.
(Figure: sine and cosine positional encoding curves. Source: Machine Learning Mastery)
Position Interpolation
Remember, the main goal here is to extend the context window on an already-trained LLM. It is not to expand the context window on a model being trained from scratch by somehow changing its model architecture.
Earlier attempts to extend the context window tried to extrapolate the encodings beyond the positions the model was trained on; however, this has been shown to lead to catastrophic values. Think of this approach as trying to extend the curves shown in the previous figure. Instead, it turns out you can keep the same curve that the original position encodings were mapped onto and interpolate new position encodings between the neighboring integer values.
The top left shows normal usage of an LLM. The top right shows how previous research attempted to extrapolate the position encodings beyond the trained values. The bottom left shows how Position Interpolation works: notice how there are many more dots (more position encodings) on the same curve. You are essentially squeezing more position embeddings into the same range of the sin/cos curve. This approach is far more stable and takes advantage of the fact that position encodings can be evaluated at non-integer positions, as the sketch below illustrates.
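To make the contrast concrete, here is a minimal NumPy sketch of the idea. It is not the paper's code: the paper applies the rescaling to LLaMA's rotary position embeddings (RoPE), whereas this uses the sinusoidal encoding from above, and the 8192-token target window is just an illustrative choice.

```python
import numpy as np

def sinusoidal_encoding(positions, d_model=128, base=10000.0):
    """Standard sinusoidal position encoding; accepts non-integer positions."""
    positions = np.asarray(positions, dtype=np.float64)[:, None]  # (seq_len, 1)
    two_i = np.arange(0, d_model, 2, dtype=np.float64)[None, :]   # (1, d_model/2)
    angles = positions / (base ** (two_i / d_model))               # k / 10000^(2i/d)
    enc = np.empty((positions.shape[0], d_model))
    enc[:, 0::2] = np.sin(angles)
    enc[:, 1::2] = np.cos(angles)
    return enc

# Illustrative window sizes (2048 is LLaMA's trained limit; 8192 is just an example).
trained_window, extended_window = 2048, 8192

# Extrapolation: feed positions the model never saw (2048..8191) straight in.
extrapolated = sinusoidal_encoding(np.arange(extended_window))

# Position Interpolation: rescale positions so the longer sequence still lands
# inside the trained range [0, 2048), at evenly spaced non-integer positions.
scale = trained_window / extended_window  # 0.25 for a 4x extension
interpolated = sinusoidal_encoding(np.arange(extended_window) * scale)

print(extrapolated.shape, interpolated.shape)  # (8192, 128) (8192, 128)
```

The only change between the two variants is that positions are multiplied by trained_window / extended_window before the encoding is computed; the model is then briefly fine-tuned so it adapts to the more densely spaced positions.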
Empirical results
Turns out Position Interpolation is highly efficient:
- With only 1000 training steps, the context window of LLaMA models ranging in size from 7B to 65B was extended from 2048 tokens to 32768 tokens.
- Model quality is preserved for tasks within the original context window size.
- Models with Position Interpolation can take advantage of their extended context window, with competitive performance on text modeling and long text summarization.