En cours de rédaction

Cet article est encore en cours de rédaction et n'est pour l'instant disponible qu'en français. Vous pouvez tout de même le lire et accéder au dépôt GitHub associé, mais le contenu est incomplet.

Training SAEs

Reproduction de quelques résultats d'interprétabilité ; entrainement de Sparse Autoencoders, kernels SpMM et autres optimisations.

AUTEUR
Rémy SIAHAAN--GENSOLLEN
PUBLIÉ LE
25 avril 2026

Le projet MistralSAE est disponible sur le lien GitHub ci-contre. Il contient le code des différentes optimisations et une CLI pour lancer la pipeline d'entrainement complète, ainsi qu'une interface web de chat avec un modèle « steeré ». Le checkpoint du SAE (~3B paramètres) est également disponible sur Hugging Face.

Je me souviens assez clairement du moment où j'ai lu Scaling Monosemanticity par Templeton et al. en 2024. C'est en explorant Transformer Circuits que j'ai découvert la recherche en interpretabilité, et comme beaucoup d'autres j'ai trouvé la possibilité de piloter (steering) les LLMs particulièrement intriguante — et amusante, cf. Golden Gate Claude. Un peu plus d'un an après et quelques lectures en plus j'ai donc cherché à répliquer ces expériences, et entraîner moi-même un Sparse Autoencoder, ou SAE, pour faire de l'apprentissage par dictionnaire. Cet article détaille mes travaux dans ce but, les différentes optimisations que j'ai mises en place (kernels triton pour l'entraînement, inférence des activations plus rapide, ...) et quelques résultats de steering.

Cadre théorique

Transformers, caractéristiques et polysémie

Comme son nom l'indique, l'interprétabilité mécaniste cherche à comprendre le comportement d'un réseau de neurones en y explicitant des mécanismes internes qui lui font donner tel résultat. Une approche classique pour traiter ce type de questions est d'essayer de « reverse engineer » l'objet d'intérêt, en le décomposant en des parties plus petites et plus prédictibles, et en comprenant comment celles-ci intéragissent. Commençons donc par rappeler brièvement l'architecture des LLM récents. Dans les grandes lignes, elle est assez similaire entre ces derniers, et basée sur la partie décodeur du transformer[Ashish Vaswani, 2023]

Attention Is All You Need

Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin (2023)

Source

.
  • un texte d'entrée est découpé en NN tokens, pour lesquels on donne un embedding dmodeld_{\textrm{model}} qui dépend du token et de la position — on a donc un tenseur x0x_0 de forme (N,dmodel)(N, d_{\textrm{model}}) ;
  • Pour chaque couche 0i<L0 \leq i < L du transformer, on lit xix_{i} et calcule xi+1=i(xi)x_{i+1} = \ell_i(x_i), également de forme (N,dmodel)(N, d_{\textrm{model}}) ;
  • en sortie de la dernière couche, on normalise et projette xLx_L pour obtenir les logits utilisés pour prédire le token suivant.
Architecture d'un LLM moderne (approximativement)

Valeurs issues de mistralai/Ministral-3-3B-Base-2512

Chaque couche ii est composée de deux blocs : un bloc d'attention et un bloc feed-forward (pour les modèles dense — c'est un peu différent pour les mixtures d'experts), chacun avec normalisation. Depuis GPT-2[Radford, 2019]

Language Models are Unsupervised Multitask Learners

Radford, Alec and Wu, Jeffrey and Child, Rewon and Luan, David and Amodei, Dario and Sutskever, Ilya (2019)

OpenAI.

Source

, la normalisation par couche se fait généralement en entrée de chaque bloc, plutôt qu'en sortie comme dans le transformer original[Ashish Vaswani, 2023]

Attention Is All You Need

Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin (2023)

Source

.
LayerNormMulti-HeadAttentionLayerNormFeedForwardNetwork
Couche ii du transformer (approximativement)
On remarquera que chaque bloc « met à jour » (xi)i(x_i)_i, qu'on appele d'ailleurs le flux résiduel du modèle. Il faut également noter qu'ils interagissent linéairement[Elhage, 2021]

