Simplified Sparse Attention via Gist Tokens

HF Daily Papers

SSA uses trained gist tokens to select relevant long-context chunks without changing the transformer architecture.

Categories: Research

Excerpt

Yuzhen Mao, Michael Y. Li, Emily B. Fox

Sparse attention can reduce the cost of long-context inference, but most variants introduce new architectural components. We introduce Simplified Sparse Attention (SSA), a simpler approach to sparse attention that requires no architectural changes. Concretely, we first perform continued pretraining on sequences interleaved with gist tokens. We optimize the standard next-token loss as usual, but an attention mask built around the gist tokens restricts which parts of the context the language model can attend to; this teaches the model to pack each chunk's important information into its gist tokens. At inference time, SSA scores chunks via attention between the current query and the small set of gist tokens, then selectively unfolds the top-k chunks by reintroducing their corresponding raw tokens. Because the query is scored only against the gist tokens, we avoid the memory-bandwidth cost of naively scoring against the full KV cache, and we do not need the auxiliary KV cache used by other sparse attention methods. On LongBench, SSA consistently outperforms compression and inference-time sparse-attention baselines at the same compression ratio. More strikingly, in retrieval-augmented generation, SSA can even outperform full attention, after continued pretraining, by over 5.7 points. We attribute this to SSA's selective unfolding, which concentrates attention on the query-relevant chunks and effectively filters out noise.
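The excerpt does not spell out the exact training mask or the chunk-scoring rule, so the following is only a minimal pure-Python sketch under stated assumptions, not the authors' implementation. It assumes each chunk is `chunk_len` raw tokens followed by `num_gist` gist tokens, and that after a chunk's gist tokens are emitted, later positions may see that chunk only through its gists. It also assumes a chunk's score is the total softmax weight the query places on that chunk's gist keys. The function names `build_gist_mask`, `score_chunks`, `select_chunks`, and `unfold` are hypothetical.

```python
import math
from typing import List, Sequence


def build_gist_mask(num_chunks: int, chunk_len: int, num_gist: int) -> List[List[bool]]:
    """mask[i][j] is True if position i may attend to position j.

    Assumed layout: [raw x chunk_len, gist x num_gist] repeated per chunk.
    Assumed rule: causal, and a token sees raw tokens only from its own
    chunk; earlier chunks are visible only through their gist tokens.
    """
    chunk_of: List[int] = []
    is_gist: List[bool] = []
    for c in range(num_chunks):
        chunk_of += [c] * (chunk_len + num_gist)
        is_gist += [False] * chunk_len + [True] * num_gist
    n = len(chunk_of)
    return [
        [j <= i and (chunk_of[j] == chunk_of[i] or is_gist[j]) for j in range(n)]
        for i in range(n)
    ]


def score_chunks(query: Sequence[float],
                 gist_keys: List[List[Sequence[float]]]) -> List[float]:
    """Score each chunk by the softmax attention mass on its gist keys.

    gist_keys[c] holds the key vectors of chunk c's gist tokens. Only gist
    keys are read here, never the raw-token KV cache. Summing the softmax
    weights per chunk is an assumed aggregation.
    """
    d = len(query)
    logits = [
        [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
        for keys in gist_keys
    ]
    m = max(x for row in logits for x in row)  # subtract max for stability
    exps = [[math.exp(x - m) for x in row] for row in logits]
    total = sum(sum(row) for row in exps)
    return [sum(row) / total for row in exps]


def select_chunks(scores: Sequence[float], k: int) -> List[int]:
    """Ids of the top-k chunks, returned in positional order."""
    top = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)[:k]
    return sorted(top)


def unfold(selected: Sequence[int], chunk_len: int, num_gist: int) -> List[int]:
    """Positions (in the assumed layout) of raw tokens to reintroduce."""
    stride = chunk_len + num_gist
    return [c * stride + t for c in selected for t in range(chunk_len)]


if __name__ == "__main__":
    # Two chunks of 3 raw tokens + 1 gist token: positions 0-2 raw, 3 gist, 4-6 raw, 7 gist.
    mask = build_gist_mask(num_chunks=2, chunk_len=3, num_gist=1)
    print(mask[4][0], mask[4][3])  # → False True

    # One gist key per chunk; the query aligns best with chunks 1 and 2.
    scores = score_chunks([1.0, 0.0], [[[0.0, 1.0]], [[2.0, 0.0]], [[1.0, 0.0]]])
    chosen = select_chunks(scores, k=2)
    print(chosen)  # → [1, 2]
    print(unfold(chosen, chunk_len=3, num_gist=1))  # → [4, 5, 6, 8, 9, 10]
```

Because `score_chunks` touches only the gist keys, the cost of choosing chunks grows with the number of gist tokens rather than with the full context length. This mirrors the memory-bandwidth saving the abstract describes. A real system would then run attention over all gist tokens plus the unfolded raw tokens.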