# DiffSwing: Neural Cart-Pole Control with JAX & MuJoCo
A hybrid control system that combines neural-network energy shaping with classical LQR stabilization for cart-pole swing-up. Trained end-to-end in a differentiable JAX simulation and deployed in real time in MuJoCo, it achieves a 98% success rate and a 1.9 s average swing-up time.
## Overview
DiffSwing is a hybrid control system that solves the classic cart-pole swing-up problem by combining:
- Neural energy shaping for robust swing-up from arbitrary initial conditions
- Classical LQR control for precise stabilization around the upright position
- End-to-end differentiable training in JAX for sample-efficient policy learning

## Key Achievements
- 98% success rate across wide initial pose/velocity distributions
- 1.9 s average swing-up time from the pendant (hanging) position
- Real-time deployment at 1 kHz in MuJoCo 3.x
- Training efficiency: < 5 minutes on CPU using analytical gradients

*Figure: real-time MuJoCo simulation during the neural-network control phase.*
## Technical Approach

### System Dynamics
The cart-pole system is parameterized as:
| Parameter | Description | Value | 
|---|---|---|
| M | Cart mass | 1.0 kg | 
| m | Pole mass | 0.1 kg | 
| l | Pole half-length | 0.5 m | 
| g | Gravitational acceleration | 9.81 m/s² | 
State representation: `[x, cos(θ), sin(θ), ẋ, θ̇]`; encoding the angle trigonometrically avoids the discontinuity at the ±π wrap, so gradients stay smooth.
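A minimal sketch of this encoding and the matching pole-energy term, assuming θ = 0 at the upright position and a point-mass pole; the actual versions live in `env/cartpole.py`, and the function names here are illustrative:

```python
import jax.numpy as jnp

m, l, g = 0.1, 0.5, 9.81  # pole mass, half-length, gravity (table above)

def encode_state(x, theta, x_dot, theta_dot):
    """Map (x, θ, ẋ, θ̇) to the wrap-free 5D observation."""
    return jnp.array([x, jnp.cos(theta), jnp.sin(theta), x_dot, theta_dot])

def pole_energy(theta, theta_dot):
    """Pole energy relative to hanging; point-mass approximation.

    Equals E_target = 2*m*g*l at the upright rest state (θ = 0, θ̇ = 0).
    """
    kinetic = 0.5 * m * (l * theta_dot) ** 2
    potential = m * g * l * (1.0 + jnp.cos(theta))
    return kinetic + potential
```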
### Hybrid Control Architecture
| Controller | Purpose | Implementation | Parameters | 
|---|---|---|---|
| Neural MLP | Energy shaping & swing-up | 2×64 hidden layers, tanh | ~9k parameters | 
| LQR | Upright stabilization | Riccati solution | Q, R matrices | 
| Linear | Baseline comparison | Direct state feedback | 4 gains | 
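As a rough illustration of how the two main controllers could be instantiated: the policy as an `eqx.nn.MLP` with two 64-unit tanh layers, and the LQR gain from SciPy's continuous-time Riccati solver. The point-mass linearization and the Q, R weights below are placeholder assumptions, not the repo's exact setup:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import numpy as np
from scipy.linalg import solve_continuous_are

M, m, l, g = 1.0, 0.1, 0.5, 9.81

# Swing-up policy: two 64-unit tanh hidden layers over the 5D observation.
policy = eqx.nn.MLP(in_size=5, out_size=1, width_size=64, depth=2,
                    activation=jnp.tanh, key=jax.random.PRNGKey(0))

# LQR gain from the linearization about the upright equilibrium,
# state ordering [x, θ, ẋ, θ̇], point-mass pole approximation.
A = np.array([[0.0, 0.0,                   1.0, 0.0],
              [0.0, 0.0,                   0.0, 1.0],
              [0.0, -m * g / M,            0.0, 0.0],
              [0.0, (M + m) * g / (M * l), 0.0, 0.0]])
B = np.array([[0.0], [0.0], [1.0 / M], [-1.0 / (M * l)]])
Q = np.diag([10.0, 100.0, 1.0, 1.0])  # placeholder state weights
R = np.array([[0.1]])                 # placeholder control weight

P = solve_continuous_are(A, B, Q, R)  # Riccati solution
K = np.linalg.solve(R, B.T @ P)       # feedback law u = -K @ state
```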
### Training Objective
The neural policy minimizes:

$$J = \sum_t \left[ w_E \left(E(t) - E_{\text{target}}\right)^2 + w_x\, x(t)^2 + w_u\, u(t)^2 \right]$$

where $E_{\text{target}} = 2mgl$ is the energy difference between the upright and hanging positions.
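As a minimal sketch, the objective maps directly onto a JAX loss over a rolled-out trajectory. The `traj` dictionary and the point-mass energy term are assumptions for illustration; the weights follow the schedule in the Configuration section:

```python
import jax.numpy as jnp

m, l, g = 0.1, 0.5, 9.81
E_target = 2.0 * m * g * l  # upright-minus-hanging energy gap

def swing_up_loss(traj, w_E=1.0, w_x=0.1, w_u=0.01):
    """Sum the three penalty terms of J along one trajectory."""
    E = (0.5 * m * (l * traj["theta_dot"]) ** 2
         + m * g * l * (1.0 + jnp.cos(traj["theta"])))  # θ = 0 is upright
    return jnp.sum(w_E * (E - E_target) ** 2
                   + w_x * traj["x"] ** 2
                   + w_u * traj["u"] ** 2)
```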
## Implementation

### Software Architecture
```
DiffSwing/
├── env/
│   └── cartpole.py           # Dynamics & energy functions
├── controller/
│   ├── linear_controller.py  # Differentiable linear control
│   ├── lqr_controller.py     # Riccati-based LQR
│   └── nn_controller.py      # Neural network policy
├── lib/
│   └── trainer.py            # Training loop & curriculum
└── scripts/
    ├── train_nn_controller.py  # Training script
    ├── nn_mujoco.py            # MuJoCo deployment
    └── run_simulation.py       # Controller comparison
```
### Key Technologies
- JAX: Automatic differentiation through dynamics
- Equinox: Neural network framework
- Diffrax: ODE integration with gradients (see the rollout sketch below)
- MuJoCo: Physics simulation and real-time control
- Optax: Gradient-based optimization (Adam)

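The sketch below shows how these pieces compose in principle: Diffrax integrates a closed-loop ODE, and JAX differentiates straight through the solve. The toy linear policy and terminal cost are illustrative stand-ins for the repo's trainer:

```python
import diffrax
import jax
import jax.numpy as jnp

M, m, l, g = 1.0, 0.1, 0.5, 9.81

def closed_loop(t, y, k):
    """Cart-pole dynamics (point-mass pole, θ = 0 upright) under u = k·y."""
    x, th, xd, thd = y
    u = jnp.dot(k, y)  # toy linear feedback policy
    s, c = jnp.sin(th), jnp.cos(th)
    denom = M + m * s**2
    xdd = (u + m * s * (l * thd**2 - g * c)) / denom
    thdd = ((M + m) * g * s - c * (u + m * l * thd**2 * s)) / (l * denom)
    return jnp.array([xd, thd, xdd, thdd])

def terminal_angle_cost(k):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(closed_loop), diffrax.Tsit5(),
        t0=0.0, t1=2.0, dt0=0.01,
        y0=jnp.array([0.0, 0.3, 0.0, 0.0]), args=k,
    )
    return sol.ys[-1, 1] ** 2  # penalize the final pole angle

# Gradients flow through the entire rollout.
grads = jax.grad(terminal_angle_cost)(jnp.array([1.0, -20.0, 2.0, -3.0]))
```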
## Results

### Performance Comparison
| Metric | Neural Network | Linear | LQR (stabilization) |
|---|---|---|---|
| Success Rate (−π…π, ±2 rad/s) | 98% | 82% | N/A |
| Swing-up Time (mean) | 1.9 s | 3.1 s | — |
| Cart Deviation (RMS) | 0.14 m | 0.22 m | 0.05 m |
| Peak Control Force | 12 N | 11 N | 6 N |
### Control Strategy
- Phase 1: The neural network performs energy pumping until |θ| < 12°
- Phase 2: Seamless handoff to the LQR controller for fine stabilization (sketched below)
- Result: Combines the neural network's robust swing-up with LQR's optimal stabilization

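A minimal sketch of this handoff, reusing the illustrative `policy` and LQR gain `K` from the earlier snippets and assuming θ arrives already wrapped to (−π, π]:

```python
import jax.numpy as jnp

SWITCH_ANGLE = jnp.deg2rad(12.0)  # Phase 1 → Phase 2 threshold

def hybrid_control(y, policy, K):
    """Neural swing-up far from upright, LQR once |θ| < 12°."""
    x, th, xd, thd = y
    obs = jnp.array([x, jnp.cos(th), jnp.sin(th), xd, thd])
    u_nn = policy(obs)[0]                          # energy-pumping force
    u_lqr = -(K @ jnp.array([x, th, xd, thd]))[0]  # upright stabilizer
    return jnp.where(jnp.abs(th) < SWITCH_ANGLE, u_lqr, u_nn)
```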
## Quick Start

### Installation
```bash
pip install jax equinox optax diffrax mujoco mujoco-python-viewer matplotlib numpy
```
### Training
```bash
# Train neural network controller (~5 minutes)
python scripts/train_nn_controller.py

# Monitor training progress
tensorboard --logdir logs/
```
### Deployment
```bash
# Deploy trained model in MuJoCo
python scripts/nn_mujoco.py --model trained_nn.eqx

# Compare different controllers
python scripts/run_simulation.py --controller neural
python scripts/run_simulation.py --controller lqr
python scripts/run_simulation.py --controller linear
```
### Configuration
Key training parameters in `config.py` (a hypothetical sketch follows the list):
- Learning rate: 1e-3 (Adam)
- Batch size: 256 initial states
- Training steps: 5000
- Weight schedule: Energy (1.0) → Position (0.1) → Control (0.01)

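A hypothetical layout of these values; the names are guesses at what `config.py` might contain, not its actual contents:

```python
# config.py (hypothetical sketch)
LEARNING_RATE = 1e-3  # Adam step size
BATCH_SIZE = 256      # initial states sampled per update
TRAIN_STEPS = 5000

# Curriculum weights: emphasis shifts from energy shaping to position
# regulation to control effort as training progresses.
W_ENERGY = 1.0
W_POSITION = 0.1
W_CONTROL = 0.01
```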
## Future Directions

### Technical Improvements
- Adaptive handoff: Smooth weighted blending between controllers
- Model predictive control: Receding-horizon energy shaping layer
- Domain randomization: Robust policies for sim-to-real transfer

### Experimental Validation
- Hardware implementation: Low-cost setup with a Teensy 4.1 microcontroller
- Real-world testing: Validation on a physical cart-pole system
- Comparative studies: Benchmarking against other swing-up methods

License: MIT — contributions and extensions welcome!