Awesome-omni-skill JAX
Essential tools for using JAX in machine learning and mathematical analysis, covering core concepts, transformations, ML specifics, control flow, and parallelism.
install
source · Clone the upstream repo
git clone https://github.com/diegosouzapw/awesome-omni-skill
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/diegosouzapw/awesome-omni-skill "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/data-ai/jax" ~/.claude/skills/diegosouzapw-awesome-omni-skill-jax && rm -rf "$T"
manifest:
skills/data-ai/jax/SKILL.md
JAX Skill
JAX is Autograd and XLA, brought together for high-performance machine learning research.
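The one-line summary can be made concrete with a minimal sketch. jax.grad (the Autograd side) turns a function into its gradient function, and jax.jit (the XLA side) compiles that gradient function. The function f and the input values are hypothetical examples, not part of the skill.

```python
import jax
import jax.numpy as jnp

# A pure function written with jax.numpy (hypothetical example).
def f(x):
    return jnp.sum(x ** 2)

# Autograd side: jax.grad returns a new function computing df/dx.
df = jax.grad(f)

# XLA side: jax.jit compiles the gradient function on its first call.
fast_df = jax.jit(df)

x = jnp.array([1.0, 2.0, 3.0])
g = fast_df(x)  # gradient of sum(x**2) is 2 * x
```

Because transformations return ordinary functions, they compose freely; jax.jit(jax.grad(f)) is simply one function applied to another.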
Contents
- Concepts & Theory
  - Immutability
  - The 4 Transformations
  - Pytrees
- Code Examples
  - jit, grad, vmap, random usage
  - Control Flow (scan, cond, fori_loop)
  - Parallelism (sharding)
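A compact sketch of several topics listed above: immutable updates, Pytrees, vmap, explicit random keys, and the three control-flow primitives. Sharding is left out because it depends on the available devices. All variable and function names here (dot, step, safe_log, and so on) are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Immutability: arrays are never modified in place; .at[...].set() returns a new array.
a = jnp.zeros(3)
b = a.at[1].set(5.0)  # a is unchanged

# Pytrees: nested dicts/lists/tuples of arrays are processed leaf by leaf.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
doubled = jax.tree_util.tree_map(lambda p: 2 * p, params)

# vmap: vectorize a per-example function over the leading axis of its inputs.
def dot(u, v):
    return jnp.dot(u, v)

us = jnp.arange(6.0).reshape(3, 2)
vs = jnp.ones((3, 2))
dots = jax.vmap(dot)(us, vs)  # one dot product per row, shape (3,)

# Random: PRNG state is an explicit key; split it rather than reusing a key.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, (3,))

# scan: thread a carry through a loop and stack per-step outputs (a running sum here).
def step(carry, x):
    carry = carry + x
    return carry, carry

total, running = jax.lax.scan(step, jnp.array(0.0), jnp.arange(1.0, 5.0))

# cond: pick a branch from a traced boolean; both branches must return matching shapes/dtypes.
def safe_log(x):
    return jax.lax.cond(x > 0, jnp.log, lambda v: jnp.zeros_like(v), x)

# fori_loop: a counted loop with carried state, here 1 + 2 + ... + 10.
s = jax.lax.fori_loop(1, 11, lambda i, acc: acc + i, jnp.array(0))
```

These primitives replace Python if/for only inside traced code; outside jit, ordinary Python control flow still works.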
Common Workflows
1. Developing a new Model
- Define your parameters as a Pytree (e.g. a dict, or a dataclass registered as a Pytree node).
- Define your forward pass as a pure function.
- Define your loss function.
- Use jax.value_and_grad to get the loss and its gradients in one call.
- Use jax.jit to compile and speed up the update step.
- See examples.md for snippets.
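The workflow steps above fit into one short sketch of a linear-regression training loop. Plain SGD with a fixed LEARNING_RATE is an assumption made to keep it self-contained (a real project would likely use an optimizer library). The names init_params, forward, loss_fn, and update are hypothetical.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed fixed step size for plain SGD

def init_params(key, in_dim, out_dim):
    # Parameters as a Pytree: a plain dict of arrays.
    return {
        "w": 0.1 * jax.random.normal(key, (in_dim, out_dim)),
        "b": jnp.zeros(out_dim),
    }

def forward(params, x):
    # Pure forward pass: the result depends only on the arguments.
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    # Mean squared error between predictions and targets.
    return jnp.mean((forward(params, x) - y) ** 2)

@jax.jit
def update(params, x, y):
    # value_and_grad returns the loss and gradients shaped like params.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # SGD step applied leaf by leaf; a new Pytree is returned, params is not mutated.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss

# Usage: fit a small synthetic regression problem.
x = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
y = x @ jnp.array([[1.0], [-2.0], [0.5]])
params = init_params(jax.random.PRNGKey(1), 3, 1)
initial_loss = loss_fn(params, x, y)
for _ in range(200):
    params, loss = update(params, x, y)
```

Only update is jitted. It is called many times with same-shaped inputs, so it compiles once and reuses the compiled step.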
2. Debugging Shapes/NaNs
- Disable JIT with jax.config.update("jax_disable_jit", True) to debug with standard Python tools.
- Use jax.debug.print inside JITted functions; a plain print there runs only during tracing and shows abstract tracers instead of values.
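Both debugging tips fit in one minimal sketch, assuming a hypothetical normalize function. Under jit, a plain print fires only at trace time but still shows the static shape. jax.debug.print reports the runtime value on every call. Disabling JIT runs the same function eagerly.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    # Runs only while tracing: x is a tracer here, but its shape is static and real.
    print("tracing normalize, x.shape =", x.shape)
    norm = jnp.linalg.norm(x)
    # Runs on every call and prints the concrete runtime value.
    jax.debug.print("norm = {n}", n=norm)
    return x / norm

out = normalize(jnp.array([3.0, 4.0]))

# With JIT disabled, the function runs op by op, so print(), pdb, and
# breakpoints see concrete arrays. Switch it back on afterwards.
jax.config.update("jax_disable_jit", True)
eager_out = normalize(jnp.array([3.0, 4.0]))
jax.config.update("jax_disable_jit", False)
```

For NaN hunting specifically, jax.config.update("jax_debug_nans", True) makes JAX raise an error at the first operation that produces a NaN.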