A Mathematical Framework for Transformer Circuits

Elhage, Nelson and Nanda, Neel and Olsson, Catherine and Henighan, Tom and Joseph, Nicholas and Mann, Ben and Askell, Amanda and Bai, Yuntao and Chen, Anna and Conerly, Tom and DasSarma, Nova and Drain, Dawn and Ganguli, Deep and Hatfield-Dodds, Zac and Hernandez, Danny and Jones, Andy and Kernion, Jackson and Lovitt, Liane and Ndousse, Kamal and Amodei, Dario and Brown, Tom and Clark, Jack and Kaplan, Jared and McCandlish, Sam and Olah, Chris (2021)

Transformer Circuits Thread.

Source

avec ce flux résiduel, au sens où :
  • leur entrée est obtenue (à normalisation près) par des transformations linéaires du flux résiduel (matrices Q, K, V des têtes d'attention, première couche linéaire du FFN);
  • leur sortie, juste avant addition au flux résiduel, est précédée par une transformation linéaire (concaténation des têtes d'attention, dernière couche linéaire du FFN).
On peut donc avoir l'intuition que les représentations que se fait le modèle se constituent progressivement au fil des couches, et par ailleurs qu'elles respectent aussi une certaine linéarité : si xx représente une voiture, 2x2x ne devrait pas représenter un chat. Les représentations sont donc des directions dans l'espace latent. Par ailleurs, la linéarité doit aussi impliquer une certaine additivité des représentations, et on peut par exemple espérer que [la représentation de] reine) soit proche de roi - homme + femme, ce type de régularités linguistiques étant déjà observées depuis 2013[Mikolov, 2013]

Linguistic Regularities in Continuous Space Word Representations

Mikolov, Tomas and Yih, Scott Wen-tau and Zweig, Geoffrey (2013)

Source

.
Les représentations, qu'on appelle aussi caractéristiques (ou features) du modèle sont des directions dans l'espace latent Rdmodel\mathbb{R}^{d_{\textrm{model}}}. Or, il n'y a a priori pas de raisons de penser que les représentations les plus explicites aient comme directions celles de la base canonique, lesquelles correspondent aux sorties des neurones. Aussi, un même neurone peut très bien intervenir dans des contextes très différents (par exemple[Olah, 2017]

Feature Visualization

Olah, Chris and Mordvintsev, Alexander and Schubert, Ludwig (2017)

Distill.

DOI: 10.23915/distill.00007

un neurone d'un modèle de vision actif sur des visages d'animaux ou des voitures). Similairement, un même token xi,nx_{i,n} peut traiter de plusieurs notions en même temps. Dans ce cas, on parlera parfois de de superposition[Olah, 2020]

Zoom In: An Introduction to Circuits

Olah, Chris and Cammarata, Nick and Schubert, Ludwig and Goh, Gabriel and Petrov, Michael and Carter, Shan (2020)

Distill.

DOI: 10.23915/distill.00024.001

, et de caractéristiques polysémiques/polysémantiques.

Apprentisage de dictionnaire

Pour revenir à la question intiale, les neurones ne sont donc pas forcément facilement interprétables. On peut donc chercher à trouver des caractéristiques qui le soient, par exemple des caractéristiques monosémantiques. En quelque sorte, on cherche à « démêler » (disentanglement) les représentations. Pour cela, une approche consiste à faire de l'apprentisage de dictionnaire : pour un token tn=xi,nRdmodelt_n = x_{i,n} \in \mathbb{R}^{d_{\text{model}}} en sortie de la couche i1i-1, on va chercher un vecteur αnRh\alpha_n \in \mathbb{R}^h tel que
tntn^:=Dαnt_n \approx \widehat{t_n} := D\alpha_n
DRdmodel×hD \in \mathbb{R}^{d_{\text{model}}\times h} est le dictionnaire, hdmodelh \gg d_{\text{model}} et αn\alpha_n est très parcimonieux/sparse (autrement dit tt est une combinaison linéaire de quelques colonnes de DD). C'est une méthode utilisée en traitement du signal, puisqu'il permet de représenter un signal avec un nombre minimal de composantes (on pourrait faire une analogie avec la décomposition en série de Fourier, où αn\alpha_n correspondrait ici aux amplitudes / phases de chaque harmonique).
On a donc approximativement le problème d'optimisation suivant
arg minD,{αn}n(tnDαn22+λαn0)\argmin_{D,\{\alpha_n\}} \sum_{n} \left({||t_n - D\alpha_n ||}^2_2 + \lambda {||\alpha_n||}_0\right)
0{||\cdot||}_0 est la pseudo-norme 0 qui compte le nombre de valeurs non-nulles, et où λ>0\lambda > 0 arbitre donc entre la parcimonie/sparsité et la qualité de l'approximation, c.-à.-d. l'erreur entre tn^\widehat{t_n} et tnt_n.
Il existe plusieurs algorithmes permettant de résoudre ce problème. Cependant, celui-ci est non-convexe (à cause de la pseudo-norme 0) et NP-difficile[Tillmann, 2015]

On the Computational Intractability of Exact and Approximate Dictionary Learning

Tillmann, Andreas M. (2015)

IEEE Signal Processing Letters, vol. 22(1), pp. 45-49.

DOI: 10.1109/LSP.2014.2345761

. En conséquence, le résoudre est potentiellement « trop fort » (voir les arguments de Tom Henighan et Chris Olah[Henighan, 2023]

Circuits Updates, May 2023, Dictionary Learning Worries

Henighan, Tom and Olah, Chris (2023)

Source

) : on risque de trouver des caractéristiques que le modèle n'utilise pas, car trop coûteuses à calculer (et donc trop coûteuses d'accès pour le modèle). Il y a donc un risque que notre dictionnaire overfit sur les données et n'extraie pas correctement les représentations que le modèle manipule. Ces méthodes ne sont également pas les plus simples pour passer à l'échelle[Bricken, 2023]

