This article is still being written and is currently only available in French. You can still read it and access the associated GitHub repository, but the content is incomplete.
Reproduction de quelques résultats d'interprétabilité ; entrainement de Sparse Autoencoders, kernels SpMM et autres optimisations.
AUTHOR
Rémy SIAHAAN--GENSOLLEN
PUBLISHED ON
April 25, 2026
Le projet MistralSAE est disponible sur le lien GitHub ci-contre. Il contient le code des différentes optimisations et une CLI pour lancer la pipeline d'entrainement complète, ainsi qu'une interface web de chat avec un modèle « steeré ». Le checkpoint du SAE (~3B paramètres) est également disponible sur Hugging Face.
Je me souviens assez clairement du moment où j'ai lu Scaling Monosemanticity par Templeton et al. en 2024. C'est en explorant Transformer Circuits que j'ai découvert la recherche en interpretabilité, et comme beaucoup d'autres j'ai trouvé la possibilité de piloter (steering) les LLMs particulièrement intriguante — et amusante, cf. Golden Gate Claude. Un peu plus d'un an après et quelques lectures en plus j'ai donc cherché à répliquer ces expériences, et entraîner moi-même un Sparse Autoencoder, ou SAE, pour faire de l'apprentissage par dictionnaire. Cet article détaille mes travaux dans ce but, les différentes optimisations que j'ai mises en place (kernels triton pour l'entraînement, inférence des activations plus rapide, ...) et quelques résultats de steering.
Cadre théorique
Transformers, caractéristiques et polysémie
Comme son nom l'indique, l'interprétabilité mécaniste cherche à comprendre le comportement d'un réseau de neurones en y explicitant des mécanismes internes qui lui font donner tel résultat. Une approche classique pour traiter ce type de questions est d'essayer de « reverse engineer » l'objet d'intérêt, en le décomposant en des parties plus petites et plus prédictibles, et en comprenant comment celles-ci intéragissent. Commençons donc par rappeler brièvement l'architecture des LLM récents. Dans les grandes lignes, elle est assez similaire entre ces derniers, et basée sur la partie décodeur du transformer[Ashish Vaswani, 2023]
Attention Is All You Need
Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin (2023)
un texte d'entrée est découpé en N tokens, pour lesquels on donne un embedding dmodel qui dépend du token et de la position — on a donc un tenseur x0 de forme (N,dmodel) ;
Pour chaque couche 0≤i<L du transformer, on lit xi et calcule xi+1=ℓi(xi), également de forme (N,dmodel) ;
en sortie de la dernière couche, on normalise et projette xL pour obtenir les logits utilisés pour prédire le token suivant.
Architecture d'un LLM moderne (approximativement)
Valeurs issues de mistralai/Ministral-3-3B-Base-2512
Chaque couche i est composée de deux blocs : un bloc d'attention et un bloc feed-forward (pour les modèles dense — c'est un peu différent pour les mixtures d'experts), chacun avec normalisation. Depuis GPT-2[Radford, 2019]
Language Models are Unsupervised Multitask Learners
Radford, Alec and Wu, Jeffrey and Child, Rewon and Luan, David and Amodei, Dario and Sutskever, Ilya (2019)
, la normalisation par couche se fait généralement en entrée de chaque bloc, plutôt qu'en sortie comme dans le transformer original[Ashish Vaswani, 2023]
Attention Is All You Need
Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin (2023)
On remarquera que chaque bloc "met à jour" (xi)i, qu'on appele d'ailleurs le flux résiduel du modèle. Il faut également noter qu'ils interagissent linéairement[Elhage, 2021]
A Mathematical Framework for Transformer Circuits
Elhage, Nelson and Nanda, Neel and Olsson, Catherine and Henighan, Tom and Joseph, Nicholas and Mann, Ben and Askell, Amanda and Bai, Yuntao and Chen, Anna and Conerly, Tom and DasSarma, Nova and Drain, Dawn and Ganguli, Deep and Hatfield-Dodds, Zac and Hernandez, Danny and Jones, Andy and Kernion, Jackson and Lovitt, Liane and Ndousse, Kamal and Amodei, Dario and Brown, Tom and Clark, Jack and Kaplan, Jared and McCandlish, Sam and Olah, Chris (2021)
leur entrée est obtenue (à normalisation près) par des transformations linéaires du flux résiduel (matrices Q, K, V des têtes d'attention, première couche linéaire du FFN);
leur sortie, juste avant addition au flux résiduel, est précédée par une transformation linéaire (concaténation des têtes d'attention, dernière couche linéaire du FFN).
On peut donc avoir l'intuition que les représentations que se fait le modèle se constituent progressivement au fil des couches, et par ailleurs qu'elles respectent aussi une certaine linéarité : si x représente une voiture, 2x ne devrait pas représenter un chat. Les représentations sont donc des directions dans l'espace latent. Par ailleurs, la linéarité doit aussi impliquer une certaine additivité des représentations, et on peut par exemple espérer que [la représentation de] reine) soit proche de roi - homme + femme, ce type de régularités linguistiques étant déjà observées depuis 2013[Mikolov, 2013]
Linguistic Regularities in Continuous Space Word Representations
Mikolov, Tomas and Yih, Scott Wen-tau and Zweig, Geoffrey (2013)
Les représentations, qu'on appelle aussi caractéristiques (ou features) du modèle sont des directions dans l'espace latent Rdmodel. Or, il n'y a a priori pas de raisons de penser que les représentations les plus explicites aient comme directions celles de la base canonique, lesquelles correspondent aux sorties des neurones. Aussi, un même neurone peut très bien intervenir dans des contextes très différents (par exemple[Olah, 2017]
Feature Visualization
Olah, Chris and Mordvintsev, Alexander and Schubert, Ludwig (2017)
un neurone d'un modèle de vision actif sur des visages d'animaux ou des voitures). Similairement, un même token xi,n peut traiter de plusieurs notions en même temps. Dans ce cas, on parlera parfois de de superposition[Olah, 2020]
Zoom In: An Introduction to Circuits
Olah, Chris and Cammarata, Nick and Schubert, Ludwig and Goh, Gabriel and Petrov, Michael and Carter, Shan (2020)
, et de caractéristiques polysémiques/polysémantiques.
Apprentisage de dictionnaire
Pour revenir à la question intiale, les neurones ne sont donc pas forcément facilement interprétables. On peut donc chercher à trouver des caractéristiques qui le soient, par exemple des caractéristiques monosémantiques. En quelque sorte, on cherche à "démêler" (disentanglement) les représentations. Pour cela, une approche consiste à faire de l'apprentisage de dictionnaire : pour un token tn=xi,n∈Rdmodel en sortie de la couche i−1, on va chercher un vecteur αn∈Rh tel que
tn≈tn:=Dαn
où D∈Rdmodel×h est le dictionnaire, h≫dmodel et αn est très parcimonieux/sparse (autrement dit t est une combinaison linéaire de quelques colonnes de D). C'est une méthode utilisée en traitement du signal, puisqu'il permet de représenter un signal avec un nombre minimal de composantes (on pourrait faire une analogie avec la décomposition en série de Fourier, où αn correspondrait ici aux amplitudes / phases de chaque harmonique).
On a donc approximativement le problème d'optimisation suivant
D,{αn}argminn∑(∣∣tn−Dαn∣∣22+λ∣∣αn∣∣0)
où ∣∣⋅∣∣0 est la pseudo-norme 0 qui compte le nombre de valeurs non-nulles, et où λ>0 arbitre donc entre la parcimonie/sparsité et la qualité de l'approximation, c.-à.-d. l'erreur entre tn et tn.
Il existe plusieurs algorithmes permettant de résoudre ce problème. Cependant, celui-ci est non-convexe (à cause de la pseudo-norme 0) et NP-difficile[Tillmann, 2015]
On the Computational Intractability of Exact and Approximate Dictionary Learning
Tillmann, Andreas M. (2015)
IEEE Signal Processing Letters, vol. 22(1), pp. 45-49.
) : on risque de trouver des caractéristiques que le modèle n'utilise pas, car trop coûteuses à calculer (et donc trop coûteuses d'accès pour le modèle). Il y a donc un risque que notre dictionnaire overfit sur les données et n'extraie pas correctement les représentations que le modèle manipule. Ces méthodes ne sont également pas les plus simples pour passer à l'échelle[Bricken, 2023]
Towards Monosemanticity: Decomposing Language Models With Dictionary Learning
Bricken, Trenton and Templeton, Adly and Batson, Joshua and Chen, Brian and Jermyn, Adam and Conerly, Tom and Turner, Nick and Anil, Cem and Denison, Carson and Askell, Amanda and Lasenby, Robert and Wu, Yifan and Kravec, Shauna and Schiefer, Nicholas and Maxwell, Tim and Joseph, Nicholas and Hatfield-Dodds, Zac and Tamkin, Alex and Nguyen, Karina and McLean, Brayden and Burke, Josiah E and Hume, Tristan and Carter, Shan and Henighan, Tom and Olah, Christopher (2023)
, or l'on souhaite idéalement traiter un corpus suffisament large (en milliards de tokens).
Sparse Autoencoders
Une autre approche consiste à entraîner un Sparse Autoencoder. Chaque token tn, passe dans un encodeur αn=enc(tn)∈Rdlatent, puis dans un decodeur qui reprojette dans la dimension d'origine tn=dec(fn)∈Rdmodel. Classiquement, il s'agit d'un réseau de neurones qu'on peut entraîner par rétro-propagation :
Schéma d'un Sparse Autoencoder
On cherche à minimiser l'erreur entre tn et tn et à ce que αn soit le plus sparse possible. On se ramène ainsi à une forme faible d'apprentissage de dictionnaire. Au biais du decoder près, tn a d'ailleurs la même forme, et le dictionnaire D correspond aux poids du decodeur. On peut donc espérer que les neurones f1,f2,… de la couche latente, autrement dit des dimensions / direction canonique de Rdlatent, correspondent à des caractéristiques monosémantiques, et soient des représentations manipulées par le modèle. Notamment, l'architecture du bloc feed-forward des LLM est très similaire à celle du SAE que l'on vient de décrire.
Il y a plusieurs variations de SAE, et les premières itérations semblent dater de décembre 2022[Sharkey, 2022]
Interim Research Report: Taking Features Out of Superposition with Sparse Autoencoders
, notamment autour de la fonction d'activation en sortie de l'encodeur. On peut par exemple avec une ReLU et imposer la sparsité au niveau de la loss (comme dans le problème d'optimisation de l'aprentissage de dictionnaire). Il s'agit de l'approche des chercheurs d'Anthropic documentée dans Scaling Monosemanticy[Templeton, 2024]
Scaling Monosemanticity: Extracting Interpretable Features from Claude 3 Sonnet
Templeton, Adly and Conerly, Tom and Marcus, Jonathan and Lindsey, Jack and Bricken, Trenton and Chen, Brian and Pearce, Adam and Citro, Craig and Ameisen, Emmanuel and Jones, Andy and Cunningham, Hoagy and Turner, Nicholas L and McDougall, Callum and MacDiarmid, Monte and Freeman, C. Daniel and Sumers, Theodore R. and Rees, Edward and Batson, Joshua and Jermyn, Adam and Carter, Shan and Olah, Chris and Henighan, Tom (2024)
Towards Monosemanticity: Decomposing Language Models With Dictionary Learning
Bricken, Trenton and Templeton, Adly and Batson, Joshua and Chen, Brian and Jermyn, Adam and Conerly, Tom and Turner, Nick and Anil, Cem and Denison, Carson and Askell, Amanda and Lasenby, Robert and Wu, Yifan and Kravec, Shauna and Schiefer, Nicholas and Maxwell, Tim and Joseph, Nicholas and Hatfield-Dodds, Zac and Tamkin, Alex and Nguyen, Karina and McLean, Brayden and Burke, Josiah E and Hume, Tristan and Carter, Shan and Henighan, Tom and Olah, Christopher (2023)
une variante où la sparsité est imposée avec une fonction d'activation TopK (on ne conserve que les K valeurs les plus élevées (et positives) de αn). D'autres variations ont également été étudiées (notamment GatedSAE[Rajamanoharan, 2024]
Improving Dictionary Learning with Gated Sparse Autoencoders
Rajamanoharan, Senthooran and Arthur Conmy and Lewis Smith and Tom Lieberum and Vikrant Varma and János Kramár and Rohin Shah and Neel Nanda (2024)
La première partie ci-dessous détaille les conditions générales d'entrainement. Les deux suivantes donnent des détails sur les optimisations (kernels, ...) que j'ai implémenté pour l'entrainement.
Présentation générale
En premier lieu il convient de choisir le LLM qu'on souhaite étudier. J'ai personnellement choisi de travailler avec Mistral-Small-3.2-24B-Instruct-2506, pour plusieurs raisons.
C'est un modèle (raisonnablement) petit : 24 milliards de paramètres, donc demande dans les 50 GO de VRAM en BF16 pour l'inférence (en pratique il faudra moins, voir sections optimisations plus bas).
C'est un modèle dense. Comme mentionné plus haut, les mixtures d'experts ont une architecture un peu différente, et mettent potentiellement à mal les hypothèses formulées. D'autre part, il y a actuellement bien plus de littérature sur les SAE entrainés sur des modèles dense que sur des MoE.
C'est un modèle "instruct". Une partie du jeu de données que j'ai utilisé était sous forme d'échanges, il était donc pratique de pouvoir formatter ainsi les entrées du modèle.
Il faut également choisir la couche i où récupérer le flux résiduel. Une intuition est que les concepts les plus simples se forment dans les premières couches, et que ceux plus complexes se situent plus loin dans le modèle Un certain nombre d'observations empire vont dans ce sens, et de nombreux travaux placent le SAE au milieu du modèle. Mistral-Small-3 ayant 40 couches, j'ai décidé de placer le SAE en sortie de la 30è couche (on regarde donc xi+1 avec i=29).
Dans ma configuration, j'entraîne le SAE de manière "online" (terminologie approximative) : Le LLM et le SAE en entrainement sont donc en même temps sur GPU, et j'infère directement le LLM jusqu'à avoir assez d'activations pour faire un batch avec lequel entraîner le SAE. J'ai utilisé une NVIDIA RTX PRO 6000 Blackwell avec 96 GO de VRAM. mistral-small-3 ayant a une dimension dmodel=5120, j'ai choisi d'entraîner mon SAE avec
dlatent=56×dmodel=286720
ce qui fait environ 3B paramètres. J'entraine nativement en BF16 avec AdamW (les poids, gradients et moments représentent donc dans les 24 GO de VRAM), et avec une batchsize de 2048. Chaque batch est mélangé et normalisé par un facteur dmodelE∣∣tn∣∣22 (basé sur les travaux d'Anthropic). Le LLM est lui inféré sur des batch de 16 séquences de 256 tokens.
Ces choix sont principalement motivés par la contrainte en VRAM (s'agissant d'un projet personnel je n'ai pas pu me permettre de meilleure configuration). J'ai éxécuté plusieurs tests afin d'arbitrer et choisir au mieux ces valeurs (cf. benchmarks.md). J'ai cherché à me rapprocher le plus proche possible de Scaling Monosemanticity, aussi la fonction d'activation est une ReLU et la sparsité est obtenue au niveau de la loss : pour un token tn, on calcule
λ>0 arbitre entre reconstruction et sparsité. La normalisation décrite plus haut m'a permis de reprendre λ=5 des travaux d'Anthropic. Wdec,j correspond à la colonne j du decodeur (soit la représention de la caractéristique fj). Il est important de l'intégrer à la pénalité sans quoi le modèle peut baisser artificiellement les valeurs de αn en augmentant les valeurs du decodeur. Une autre approche[Leo Gao, 2024]
Scaling and evaluating sparse autoencoders
Leo Gao and Tom Dupré la Tour and Henk Tillman and Gabriel Goh and Rajan Troll and Alec Radford and Ilya Sutskever and Jan Leike and Jeffrey Wu (2024)
Wenc est initialisé comme la transposée de Wdec, lui-même initialisé selon He/Kaiming uniforme ; les biais benc et bdec sont initialisés à 0. AdamW est configuré avec β1=0.9, β2=0.999, pas de weight decay et ε=10−15. J'utilise du gradient clipping, un "warmup" de λ de 0 à 5 sur les 5 premiers % de l'entrainement, et un learning rate de 5×10−5 constant puis qui décroit sur les derniers 20 % de l'entrainement.
J'ai effectué des entraînements sur 600 millions, 1.2 milliards et 1.6 milliards de tokens. Pour surveiller le bon déroulement, j'ai calculé et log périodiquement:
le nombre de valeurs non nulles (nnz) par token, avec des statistiques de répartition (moyenne, médiane, p90, p99) pour caractériser la sparsité effective ;
les fréquences d'activation des latents, moyenne, maximale, quantiles (p50, p90, p99), ainsi que la proportion de latents dans différentes plages de fréquence (0.1-1%, 1-10%, >10%) ;
les amplitudes des activations conditionnelles à l'activité, moyenne, écart-type et p99, pour caractériser la distribution des contributions ;
des métriques sur le résidu de reconstruction, notamment la variance expliquée, la corrélation avec l'entrée et la norme du résidu.
Optimisation de l'inférence (LLM)
Il existe de nombreux travaux pour accélérer l'inférence des modèles de langages. On peut notamment citer PagedAttention du projet vLLM[Kwon, 2023]
Efficient Memory Management for Large Language Model Serving with PagedAttention
Kwon, Woosuk and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica (2023)
. Si les optimisations du premier concernent l'utilisation autorégressive et concurrente du modèle (ce qui n'est pas notre cas), celles de FlashAttention nous sont tout à fait utiles, augmentent le débit de tokens par seconde et réduisent la mémoire utilisée.
La manière classique de récupérer les activations/états intermédiaires d'un modèle est d'utiliser des hook ou des bibliothèques comme TransformerLens (qui elle aussi utilise des hook). Dans notre cas on peut aussi utiliser output_hidden_states=True dans le forward d'un modèle chargé avec transformers. Cette approche présente plusieurs défauts.
Elle infère toutes les couches du modèles, y compris les i′>i situées après le SAE, dont nous n'avons pas besoin.
Elle charge en VRAM la partie vision du modèle (vision tower et projecteur multimodal) dont nous n'avons pas besoin.
On peut plutôt modifier l'implémentation du modèle dans transformers pour retirer la partie multimodale ("langage seul"), et les couches non utiles et normalisation finale ("tranché").
Performance des implémentations
Sans FlashAttention ou compilation. mistralai/Mistral-Small-3.2-24B-Instruct-2506, BF16, i = 30.
Comme dit plus haut, un autre avantage et qu'on peut compiler le module obtenu. On teste la performance pour différentes méthodes de compilation possibles (avec et sans FlashAttention), puis pour différentes tailles de batch et de séquence. Il faut noter que la compilation demande idéalement d'avoir des tenseurs de forme fixe. On ajoute des tokens de padding et masque d'attention pour les séquences trop courtes, et on tronque les séquences trop longue. Il faut noter que les tokens de padding doivent être retirés après inférence avant d'être donnés au SAE, ce qui crée un surcoût (croissant en la longueur de séquence souhaitée). Au final, on parvient atteint plus du double en débit par rapport à l'implémentation de base.
Latence et mémoire maximale des implémentations
mistralai/Ministral-3-3B-Base-2512, BF16, i = 30.
Optimisation de l'entraînement (SAE)
Le SAE a globalement une structure très simple. Les optimisations à mettre en place tirent principalement profit de son caractère parcimonieux/sparse : pour un token, seulement quelques caractéristiques parmi f1,f2,…,fdlatent sont actives. On peut donc trouver une meilleure manière de représenter la sortie αn de l’encodeur, et implémenter de manière moins naïve l’encodeur et le decodeur. Commençons par l'encodeur : on a
αn=enc(tn)=ReLU(Wenctn+benc)
En pratique, on traite un batch de B tokens. Dans la suite, on notera ⋅ABC un tenseur de forme (A,B,C) ; notamment on notera WHD=Wenc les poids et bH=benc biais de l'encodeur. Pour le forward, notre encodeur effectue
zBHaBH=xBD(WHD)⊤+1BbH⊤=ReLU(zBH)
On va donc chercher à représenter de manière sparse aBH.
Dans le cas d'un SAE TopK, on obtient les matrices iBK et aBK de formes B×K suivantes :
∀0≤b<B,iBK[b,:],aBK[b,:]=TopK(aBH[b,:])
où aBH[b,:] est la b-ième ligne de aBH. Ses K plus grandes valeurs sont données par aBK[b,:], situées aux colonnes iBK[b,:] et et valeurs. Pour chaque token b du batch, on stocke donc quelles sont les K caractéristiques les plus actives et leur niveau d'activation.
Le cas d'un SAE où l'on a plutôt une pénalité et simplement une activation ReLU est un peu plus subtil, car on n'a pas de nombre fixé de caractéristiques actives par token ou même par batch. Notons N le nombre total de caractéristiques activées sur le batch. On a les vecteurs iBN, iHN et aN de taille N suivants :
où where>0 donne les lignes iBN, colonnes iHN et valeurs correspondantes aN strictement positives de aBH. Il s'agit du format COOrdinate (ou COO) pour les tenseurs sparse. Pour un 0≤n<N donné, aN[n] donne l'activation de la caractéristique iHN[n] pour le token iBN[n].
Pour le backward, on cherche à calculer ∂WHD∂L le gradient des poids, ∂bH∂L celui des bias et potentiellement ∂xBD∂L celui des entrées. Pour rappel, on a aBH=ReLU(zBH) donc
∂zBH∂L=JReLU(zBH)⊤∂aBH∂L=1zBH>0⊙∂aBH∂L
où ⊙ est le produit élément par élément (Hadamard).
Puisqu'on a une représentation sparse de aBH, il en va de même de son gradient ∂aBH∂L et donc du gradient ∂zBH∂L. Dans notre cas, on n'a pas besoin des gradients sur les entrées xBD. On cherche donc à faire la somme d'une matrice sparse le long d'une dimension pour le gradient de bH, et la multuplication de la transposée d'une matrice sparse par une matrice dense pour celui de WHD. Pour le premier, un scatter_add fait l'affaire. Pour le second, il va s'agir d'implémenter un kernel SpMM dédié.
Commençons par le cas TopK. Le gradient Gz=∂zBH∂L est une matrice sparse, représenté de la même manière que aBH : pour chaque ligne 0≤b<B (soit chaque token du batch), on a exactement K valeurs non-nulles, situées aux colonnes d'indices iBK[b,:]. On veut calculer le gradient GW=∂WHD∂L=Gz⊤xBD. On peut donc initialiser GW à zéro, puis ajouter la contribution de chaque token b du batch. On multiplie la ligne la ligne b de xBD par les K valeurs de la colonne b de G⊤, situées aux lignes iBK[b,:]. On ajoute ces K lignes à GW aux positions iBK[b,:].
Kernel SpMM scatter
On démarre B programmes, chacun s'occupant d'un b. On peut même subdiviser manuellement le travail selon D, en démarrant une grille B×BSD où chaque programme s'occupe d'un b et de kBSD≤d<(k+1)BSD. Un des avantages de cette approche est que les accès mémoire sont contigus : on lit xBD ligne par ligne et (Gz)⊤ colonne par colonne. Le défaut est cependant la présence d'additions atomiques à GW : une même caractéristique h peut être active dans plusieurs tokens donc plusieurs programmes peuvent ajouter aux mêmes positions. On a ici un kernel "scatter".
Cette situation est très similaire à celle de vouloir aggréger les sorties de différents experts d'une MoE : un même token n'active que quelques experts (caractère sparse), et plusieurs experts peuvent contribuer au même token.