Replacement-aware SAE Training

Author: Evan Lloyd
Published: 2026-01-12T07:52Z
Modified: 2026-01-28T01:55Z
Categories: mechinterp

This post is a writeup of some mechinterp ideas I explored in late 2025. You can reproduce the results presented here using this notebook. Full source code is available here.

Note added 2026-01-28: While working on an upcoming blog post, I discovered a bug in the implementation originally posted here, primarily affecting the next layer + finetuning method. Applying this fix makes my method look slightly better by the metrics shown here. I have updated the plots in the post below; the original plots can still be viewed in the notebook using the prior version of the code. I also fixed some minor inconsistencies in the indices in displayed math equations.


Summary

One important use case for SAEs is in scalably identifying feature circuits (Ameisen et al. 2025), based on a replacement model in which transformer component outputs are re-encoded by their corresponding SAEs. A major drawback of this approach is that reconstruction errors tend to compound, since transformer models perform sequential operations on a shared residual stream. Standard residual stream SAE training is maximally myopic with respect to this use case, in that every SAE is trained in isolation. Even end-to-end training (Braun et al. 2024), which additionally looks at overall model performance, still trains each SAE’s parameters individually without consideration of how this will affect the downstream SAEs. We should be able to do better if we explicitly consider the replacement model during training. In this post I consider one of the simplest possible replacement-aware training paradigms, in which a term for the feature-space error of the next layer’s SAE is added to the standard SAE loss function. I present a brief quantitative comparison of SAEs trained by these various approaches on TinyStories (Eldan and Li 2023).

The key ideas of this post are:

  1. Multi-SAE replacement models can be improved by considering the impact on downstream SAEs during training, and
  2. It is better to measure this impact in SAE feature space rather than in the native activation space.

Motivation

As part of the project I worked on for SPAR, I looked into extracting circuits built out of SAE features in a replacement model (Ameisen et al. 2025) formed by re-encoding each layer’s output with an SAE. This proved to be infeasible without the use of error nodes (Marks et al. 2024), since re-encoding more than a single layer made the model’s output completely incoherent–it became unable to generate plausible English of any kind. In my opinion, this calls into question the validity of this entire approach to circuit extraction. If the “error” terms are so critical that without them the model completely breaks down, I don’t see how we can trust any causal story for the model’s behavior that handwaves them away as noise. While some degradation in performance is to be expected, I think it’s worthwhile to demand that our replacement model retain the qualitative property of “still being a language model!”

One way we might achieve this goal is by simply improving the quality of reconstruction at each layer. The SAE is just one of many possible scalable methods to extract interpretable features out of model activations. We might look to the broader field of sparse coding for inspiration; see for example Lewis Smith’s writeup of trying the gradient pursuit algorithm for this task in a Google DeepMind team blogpost (Nanda et al. 2024). I also have some ideas, and a few small experiments, in this vein loosely based on LISTA (Gregor and LeCun 2010), which I will write up more formally at a later date. However, I am not terribly optimistic that better sparse coding techniques alone will scale to frontier models. This is because errors in the replacement model inherently have a compounding effect, since the outputs of every subsequent layer depend on the previous ones. This problem gets worse the deeper a model gets.

An orthogonal angle of attack is to mitigate the compounding effect itself for a given level of error. This requires some reworking of the training setup. The standard approach treats each transformer layer independently. This has some advantages, such as lower memory and compute requirements and the ability to train all layers in parallel. However, this setup fails to take into account the use case at the start of this section; we generally also care about the overall performance of the model. The way one cares about things in machine learning is by adding them to the loss function; (Braun et al. 2024) achieve this by adding a term for the KL divergence between variants of the model with and without each SAE. While this does, as expected, improve the performance of the replacement model, it is much more computationally and memory intensive. These costs can be mitigated quite substantially with a hybrid approach. (Karvonen 2025) found that following up a standard training run with an end-to-end KL finetuning step, with as few as 5% of the total training tokens, resulted in comparable performance. This step of course requires just as much memory and per-token compute, but running it on many fewer tokens lowers the overall cost.

While the end-to-end training setup does attempt to address overall model performance, it is subtly different from the motivating application. Specifically, it only considers the effect of re-encoding a single layer at a time, whereas for circuit extraction we want to replace all layers simultaneously. In the following I will formalize this distinction and propose an alternative procedure that more closely approximates the replacement model during training.

Replacement models

Let \(M\) represent a GPT-style transformer model, consisting of \(m\) layers \(T_i\) for \(i\in[0..m-1]\). Our goal will be to train a set of residual stream SAEs, one for each layer. For these purposes, we can treat the output of the transformer as the sequential composition of each layer \(T_i\) (where \(h\) outputs the logits from the final residual, and \(x\) is the embedding of the input):

\[ M(x) = h \circ T_{m-1} \circ T_{m-2} \circ \dots \circ T_0(x) \]

Then, the \(r\)-replacement model \(\widehat{M}^r\), where \(r \subseteq[0..m-1]\), is an alternative model that re-encodes some subset of layers with the corresponding SAE:

\[ \begin{gather} \widehat{M}^r(x) = h \circ \widehat{T}^r_{m-1} \circ \widehat{T}^r_{m-2} \circ \dots \circ \widehat{T}^r_0(x) \\ \widehat{T}^r_i=\cases{ SAE_i \circ T_i & if $i \in r$ \\ T_i & otherwise } \end{gather} \]

In this notation, we would designate the original model as \(\widehat{M}^{\emptyset}\), and the full replacement model as \(\widehat{M}^{[0..m-1]}\). For convenience, I will also define the activation of the \(i^{th}\) layer of the \(r\)-replacement model for input embedding \(x\) as

\[ \widehat{A}^r_i(x)=\widehat{T}^r_i\circ\widehat{T}^r_{i-1}\circ \dots \circ \widehat{T}^r_0(x). \]
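To make this concrete, here is a minimal PyTorch sketch of how \(\widehat{A}^r_i\) and \(\widehat{M}^r\) could be computed with forward hooks. This is illustrative only, not the notebook's implementation: it assumes a GPT-Neo-style causal LM from transformers whose blocks live in model.transformer.h, and the function name and the saes dictionary are hypothetical.

```python
import torch

def replacement_activations(model, saes, r, input_ids):
    """Compute the residual-stream activations A_hat^r_i(x) for every layer i, and the
    logits of the r-replacement model, by re-encoding the output of each layer in `r`
    with its SAE via forward hooks.

    Assumes a GPT-Neo-style causal LM from `transformers` whose blocks live in
    `model.transformer.h`; `saes` is a hypothetical dict {layer_index: sae_module}.
    """
    acts, handles = {}, []

    def make_hook(i):
        def hook(module, inputs, output):
            resid = output[0]              # GPT-Neo blocks return a tuple; the residual stream is element 0
            if i in r:
                resid = saes[i](resid)     # re-encode this layer's output with SAE_i
            acts[i] = resid
            return (resid,) + output[1:]   # downstream layers now see the (possibly replaced) stream
        return hook

    for i, block in enumerate(model.transformer.h):
        handles.append(block.register_forward_hook(make_hook(i)))
    try:
        logits = model(input_ids).logits   # M_hat^r(x)
    finally:
        for h in handles:
            h.remove()
    return acts, logits
```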

Standard SAE setup

In this post, we will use the common TopK SAE variant, which is defined as \[ \begin{aligned} SAE_i(a) &= F_i(a)(W^{dec}_i)^T + b^{dec}_i\\ F_i(a)[t, \ell] &= \left\{\begin{array}{lr} Y_i(a)[t, \ell], & \ell \in \textrm{argtopk}(k, Y_i(a)[t, :]) \text{ and}\ Y_i(a)[t, \ell] \geq 0 \\ 0 & \text{ otherwise} \end{array} \right. \\ Y_i(a) &= a(W^{enc}_i)^T + b^{enc}_i . \end{aligned} \tag{1}\]

Note that \(F_i(a)\), the feature vector for that input, is constructed element-wise at each token position \(t\) from the affine transformation \(Y_i(a)\), where only the top \(k\) largest elements are retained, followed by a ReLU.[1] The parameters of the SAEs are the encoder weights and bias (\(W_i^{enc} \in \mathbb{R}^{d_{sae} \times d_{model}}\) and \(b_i^{enc} \in \mathbb{R}^{d_{sae}}\)) and decoder weights and bias (\(W_i^{dec} \in \mathbb{R}^{d_{model} \times d_{sae}}\) and \(b_i^{dec} \in \mathbb{R}^{d_{model}}\)). \(k\in\mathbb{Z}^+\) is a hyperparameter controlling the sparsity of the feature vectors \(F_i\).
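For reference, here is a minimal PyTorch sketch of a TopK SAE module following Equation 1. This is a sketch rather than the post's actual code; the class and attribute names are my own.

```python
import torch
import torch.nn as nn

class TopKSAE(nn.Module):
    """Minimal TopK SAE following Equation 1 (a sketch, not the post's implementation)."""

    def __init__(self, d_model: int, d_sae: int, k: int):
        super().__init__()
        self.k = k
        self.W_enc = nn.Parameter(torch.randn(d_sae, d_model) / d_model**0.5)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.W_dec = nn.Parameter(torch.randn(d_model, d_sae) / d_sae**0.5)
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, a: torch.Tensor) -> torch.Tensor:
        # Y_i(a): affine map into feature space, shape [..., d_sae]
        y = a @ self.W_enc.T + self.b_enc
        # F_i(a): keep only the top-k pre-activations per token, then ReLU
        topk = torch.topk(y, self.k, dim=-1)
        f = torch.zeros_like(y).scatter_(-1, topk.indices, topk.values)
        return torch.relu(f)

    def decode(self, f: torch.Tensor) -> torch.Tensor:
        return f @ self.W_dec.T + self.b_dec

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(a))
```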

Then, the standard TopK SAE training objective over a minibatch \(X\) of \(n\) examples is defined for each \(SAE_i\) as \[ \mathscr{L}^{std}_i = {1\over{n}} \cdot \sum_{j=0}^{n-1}||\widehat{A}^{\{i\}}_i(X_j) - \widehat{A}^{\emptyset}_i(X_j)||_2^2, \tag{2}\]

i.e. the MSE between the replaced and original \(i^{th}\) layer activations. The SAE for each layer is trained independently to minimize Equation 2.
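Since only layer \(i\) is replaced, \(\widehat{A}^{\{i\}}_i(X_j) = SAE_i(\widehat{A}^{\emptyset}_i(X_j))\), so Equation 2 can be computed directly on cached activations. A sketch, under one reasonable reading of the norm (squared error summed over the model dimension and averaged over tokens and examples):

```python
def standard_loss(sae: TopKSAE, acts: torch.Tensor) -> torch.Tensor:
    """Equation 2 on cached activations `acts` of shape [batch, seq, d_model]:
    squared error summed over the model dimension, averaged over tokens and batch."""
    return (sae(acts) - acts).pow(2).sum(dim=-1).mean()
```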

KL fine-tuning and end-to-end training

(Karvonen 2025) first train each SAE using the standard loss of Equation 2, followed up by a fine-tuning phase that adds in the KL divergence of the replacement model including that layer’s SAE:

\[ \mathscr{L}^{finetune}_i = \mathscr{L}^{std}_i +{\beta\over{n}} \cdot\sum_{j=0}^{n-1} KL(\widehat{M}^{\{i\}}(X_j)\ ||\ \widehat{M}^{\emptyset}(X_j)) . \tag{3}\]

\(\beta\)[2] is computed for each minibatch such that the magnitudes of the two loss terms match. I also follow this approach in my implementation.
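For illustration, here is one way Equation 3 and the per-minibatch \(\beta\) balancing might look, reusing the hypothetical replacement_activations and standard_loss helpers sketched above; base_acts and base_logits are assumed to come from a clean forward pass of the original model.

```python
import torch.nn.functional as F

def kl_finetune_loss(model, sae, layer, input_ids, base_acts, base_logits):
    """Sketch of Equation 3 for one minibatch (not the post's actual code)."""
    mse = standard_loss(sae, base_acts)

    # Logits of the {layer}-replacement model
    _, repl_logits = replacement_activations(model, {layer: sae}, {layer}, input_ids)
    log_p = F.log_softmax(repl_logits, dim=-1)  # replacement model (gradients reach the SAE)
    log_q = F.log_softmax(base_logits, dim=-1)  # original model (fixed)
    kl = (log_p.exp() * (log_p - log_q)).sum(dim=-1).mean()

    # beta is chosen per minibatch so that the two terms have matching magnitude
    beta = mse.detach() / kl.detach().clamp_min(1e-12)
    return mse + beta * kl
```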

In end-to-end training (Braun et al. 2024), the reconstruction errors for each later layer are combined into a single training objective, with an added term for the KL divergence between the \(\{i\}\)-replacement model and the original. Adapting the notation somewhat for this post[3], and computing \(\beta\) as above (dividing by the number of MSE terms so that the KL term is balanced to their average), this loss is \[ \mathscr{L}^{e2e}_i = {1\over{n}} \cdot \sum_{j=0}^{n-1} \left( \begin{aligned} \sum_{\ell=i+1}^{m-1} ||\widehat{A}^{\{i\}}_\ell(X_j) - \widehat{A}^{\emptyset}_\ell(X_j)||^2_2 \\ + {\beta \over{m-i}}\cdot KL(\widehat{M}^{\{i\}}(X_j)\ ||\ \widehat{M}^{\emptyset}(X_j)) \end{aligned} \right) \tag{4}\]

Note that this intentionally does not include a term for the current layer’s reconstruction error, only those downstream of it.
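Under the same assumptions, a sketch of Equation 4 for a non-final layer, where base_acts_all is a hypothetical dict mapping each layer index to the original model's activations:

```python
def e2e_loss(model, sae, layer, input_ids, base_acts_all, base_logits):
    """Sketch of Equation 4: reconstruction error at every layer downstream of
    `layer` when only `layer` is replaced, plus the balanced KL term."""
    m = len(model.transformer.h)
    repl_acts, repl_logits = replacement_activations(model, {layer: sae}, {layer}, input_ids)

    mse = sum(
        (repl_acts[l] - base_acts_all[l]).pow(2).sum(dim=-1).mean()
        for l in range(layer + 1, m)
    )

    log_p = F.log_softmax(repl_logits, dim=-1)
    log_q = F.log_softmax(base_logits, dim=-1)
    kl = (log_p.exp() * (log_p - log_q)).sum(dim=-1).mean()

    beta = mse.detach() / kl.detach().clamp_min(1e-12)  # balanced per minibatch
    return mse + beta / (m - layer) * kl
```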

Replacement-aware training

The loss functions presented so far have dealt with a quite limited replacement model, in which only a single layer was re-encoded with its SAE. This is somewhat mismatched for applications like circuit extraction where we would like to use the full replacement model. The most obvious way to accommodate this would be to use the full replacement model directly in training, i.e. by training all SAEs simultaneously with the following loss:

\[ \mathscr{L}^{full} = {1\over{n}} \cdot \sum_{j=0}^{n-1} \left( \begin{aligned} \sum_{\ell=0}^{m-1} ||\widehat{A}^{[0..m-1]}_\ell(X_j) - \widehat{A}^{\emptyset}_\ell(X_j)||^2_2 \\ + {\beta \over{m}} \cdot KL(\widehat{M}^{[0..m-1]}(X_j)\ ||\ \widehat{M}^{\emptyset}(X_j)) \end{aligned} \right) \tag{5}\]

Note that this is almost identical to Equation 4; the important difference is that it uses the full replacement model in both the reconstruction and KL terms. Implemented naively, this would be even more expensive than end-to-end training, since we would have to keep every SAE’s parameters and optimizer state in memory, rather than just those for a single layer. This may prove to be impractical for larger models.[4] However, I suspect that we can get most of the benefits by looking ahead by a single layer (where \(\beta\) is again computed as above to balance the two loss terms):

\[ \mathscr{L}^{next}_i = \mathscr{L}^{std}_i + {\beta\over{n}} \cdot \sum_{j=0}^{n-1} \cases{ KL(\widehat{M}^{\{i\}}(X_j)\ ||\ \widehat{M}^{\emptyset}(X_j)) & $i = m-1$ \\ ||\widehat{A}^{\{i,i+1\}}_{i+1}(X_j) - \widehat{A}^{\emptyset}_{i+1}(X_j)||^2_2 & $i \in [0..m-2]$ } \tag{6}\]

That is, in addition to the standard reconstruction loss, we add a term that for the final layer measures the KL divergence of the replacement model. For all other layers, this term instead depends on the reconstruction error of the next layer when both this layer and the next are replaced by their SAEs (note the two elements of the replacement set).

Why do I think looking ahead by one layer is likely to be sufficient? If we take seriously the idea that SAEs capture exactly the information that is most important to the model, then we ought to be able to treat the next layer’s SAE as if it screens off the influence of the current layer on those even further downstream. (We still need that next SAE, because the one we’re currently training hasn’t learned which features are important yet!) If we train the set of SAEs in reverse layer order, freezing their parameters as we go, we can approximate the effects of the full replacement model while only needing to run two transformer layers and SAEs in each training loop.
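A sketch of Equation 6 for a non-final layer, again reusing the hypothetical helpers above. For simplicity it runs the whole model through hooks, whereas the scheme just described only needs to run layers \(i\) and \(i+1\) on cached activations:

```python
def next_layer_loss(model, sae_i, sae_next, layer, input_ids, base_acts_all):
    """Sketch of Equation 6: standard loss at `layer` plus the reconstruction error at
    `layer + 1` when both layers are replaced (`sae_next` is already trained and frozen)."""
    mse_i = standard_loss(sae_i, base_acts_all[layer])

    # A_hat^{i, i+1}_{i+1}: the next layer's activation with both layers re-encoded
    saes = {layer: sae_i, layer + 1: sae_next}
    repl_acts, _ = replacement_activations(model, saes, set(saes), input_ids)
    mse_next = (repl_acts[layer + 1] - base_acts_all[layer + 1]).pow(2).sum(dim=-1).mean()

    beta = mse_i.detach() / mse_next.detach().clamp_min(1e-12)
    return mse_i + beta * mse_next
```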

After my preliminary implementation of Equation 6, I had an idea that leans even harder into the notion that SAE features are the only important information passed through the residual stream. The most direct way to preserve this information is to reduce disruption to the features themselves. Reconstruction error gets at this only indirectly, since we expect the residual stream to include some noise. If we waste optimization power on recovering this noise, we get less of what we actually want. Hence, we should calculate the next-layer term in feature space. Because features are very sparse vectors, we likely wouldn’t want to do this with Euclidean distance; something like cosine distance is more appropriate. This leads to the following loss:

\[ \mathscr{L}^{feature}_i = \mathscr{L}^{std}_i + {\beta\over{n}} \cdot \sum_{j=0}^{n-1} \cases{ KL(\widehat{M}^{\{i\}}(X_j)\ ||\ \widehat{M}^{\emptyset}(X_j)) & $i = m-1$ \\ 1 - \frac{F_{i+1}(\widehat{A}^{\{i\}}_{i+1}(X_j)) \cdot F_{i+1}(\widehat{A}^{\emptyset}_{i+1}(X_j))}{||F_{i+1}(\widehat{A}^{\{i\}}_{i+1}(X_j))||_2 \cdot ||F_{i+1}(\widehat{A}^{\emptyset}_{i+1}(X_j))||_2} & $i \in [0..m-2]$ } \tag{7}\]

where the complicated term is the cosine distance between the features computed by the next layer’s SAE given the current layer’s SAE output, versus the features computed by the next layer’s SAE given the baseline activations.
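Concretely, this amounts to encoding both the perturbed and baseline next-layer activations with the frozen \(SAE_{i+1}\) and taking the per-token cosine distance between the resulting feature vectors. A sketch, under the same assumptions as the earlier snippets:

```python
def feature_space_loss(model, sae_i, sae_next, layer, input_ids, base_acts_all):
    """Sketch of Equation 7: the downstream term compares the frozen next-layer SAE's
    features with and without layer `layer` replaced, using cosine distance."""
    mse_i = standard_loss(sae_i, base_acts_all[layer])

    # Features of SAE_{i+1} given the current layer's SAE output ...
    repl_acts, _ = replacement_activations(model, {layer: sae_i}, {layer}, input_ids)
    f_replaced = sae_next.encode(repl_acts[layer + 1])
    # ... versus its features given the baseline activations
    f_base = sae_next.encode(base_acts_all[layer + 1])

    cos_dist = (1 - F.cosine_similarity(f_replaced, f_base, dim=-1)).mean()

    beta = mse_i.detach() / cos_dist.detach().clamp_min(1e-12)
    return mse_i + beta * cos_dist
```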

I implemented two variants of this idea for this post. The first, which I call the next-layer method, directly uses Equation 7 as its loss function. The second variant, inspired by (Karvonen 2025), follows this up with a fine-tuning phase that adds in the full replacement model KL loss, i.e. the KL term from Equation 5.

Results

Each method was run using TinyStories-33M as the base model over 10 million tokens from the TinyStories dataset (Eldan and Li 2023), using TopK SAEs with \(k=100\) and an expansion factor of 4. See Appendix A for more implementation details. Figures in this section can be zoomed and scrolled after clicking on them.

I’ll start with the key results. Below, I plot a histogram of the KL divergence at each token for the full replacement models over 1 million tokens from the TinyStories validation set. Note that the x-axis is log-scaled, since the KL values vary by several orders of magnitude and appear to be approximately log-normally distributed:

While the distributions are visually quite similar, we can see in the reported summary statistics that the fine-tuned next-layer method achieves the lowest geometric mean KL divergence. Given the approximate log-normality of these distributions, I would argue that this is the most natural summary statistic. We can also see that end-to-end and KL fine-tuning get essentially identical results to each other, as observed by (Karvonen 2025), while standard training does much worse.
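For concreteness, the geometric mean reported here is just the exponential of the mean log KL. A short sketch, where kl_per_token is a hypothetical array of per-token KL values:

```python
import numpy as np

def geometric_mean_kl(kl_per_token: np.ndarray) -> float:
    """Geometric mean of per-token KL divergences: exp of the mean log KL."""
    return float(np.exp(np.log(kl_per_token).mean()))
```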

I think it’s also instructive to look at the reconstruction errors for each method, since these can give a rough estimate of SAE faithfulness–large error could be an indication that the SAE is using different mechanisms than the base model. In Figure 2 I plot the distribution of relative reconstruction error, defined as \[ ||\widehat{A}_i^{[i..m-1]}(x) - \widehat{A}^{\emptyset}_i(x)||_2 \over ||\widehat{A}^{\emptyset}_i(x)||_2 \] for each layer and method. I use this metric rather than MSE because norms vary a lot across layers and token positions, and this should make the values more directly comparable.
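In code, this metric is a per-token ratio of norms over the model dimension; a short sketch with hypothetical repl and base activation tensors:

```python
import torch

def relative_reconstruction_error(repl: torch.Tensor, base: torch.Tensor) -> torch.Tensor:
    """Per-token relative error ||A_hat - A||_2 / ||A||_2 for activation tensors of
    shape [batch, seq, d_model]."""
    return (repl - base).norm(dim=-1) / base.norm(dim=-1)
```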

One key thing I notice in these charts is that standard training achieves the best reconstruction, which is unsurprising since that’s its only training objective. We can also see that end-to-end achieves the worst reconstruction–also unsurprising since its loss doesn’t include a term for the current layer. Next layer is roughly on par with KL fine-tuning, which I also think makes sense; both are balancing reconstruction error with a single other term, so they naturally place between standard and end-to-end. I’m not sure what to make of the fine-tuned next layer having slightly worse reconstruction error than KL fine-tuned, especially on layer 3; this might be driven by the shorter left tail, since the distributions otherwise look extremely similar.

Another aspect of note is that earlier layers seem to be easier to reconstruct across the board. My hunch is that this is because these activations are more similar to the raw embeddings and thus have simpler structure. A final curiosity is the small second mode in the left tail in most of the later-layer plots. I would guess that this is an artifact of the synthetic nature of the TinyStories dataset (e.g., many examples start with “Once upon a time”, so these specific tokens may have been overrepresented), but would be interested to see if this holds in more realistic settings.

Next, let’s look at how each method performed during training. Below, I plot the batch mean of the \([i..m-1]\)-replacement model KL on the training set during training (roughly every 10,000 tokens) for each layer. Because layers are trained sequentially in reverse order, this means that the plot for layer 0 reflects the KL for the full replacement model. To save on training time, the KL fine-tuned model was initialized from a checkpoint of the standard model at 90% of training tokens, and similarly for the next-layer fine-tuned model from the next-layer model. Hence, the trajectories match until that point. Note that the y axis uses a log-scale for better separation, and the gray dashed line represents the switch to fine-tuning for the fine-tuned methods:

From Figure 3, we can see the compounding effect of SAE reconstruction error: the KL divergence attained after training increases for the earlier layers, since they replace progressively more of the model. We can also see that the standard training is straightforwardly worse than the other methods, while my next-layer-only method is somewhere between standard and end-to-end. Both forms of fine-tuning very quickly recover nearly identical performance to the end-to-end method.

Finally, it’s important to check the integrity of the replacement models qualitatively. Example generations with the full replacement models are presented in Table 1, all given the prompt “Once upon a time”. As far as I can tell, they all seem to be as coherent as the baseline; this model is likely too small to exhibit the complete degradation I observed when using SAEs on larger models like DeepSeek-R1-Distill-Llama-8B.

Table 1: Example generations

Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big, red apple on the ground. She picked it up and took a bite. It was so juicy and delicious!

Suddenly, she heard a loud noise. It was a big, scary dog! Lily was scared and didn’t know what to do. But then, she remembered the apple she had picked earlier. She took a bite and it was even more delicious than before!

Lily learned that sometimes things that look scary can be delicious. She also learned that it’s important to be brave and not give up. From that day on, Lily always carried an apple with her, just in case she needed to face her fears.

Once upon a time, there was a little girl named Lily. She loved to play outside and explore. One day, she found a big, shiny rock in the grass. She picked it up and held it in her hands. She was so happy to have it.

Lily decided to take it home with her. She put it in her pocket and went to her mommy. Her mommy said, “Lily, you have a very special rock. It is very special.” Lily was so happy to have her special rock.

Later that day, Lily went to the park with her mommy. She saw a big tree and wanted to climb it. But she was scared and didn’t know how to climb it. Her mommy said, “Don’t worry, Lily. We can use your special rock to help you climb the tree.” Lily was happy and started to climb. When she reached the top, she was so proud of herself.

From then on, Lily knew she could do anything with her special rock. And she knew that she could do anything she wanted.

Once upon a time, there was a little girl named Lily. She loved to play with her toys and explore the world around her. One day, she found a shiny silver object in her pocket. She picked it up and showed it to her mommy.

“Mommy, look what I found!” Lily said, holding up the shiny silver object.

“Wow, that’s a pretty silver,” her mommy replied.

Lily put the silver in her pocket and continued to explore the world around her. She wanted to show her mommy the shiny silver object, but she was too small to make it work. So, she decided to take it out and show it to her mommy.

“Mommy, can I keep this shiny silver thing?” Lily asked.

“Yes, you can keep it,” her mommy replied.

Lily was so excited to have found such a special thing in her pocket. She put the silver shiny silver object back in her pocket and went back to playing with her toys. She was so proud of herself for being so brave and not too small to keep something shiny.

Once upon a time, there was a little girl named Lily. She loved to play outside in the park. One day, she saw a big tree with a swing and wanted to swing on it. She asked her mommy if she could have a swing. Her mommy said yes and Lily ran to the swing set. She swung high and low, feeling the wind in her hair. She felt so happy and peaceful.

After a while, Lily’s mommy said it was time to go home. Lily was sad to leave but she knew she had to go. She said goodbye to the swing and went home. When she got home, her mommy gave her a big hug and said she had a surprise for her. She had a big bag of candy and a big smile on her face. Lily was so happy and grateful for her candy. She couldn’t wait to eat it and show it to her friends.

Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big tree with lots of shade in the shade. She wanted to climb up the tree, but it was too high.

Lily asked her mom for help. Her mom said, “Let’s build a ladder out of some wood and sticks. Then you can climb up the tree.” So, they built a ladder out of the tree and Lily climbed up.

When Lily reached the top, she felt very proud of herself. She looked down at the sky and saw the sun shining. She said, “I’m so happy to be at the top!” Her mom smiled and said, “Yes, it sure is. You did a great job!”

Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big, red apple on the ground. She picked it up and took a bite. It was so yummy!

Suddenly, Lily’s friend came over and asked her if she wanted to play with her toys. Lily said yes and they played together for a while. But then, Lily’s friend accidentally knocked over the apple. It broke into pieces and made a big mess.

Lily felt sad because she loved her apple. But then, her mom came and said they could make a new one together. They worked together to make a new apple cake and it was even better than the first one. Lily was so happy and thanked her mom for helping her.

Conclusion

SAE training made replacement-aware by incorporating disruption to the next layer’s SAE features in the loss function slightly improves replacement model performance on TinyStories. The experiments here were of course extremely limited in scope, so I will be very interested to see if this transfers to larger models and how much it matters qualitatively. Ultimately I want to improve the foundations of scalable circuit extraction with the ability to run replacement models without error nodes, and hopefully this is a step in that direction.

References

Ameisen, Emmanuel, Jack Lindsey, Adam Pearce, Wes Gurnee, Nicholas L Turner, Brian Chen, Craig Citro, et al. 2025. “Circuit tracing: Revealing computational graphs in language models.” https://transformer-circuits.pub/2025/attribution-graphs/methods.html.
Braun, Dan, Jordan Taylor, Nicholas Goldowsky-Dill, and Lee Sharkey. 2024. “Identifying functionally important features with end-to-end sparse dictionary learning.” arXiv [cs.LG], May.
Eldan, Ronen, and Yuanzhi Li. 2023. “TinyStories: How small can language models be and still speak coherent English?” arXiv [cs.CL], May.
Gregor, Karol, and Yann LeCun. 2010. “Learning fast approximations of sparse coding.” International Conference on Machine Learning, June, 399–406.
Karvonen, Adam. 2025. “Revisiting end-to-end sparse autoencoder training: A short finetune is all you need.” arXiv [cs.LG], March.
Marks, Samuel, Can Rager, Eric J Michaud, Yonatan Belinkov, David Bau, and Aaron Mueller. 2024. “Sparse feature circuits: Discovering and editing interpretable causal graphs in language models.” arXiv [cs.LG], March.
Nanda, Neel, Arthur Conmy, Lewis Smith, Senthooran Rajamanoharan, Tom Lieberum, János Kramár, and Vikrant Varma. 2024. “[Full Post] Progress Update #1 from the GDM Mech Interp Team.” https://www.alignmentforum.org/posts/C5KAZQib3bzzpeyrg/full-post-progress-update-1-from-the-gdm-mech-interp-team.

Appendix A: Implementation details

Source code is available on GitHub. Note that the link includes a git tag, which marks the code as it appeared at the time of publishing. I will likely be iterating on these experiments in the future, so the main branch may diverge.

I implemented a torch-based SAE training pipeline from scratch, using transformers and datasets to handle loading model weights, code, and training data. It is somewhat hardcoded to expect the model to be TinyStories (which is based on GPT-Neo), but is likely compatible with most GPT-like models with slight tweaking. In addition to my next-layer method, I implemented standard training, KL fine-tuning, and end-to-end training.

For the results in this post, I used a fixed number of SAE features across all training methods and layers, set to 4*d_model, which for TinyStories works out to 3072. The SAEs use the TopK activation function with k=100. The model used was TinyStories-33M. For each method, I trained the SAEs using 10 million (1e7) tokens from the TinyStories dataset (Eldan and Li 2023). For the KL and replacement model fine-tuning methods, 90% of the token budget (9 million tokens) went into standard training, with the fine-tuning step using the remaining 1 million tokens.

I do not normalize or shuffle activations. Activation shuffling is incompatible with all of the methods discussed here other than standard training. This is because computing the replacement model KL (or in my case, next layer features) requires the full context window, since the perturbation to the residual stream caused by the SAEs changes the attention patterns of downstream layers.

All methods use the torch implementation of Adam as their optimizer. I did not do a systematic hyperparameter sweep. After trying out a handful of values, I ended up using a constant learning rate of 1e-3, which seems to be reasonably effective across the different training methods. MSE-based and KL-based loss terms are balanced to have the same magnitude at each minibatch; there is currently no hyperparameter controlling their relative ratio. Note that due to this balancing, differences in intrinsic scale should theoretically be smoothed out, meaning that the optimal learning rate should transfer across methods. In the fine-tuning phases of the KL fine-tuning and next-layer fine-tuning methods, I decay the learning rate linearly to 0 based on the fraction of total tokens consumed, following (Karvonen 2025).

For all methods during the main training phase, I train the SAEs sequentially in reverse layer order. I initialize the weights for the SAE at layer i from the already-trained SAE at layer i+1 if it exists, as I found informally that this speeds up convergence. Speculatively, this may also help the replacement model performance, since it will tend to nudge the SAE to keep features that are used by the next layer. For the methods that involve a fine-tuning phase, I cloned them from a checkpoint of the corresponding base model (standard training for KL fine-tuning, next-layer for next-layer + fine-tuning) at the appropriate number of training tokens, saving on total training time.

The data presented in this post were generated by running this notebook on an A40 instance rented from RunPod; this takes about 3 and a half hours to run.

Footnotes

  1. I’ve never seen a rendition of the TopK SAE in mathematical notation that I liked. I promise that the code looks much cleaner!↩︎

  2. The paper uses the notation \(\alpha_{KL}\).↩︎

  3. The original paper uses a “vanilla” SAE and so has an additional L1 regularization term. The weight on the KL term is also fixed at 1, and instead hyperparameters control the relative importance of each MSE term. The formulation here is equivalent to some setting of those parameters.↩︎

  4. As of writing, I haven’t yet tried this full approach, as this framing occurred to me only as I was formalizing the less-intensive next-layer-only approach.↩︎

Citation

BibTeX citation:
@online{lloyd2026-replacement-aware-sae-training,
  author = {Lloyd, Evan},
  title = {Replacement-aware {SAE} Training},
  date = {2026-01-12},
  url = {https://elloworld.net/posts/replacement-aware-sae-training/},
  langid = {en}
}
For attribution, please cite this work as:
Lloyd, Evan. 2026. Replacement-aware SAE Training. January 12, 2026. https://elloworld.net/posts/replacement-aware-sae-training/.