Towards Monosemanticity: Decomposing Language Models With Dictionary Learning

Bricken, Trenton and Templeton, Adly and Batson, Joshua and Chen, Brian and Jermyn, Adam and Conerly, Tom and Turner, Nick and Anil, Cem and Denison, Carson and Askell, Amanda and Lasenby, Robert and Wu, Yifan and Kravec, Shauna and Schiefer, Nicholas and Maxwell, Tim and Joseph, Nicholas and Hatfield-Dodds, Zac and Tamkin, Alex and Nguyen, Karina and McLean, Brayden and Burke, Josiah E and Hume, Tristan and Carter, Shan and Henighan, Tom and Olah, Christopher (2023)

Transformer Circuits Thread.

Source

, or l'on souhaite idéalement traiter un corpus suffisament large (en milliards de tokens).

Sparse Autoencoders

Une autre approche consiste à entraîner un Sparse Autoencoder. Chaque token tnt_n, passe dans un encodeur αn=enc(tn)Rdlatent\alpha_n = \text{enc}(t_n) \in \mathbb{R}^{d_{\text{latent}}}, puis dans un decodeur qui reprojette dans la dimension d'origine tn^=dec(fn)Rdmodel\widehat{t_n} = \text{dec}(f_n) \in \mathbb{R}^{d_{\text{model}}}. Classiquement, il s'agit d'un réseau de neurones qu'on peut entraîner par rétro-propagation :
Schéma d'un Sparse Autoencoder
On cherche à minimiser l'erreur entre tn^\widehat{t_n} et tnt_n et à ce que αn\alpha_n soit le plus sparse possible. On se ramène ainsi à une forme faible d'apprentissage de dictionnaire. Au biais du decoder près, tn^\widehat{t_n} a d'ailleurs la même forme, et le dictionnaire DD correspond aux poids du decodeur. On peut donc espérer que les neurones f1,f2,f_1, f_2, \dots de la couche latente, autrement dit des dimensions / direction canonique de Rdlatent\mathbb{R}^{d_{\text{latent}}}, correspondent à des caractéristiques monosémantiques, et soient des représentations manipulées par le modèle. Notamment, l'architecture du bloc feed-forward des LLM est très similaire à celle du SAE que l'on vient de décrire.
Il y a plusieurs variations de SAE, et les premières itérations semblent dater de décembre 2022[Sharkey, 2022]

