asi · kscale-ksim

K-Scale ksim Skill

Install

Source · Clone the upstream repo:

git clone https://github.com/plurigrid/asi

Claude Code · Install into ~/.claude/skills/:

T=$(mktemp -d) && git clone --depth=1 https://github.com/plurigrid/asi "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/kscale-ksim" ~/.claude/skills/plurigrid-asi-kscale-ksim && rm -rf "$T"

Manifest: skills/kscale-ksim/SKILL.md

Source Content

K-Scale ksim Skill

"RL training library for humanoid locomotion and manipulation. Built on MuJoCo and JAX."

Trigger Conditions

  • User asks about robot simulation, humanoid locomotion, or RL policy training
  • Questions about MuJoCo/MJX, JAX-based physics simulation
  • Training walking/manipulation policies for humanoid robots
  • Sim2Real transfer, domain randomization, curriculum learning

Overview

ksim is K-Scale Labs' modular reinforcement learning framework for training robot control policies. It provides four main pieces (composed as sketched just after this list):

  1. Physics Engines: MuJoCo (CPU) and MJX (GPU/JAX-native)
  2. Observation System: Stateless and stateful observations with noise injection
  3. Reward Functions: Composable, curriculum-scaled reward components
  4. Action Processing: Latency modeling, actuator dynamics
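
How these pieces compose, as a rough sketch: the class names mirror the module table below, but the constructor arguments are illustrative assumptions, not ksim's verbatim API.

# Illustrative composition only; constructor signatures are assumptions.
engine = MjxEngine()                                   # 1. physics engine (GPU/JAX-native)
observations = [BasePositionObservation()]             # 2. state extraction with noise injection
rewards = [BaseHeightReward(target=1.0, scale=0.1)]    # 3. composable, curriculum-scaled terms
actuators = PositionActuators()                        # 4. latency modeling + motor dynamics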

Architecture

┌─────────────────────────────────────────────────────────────────────────┐
│  ksim Control Loop                                                       │
│                                                                          │
│  ┌──────────────┐  step   ┌──────────────┐  observe  ┌──────────────┐   │
│  │ PhysicsState │────────▶│   Trajectory │──────────▶│  Observation │   │
│  │   (MJX/MJ)   │         │   (batched)  │           │   (noisy)    │   │
│  └──────┬───────┘         └──────┬───────┘           └──────────────┘   │
│         │                        │                                       │
│         │ reset                  │ reward                                │
│         ▼                        ▼                                       │
│  ┌──────────────┐         ┌──────────────┐                              │
│  │    Engine    │         │  RewardState │                              │
│  │ (JIT-compiled)│        │  (component) │                              │
│  └──────────────┘         └──────────────┘                              │
│         │                                                                │
│         │ actuate                                                        │
│         ▼                                                                │
│  ┌──────────────┐                                                       │
│  │    Action    │◀─────── policy(observation)                           │
│  │  (latency)   │                                                       │
│  └──────────────┘                                                       │
└─────────────────────────────────────────────────────────────────────────┘
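
In pseudocode, one iteration of this loop looks roughly like the following; every function name here is a placeholder for the behavior named in the diagram, not ksim's exact API.

# One control-loop iteration mirroring the diagram above.
# All names are placeholders, not ksim's exact API.
def control_step(physics_state, trajectory, policy, carry, rng):
    obs = observe(physics_state, rng)                    # noisy observation
    action, carry = policy(obs, carry)                   # policy(observation)
    action = apply_latency(action)                       # action processing
    physics_state = engine_step(physics_state, action)   # JIT-compiled physics step
    trajectory = append(trajectory, physics_state, action)
    reward = compute_reward(trajectory)                  # composable reward terms
    return physics_state, trajectory, carry, reward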

Module Structure

Module             Purpose                Key Classes
engine.py          Physics stepping       MjxEngine, MujocoEngine, PhysicsEngine
observation.py     State extraction       Observation, StatefulObservation
rewards.py         Reward computation     Reward, StatefulReward
actions.py         Action processing      Action, latency buffers
actuators.py       Motor dynamics         Actuators, PositionActuators
types.py           Core dataclasses       PhysicsState, Trajectory, RewardState
curriculum.py      Training progression   Scale, curriculum schedules
terminations.py    Episode ending         Termination conditions
resets.py          State initialization   Reset distributions
randomization.py   Domain randomization   Parameter perturbations

Behavior Type Taxonomy

Tree-Sitter Lifted Type Hierarchy

RL_BEHAVIOR_INTERFACE (Root)
├─ SIMULATION_BEHAVIOR
│  ├─ PHYSICS_STEP: dt × Action → State
│  ├─ RESET: InitConfig → State
│  ├─ RENDER: State → Visualization
│  └─ BATCH_VECTORIZATION: 1 Scene → N Scenes (parallel)
│
├─ POLICY_BEHAVIOR  
│  ├─ ACTOR: (Obs, LSTM_Carry) → (Distribution, LSTM_Carry)
│  ├─ CRITIC: (Obs, LSTM_Carry) → (Value, LSTM_Carry)
│  └─ ACTION_SAMPLING: Distribution → Action (JAX random)
│
├─ REWARD_BEHAVIOR
│  ├─ PENALTY_COMPOSITION: [Penalty] → Scalar
│  ├─ JOINT_DEVIATION: Physics → Scalar
│  └─ POSTURE_CONSTRAINT: Physics → Scalar
│
├─ TRAINING_BEHAVIOR
│  ├─ TRAJECTORY_COLLECTION: Action × Physics → Experience
│  ├─ GRADIENT_COMPUTATION: Trajectory → Gradients (PPO)
│  ├─ MODEL_UPDATE: Gradients → Model′
│  └─ CHECKPOINT: Model → Disk
│
├─ CONFIGURATION_BEHAVIOR
│  ├─ HYPERPARAMETER_SPEC: Type-safe declarative
│  └─ ENV_FACTORY: Config → (Model, Physics, Task)
│
└─ RECURRENCE_BEHAVIOR
   ├─ LSTM_STATE_CARRY: Array → LSTM_Carry
   ├─ STATE_RESET: () → LSTM_Carry
   └─ STATE_EVOLUTION: (Carry, Obs) → Carry′

Type Signature Contracts

Behavior            Input Type               Output Type                JAX/Equinox Traits
Actor.forward       (Array[B,O], Array[H])   (Distribution, Array[H])   PyTree, JIT-compiled
Critic.forward      (Array[B,O], Array[H])   (Array[B,1], Array[H])     PyTree, Differentiable
step(action)        Action: Array[B,A]       State: PhysicsModel        Vectorized, MJX batch
get_rewards()       PhysicsModel             Array[B,1]                 JAX-pure function
sample_action()     (Model, PhysicsModel)    Action: Array[B,A]         Random keyed, PRNGKey
Config.__init__()   Keyword args             Config: dataclass          Immutable, type-checked
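
The "Vectorized, MJX batch" trait comes from JAX's vmap transform: a single-scene step function is lifted to N parallel scenes (the BATCH_VECTORIZATION behavior above). A runnable toy illustration, with a stand-in dynamics function rather than real MJX:

import jax
import jax.numpy as jnp

def single_step(qpos, action):
    # Stand-in for one physics step on one scene (not real MJX dynamics).
    return qpos + 0.01 * action

batched_step = jax.vmap(single_step)        # 1 scene → N scenes
qpos = batched_step(jnp.zeros((4096, 23)),  # B = 4096 environments
                    jnp.ones((4096, 23)))   # 23 actuated DoF each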

Stateless Behaviors (Pure Functions)

import jax.numpy as jnp

# Observation: PhysicsState → Array
class BasePositionObservation(Observation):
    def observe(self, state: PhysicsState) -> Array:
        # Root body position (x, y, z) from the generalized coordinates
        return state.data.qpos[0:3]

# Reward: Trajectory → Array
class BaseHeightReward(Reward):
    def get_reward(self, trajectory: Trajectory) -> Array:
        height = trajectory.qpos[:, 2]
        # Gaussian kernel around the target base height
        return jnp.exp(-((height - self.target) ** 2) / (2 * self.scale ** 2))

Stateful Behaviors (With Carry)

# StatefulObservation: (PhysicsState, Carry) → (Array, Carry)
class DelayedJointPositionObservation(StatefulObservation):
    def observe_stateful(self, state, carry):
        # Ring buffer simulating sensor latency: push the newest joint
        # positions (qpos[7:] skips the 7-DoF floating base: xyz + quaternion)
        # and return the oldest entry as the delayed observation.
        new_carry = jnp.roll(carry, 1, axis=0)
        new_carry = new_carry.at[0].set(state.data.qpos[7:])
        return carry[-1], new_carry

# StatefulReward: (Trajectory, Carry) → (Array, Carry)
class FeetAirTimeReward(StatefulReward):
    def get_reward_stateful(self, trajectory, carry):
        # Track contact state over time
        ...
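
In both cases the framework threads the returned carry into the next call, exactly like an RNN hidden state: the functions stay JAX-pure (no hidden mutation), while still expressing history-dependent behaviors such as latency buffers and air-time tracking.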

Neural Network Behavioral Contracts (Equinox)

class Model(eqx.Module):
    actor: Actor   # Stochastic Policy Behavior
    critic: Critic # Value Estimation Behavior

class Actor(eqx.Module):
    """Behavioral Contract: (Obs, LSTM_Carry) → (Distribution, LSTM_Carry)"""
    def forward(self, obs_n: Array, carry: Array) -> tuple[Distribution, Array]:
        ...

class Critic(eqx.Module):
    """Behavioral Contract: (Obs, LSTM_Carry) → (Value, LSTM_Carry)"""
    def forward(self, obs_n: Array, carry: Array) -> tuple[Array, Array]:
        ...
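
A minimal Actor satisfying this contract might look like the following sketch. The LSTM cell, layer sizes, and Gaussian head are assumptions about the model family, and distrax is used here for the Distribution type; ksim's actual model may differ.

import distrax
import equinox as eqx
import jax
import jax.numpy as jnp

class GaussianActor(eqx.Module):
    cell: eqx.nn.LSTMCell
    mean_head: eqx.nn.Linear
    log_std: jax.Array

    def __init__(self, obs_dim: int, act_dim: int, hidden: int, *, key):
        k1, k2 = jax.random.split(key)
        self.cell = eqx.nn.LSTMCell(obs_dim, hidden, key=k1)
        self.mean_head = eqx.nn.Linear(hidden, act_dim, key=k2)
        self.log_std = jnp.zeros(act_dim)

    def forward(self, obs_n, carry):
        # carry is the LSTM (hidden, cell) pair threaded across timesteps
        h, c = self.cell(obs_n, carry)
        dist = distrax.Normal(self.mean_head(h), jnp.exp(self.log_std))
        return dist, (h, c)  # (Distribution, LSTM_Carry)

A fresh carry is (jnp.zeros(hidden), jnp.zeros(hidden)), matching STATE_RESET in the taxonomy above.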

Key Patterns

1. JIT Compilation with Equinox

@eqx.filter_jit
def step(self, action, physics_model, physics_state, curriculum_level, rng):
    # Efficient GPU execution via JAX tracing
    ...
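
eqx.filter_jit traces only the array-valued leaves of its arguments; non-array fields of an eqx.Module (Python scalars, activation functions, configuration) are treated as static, so whole models can be passed into the step function without manual static_argnums bookkeeping.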

2. Exponential Kernel Rewards

def exp_kernel(x, scale):
    return jnp.exp(-(x ** 2) / (2 * scale ** 2))
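
The kernel peaks at 1 when x = 0 and decays to exp(-1/2) ≈ 0.607 at x = scale, so scale directly sets the error tolerance of each reward term; BaseHeightReward above is exactly this kernel applied to height - target.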

3. Curriculum Scaling

class Scale:
    def __call__(self, curriculum_level: Array) -> Array:
        # Modulate reward/observation based on training progress
        ...
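
A concrete Scale can be as simple as a linear ramp. The class below is a hypothetical example of the pattern, not one ksim ships.

# Hypothetical linear ramp over curriculum_level ∈ [0, 1].
class LinearScale(Scale):
    def __init__(self, start: float, end: float):
        self.start, self.end = start, end

    def __call__(self, curriculum_level: Array) -> Array:
        # Interpolate a reward/observation weight as training progresses.
        return self.start + (self.end - self.start) * curriculum_level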

GF(3) Trit Assignment

Trit: 0 (ERGODIC)
Role: Infrastructure/Coordination
Color: #25BC3D
URI: skill://kscale-ksim#25BC3D

Balanced Triads

kscale-ksim (0) ⊗ kscale-kos (-1) ⊗ gym (+1) = 0 ✓
kscale-ksim (0) ⊗ jax-rl (-1) ⊗ mujoco-playground (+1) = 0 ✓

Related Skills

  • kscale-kos: K-Scale Operating System (firmware layer)
  • kscale-kinfer: Model inference engine
  • kscale-urdf: Robot description conversion
  • gym: OpenAI Gym environments
  • jax: JAX numerical computing

Key Contributors (Cognitive Superposition)

Contributor                  Focus Areas                       Commits
codekansas (Ben Bolte)       Architecture, rewards, training   1475+
b-vm                         Randomization, disturbances       500+
WT-MM (Wesley Maa)           Tooling, visualization            300+
alik-git (Ali Kuwajerwala)   Integration, testing              200+

Commands

# Install ksim
pip install ksim

# Train a walking policy (RTX 4090: ~30 min for 80 steps)
python -m ksim.train --config configs/kbot_walk.yaml

# Visualize trained policy
python -m ksim.vis --checkpoint path/to/model.ckpt


Narya Compatibility (Structure-Aware Diffing)

Field    Definition
before   PhysicsState at timestep t (qpos, qvel, control)
after    PhysicsState at timestep t+1 after action execution
delta    Trajectory segment: the action taken + reward received
birth    Initial PhysicsState from reset() with domain randomization
impact   1 if episode terminated (fall, out-of-bounds), 0 otherwise

Behavior Type Diffing

import json
from dataclasses import dataclass

@dataclass
class KsimNaryaEvent:
    """Structure-aware diff for ksim state transitions."""
    event_id: str
    before: PhysicsState      # State before action
    after: PhysicsState       # State after action
    delta: TrajectorySegment  # Action + reward + info
    trit: int                 # GF(3): -1=penalty, 0=neutral, +1=reward
    
    @property
    def impact(self) -> int:
        """1 if state change is significant (termination/reset)."""
        return 1 if self.delta.done else 0
    
    def to_jsonl(self) -> str:
        return json.dumps({
            "event_id": self.event_id,
            "before_hash": hash_state(self.before),
            "after_hash": hash_state(self.after),
            "delta": {"action": self.delta.action.tolist(),
                      "reward": float(self.delta.reward)},
            "trit": self.trit,
            "impact": self.impact
        })

Replay Determinism

# Same seed → same trajectory (critical for sim2real debugging).
# This sketch assumes a gym-style env wrapper; env and max_steps are
# defined elsewhere.
import jax
import jax.numpy as jnp

def replay_episode(seed: int, policy: Model) -> list[KsimNaryaEvent]:
    rng = jax.random.PRNGKey(seed)
    state = env.reset(rng)  # birth
    events = []
    
    for t in range(max_steps):
        rng, action_rng = jax.random.split(rng)
        action = policy.sample(state.obs, action_rng)
        
        before = state
        state, reward, done, info = env.step(action)
        
        events.append(KsimNaryaEvent(
            event_id=f"step_{t}",
            before=before,
            after=state,
            delta=TrajectorySegment(action, reward, done, info),
            trit=int(jnp.sign(reward))  # GF(3): -1, 0, +1
        ))
        
        if done:
            break
    
    return events
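
The determinism rests on JAX's splittable, counter-based PRNG: jax.random.split is a pure function of its input key, so the same seed yields the same key sequence, the same sampled actions, and (on the same backend) bit-identical trajectories.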

ACSet Schema (Catlab.jl)

@present SchKsim(FreeSchema) begin
    PhysicsState::Ob
    Trajectory::Ob
    Observation::Ob
    Reward::Ob
    Action::Ob
    
    step::Hom(Action, PhysicsState)
    observe::Hom(PhysicsState, Observation)
    reward::Hom(Trajectory, Reward)
    
    StateData::AttrType
    qpos::Attr(PhysicsState, StateData)
    qvel::Attr(PhysicsState, StateData)
end