Robert Važan

Sequence editing models

I want to describe what I call sequence editing models. In contrast to current next-token language models, hypothetical sequence editing models generate the entire sequence in parallel, probably by iteratively refining drafts, which makes them a better fit for current hardware. Diffusion models are an early and imperfect example. We need something that supports insertions and deletions though. I offer an idea inspired by Gaussian splats.

Hardware dictates architecture

GPUs have a lot of compute compared to memory bandwidth. They can perform hundreds of compute operations per single memory transfer operation. In order to make use of all that compute, neural network architectures try to formulate most of the workload as matrix-matrix multiplications, which cost O(N³) in compute but only O(N²) in memory bandwidth. For comparison, matrix-vector multiplications are O(N²) in both compute and bandwidth, so they end up being bandwidth-limited.
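
To make the gap concrete, here is a back-of-the-envelope sketch in Python. The fp16 element size and the matrix dimension are illustrative assumptions of mine, not measurements:

```python
# Rough arithmetic intensity (FLOPs per byte of memory traffic) for square
# matrices of size N with 2-byte (fp16) elements.
def matmul_intensity(n, bytes_per_elem=2):
    flops = 2 * n**3                           # one multiply-add per term
    traffic = 3 * n**2 * bytes_per_elem        # read A, read B, write C
    return flops / traffic

def matvec_intensity(n, bytes_per_elem=2):
    flops = 2 * n**2
    traffic = (n**2 + 2 * n) * bytes_per_elem  # the matrix dominates traffic
    return flops / traffic

n = 4096
print(f"matrix-matrix: {matmul_intensity(n):.0f} FLOPs/byte")  # ~1365
print(f"matrix-vector: {matvec_intensity(n):.2f} FLOPs/byte")  # ~1
```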

The reason current LLMs look the way they do is rooted in hardware. In order to train efficiently, the transformer architecture evaluates all tokens in the sequence in parallel. Sequential dependencies are introduced only on layer boundaries. There are only ~100 layers, so the architecture is almost perfectly parallel. The transformer architecture is designed to compile into a series of efficient matrix-matrix multiplications on the GPU. This makes training of transformer models far more efficient than training recurrent models, for example.
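
In shape terms, the difference looks roughly like this. A minimal sketch with made-up dimensions; the projection W stands in for any weight matrix inside a layer:

```python
import torch

# During training, all T positions pass through each weight matrix as a single
# matrix-matrix multiply; at decode time only the newest position does, which
# degenerates into a matrix-vector multiply. Shapes are illustrative.
d_model, T = 4096, 2048
W = torch.randn(d_model, d_model)     # one projection inside a layer
x_train = torch.randn(T, d_model)     # whole sequence, known in advance
y_train = x_train @ W                 # (T, d) x (d, d): compute-bound
x_decode = torch.randn(1, d_model)    # single new token at inference time
y_decode = x_decode @ W               # (1, d) x (d, d): bandwidth-bound
```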

Transformers are, however, parallel only during training, when all tokens are known in advance. At inference time, transformer-based LLMs have to generate one token at a time, because generating each token requires knowledge of all prior tokens. Current next-token LLMs are thus grossly inefficient at inference time, which somewhat surprisingly impacts prompt processing too.
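
The serial dependency is easiest to see in the decoding loop itself. This is a generic sketch of greedy autoregressive decoding, not any particular implementation; `model` is a hypothetical callable returning per-position logits:

```python
# Every iteration depends on the token produced by the previous one, so the
# model weights have to be streamed from memory once per generated token.
def generate(model, prompt_tokens, eos_token, max_new_tokens=256):
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        logits = model(tokens)                  # forward pass over the sequence
        next_token = int(logits[-1].argmax())   # greedy pick for simplicity
        tokens.append(next_token)
        if next_token == eos_token:
            break
    return tokens
```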

Memory will not get faster. Physics limits non-local memory access latency to the cube root of memory capacity (see also SE discussion). I would generalize this to bandwidth too, because energy consumption increases with distance. You can cheat a little with a wide bus to RAM and an even wider and shorter bus to HBM, but you will eventually run into energy and heat dissipation limits. Details vary with hardware, but even an L2 cache access needs an order of magnitude more energy than a single floating-point multiplication.

If memory bandwidth is going to remain limited, we have only two options left: either colocate compute with memory to eliminate memory access entirely or switch to models with parallelizable inference to minimize data transfer overhead. The human brain colocates compute with memory, but replicating that in GPUs would require radical hardware changes. Models with parallel inference are the best short-term bet, because they require only software changes.

Diffusion models are too rigid

Token-level denoising (or diffusion) language models, like the recent Gemini Diffusion, add some sort of token-level noise and then train the model to remove it. The noise comes in many forms: masking tokens, token substitution and reordering, embedding noise, or noise in the token probability distribution. The model can control sequence length by filling the unused part of the token buffer with padding tokens. Token denoising models are the best-known parallel alternative to next-token prediction, so let's discuss their positives.
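
For concreteness, here is a toy sketch of one common formulation, mask-based iterative denoising, where all positions are predicted in parallel and only the most confident ones are committed each step. This is my illustration, not a description of how Gemini Diffusion specifically works; `model`, `MASK_ID`, and `PAD_ID` are hypothetical:

```python
import torch

MASK_ID, PAD_ID = 0, 1  # hypothetical special token ids

def denoise(model, seq_len, steps=8):
    tokens = torch.full((seq_len,), MASK_ID)       # start fully masked
    for step in range(steps):
        logits = model(tokens)                     # parallel over positions
        conf, pred = logits.softmax(dim=-1).max(dim=-1)
        masked = tokens == MASK_ID
        # Commit a fraction of the most confident masked positions per step.
        budget = max(1, int(masked.sum()) // (steps - step))
        scores = torch.where(masked, conf, torch.tensor(-1.0))
        commit = scores.topk(min(budget, int(masked.sum()))).indices
        tokens[commit] = pred[commit]
    # The model can emit PAD_ID in the unused tail to control output length.
    return tokens
```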

Denoising seems like a totally alien approach to producing language. No human thinks like that. It works, however, because text shares certain properties with images, where denoising models originated (the famous Stable Diffusion). Tokens predict surrounding text just like pixels predict the surrounding image. If you mask 90% of the words in an article, you can still tell what the article is about. Furthermore, language is elastic like images. You can fit the same information into a varying number of tokens. This alleviates the need for insertions and deletions to some extent. A denoising model works like a journalist who has been asked to produce a 5,000-word article. Just like the journalist, the model allocates the token budget to sections, paragraphs, and sentences and adapts both the overall structure and the detailed wording to fit within the budget.

I am not a big fan of denoising models though, because token-level denoising invariably resists insertions and deletions. Inserting or removing a token would require pushing or pulling all subsequent tokens. The model can theoretically do that with the right training objective, but it seems difficult, and I have yet to see a denoising model that happily performs insertions and deletions. Making statements shorter or longer is tolerable in journalistic writing, but I cannot imagine LLMs generating sensible source code with a fixed token budget.

Drawing inspiration from Gaussian splats

Tokens with a fixed position, whether in next-token models or diffusion models, are akin to voxels in 3D modeling. Like voxels, tokens can change their internal properties (the sampled token or the embedding behind it), but their location and size are fixed. Gaussian splats are an interesting inspiration, because they replace rigid voxels with mobile, resizable, and overlapping splats.

The textual equivalent of Gaussian splats would be to model the sequence as an unordered set of independent text embeddings. Every embedding can represent a range of text from a single token to the whole sequence. During inference, all embeddings can move within the sequence, change size, and change the embedding vector. I don't know of any existing implementations, so this is just a rough idea.
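
As a strawman, one possible parameterization of such a "text splat" could look like this. This is my sketch of the rough idea above, not an established design; the Gaussian influence function is an assumption borrowed directly from splatting:

```python
from dataclasses import dataclass
import torch

@dataclass
class TextSplat:
    center: torch.Tensor     # continuous position within the sequence
    width: torch.Tensor      # how many token positions the splat covers
    embedding: torch.Tensor  # vector describing the covered span of text

def splat_weight(splat, position):
    """Gaussian influence of a splat on a given token position."""
    return torch.exp(-0.5 * ((position - splat.center) / splat.width) ** 2)
```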

Graphs instead of sequences

To ease insertions and deletions, it would be better to use content addressing to represent position in the sequence instead of a numerical offset. Every embedding would carry some information about the surrounding context, which would implicitly give it a position in the sequence relative to other embeddings. Content addressing would allow embeddings to form a directed graph of alternative solutions, whereas numerical offsets would confine embeddings to a single totally ordered sequence.
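
To illustrate what content addressing could mean here (again my own sketch, not an existing design): each embedding points at its neighbors with key vectors rather than numeric offsets, so inserting a node never renumbers anything:

```python
from dataclasses import dataclass
import torch

@dataclass
class GraphNode:
    embedding: torch.Tensor        # content of this span
    predecessor_key: torch.Tensor  # soft pointer to whatever should precede it

def resolve_predecessor(node, nodes):
    """Attention-style lookup: the node whose content best matches the key."""
    contents = torch.stack([n.embedding for n in nodes])  # (N, dim)
    scores = contents @ node.predecessor_key              # similarity scores
    return nodes[int(scores.argmax())]
```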

An embedding graph would allow some embeddings to form silent scaffolding that provides support, grounding, and space for experiments without directly altering the output. Such scaffolding could function as a parallel alternative to reasoning tokens. Long-span embeddings could be used early in the inference process as placeholders for high-level parts of the response that would be gradually populated by embeddings representing shorter spans or individual tokens.

Training and inference with mobile embeddings

I have been carrying this idea of a graph of embeddings in my head for years, but the details are still hazy. Gaussian splats are somewhat complicated to optimize, because they tend to get stuck in local optima. I guess inference with mobile text embeddings is going to be complicated too. The complexity does not matter though. What matters is that all embeddings can be optimized in parallel.

While inference seems merely complicated, training a model for mobile embeddings seems completely elusive. The basic idea is that there must be a neural network that optimizes every embedding while taking into account information from every other embedding. Something like a transformer, but without the rigid positions. But what would be the training objective? I have no idea.

The catch here is that in Gaussian splats, moving the splats to fit a given set of images is itself implemented using gradient descent. The text equivalent is to move output embeddings to match a given fixed input. That means we would be using gradient descent for inference. And if gradient descent is used for inference, what would training look like? Maybe the Gaussian splat analogy breaks down here and mobile text embeddings require a different approach to inference and training.
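
To make "gradient descent for inference" concrete, here is a sketch under a big assumption: that some frozen, differentiable scorer exists that rates how well a set of mobile embeddings explains the fixed input. The scorer, shapes, and hyperparameters below are all hypothetical:

```python
import torch

def infer(scorer, prompt_embedding, num_splats=64, dim=512, steps=200):
    # Splat parameters are the variables being optimized at inference time.
    centers = torch.rand(num_splats, requires_grad=True)
    widths = torch.rand(num_splats, requires_grad=True)
    vectors = torch.randn(num_splats, dim, requires_grad=True)
    opt = torch.optim.Adam([centers, widths, vectors], lr=1e-2)
    for _ in range(steps):
        opt.zero_grad()
        # Higher score = output embeddings more consistent with the input.
        # All splats receive gradients in parallel, like splats fit to images.
        loss = -scorer(centers, widths, vectors, prompt_embedding)
        loss.backward()
        opt.step()
    return centers.detach(), widths.detach(), vectors.detach()
```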

If we could train a separate model for fixed-position embeddings first, it would be easier to train a mobile embedding model on top of it. A good fixed-position embedding represents both the current token and the surrounding context. A good fixed-position embedding is also additive, so that it can be constructed as a weighted sum of several mobile embeddings. We could then somehow compare the similarity between expected fixed-position embeddings and actual mobile embeddings, perhaps by materializing actual fixed-position embeddings from the mobile embeddings. That would hopefully give us a differentiable loss function. We would still have to figure out how to construct training samples with varying degrees of noise.
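
A sketch of what that materialization and loss could look like, assuming Gaussian weights by distance and a mean-squared-error comparison (both of which are my assumptions, not part of the proposal above):

```python
import torch

def materialize(centers, widths, vectors, seq_len):
    """Fixed-position embeddings as a weighted sum of mobile embeddings."""
    positions = torch.arange(seq_len, dtype=torch.float32)         # (seq_len,)
    dist = (positions[:, None] - centers[None, :]) / widths[None, :]
    weights = torch.exp(-0.5 * dist ** 2)                          # (seq_len, splats)
    weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8)
    return weights @ vectors                                       # (seq_len, dim)

def embedding_loss(centers, widths, vectors, target_embeddings):
    """Differentiable loss against embeddings from a fixed-position model."""
    actual = materialize(centers, widths, vectors, target_embeddings.shape[0])
    return torch.nn.functional.mse_loss(actual, target_embeddings)
```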

Anyway, I am in way over my head here. What I wanted to say with this article is that we need some sequence editing architecture. The above idea of mobile embeddings is just an example of what such an architecture could look like. I would be happy to see research in this direction.