Interim Research Report: Taking Features Out of Superposition with Sparse Autoencoders

Sharkey, Lee and Braun, Dan and beren (2022)

Source

, notamment autour de la fonction d'activation en sortie de l'encodeur. On peut par exemple avec une ReLU\text{ReLU} et imposer la sparsité au niveau de la loss (comme dans le problème d'optimisation de l'aprentissage de dictionnaire). Il s'agit de l'approche des chercheurs d'Anthropic documentée dans Scaling Monosemanticy[Templeton, 2024]

Scaling Monosemanticity: Extracting Interpretable Features from Claude 3 Sonnet

Templeton, Adly and Conerly, Tom and Marcus, Jonathan and Lindsey, Jack and Bricken, Trenton and Chen, Brian and Pearce, Adam and Citro, Craig and Ameisen, Emmanuel and Jones, Andy and Cunningham, Hoagy and Turner, Nicholas L and McDougall, Callum and MacDiarmid, Monte and Freeman, C. Daniel and Sumers, Theodore R. and Rees, Edward and Batson, Joshua and Jermyn, Adam and Carter, Shan and Olah, Chris and Henighan, Tom (2024)

Transformer Circuits Thread.

Source

(et[Bricken, 2023]

Towards Monosemanticity: Decomposing Language Models With Dictionary Learning

Bricken, Trenton and Templeton, Adly and Batson, Joshua and Chen, Brian and Jermyn, Adam and Conerly, Tom and Turner, Nick and Anil, Cem and Denison, Carson and Askell, Amanda and Lasenby, Robert and Wu, Yifan and Kravec, Shauna and Schiefer, Nicholas and Maxwell, Tim and Joseph, Nicholas and Hatfield-Dodds, Zac and Tamkin, Alex and Nguyen, Karina and McLean, Brayden and Burke, Josiah E and Hume, Tristan and Carter, Shan and Henighan, Tom and Olah, Christopher (2023)

Transformer Circuits Thread.

Source

). Des chercheurs d'OpenAI ont introduit[Leo Gao, 2024]

Scaling and evaluating sparse autoencoders

Leo Gao and Tom Dupré la Tour and Henk Tillman and Gabriel Goh and Rajan Troll and Alec Radford and Ilya Sutskever and Jan Leike and Jeffrey Wu (2024)

Source

une variante où la sparsité est imposée avec une fonction d'activation TopK\text{Top}_K (on ne conserve que les KK valeurs les plus élevées (et positives) de αn\alpha_n). D'autres variations ont également été étudiées (notamment GatedSAE[Rajamanoharan, 2024]

Improving Dictionary Learning with Gated Sparse Autoencoders

Rajamanoharan, Senthooran and Arthur Conmy and Lewis Smith and Tom Lieberum and Vikrant Varma and János Kramár and Rohin Shah and Neel Nanda (2024)

Source

et JumpReLU SAE[Rajamanoharan, 2024]

Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders

Rajamanoharan, Senthooran and Tom Lieberum and Nicolas Sonnerat and Arthur Conmy and Vikrant Varma and János Kramár and Neel Nanda (2024)

Source

).

Entrainement

La première partie ci-dessous détaille les conditions générales d'entrainement. Les deux suivantes donnent des détails sur les optimisations (kernels, ...) que j'ai implémenté pour l'entrainement.

Présentation générale

En premier lieu il convient de choisir le LLM qu'on souhaite étudier. J'ai personnellement choisi de travailler avec Mistral-Small-3.2-24B-Instruct-2506, pour plusieurs raisons.
  • C'est un modèle (raisonnablement) petit : 24 milliards de paramètres, donc demande dans les 50 GO de VRAM en BF16 pour l'inférence (en pratique il faudra moins, voir sections optimisations plus bas).
  • C'est un modèle dense. Comme mentionné plus haut, les mixtures d'experts ont une architecture un peu différente, et mettent potentiellement à mal les hypothèses formulées. D'autre part, il y a actuellement bien plus de littérature sur les SAE entrainés sur des modèles dense que sur des MoE.
  • C'est un modèle « instruct ». Une partie du jeu de données que j'ai utilisé était sous forme d'échanges, il était donc pratique de pouvoir formatter ainsi les entrées du modèle.
