Self-Healing Training System (SHTS)

Autonomous debugging and error recovery for Hugging Face TRL trainers. Add one callback, wrap your trainer with SelfHealingTrainer, and common training failures are detected, diagnosed, and recovered from without manual intervention.

License: MIT


The Problem

ML training runs fail in recurring ways:

  • CUDA OOM kills the job at step 847/1000: restart from scratch
  • NaN loss silently corrupts the model: discovered hours later
  • Loss spikes cascade into divergence: manual intervention required
  • DPO loss plateaus at 0.693 (≈ ln 2, the value when the policy ranks pairs no better than random chance): wasted GPU hours
  • No postmortem: "what step did it die on?"

Each failure costs developer time, GPU credits, and schedule delay. At scale, the wasted compute adds up quickly.

The Solution

SHTS wraps any Hugging Face TRL trainer with four autonomous layers:

┌─────────────────────────────────────────┐
│  LAYER 4: ORCHESTRATION                 │
│  SelfHealingTrainer retry loop          │
│  while not converged: try → recover     │
├─────────────────────────────────────────┤
│  LAYER 3: RECOVERY                      │
│  HealingActions: rollback, halve LR,    │
│  halve batch, reclip, clear cache       │
├─────────────────────────────────────────┤
│  LAYER 2: DIAGNOSIS                     │
│  Root-cause classifier: NaN/divergence/ │
│  OOM/data/API — with literature refs    │
├─────────────────────────────────────────┤
│  LAYER 1: DETECTION                     │
│  SelfHealingCallback: loss, gradients,  │
│  memory, ZClip adaptive clipping        │
└─────────────────────────────────────────┘
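To make the diagnosis layer concrete, here is a minimal sketch of a keyword-based root-cause classifier. It is illustrative only: `FailureKind` and `classify_failure` are hypothetical names, not taken from the SHTS source, and the keyword lists simply mirror the detection rules listed under "What Handles What".

```python
import math
from enum import Enum
from typing import Optional


class FailureKind(Enum):  # hypothetical categories, for illustration only
    NAN = "nan"
    OOM = "oom"
    API = "api"
    DATA = "data"
    UNKNOWN = "unknown"


def classify_failure(exc: Optional[BaseException], loss: Optional[float]) -> FailureKind:
    """Map an exception and/or the latest loss to a coarse failure category."""
    # A NaN loss takes priority: it can occur without any exception being raised.
    if loss is not None and math.isnan(loss):
        return FailureKind.NAN
    if exc is None:
        return FailureKind.UNKNOWN
    # Match on the exception type name and message, case-insensitively.
    text = f"{type(exc).__name__} {exc}".lower()
    if "outofmemory" in text or "out of memory" in text:
        return FailureKind.OOM
    if any(keyword in text for keyword in ("api", "network", "timeout")):
        return FailureKind.API
    if any(keyword in text for keyword in ("shape", "dimension", "index")):
        return FailureKind.DATA
    return FailureKind.UNKNOWN
```

Plain substring matching is deliberately crude (for example, "api" also matches unrelated words), so a real classifier would likely check exception types first and fall back to keywords.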

Quick Start

pip install git+https://huggingface.co/ScottzillaSystems/self-healing-training

from self_healing import SelfHealingTrainer, HealingConfig
from trl import SFTTrainer, SFTConfig

# Your normal training setup
trainer = SFTTrainer(
    model=model,
    args=SFTConfig(
        output_dir="./output",
        learning_rate=2e-5,
        per_device_train_batch_size=4,
    ),
    train_dataset=dataset,
    tokenizer=tokenizer,  # recent TRL releases name this argument processing_class
)

# Wrap with self-healing — that's it!
sh = SelfHealingTrainer(
    trainer,
    HealingConfig(
        max_recovery_attempts=5,
        zclip_enabled=True,
    ),
)

# Optional: dry-run to catch config errors before full training
sh.dry_run(num_steps=2)

# Train with full autonomy
result = sh.train()

What Handles What

| Failure | Detection | Recovery | Paper |
|---|---|---|---|
| NaN loss | math.isnan(loss) after each step | Rollback → halve LR → enable grad clip | ZClip arxiv:2504.02507 |
| CUDA OOM | on_exception catches OutOfMemoryError | Halve batch (preserve effective via GA) → gradient checkpointing → clear cache | Unicron arxiv:2401.00134 |
| Loss spike | Loss > 5× running mean over window | ZClip adaptive gradient clipping → emergency checkpoint | ZClip arxiv:2504.02507 |
| Divergence | Loss increasing for N consecutive steps | Rollback → halve LR | Pioneer Agent arxiv:2604.09791 |
| Gradient explosion | grad_norm > 100 | ZClip → enable max_grad_norm=1.0 | AdaGC arxiv:2502.11034 |
| DPO plateau | loss ≈ 0.693 (random chance) | Increase LR 2-5× → check data quality | Rafailov et al. (2023) |
| Overfitting | eval_loss - train_loss > 2.0 | Alert with actionable recommendation | Standard practice |
| API errors | Exception with "api/network/timeout" | Exponential backoff (30s → 60s → 120s → ...) | Standard pattern |
| Data errors | Exception with "shape/dimension/index" | Skip batch → log bad sample | Deep Researcher arxiv:2604.05854 |
| Crash postmortem | Always | postmortem.json with exit reason, last step, metrics, recovery history | PTT pattern |
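Two of the recovery rules above are plain arithmetic and can be sketched directly: halving the per-device batch while keeping the effective batch size through gradient accumulation (GA), and the exponential backoff schedule for API errors. The function names and default values below are hypothetical and are not the SHTS implementation.

```python
from typing import Iterator, Tuple


def halve_batch_keep_effective(per_device_batch: int, grad_accum_steps: int) -> Tuple[int, int]:
    """Halve the per-device batch and raise gradient accumulation to compensate.

    The effective batch (per_device_batch * grad_accum_steps) is preserved
    exactly when it is divisible by the new batch size, and is rounded up
    otherwise, so it never shrinks.
    """
    if per_device_batch <= 1:
        raise ValueError("per-device batch size is already 1; cannot halve further")
    effective = per_device_batch * grad_accum_steps
    new_batch = per_device_batch // 2
    new_accum = -(-effective // new_batch)  # ceiling division
    return new_batch, new_accum


def backoff_delays(base_seconds: float = 30.0, attempts: int = 4) -> Iterator[float]:
    """Yield exponentially growing delays in seconds: base, 2*base, 4*base, ..."""
    for attempt in range(attempts):
        yield base_seconds * (2 ** attempt)
```

For example, a run configured with a per-device batch of 4 and no accumulation would retry with a batch of 2 and 2 accumulation steps, which leaves the optimizer seeing the same number of samples per update.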

Crash Postmortem

Every training interruption produces a postmortem.json:

{
  "exit_reason": "exception",
  "exception_type": "OutOfMemoryError",
  "last_step": 847,
  "timestamp": "2026-04-30T15:26:04Z",
  "final_metrics": {"loss": 2.15, "grad_norm": 42.3},
  "recovery_actions": [
    {
      "failure": "oom",
      "diagnosis": "CUDA Out of Memory. Batch size exceeds GPU capacity.",
      "actions": ["halve_batch_size", "enable_gradient_checkpointing", "clear_cache"]
    }
  ],
  "running_time_seconds": 1847.3
}
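A postmortem in this shape is easy to consume from scripts or CI. The helper below is a small sketch that reads a file with the fields shown in the example and prints a one-line summary; `summarize_postmortem` is a hypothetical name, and the path is passed in explicitly because the location where SHTS writes the file is not specified here.

```python
import json
from pathlib import Path


def summarize_postmortem(path: str) -> str:
    """Return a one-line summary of a postmortem.json file."""
    report = json.loads(Path(path).read_text(encoding="utf-8"))
    # Flatten the action lists of all recorded recovery attempts, in order.
    actions = [
        action
        for record in report.get("recovery_actions", [])
        for action in record.get("actions", [])
    ]
    return (
        f"{report.get('exit_reason', 'unknown')} "
        f"({report.get('exception_type', 'n/a')}) "
        f"at step {report.get('last_step', '?')}; "
        f"recovery actions: {', '.join(actions) if actions else 'none'}"
    )
```

Missing keys fall back to placeholders, so the helper also works on a postmortem from a run that ended without an exception.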

Trackio Integration

Set report_to="trackio" in your training args. SHTS emits:

  • Alerts at every decision point (INFO/WARN/ERROR)
  • Metrics: healing/recovery_attempts, healing/nan_count, healing/loss_spike_ratio, healing/eval_gap
  • ZClip metrics: zclip/raw_grad_norm, zclip/clipped_grad_norm, zclip/z_score, zclip/total_clips

Dashboard URL: https://huggingface.co/spaces/<username>/<trackio-space>
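To show where the zclip/* values could come from, here is a simplified sketch of z-score gradient clipping on a stream of gradient norms. It tracks an exponential moving average of the mean and variance and clips any norm whose z-score exceeds the threshold. This is a simplification in the spirit of ZClip, not the paper's exact update rule and not the SHTS code; the class name, `alpha`, and `warmup_steps` are assumptions.

```python
import math
from typing import Dict


class ZScoreClipper:
    """Simplified z-score clipping over gradient norms (illustrative only)."""

    def __init__(self, z_threshold: float = 3.0, alpha: float = 0.97, warmup_steps: int = 5):
        self.z_threshold = z_threshold
        self.alpha = alpha                # EMA smoothing factor (assumed value)
        self.warmup_steps = warmup_steps  # steps observed before clipping starts
        self.mean = 0.0
        self.var = 0.0
        self.seen = 0
        self.total_clips = 0

    def step(self, grad_norm: float) -> Dict[str, float]:
        """Compute this step's metrics, then update the running statistics."""
        z_score, clipped = 0.0, grad_norm
        if self.seen >= self.warmup_steps:
            std = math.sqrt(self.var) + 1e-8
            z_score = (grad_norm - self.mean) / std
            if z_score > self.z_threshold:
                # Pull the norm back to the edge of the accepted band.
                clipped = self.mean + self.z_threshold * std
                self.total_clips += 1
        self._update(clipped)
        return {
            "zclip/raw_grad_norm": grad_norm,
            "zclip/clipped_grad_norm": clipped,
            "zclip/z_score": z_score,
            "zclip/total_clips": self.total_clips,
        }

    def _update(self, value: float) -> None:
        """Update EMA mean and variance with the (possibly clipped) norm."""
        self.seen += 1
        if self.seen == 1:
            self.mean, self.var = value, 0.0
            return
        delta = value - self.mean
        self.mean = self.alpha * self.mean + (1 - self.alpha) * value
        self.var = self.alpha * self.var + (1 - self.alpha) * delta * delta
```

Updating the statistics with the clipped value rather than the raw one keeps a single outlier from inflating the baseline and masking the next spike.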

HealingConfig Presets

# Aggressive — for unstable training, low tolerance
config = HealingConfig.aggressive()
# nan_patience=1, zclip_z_threshold=2.0, max_recovery_attempts=10

# Conservative — only intervene on clear failures
config = HealingConfig.conservative()
# nan_patience=10, loss_spike_factor=10.0, zclip_z_threshold=4.0, max_recovery_attempts=2

# Custom
config = HealingConfig(
    nan_patience=5,
    loss_spike_factor=8.0,
    divergence_patience=100,
    max_recovery_attempts=3,
    zclip_enabled=True,
    zclip_z_threshold=3.0,
)
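As an illustration of what a threshold such as loss_spike_factor controls, the sketch below flags a loss as a spike when it exceeds that factor times the running mean over a window, which is the detection rule stated in the table above. The class, the window size, and the `min_history` parameter are hypothetical; only the parameter name loss_spike_factor comes from HealingConfig.

```python
import math
from collections import deque


class LossSpikeDetector:
    """Flag losses that exceed `loss_spike_factor` times the running mean."""

    def __init__(self, loss_spike_factor: float = 5.0, window: int = 50, min_history: int = 5):
        self.loss_spike_factor = loss_spike_factor
        self.min_history = min_history        # wait for a baseline before judging
        self.history = deque(maxlen=window)   # most recent non-spike losses

    def update(self, loss: float) -> bool:
        """Return True if `loss` is a spike relative to the running mean."""
        if not math.isfinite(loss):
            # NaN/inf is a separate failure mode; keep it out of the baseline.
            return False
        is_spike = False
        if len(self.history) >= self.min_history:
            running_mean = sum(self.history) / len(self.history)
            is_spike = loss > self.loss_spike_factor * running_mean
        if not is_spike:
            self.history.append(loss)
        return is_spike
```

With a baseline loss near 2.0, a value of 12.0 counts as a spike at a factor of 5.0 but not at the more tolerant 8.0 used in the custom config above.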

Compatibility

| Trainer | Status | Notes |
|---|---|---|
| SFTTrainer (TRL) | ✅ Full | All metrics captured |
| DPOTrainer (TRL) | ✅ Full | DPO plateau detection (loss≈0.693) |
| GRPOTrainer (TRL) | ✅ Full | Group reward monitoring |
| PPOTrainer (TRL) | ✅ Full | KL divergence tracking |
| ORPOTrainer (TRL) | ✅ Full | Odds ratio monitoring |
| KTOTrainer (TRL) | ✅ Full | Desirable/undesirable logps |
| CPOTrainer (TRL) | ✅ Full | Contrastive preference |
| Trainer (Transformers) | ✅ Full | Standard ML training |
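The 0.693 figure used for DPO plateau detection is ln 2. The DPO loss is -log σ(margin), and when the policy scores chosen and rejected responses exactly like the reference model the margin is 0, so the loss is -log σ(0) = ln 2 ≈ 0.6931. The sketch below checks this numerically for a single pair; `dpo_loss` follows the standard per-pair DPO objective, while `is_plateaued` and its tolerance are hypothetical and only illustrate one way such a check could look.

```python
import math
from typing import Sequence

LN2 = math.log(2)  # ≈ 0.6931, the DPO loss at zero margin


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
) -> float:
    """Per-pair DPO loss: -log(sigmoid(beta * (chosen log-ratio - rejected log-ratio)))."""
    margin = beta * (
        (policy_chosen_logp - ref_chosen_logp)
        - (policy_rejected_logp - ref_rejected_logp)
    )
    # -log(sigmoid(m)) equals log(1 + exp(-m)).
    return math.log1p(math.exp(-margin))


def is_plateaued(losses: Sequence[float], tol: float = 0.005) -> bool:
    """True if every recent loss sits within `tol` of ln 2 (hypothetical check)."""
    return len(losses) > 0 and all(abs(loss - LN2) < tol for loss in losses)
```

A loss stuck near ln 2 therefore means the policy has not yet separated chosen from rejected responses relative to the reference, which is why the recovery step suggests raising the learning rate and checking data quality.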

Architecture

SelfHealingTrainer.train()
  │
  ├── dry_run()              ← Validate setup first
  │
  └── while not converged:
      │
      ├── trainer.train()    ← Run training
      │     │
      │     ├── on_step_end  ← Detect NaN, spikes, divergence
      │     ├── on_log       ← Monitor gradients (ZClip)
      │     ├── on_evaluate  ← Check overfitting
      │     └── on_exception ← Catch OOM, API, data errors
      │
      ├── [recovery needed?]
      │     ├── diagnose     ← Classify failure type
      │     ├── heal         ← Apply recovery actions
      │     └── retry        ← resume_from_checkpoint=True
      │
      └── [converged]        ← Done!
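The control flow in the diagram can be sketched as a bounded retry loop. This is a minimal illustration under stated assumptions, not the SelfHealingTrainer source: `train_with_recovery`, `run_training`, and `heal` are hypothetical names, and the diagnosis and healing steps are collapsed into a single callback.

```python
from typing import Callable, List


def train_with_recovery(
    run_training: Callable[[bool], None],
    heal: Callable[[BaseException], None],
    max_recovery_attempts: int = 5,
) -> List[str]:
    """Run training, healing and retrying on failure.

    `run_training(resume)` runs one training attempt; `resume` is False on the
    first attempt and True on retries. Returns the names of the exceptions
    that were recovered from, in order.
    """
    history: List[str] = []
    resume = False
    for attempt in range(max_recovery_attempts + 1):
        try:
            run_training(resume)
            return history  # training finished without an exception
        except Exception as exc:
            if attempt == max_recovery_attempts:
                raise  # recovery budget exhausted; surface the last error
            heal(exc)  # diagnose and apply recovery actions
            history.append(type(exc).__name__)
            resume = True  # later attempts continue from the last checkpoint
    raise AssertionError("unreachable")
```

With a Transformers or TRL trainer, `run_training` could be `lambda resume: trainer.train(resume_from_checkpoint=resume)`, which assumes a checkpoint has already been saved to the output directory before the first failure.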

References

| Paper | ID | Contribution |
|---|---|---|
| Unicron | arxiv:2401.00134 | Cost-aware self-healing at cluster scale, error taxonomy (4 types), elastic scaling |
| ZClip | arxiv:2504.02507 | Z-score adaptive gradient clipping, eliminates catastrophic loss spikes |
| AdaGC | arxiv:2502.11034 | Per-tensor adaptive gradient clipping, optimizer-agnostic |
| Pioneer Agent | arxiv:2604.09791 | Structured decision tree by score buckets for autonomous iteration |
| Deep Researcher | arxiv:2604.05854 | Dry-run validation, zero-cost monitoring, constant-size memory |
| CheckFree | arxiv:2506.15461 | Pipeline-parallel recovery via neighbor averaging |
| DPO | Rafailov et al. (2023) | DPO plateau at 0.693 = random chance (Section 4.2) |
| PTT | post-training-toolkit | DiagnosticsCallback + postmortem pattern |

License

MIT — use freely, attribution appreciated.


Built autonomously by ML Intern. Questions? Open an issue on the Hub.
