Awesome-omni-skill JAX

Essential tools for using JAX in machine learning and mathematical analysis, covering core concepts, transformations, ML specifics, control flow, and parallelism.

install
source · Clone the upstream repo
git clone https://github.com/diegosouzapw/awesome-omni-skill
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/diegosouzapw/awesome-omni-skill "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/data-ai/jax" ~/.claude/skills/diegosouzapw-awesome-omni-skill-jax && rm -rf "$T"
manifest: skills/data-ai/jax/SKILL.md
source content

JAX Skill

JAX is Autograd and XLA, brought together for high-performance machine learning research.

Contents

  • Concepts & Theory
    • Immutability
    • The 4 Transformations
    • Pytrees
  • Code Examples
    • `jit`, `grad`, `vmap`, `random` usage
    • Control Flow (`scan`, `cond`, `fori_loop`)
    • Parallelism (`sharding`)
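The four core transformations listed above compose freely. A minimal sketch of each, using only public JAX APIs:

```python
import jax
import jax.numpy as jnp

# grad: differentiate a scalar-valued function
f = lambda x: jnp.sum(x ** 2)
df = jax.grad(f)              # gradient of sum(x^2) is 2*x

# jit: compile the function with XLA
fast_f = jax.jit(f)

# vmap: vectorize the gradient over a leading batch dimension
batched_df = jax.vmap(df)

# random: explicit PRNG keys instead of global state
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (3,))
```

Because each transformation returns an ordinary function, they can be stacked, e.g. `jax.jit(jax.vmap(jax.grad(f)))`.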

Common Workflows

1. Developing a New Model

  1. Define your parameters as a Pytree (dict/dataclass).
  2. Define your forward pass as a pure function.
  3. Define your loss function.
  4. Use `jax.value_and_grad` to get the loss and gradients together.
  5. Use `jax.jit` to speed up the update step.
  6. See examples.md for snippets.
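The steps above can be sketched end to end. The model here (a linear layer with mean-squared-error loss) and all parameter values are illustrative, not from the skill's examples.md:

```python
import jax
import jax.numpy as jnp

# 1. Parameters as a pytree (a plain dict of arrays).
params = {"w": jnp.zeros((2,)), "b": jnp.zeros(())}

# 2. Pure forward pass: no side effects, no in-place mutation.
def forward(params, x):
    return x @ params["w"] + params["b"]

# 3. Loss function (mean squared error).
def loss_fn(params, x, y):
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

# 4 + 5. value_and_grad gives loss and gradients in one pass;
# jit compiles the whole update step.
@jax.jit
def update(params, x, y, lr=0.1):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Gradient descent applied leaf-by-leaf across the pytree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([1.0, -1.0])
params, loss = update(params, x, y)
```

Because `params` and `grads` share the same pytree structure, `tree_map` applies the update to every leaf without the model code knowing how the parameters are nested.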

2. Debugging Shapes/NaNs

  1. Disable JIT with `jax.config.update("jax_disable_jit", True)` to debug with standard Python tools.
  2. Use `jax.debug.print` inside JIT-compiled functions.
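A short sketch of both techniques; the function and inputs are hypothetical. Note that `jax.disable_jit()` is also available as a context manager, which is often more convenient than flipping the global config flag:

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    y = jnp.log(x)  # yields -inf/nan for x <= 0 — a typical NaN source
    # Ordinary print() would fire once at trace time with abstract
    # tracer values; jax.debug.print shows runtime values inside jit.
    jax.debug.print("y = {}", y)
    return y * 2

# Compiled path: debug.print still reports concrete runtime values.
step(jnp.array([1.0, 2.0]))

# Disabled-JIT path: the function runs as plain Python/NumPy-style
# code, so pdb, breakpoints, and print() all see concrete arrays.
with jax.disable_jit():
    out = step(jnp.array([1.0, 2.0]))
```

For shape bugs specifically, running under `disable_jit` turns opaque tracer errors into ordinary Python stack traces at the failing line.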