Inference-Friendly Models with MixAttention | Databricks Blog



Transformer models, the backbone of modern language AI, rely on the attention mechanism to process context when generating output. During inference, the attention mechanism works by computing the key and value vectors for each token seen so far, and using these vectors to update the internal representation of the next token that will be output. Because the same key and value vectors of the past tokens get reused every time the model outputs a new token, it’s common practice to cache them in a data structure called the Key-Value (KV) cache. Since the KV cache grows proportionally to the number of tokens seen so far, KV cache size is a major factor in determining both the maximum context length (i.e., the maximum number of tokens) and the maximum number of concurrent requests that can be supported for inference on modern language models. Particularly for long inputs, LLM inference can also be dominated by the I/O cost of moving the KV cache from High Bandwidth Memory (HBM) to the GPU’s shared memory. Therefore, reducing the KV cache size has the potential to be a powerful method to speed up and reduce the cost of inference on modern language models. In this post, we explore ideas recently proposed by Character.AI for reducing KV cache size by replacing most of the layers in the network with sliding window attention (a form of local attention that only uses the key and value vectors of a small number of the most recent tokens) and sharing the KV cache among layers. We call this architecture MixAttention; our experiments with different variants of this architecture have demonstrated that it maintains both short and long context model quality while improving inference speed and memory footprint.

MixAttention Performance Tables
Figure 1: Speed and accuracy of MixAttention model variants. (Model variants shown in Figure 2.) Top: We see that MixAttention models are faster and use less memory during inference at 32K context length. Bottom: MixAttention models maintain quality – they match the standard attention model on most evals. The models are all Mixture of Experts with 2B active and 5B total parameters.

We found that KV cache sharing between layers and adding sliding window layers can speed up inference and reduce inference memory usage while maintaining model quality, although some eval metrics show some degradation. In addition, our ablation experiments showed the following:

 

  • Having a few standard attention layers is crucial for the model’s long context abilities. In particular, having the standard KV cache computed in the deeper layers is more important for long context abilities than the standard KV cache of the first few layers.
  • The KV cache of standard attention layers can be shared between non-consecutive layers without any observed degradation in long context abilities.
  • Increasing the KV cache sharing between sliding window layers too much also hurts long context abilities.

We’ve provided a guide to configuring and training MixAttention models using LLM Foundry in the appendix of this blog post.

 

Image 2 Mix Attention Blog
Figure 2: (Left) A standard transformer model where all layers are standard attention layers. (Middle) Inference-friendly models with MixAttention. Green bars represent sliding window attention, and the lines connecting bars represent KV cache sharing. (Right) A model where all layers are sliding window attention.

MixAttention Architecture Overview

Standard transformer models use global attention in every layer. To create inference-friendly model architectures, we used a combination of sliding window attention layers, standard attention layers, and KV cache reuse layers. Below is a brief discussion of each component:

 

  • Sliding Window Attention Layers: In Sliding Window Attention (or Local Attention) with window size s, the query only attends to the last s keys instead of all the keys preceding it. This means that during inference, the KV cache only needs to store the KV tensors for the past s tokens instead of storing the KV tensors for all the preceding tokens. In our experiments, we set a window size of s=1024 tokens.
  • Standard Attention Layers: We found that even though Standard Attention Layers lead to bigger KV caches and slower attention computation compared to Sliding Window Attention, having a few Standard Attention Layers is crucial for the model’s long context abilities.
  • KV cache reuse: This refers to a layer in the transformer network that reuses the KV cache computed by a previous layer. Hence, if every l layers share KV tensors, then the KV cache is reduced to 1/l of its original size. (A minimal configuration sketch combining these components follows this list.)
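
To make the combination concrete, here is a minimal sketch of how these components can be expressed with LLM Foundry’s block_overrides feature (covered in detail in the appendix). The layer layout and names below are purely illustrative, not the configuration of any model in this post:

```yaml
block_overrides:
  order:
    - name: sliding_window_layer        # sliding window attention with its own KV cache
    - name: sliding_window_layer_reuse  # sliding window attention reusing the previous layer's KV cache
    - name: default                     # a standard (global) attention layer
    - name: kv_reuse_layer              # standard attention reusing the KV cache of the layer above
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
    sliding_window_layer_reuse:
      attn_config:
        sliding_window_size: 1024
        reuse_kv_layer_idx: -1          # reuse the KV cache of the layer one position earlier
    kv_reuse_layer:
      attn_config:
        reuse_kv_layer_idx: -1
```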

 

We experimented with different combinations of the components above to ablate the effects of each of them. (More combinations are described in the appendices.) We found that not only does each of the above components play an important role in long context abilities and inference speed and memory consumption, but their relative positions and counts also have significant effects on those metrics.

 

The models we trained are 24-layer Mixture of Experts (MoE) models with 1.64B active and 5.21B total parameters. We used RoPE positional embeddings and increased the RoPE base theta as we increased the context length during training. We used Grouped Query Attention with 12 attention heads and 3 KV heads.
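
For reference, the attention-related part of such a model’s configuration would look roughly like the following sketch in LLM Foundry’s MPT config format (the MoE feed-forward settings are omitted here, and exact key names may differ across LLM Foundry versions):

```yaml
model:
  name: mpt_causal_lm
  n_layers: 24
  n_heads: 12                          # query heads
  attn_config:
    attn_type: grouped_query_attention
    kv_n_heads: 3                      # KV heads for Grouped Query Attention
    rope: true
    rope_theta: 500000                 # raised to 8e6 for the long-context training stages
```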

 

Training

We used LLM Foundry to train MixAttention models. Similar to prior work on training long context models, we followed a multi-stage training procedure to impart long context abilities to the models.

 

  1. We pretrained the models with a RoPE theta of 0.5M on 101B tokens, where each sequence was truncated to 4k token length.
  2. To increase the context length, we then trained the model on 9B tokens from a mix of natural language and code data, where the sequences were truncated to 32k tokens. We increased the RoPE theta to 8M for this stage. When training at 32k context length, we trained only the attention weights and froze the rest of the network. We found that this delivered better results than full network training.
  3. Finally, we trained the model on a 32k-length, synthetic, long-context QA dataset.
    • To create the dataset, we took natural language documents and chunked them into 1k-token chunks. Each chunk was then fed to a pretrained instruction model, which was prompted to generate a question-answer pair based on the chunk. Then, we concatenated chunks from different documents together to serve as the “long context.” At the end of this long context, the question-answer pairs for each of the chunks were added. The loss gradients were computed only on the answer parts of these sequences.
    • This phase of training was done on 500M tokens (this number includes the tokens from the context, questions, and answers). The RoPE theta was kept at 8M for this stage.

Evaluation

The models were evaluated on the Mosaic Evaluation Gauntlet to measure model quality across various metrics including reading comprehension, commonsense reasoning, world knowledge, symbolic problem solving, and language understanding. To evaluate the models’ long context abilities, we used RULER at a context length of 32000 tokens. RULER is a composite benchmark consisting of 13 individual evals of the following types:

 

  • Needle-in-a-haystack (NIAH): These evals hide a single key or multiple keys and values in a long text, and the model is evaluated on its ability to retrieve the correct value(s) from the long context for a given key(s).
  • Variable Tracking (VT): This eval provides the model with a long context containing variable assignment statements, and the model is tasked with determining which variables have a particular value by the end of all the variable assignments.
  • Common and Frequent Word Extraction (CWE and FWE): These tasks ask the model to extract the most common or frequent words from the text.
  • Question Answering (QA): Given a long context, the model is asked a question about something in the context and is evaluated on whether it can correctly answer that question.

 

We used SGLang to deploy our models on 1 NVIDIA H100 GPU to run RULER and to collect inference speed and memory consumption metrics.

Results

Position and Count of Standard Attention KV Caches

To measure the effect of the position and count of the standard attention KV caches, we tried four variants. All of the configurations are variants of the configuration proposed in Character.AI’s blog post.

MixAttention Image 3
Figure 3: KV cache positions and counts. To measure the effect of the position and count of the standard attention KV caches on MixAttention’s long context abilities, we trained and evaluated the four models shown above.
  1. MA: This variant has a single standard attention KV cache, which is the KV cache of the first layer. All the other standard attention layers share this KV cache.
  2. MA-EndSlide: This variant is the same as MA, but the last layer is a sliding window attention layer. This was done to measure how much having standard attention in the last layer affects long-context abilities.
  3. MA-Offset: This variant is similar to MA, but the first standard attention layer is offset to a later layer to allow the model to process the local context for a few layers before the standard attention layer is used to look at longer contexts.
  4. MA-Pairs: This variant computes two standard attention KV caches (at the first and thirteenth layers), each of which is then shared with another standard attention layer.

We compared these models to a transformer model with Standard Attention and a transformer model with Sliding Window Attention in all layers.

MixAttention image 4

MixAttention_image5
Fig. 4 and 5: Effect of Standard Attention Layers. (Top) Loss curves of the models when fine-tuning on the long context QA dataset. (Bottom) RULER evals for the models. MA and MA-EndSlide perform poorly on long context tasks while MA-Offset and MA-Pairs perform well. This indicates that having a standard attention KV cache that is computed in later layers is important for long context abilities. We also found that the loss on the long context QA dataset correlates well with the model’s long context abilities.

While the loss curves in Stages 1 and 2 of Training were close for all the models, we found that in Stage 3 (training on the long context QA dataset), there was a clear bifurcation in the loss curves. In particular, we see that configurations MA and MA-EndSlide show much worse loss than the others. These results are consistent with the long context RULER evals, where we found that MA and MA-EndSlide performed much worse than the others. Their performance was similar to that of the network with only sliding window attention in all layers. We think the loss in Stage 3 correlates well with RULER evals because, unlike Stages 1 and 2, which were next-word prediction tasks where local context was sufficient to predict the next word most of the time, in Stage 3 the model needed to retrieve the correct information from potentially long-distance context to answer the questions.

 

As we see from the RULER evals, MA-Offset and MA-Pairs have better long-context abilities than MA and MA-EndSlide across all the categories. Both MA and MA-EndSlide have only one standard attention KV cache, which is computed in the first layer, whereas both MA-Offset and MA-Pairs have at least one standard attention KV cache which is computed in deeper layers. Hence, this indicates that having at least one standard attention KV cache computed in the deeper layers of a transformer model is necessary for good long-context abilities.

KV cache sharing in sliding window layers

MixAttention Image 6
Fig. 6: Increasing KV cache sharing in sliding window layers. To measure the effect of KV cache sharing in the sliding window layers, we compared the architectures shown in the figure above.

Mix Attention Image 7

Mix Attention Image 8
Fig. 7 and 8: Effect of increasing KV cache sharing in sliding window layers. (Top) Loss curves of the models when fine-tuning on the long context QA dataset. (Bottom) RULER evals for the models. We found that increasing the KV cache sharing in sliding window layers worsened the long context abilities of MixAttention models.

We found that increasing the sharing between sliding window layers degraded the model’s long context performance: MA-Offset-SlideShare was worse than MA-Offset, and MA-Pairs-SlideShare was worse than MA-Pairs. This shows that the KV cache sharing pattern among the sliding window layers is also important for long context abilities.

 

We’ve provided the results of some additional ablation experiments in the appendices.

Gauntlet Evals

Using the Mosaic Eval Gauntlet v0.3.0, we also measured the performance of MixAttention models on standard tasks like MMLU, HellaSwag, etc. to verify that they retain good shorter context abilities. All the tasks in this eval suite have context lengths of less than a few thousand tokens.

MixAttention Figure 9
Fig. 9: Performance of MixAttention models on the Eval Gauntlet. We found that MixAttention models have similar eval metrics to the baseline model on commonsense reasoning, language understanding, and world knowledge. However, we see that they perform worse on reading comprehension.

We found that MixAttention models have similar eval metrics to the baseline model on commonsense reasoning, language understanding, and world knowledge; however, they performed worse on reading comprehension. An interesting open question is whether reading comprehension abilities could be improved with a different MixAttention configuration or by training MixAttention models longer.

Inference Speed and Memory Consumption

Mix Attention Image 10

MixAttention Image 11
Fig. 10 and 11: (Top) MixAttention models have significantly faster inference than standard transformers. (Bottom) MixAttention models can support more tokens, and thus larger batch sizes, during inference.

We benchmarked the inference speed and memory consumption of MixAttention models by deploying them on a single NVIDIA H100 GPU using SGLang and querying them with 300 prompts, with an input length of 31000 and output length of 1000. In the figure, we show that the inference speed of MixAttention models is much faster than that of standard attention models. We also show that with MixAttention, we can support a much larger inference batch size in terms of the total number of tokens.

 

We found that the current implementation of Sliding Window Attention in SGLang does not optimize memory consumption for sliding window attention; hence, sliding window attention supports the same maximum number of tokens as the standard attention model. Optimizing the memory consumption for sliding window attention should further increase the maximum number of tokens that MixAttention can support during inference.

Conclusion

We found that MixAttention models are competitive with standard attention models on both long- and short-context abilities while being faster during inference and supporting larger batch sizes. We also observed that on some long context tasks like Variable Tracking and Common Word Extraction, neither MixAttention nor standard attention models performed well. We believe this was because our models weren’t trained long enough, or the models need a different kind of long context data to be trained for such tasks. More research needs to be done to measure the impact of MixAttention architectures on those metrics.

 

We encourage others to explore more MixAttention architectures to learn more about them. Below are a few observations to help with further research:

 

  • Adding a standard attention layer in the initial layers by itself does not seem to help long context abilities (for example, see MA-NoShare-1 in the appendix), even if the KV cache from that layer is reused in layers deeper in the network (MA and MA-EndSlide). Hence we recommend placing the first standard attention layer deeper in the network (like MA-Offset) or having multiple standard attention layers, at least one of which is computed at a deeper layer (like MA-Pairs).
  • Sliding window layers also contribute to the model’s long context abilities. Increasing the KV cache sharing among the sliding window layers worsened long context abilities (MA-Offset-SlideShare and MA-Pairs-SlideShare). For that reason, we think that the 2-3 sharing pattern in sliding window layers seems to strike a good balance.
  • Sharing full attention KV caches between consecutive layers gave mixed results, with slightly worse accuracy on long context QA tasks (see the appendix).
  • In our experiments, MA-Offset and MA-Pairs showed great speedup and memory savings during inference, while also maintaining long and short context abilities. Hence, MA-Offset and MA-Pairs might be good configurations for further research.
  • MixAttention models can be trained with LLM Foundry. Please see the appendix for guidelines.

 

In general, there is a large hyperparameter space to explore, and we look forward to seeing a variety of new strategies for reducing the cost of inference via combinations of sliding window attention and KV cache reuse.

Appendix: Using LLM Foundry to train MixAttention models

The way to configure MixAttention models with LLM Foundry is to use the block_overrides feature. The block_overrides definition consists of two sections: order and overrides. The order key defines the ordering and the names of the layers in the network, while the overrides key contains the custom configuration of each named layer.

 

For example, to create a 5 layer network with the first two layers being standard attention layers, the next two being sliding window layers, and the last one being a standard attention layer, we use the following YAML:

CodeSnippet1
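
A rough sketch of this configuration (showing only the block_overrides section of the model config):

```yaml
block_overrides:
  order:
    - name: default               # first two layers: standard attention
      repeat: 2
    - name: sliding_window_layer  # next two layers: sliding window attention
      repeat: 2
    - name: default               # last layer: standard attention
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
```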

Here, the order section conveys that the first two layers are of type ‘default’, the next two are of type ‘sliding_window_layer’, and the last is of type ‘default’ again. The definition of each of these types is contained in the overrides section, using the names defined in the order section. It says that ‘sliding_window_layer’ should have a sliding_window_size of 1024. Note that ‘default’ is a special type which does not need a definition in the overrides section, because it simply refers to the default layer (in this case, a standard attention layer). Also, note that ‘sliding_window_layer’ is just a custom name and can be replaced with any other arbitrary name as long as that name is correspondingly also defined in the overrides section.

 

The model configuration is printed in the logs, which can be used to verify that the model is configured correctly. For example, the above YAML will result in the following being printed in the logs:

CodeSnippet2

We can also configure the two sliding window layers to have different sliding window sizes as follows:

CodeSnippet3
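
One way to sketch this, using two differently named sliding window override types and explicit repeat counts:

```yaml
block_overrides:
  order:
    - name: default
      repeat: 2
    - name: sliding_window_layer    # third layer: window size 1024
      repeat: 1
    - name: sliding_window_layer_2  # fourth layer: window size 512
      repeat: 1
    - name: default
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
    sliding_window_layer_2:
      attn_config:
        sliding_window_size: 512
```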

The above will result in the third layer having a sliding window size of 1024 and the fourth layer having a sliding window size of 512. Note that the repeat keyword defaults to 1, so the above YAML can also be written as:

CodeSnippet4
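
That is, the explicit repeat: 1 entries can be dropped, giving an equivalent sketch:

```yaml
block_overrides:
  order:
    - name: default
      repeat: 2
    - name: sliding_window_layer    # third layer: window size 1024
    - name: sliding_window_layer_2  # fourth layer: window size 512
    - name: default
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
    sliding_window_layer_2:
      attn_config:
        sliding_window_size: 512
```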

The repeat keyword is also applicable to the order keyword. So, if we want to create a 4 layer network with alternating standard and sliding window attention layers like the following,

MixAttention Appendix 1

then we can use the following YAML:

CodeSnippet5
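
A sketch of this layout, applying repeat to an order sub-block:

```yaml
block_overrides:
  order:
    - order:
        - name: default               # standard attention layer
        - name: sliding_window_layer  # sliding window attention layer
      repeat: 2                       # repeat the (standard, sliding window) pair twice
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
```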

To make a layer reuse the KV cache of a previous layer, we use reuse_kv_layer_idx in the attn_config in the override definition. The key reuse_kv_layer_idx contains the relative index of the layer whose KV cache we want this layer to reuse. To make a two-layer network where the second layer reuses the first layer’s KV cache, we can use the following YAML:

CodeSnippet6
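
A sketch of this two-layer configuration:

```yaml
block_overrides:
  order:
    - name: default
    - name: kv_reuse_layer
  overrides:
    kv_reuse_layer:
      attn_config:
        reuse_kv_layer_idx: -1  # reuse the KV cache of the layer one position earlier
```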

The value -1 indicates that the layer named kv_reuse_layer reuses the KV cache of the layer that is one layer before it. To create a 5 layer network with the following configuration,

Mix Attention Appendix Image 2

we can use the following YAML:

CodeSnippet7
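
One 5-layer layout consistent with the note below (the exact layout is the one shown in the figure) is sketched here:

```yaml
block_overrides:
  order:
    - name: default
      repeat: 2
    - name: kv_reuse_layer  # layer #3 reuses layer #2's KV cache; layer #4 reuses layer #3's
      repeat: 2
    - name: default
  overrides:
    kv_reuse_layer:
      attn_config:
        reuse_kv_layer_idx: -1
```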

Note that in the above configuration, layer #4 reuses the KV cache of layer #3, which in turn reuses the KV cache of layer #2. Hence, layer #4 ends up reusing the KV cache of layer #2.

 

Finally, note that order can be defined recursively; that is, an order block can contain another order sub-block. For example, MA-Offset-SlideShare

Appendix 3 image

can be defined as follows:

CodeSnippet8
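
The exact MA-Offset-SlideShare layer layout is the one shown in the figure above; as an illustration of the recursive pattern only (the layer counts below are placeholders, not the actual 24-layer configuration), such a definition might look like:

```yaml
block_overrides:
  order:
    - name: sliding_window_layer
    - name: sliding_window_layer_reuse
    - name: default                        # the offset standard attention layer
    - order:                               # a repeated sliding-window KV-cache-sharing group
        - name: sliding_window_layer
        - name: sliding_window_layer_reuse
          repeat: 2
      repeat: 3
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
    sliding_window_layer_reuse:
      attn_config:
        sliding_window_size: 1024
        reuse_kv_layer_idx: -1
```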

Appendix: Other Ablation Experiments

Sharing Standard Attention KV Caches between Consecutive Layers

Since the transformer layers progressively update the latent representation of a token as it moves through the layers, the Query, Key, and Value tensors might have significantly different representations for layers that are far apart. Hence, it might make more sense to share KV caches between consecutive layers. To test this, we compared four such configurations: MA-Successive-1, MA-Successive-2, MA-Successive-3, and MA-Successive-4 against MA-Pairs. These configurations vary the positions of the standard KV attention layers and the distance between the consecutive pairs of standard KV attention layers.

MixAttention image 4
KV cache sharing between consecutive layers: To measure the effect of KV cache sharing between consecutive layers, we tried the four configurations shown above.

 


MixAttention appendix 5

MixAttention appendix 6
Effect of KV cache sharing between consecutive layers: (Top) Loss curves of the models when fine-tuning on the long context QA dataset. (Bottom) RULER evals for the models. We found that KV cache sharing between consecutive layers does not consistently improve long context abilities across all evals. However, for tasks like SQuAD QA and HotpotQA, which can be indicative of long context RAG abilities, the performance was slightly worse when sharing KV cache between consecutive layers.

We found that all the models have similar loss curves and similar performance on the NIAH single 1, 2, and 3 tasks, which we consider to be the easiest long context tasks. However, we did not see a consistent pattern across the other NIAH tasks. For long context QA tasks, we found that MA-Pairs was slightly better than the others. These results indicate that sharing standard attention KV cache between layers that are further apart does not lead to any significant degradation in long context abilities as compared to sharing standard attention KV cache between consecutive layers.

Effect of Sharing Standard Attention KV Cache

MixAttention appendix 7
No standard attention KV-cache sharing: To measure the effect of KV cache sharing between standard attention layers, we compare the architectures shown in the figure above.

MixAttention appendix 8

MixAttention appendix 9
Effect of no standard attention KV-cache sharing: (Top) Loss curves of the models when fine-tuning on the long context QA dataset. (Bottom) RULER evals for the models. We found that both MA-NoShare-2 and MA-NoShare-3 were comparable to MA-Offset.

 

To test the effect of sharing the KV cache between standard attention layers, we tried three configurations: MA-NoShare-1, MA-NoShare-2, and MA-NoShare-3. We found that MA-NoShare-1 performed very badly on RULER, indicating its lack of long context abilities. However, MA-NoShare-2 and MA-NoShare-3 were comparable to MA-Offset on long context tasks. Hence, we think that further research is needed to determine the effects of sharing standard attention KV cache.