Il faut également choisir la couche ii où récupérer le flux résiduel. Une intuition est que les concepts les plus simples se forment dans les premières couches, et que ceux plus complexes se situent plus loin dans le modèle Un certain nombre d'observations empire vont dans ce sens, et de nombreux travaux placent le SAE au milieu du modèle. Mistral-Small-3 ayant 40 couches, j'ai décidé de placer le SAE en sortie de la 30è couche (on regarde donc xi+1x_{i+1} avec i=29i = 29).
Dans ma configuration, j'entraîne le SAE de manière "online" (terminologie approximative) : Le LLM et le SAE en entrainement sont donc en même temps sur GPU, et j'infère directement le LLM jusqu'à avoir assez d'activations pour faire un batch avec lequel entraîner le SAE. J'ai utilisé une NVIDIA RTX PRO 6000 Blackwell avec 96 GO de VRAM. mistral-small-3 ayant a une dimension dmodel=5120d_{\text{model}} = 5120, j'ai choisi d'entraîner mon SAE avec
dlatent=56×dmodel=286720d_{\text{latent}} = 56 \times d_{\text{model}} = 286 720
ce qui fait environ 3B paramètres. J'entraine nativement en BF16 avec AdamW (les poids, gradients et moments représentent donc dans les 24 GO de VRAM), et avec une batchsize de 20482048. Chaque batch est mélangé et normalisé par un facteur Etn22dmodel\sqrt{\frac{\bdE{||t_n||}_2^2}{d_{\text{model}}}} (basé sur les travaux d'Anthropic). Le LLM est lui inféré sur des batch de 1616 séquences de 256256 tokens.
Ces choix sont principalement motivés par la contrainte en VRAM (s'agissant d'un projet personnel je n'ai pas pu me permettre de meilleure configuration). J'ai éxécuté plusieurs tests afin d'arbitrer et choisir au mieux ces valeurs (cf. benchmarks.md). J'ai cherché à me rapprocher le plus proche possible de Scaling Monosemanticity, aussi la fonction d'activation est une ReLU\text{ReLU} et la sparsité est obtenue au niveau de la loss : pour un token tnt_n, on calcule
αn=enc(tn)=ReLU(Wenctn+benc)tn^=dec(αn)=Wdecαn+bdec\begin{align*} \alpha_n = \text{enc}(t_n) &= \text{ReLU}\left(W_{\text{enc}}t_n + b_{\text{enc}}\right)\\ \widehat{t_n} = \text{dec}(\alpha_n) &= W_{\text{dec}}\alpha_n + b_{\text{dec}} \end{align*}
Wenc,Wdec,benc,bdecW_{\text{enc}}, W_{\text{dec}}, b_{\text{enc}}, b_{\text{dec}} sont les poids et biais de l'encodeur et du decodeur. La loss est
L=E[tn^tn22]+λE[jdlatentαn,jWdec,j22]\bcL = \bdE\left[{||\widehat{t_n} - t_n||}_2^2\right] + \lambda\bdE\left[\sum_{j}^{d_{\text{latent}}}\alpha_{n,j}{||W_{\text{dec},j}||}_2^2\right]
λ>0\lambda > 0 arbitre entre reconstruction et sparsité. La normalisation décrite plus haut m'a permis de reprendre λ=5\lambda = 5 des travaux d'Anthropic. Wdec,jW_{\text{dec},j} correspond à la colonne jj du decodeur (soit la représention de la caractéristique fjf_j). Il est important de l'intégrer à la pénalité sans quoi le modèle peut baisser artificiellement les valeurs de αn\alpha_n en augmentant les valeurs du decodeur. Une autre approche[Leo Gao, 2024]

Scaling and evaluating sparse autoencoders

Leo Gao and Tom Dupré la Tour and Henk Tillman and Gabriel Goh and Rajan Troll and Alec Radford and Ilya Sutskever and Jan Leike and Jeffrey Wu (2024)

Source

consiste à normaliser les Wdec,jW_{\text{dec},j} à chaque step.
WencW_{\text{enc}} est initialisé comme la transposée de WdecW_{\text{dec}}, lui-même initialisé selon He/Kaiming uniforme ; les biais bencb_{\text{enc}} et bdecb_{\text{dec}} sont initialisés à 0. AdamW est configuré avec β1=0.9\beta_1 = 0.9, β2=0.999\beta_2 = 0.999, pas de weight decay et ε=1015\varepsilon = 10^{-15}. J'utilise du gradient clipping, un « warmup » de λ\lambda de 00 à 55 sur les 5 premiers % de l'entrainement, et un learning rate de 5×1055 \times 10^{-5} constant puis qui décroit sur les derniers 20 % de l'entrainement.
J'ai effectué des entraînements sur 600 millions, 1.2 milliards et 1.6 milliards de tokens. Pour surveiller le bon déroulement, j'ai calculé et log périodiquement:
  • le nombre de valeurs non nulles (nnz) par token, avec des statistiques de répartition (moyenne, médiane, p90, p99) pour caractériser la sparsité effective ;
  • les fréquences d'activation des latents, moyenne, maximale, quantiles (p50, p90, p99), ainsi que la proportion de latents dans différentes plages de fréquence (0.1-1%, 1-10%, >10%) ;
  • les amplitudes des activations conditionnelles à l'activité, moyenne, écart-type et p99, pour caractériser la distribution des contributions ;
  • des métriques sur le résidu de reconstruction, notamment la variance expliquée, la corrélation avec l'entrée et la norme du résidu.

Optimisation de l'inférence (LLM)

Il existe de nombreux travaux pour accélérer l'inférence des modèles de langages. On peut notamment citer PagedAttention du projet vLLM[Kwon, 2023]

Efficient Memory Management for Large Language Model Serving with PagedAttention

Kwon, Woosuk and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica (2023)

Source

ou FlashAttention de Tri Dao[Dao, 2023]

FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

Dao, Tri (2023)

Source

. Si les optimisations du premier concernent l'utilisation autorégressive et concurrente du modèle (ce qui n'est pas notre cas), celles de FlashAttention nous sont tout à fait utiles, augmentent le débit de tokens par seconde et réduisent la mémoire utilisée.
La manière classique de récupérer les activations/états intermédiaires d'un modèle est d'utiliser des hook ou des bibliothèques comme TransformerLens (qui elle aussi utilise des hook). Dans notre cas on peut aussi utiliser output_hidden_states=True dans le forward d'un modèle chargé avec transformers. Cette approche présente plusieurs défauts.
  • Elle infère toutes les couches du modèles, y compris les i>ii' > i situées après le SAE, dont nous n'avons pas besoin.
  • Elle charge en VRAM la partie vision du modèle (vision tower et projecteur multimodal) dont nous n'avons pas besoin.
  • Elle est peu ou pas compatible avec torch.compile().
