SMolLM: Small language models learn small molecular grammar
1 SMolLM trains a 53K-parameter weight-shared transformer (one Modern GPT block reused for multiple passes) to generate SMILES with 95.3% RDKit validity on ZINC-250K, beating an identically trained unshared GPT with 10× more parameters (GPT-527K: 87.6% validity).
2 The key idea is trading parameters for iterative computation: for each next-token prediction, the same block is applied repeatedly (virtual depth). Validity climbs sharply with more passes and plateaus around 8 passes (e.g., WS-206K: 33% at D=1 → 90.5% at D=2 → 97.1% at D=4 → 98.6% at D=8).
3 Weight sharing dominates the parameter–validity Pareto frontier below 1M parameters: WS models consistently achieve higher validity than unshared GPT baselines of similar or larger size (e.g., WS-53K reaches 95.3%, a level that only much larger unshared GPTs approach).
4 Mechanistic result: across passes, the same shared block resolves SMILES constraints in a consistent, staged order (brackets first, ring closures second, valence last). The ordering is quantified with exact symbolic checks for each constraint, which enables per-pass “grammar progress” measurement.
5 Output-level error classification shows the stage ordering directly: bracket errors collapse by pass 2 (e.g., 23.0% → 1.3% from D=1 to D=2), ring errors drop next (notably from pass 2 to pass 4), and valence violations are the last major residual class (improving mainly from pass 4 to pass 8).
6 Internal representations mirror the same hierarchy. Linear probes decode bracket depth earlier than ring state (WS-206K: bracket-depth probe peaks at pass 2; ring-state probe peaks at pass 5). Sparse autoencoders (SAEs) recover early token-level detectors (bracket / ring-digit), followed by later compositional state features (bracket depth, ring state), and then more atom-level identity features.
7 Causal localization: a single attention head is responsible for the earliest bracket-matching step. Ablating that head at pass 1 selectively increases bracket errors (WS-206K: bracket errors +22 pp; validity −21 pp) while leaving ring/valence errors near baseline, which supports a specific circuit rather than general sensitivity.
8 Capacity mainly changes how many passes are needed, not the ordering. The smaller WS-53K reaches similar milestones ~1–2 passes later and relies on the bracket head across more early passes, suggesting iterative refinement compensates for limited per-pass capacity.
9 Robustness checks: benefits persist across decoding settings (temperature/top-k/top-p) and are not explained by extra training of unshared baselines. Interestingly, standard post-training methods (offline distillation, DPO) reduce validity, implying that preserving per-pass iterative computation may require different objectives than typical output-only alignment.
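To make the "exact symbolic checks" from points 4 and 5 concrete, here is a minimal Python sketch of what bracket and ring-closure checks on a SMILES string could look like. It is not the paper's implementation. The function names (`brackets_ok`, `rings_ok`, `first_error`) are hypothetical, and the checks are deliberately simplified: only single-digit ring labels are handled (no `%nn`), and valence is left out because it needs a chemistry toolkit such as RDKit.

```python
import re

# Simplified, illustrative checks (NOT the paper's implementation).
# Assumptions: single-digit ring-closure labels only (no %nn), and digits
# inside [...] atoms (isotopes, charges, H counts) are not ring labels.


def brackets_ok(smiles: str) -> bool:
    """True if '(' / ')' and '[' / ']' are balanced and properly nested."""
    pairs = {")": "(", "]": "["}
    stack = []
    for ch in smiles:
        if ch in "([":
            stack.append(ch)
        elif ch in pairs:
            if not stack or stack.pop() != pairs[ch]:
                return False
    return not stack


def rings_ok(smiles: str) -> bool:
    """True if every ring-closure digit outside [...] is opened and closed."""
    outside = re.sub(r"\[[^\]]*\]", "", smiles)  # drop bracket atoms
    open_labels = set()
    for ch in outside:
        if ch.isdigit():
            # First occurrence opens the ring bond, the next one closes it.
            open_labels.symmetric_difference_update({ch})
    return not open_labels


def first_error(smiles: str) -> str:
    """Classify by the first failing check: brackets, then rings."""
    if not brackets_ok(smiles):
        return "bracket"
    if not rings_ok(smiles):
        return "ring"
    return "ok_so_far"  # valence would need a toolkit such as RDKit


if __name__ == "__main__":
    for s in ["c1ccccc1", "CC(C", "C1CCCC", "[13CH4]"]:
        print(s, first_error(s))
```

Running `first_error` over samples drawn at each pass count D and tallying the labels would give a per-pass error breakdown in the spirit of point 5, under the simplifications stated above.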
💻Code:
github.com/akhljndl/smollm
📜Paper:
arxiv.org/abs/2605.06322
#ComputationalBiology #Cheminformatics #MolecularGeneration #SMILES #Transformers #MechanisticInterpretability #WeightSharing #SparseAutoencoders #DrugDiscovery #MachineLearning