On peut plutôt modifier l'implémentation du modèle dans transformers pour retirer la partie multimodale (« langage seul »), et les couches non utiles et normalisation finale (« tranché »).
Performance des implémentations

Sans FlashAttention ou compilation. mistralai/Mistral-Small-3.2-24B-Instruct-2506, BF16, i = 30.

Comme dit plus haut, un autre avantage et qu'on peut compiler le module obtenu. On teste la performance pour différentes méthodes de compilation possibles (avec et sans FlashAttention), puis pour différentes tailles de batch et de séquence. Il faut noter que la compilation demande idéalement d'avoir des tenseurs de forme fixe. On ajoute des tokens de padding et masque d'attention pour les séquences trop courtes, et on tronque les séquences trop longue. Il faut noter que les tokens de padding doivent être retirés après inférence avant d'être donnés au SAE, ce qui crée un surcoût (croissant en la longueur de séquence souhaitée). Au final, on parvient atteint plus du double en débit par rapport à l'implémentation de base.
Latence et mémoire maximale des implémentations

mistralai/Ministral-3-3B-Base-2512, BF16, i = 30.

Optimisation de l'entraînement (SAE)

Le SAE a globalement une structure très simple. Les optimisations à mettre en place tirent principalement profit de son caractère parcimonieux/sparse : pour un token, seulement quelques caractéristiques parmi f1,f2,,fdlatentf_1, f_2, \dots{}, f_{d_{\text{latent}}} sont actives. On peut donc trouver une meilleure manière de représenter la sortie αn\alpha_n de l’encodeur, et implémenter de manière moins naïve l’encodeur et le decodeur. Commençons par l'encodeur : on a
αn=enc(tn)=ReLU(Wenctn+benc)\alpha_n = \text{enc}(t_n) = \text{ReLU}\left(W_{\text{enc}}t_n + b_{\text{enc}}\right)
En pratique, on traite un batch de BB tokens. Dans la suite, on notera ABC\cdot_{ABC} un tenseur de forme (A,B,C)(A, B, C) ; notamment on notera WHD=WencW_{HD} = W_{\text{enc}} les poids et bH=bencb_H = b_{\text{enc}} biais de l'encodeur. Pour le forward, notre encodeur effectue
zBH=xBD(WHD)+1BbHaBH=ReLU(zBH)\begin{align*} z_{BH} &= x_{BD}(W_{HD})^\top + 1_{B}b_H^\top\\ a_{BH} &= \text{ReLU}(z_{BH})\\ \end{align*}
On va donc chercher à représenter de manière sparse aBHa_{BH}.
Dans le cas d'un SAE TopK\text{Top}_K, on obtient les matrices iBKi_{BK} et aBKa_{BK} de formes B×KB \times K suivantes :
0b<B,iBK[b,:],aBK[b,:]=TopK(aBH[b,:])\forall 0 \leq b < B,\qquad i_{BK}[b, :], a_{BK}[b, :] = \text{Top}_K(a_{BH}[b, :])
aBH[b,:]a_{BH}[b, :] est la bb-ième ligne de aBHa_{BH}. Ses KK plus grandes valeurs sont données par aBK[b,:]a_{BK}[b, :], situées aux colonnes iBK[b,:]i_{BK}[b, :] et et valeurs. Pour chaque token bb du batch, on stocke donc quelles sont les KK caractéristiques les plus actives et leur niveau d'activation.
Le cas d'un SAE où l'on a plutôt une pénalité et simplement une activation ReLU\text{ReLU} est un peu plus subtil, car on n'a pas de nombre fixé de caractéristiques actives par token ou même par batch. Notons NN le nombre total de caractéristiques activées sur le batch. On a les vecteurs iBNiB_N, iHNiH_N et aNa_N de taille NN suivants :
iBN,iHN=where>0(aBH)0n<N,aN[n]=aBH[BN[n],HN[n]]\begin{align*} &&iB_N, iH_N &= \text{where}_{> 0}(a_{BH})\\ \forall 0 \leq n < N,\qquad&& a_N[n] &= a_{BH}[B_N[n], H_N[n]]\end{align*}
where>0\text{where}_{> 0} donne les lignes iBNiB_N, colonnes iHNiH_N et valeurs correspondantes aNa_N strictement positives de aBHa_{BH}. Il s'agit du format COOrdinate (ou COO) pour les tenseurs sparse. Pour un 0n<N0 \leq n < N donné, aN[n]a_N[n] donne l'activation de la caractéristique iHN[n]iH_N[n] pour le token iBN[n]iB_N[n].
Pour le backward, on cherche à calculer LWHD\frac{\partial\bcL}{\partial W_{HD}} le gradient des poids, LbH\frac{\partial\bcL}{\partial b_{H}} celui des bias et potentiellement LxBD\frac{\partial\bcL}{\partial x_{BD}} celui des entrées. Pour rappel, on a aBH=ReLU(zBH)a_{BH} = \text{ReLU}(z_{BH}) donc
LzBH=JReLU(zBH)LaBH=1zBH>0LaBH\dfrac{\partial \bcL}{\partial z_{BH}} = J_{\text{ReLU}}(z_{BH})^\top \dfrac{\partial \bcL}{\partial a_{BH}} = \mathbf{1}_{z_{BH} > 0} \odot \dfrac{\partial \bcL}{\partial a_{BH}}
\odot est le produit élément par élément (Hadamard).
Ensuite, zBH=xBD(WHD)+1BbHz_{BH} = x_{BD}(W_{HD})^\top + 1_{B}b_H^\top nous livre
LWHD=(LzBH)xBDLxBD=(LzBH)WHDLbH=(LzBH)1B=BLzBH\begin{align*} \dfrac{\partial \bcL}{\partial W_{HD}} &= \left( \dfrac{\partial \bcL}{\partial z_{BH}}\right)^\top x_{BD} \\ \dfrac{\partial \bcL}{\partial x_{BD}} &= \left(\dfrac{\partial \bcL}{\partial z_{BH}}\right)W_{HD}\\ \dfrac{\partial \bcL}{\partial b_{H}} &= \left(\dfrac{\partial \bcL}{\partial z_{BH}} \right)^\top 1_B = \sum_{B} \dfrac{\partial \bcL}{\partial z_{BH}} \end{align*}
Puisqu'on a une représentation sparse de aBHa_{BH}, il en va de même de son gradient LaBH\frac{\partial \bcL}{\partial a_{BH}} et donc du gradient LzBH\frac{\partial \bcL}{\partial z_{BH}}. Dans notre cas, on n'a pas besoin des gradients sur les entrées xBDx_{BD}. On cherche donc à faire la somme d'une matrice sparse le long d'une dimension pour le gradient de bHb_H, et la multuplication de la transposée d'une matrice sparse par une matrice dense pour celui de WHDW_{HD}. Pour le premier, un scatter_add fait l'affaire. Pour le second, il va s'agir d'implémenter un kernel SpMM dédié.
Commençons par le cas TopK\text{TopK}. Le gradient Gz=LzBHG_z = \frac{\partial \bcL}{\partial z_{BH}} est une matrice sparse, représenté de la même manière que aBHa_{BH} : pour chaque ligne 0b<B0 \leq b < B (soit chaque token du batch), on a exactement KK valeurs non-nulles, situées aux colonnes d'indices iBK[b,:]i_{BK}[b, :]. On veut calculer le gradient GW=LWHD=GzxBDG_W = \frac{\partial \bcL}{\partial W_{HD}} = G_z^\top x_{BD}. On peut donc initialiser GWG_W à zéro, puis ajouter la contribution de chaque token bb du batch. On multiplie la ligne la ligne bb de xBDx_{BD} par les KK valeurs de la colonne bb de GG^\top, situées aux lignes iBK[b,:]i_{BK}[b, :]. On ajoute ces KK lignes à GWG_W aux positions iBK[b,:]i_{BK}[b, :].
Kernel SpMM scatter
On démarre BB programmes, chacun s'occupant d'un bb. On peut même subdiviser manuellement le travail selon DD, en démarrant une grille B×BSDB \times \bbB\bbS_D où chaque programme s'occupe d'un bb et de kBSDd<(k+1)BSDk\bbB\bbS_D \leq d < (k+1)\bbB\bbS_D. Un des avantages de cette approche est que les accès mémoire sont contigus : on lit xBDx_{BD} ligne par ligne et (Gz)(G_z)^\top colonne par colonne. Le défaut est cependant la présence d'additions atomiques à GWG_W : une même caractéristique hh peut être active dans plusieurs tokens donc plusieurs programmes peuvent ajouter aux mêmes positions. On a ici un kernel « scatter ».
Cette situation est très similaire à celle de vouloir aggréger les sorties de différents experts d'une MoE : un même token n'active que quelques experts (caractère sparse), et plusieurs experts peuvent contribuer au même token.

Bibliographie

© Rémy SIAHAAN–GENSOLLEN, 2026
remy-siahaan.com