MLX

// Apple's array framework for Apple Silicon — tensors, autograd, training, and on-device LLMs (with Gemma 3 4B walkthrough)

What MLX is

MLX is Apple's array framework for Apple Silicon, written in C++ with idiomatic Python and Swift bindings. It's PyTorch-shaped on the surface — arrays, autograd, neural-network modules, optimizers — but designed top-to-bottom for unified memory: arrays live in one address space accessed by both CPU and GPU with no .to(device) dance. Lazy evaluation lets the compiler fuse ops; Metal kernels run on the GPU; CPU fallback is automatic.

For LLMs, the practical entry point is mlx-lm — a separate package on top of MLX that handles tokenizers, model loaders, generation loops, quantization, and LoRA fine-tuning. Most users start here, drop down to raw MLX only when they need to.

Install

# core framework
pip install mlx

# LLM stack on top — what you actually use day-to-day
pip install mlx-lm

# vision-language models (Gemma 3 multimodal, Llava, etc.)
pip install mlx-vlm

# quick verify
python -c "import mlx.core as mx; print(mx.array([1,2,3]) * 2)"
# → array([2, 4, 6], dtype=int32)

Apple Silicon only. MLX requires M1 or later. Intel Macs get nothing — there's no fallback target. macOS 13.5+, Python 3.9+.

Core concepts

concept                  what it means
───────                  ─────────────
Unified memory           One address space. CPU and GPU read the same arrays. No x.cuda() / x.cpu() shuffle.
Lazy evaluation          Operations build a graph; computation defers until you call mx.eval(x) or read a value. Lets the compiler fuse ops and elide intermediates.
Composable transforms    mx.grad, mx.value_and_grad, mx.vmap, mx.compile compose like JAX — wrap a Python fn, get the transformed fn back.
Multi-device             CPU and GPU streams (mx.cpu / mx.gpu). Most ops auto-route; explicit placement via mx.set_default_device().
NumPy-shaped API         Familiar names — mx.array, mx.matmul, mx.softmax — but lazy and Metal-backed.
   numpy        torch        jax         mlx
   ─────        ─────        ───         ───
   eager        eager        lazy        lazy
   cpu          cpu/cuda     cpu/gpu     unified
   no grad      autograd     transform   transform
   −            tensors      arrays      arrays
                .cuda()      jit()       compile()
                                         metal kernels

Arrays & ops

import mlx.core as mx

x = mx.array([[1, 2], [3, 4]], dtype=mx.float32)
y = mx.ones((2, 2))
z = x @ y                         # matmul, lazy

mx.eval(z)                        # force compute
print(z)
# array([[3, 3], [7, 7]], dtype=float32)

# dtype control matters — bf16 / float16 halve memory + speed up matmul
w = mx.random.normal((1024, 1024), dtype=mx.bfloat16)

# transforms — autograd
# use a (2, 2) weight so shapes line up with x and y above
loss_fn = lambda w, x, y: ((x @ w - y) ** 2).mean()
grad_fn = mx.grad(loss_fn)        # differentiates w.r.t. the first argument
g = grad_fn(mx.random.normal((2, 2)), x, y)

Neural-network modules

import mlx.core as mx
import mlx.nn   as nn

class MLP(nn.Module):
    def __init__(self, in_dim: int, hidden: int, out_dim: int):
        super().__init__()
        self.l1 = nn.Linear(in_dim, hidden)
        self.l2 = nn.Linear(hidden, out_dim)

    def __call__(self, x):
        return self.l2(nn.relu(self.l1(x)))

model = MLP(784, 256, 10)
mx.eval(model.parameters())                # materialize params on first use
out = model(mx.random.normal((32, 784)))   # forward pass

Optimizers

import mlx.optimizers as optim

opt = optim.AdamW(learning_rate=1e-4, weight_decay=0.01)

# the loss takes the model as its first argument so nn.value_and_grad
# can compute gradients w.r.t. the model's trainable parameters
def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")

loss_and_grad = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad(model, x_batch, y_batch)
opt.update(model, grads)
mx.eval(model.parameters(), opt.state)     # materialize after step

Training loop

for epoch in range(EPOCHS):
    for x_batch, y_batch in loader:
        loss, grads = loss_and_grad(model, x_batch, y_batch)
        opt.update(model, grads)
        mx.eval(model.parameters(), opt.state)   # break the lazy graph each step
    print(f"epoch {epoch}: loss={loss.item():.4f}")

Always mx.eval(...) after each optimizer step. Lazy evaluation will happily build a graph that spans your whole epoch; without periodic evaluation, the graph — and the memory backing it — grows unbounded.

mlx-lm — the LLM stack

mlx-lm wraps tokenizer + model + sampler + cache for dozens of open-weight architectures: Llama, Mistral, Qwen, Phi, DeepSeek, Gemma, OLMo, etc. One load() + one generate() covers most use cases.

from mlx_lm import load, generate, stream_generate

model, tok = load("mlx-community/gemma-3-4b-it-bf16")

# one-shot
text = generate(model, tok,
    prompt="Explain unified memory in one sentence.",
    max_tokens=128, temp=0.7)
print(text)

# streaming
for chunk in stream_generate(model, tok, prompt="Hello", max_tokens=64):
    print(chunk.text, end="", flush=True)

Gemma 3 4B — end-to-end on Apple Silicon

Google's Gemma 3 family (270M / 1B / 4B / 12B / 27B) hits a sweet spot for local inference on Macs with the 4B variant — ~8 GB of weights at bf16 (comfortable on a 16 GB machine; under 3 GB at 4-bit), fast on M2/M3/M4. Multimodal variants (gemma-3-4b-it includes vision) work via mlx-vlm.

1. Pull the MLX-converted weights

# the mlx-community org on HF hosts pre-converted MLX weights for most popular models
# bf16 — full quality, ~8 GB
mlx_lm.generate \
  --model mlx-community/gemma-3-4b-it-bf16 \
  --prompt "List 3 surprising uses of unified memory." \
  --max-tokens 256

# 4-bit quantized — ~2.5 GB, ~3-4× faster, slight quality drop
mlx_lm.generate \
  --model mlx-community/gemma-3-4b-it-4bit \
  --prompt "..."

2. Programmatic generation with chat template

from mlx_lm import load, generate

model, tok = load("mlx-community/gemma-3-4b-it-bf16")

messages = [
    {"role": "system", "content": "You are a precise assistant."},
    {"role": "user",   "content": "Why is bf16 a good default for inference?"},
]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
out = generate(model, tok, prompt=prompt, max_tokens=512, temp=0.7)
print(out)

3. Throughput knobs

knob                   effect
────                   ──────
temp · top_p · top_k   Sampling shape. temp=0 = greedy.
repetition_penalty     Reduces loops. 1.1 is a safe baseline.
max_kv_size            KV cache cap (tokens). Sets a memory ceiling for long contexts.
kv_bits=8              Quantize the KV cache for longer contexts. ~2× context for ~5% quality cost.
prompt_cache           Reuse KV across calls when prompts share a prefix (system prompts, RAG context).
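To pick max_kv_size, a back-of-envelope estimate helps. A sketch in plain Python — the layer count, KV-head count, and head dim below are illustrative placeholders, not Gemma 3's actual config; read yours from the model's config.json:

```python
def kv_cache_bytes(tokens, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    """Approximate KV cache size: keys + values across all layers.

    Per token, per layer: K and V, each n_kv_heads * head_dim elements.
    bytes_per_elem=2 assumes a bf16 cache; kv_bits=8 roughly halves it.
    """
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem * tokens

# hypothetical dims for a ~4B grouped-query-attention model
gb = kv_cache_bytes(tokens=8192, n_layers=34, n_kv_heads=8, head_dim=128) / 1e9
print(f"{gb:.2f} GB")   # → 1.14 GB
```

The linear growth in tokens is why max_kv_size acts as a hard memory ceiling: cap the tokens, cap the cache.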

Quantization

mlx-lm ships mlx_lm.convert for converting + quantizing any HF model into MLX format. 4-bit quantization is the standard for local inference; use 8-bit when 4-bit costs your task too much quality.

# take a HF model, convert + 4-bit quantize, save locally
mlx_lm.convert \
  --hf-path google/gemma-3-4b-it \
  --mlx-path ./gemma-3-4b-it-4bit \
  --quantize \
  --q-bits 4 \
  --q-group-size 64

# or push the converted weights back to HF (requires write token)
mlx_lm.convert \
  --hf-path google/gemma-3-4b-it \
  --mlx-path gemma-3-4b-it-4bit \
  --upload-repo my-username/gemma-3-4b-it-mlx-4bit \
  --quantize

Group-size caveat. MLX quantizes weights in groups (q-group-size 64 by default); smaller groups mean better fidelity but larger files. For Gemma 3 4B, --q-bits 4 --q-group-size 64 is the sweet spot — quality stays high enough for production.
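The bits-vs-group-size tradeoff is easy to quantify. Each group stores its weights at q-bits plus a per-group scale and bias — assumed here to be one fp16 each (32 bits per group total), which is a simplification of the actual on-disk format:

```python
def quantized_gb(n_params, q_bits=4, group_size=64, scale_bias_bits=32):
    """Approximate size of group-quantized weights:
    q_bits per weight, plus scale+bias overhead amortized per group."""
    bits_per_weight = q_bits + scale_bias_bits / group_size
    return n_params * bits_per_weight / 8 / 1e9

print(f"{quantized_gb(4e9, 4, 64):.2f} GB")   # → 2.25 GB — the ~2.5 GB figure above
print(f"{quantized_gb(4e9, 4, 32):.2f} GB")   # → 2.50 GB — halving group size adds ~11%
```

Effective bits per weight at group 64 is 4.5, not 4 — the overhead is why quantized checkpoints run slightly larger than the raw bit count suggests.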

LoRA fine-tune

mlx-lm supports LoRA + DoRA fine-tuning out of the box. On a 16 GB M-series, you can fine-tune Gemma 3 4B in bf16 in an afternoon on a few thousand examples.

Data format — JSONL

# data/train.jsonl — one example per line
{"messages": [{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
{"messages": [...]}

# or text-completion style
{"text": "prompt → completion all in one string"}
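A small helper for producing chat-format JSONL from raw (user, assistant) string pairs — a sketch; mlx_lm.lora expects train.jsonl (and optionally valid.jsonl) inside the --data directory:

```python
import json
from pathlib import Path

def write_chat_jsonl(pairs, path):
    """Write (user, assistant) pairs as one chat-format JSON object per line."""
    with open(path, "w") as f:
        for user, assistant in pairs:
            example = {"messages": [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]}
            f.write(json.dumps(example) + "\n")

pairs = [("What is MLX?", "Apple's array framework for Apple Silicon.")]
Path("data").mkdir(exist_ok=True)
write_chat_jsonl(pairs, "data/train.jsonl")
```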

Train

mlx_lm.lora \
  --model mlx-community/gemma-3-4b-it-bf16 \
  --train \
  --data ./data \
  --iters 1000 \
  --batch-size 4 \
  --lora-layers 16 \
  --learning-rate 1e-4 \
  --adapter-path ./adapters

Generate with adapter

mlx_lm.generate \
  --model mlx-community/gemma-3-4b-it-bf16 \
  --adapter-path ./adapters \
  --prompt "..."

Fuse adapter into the base — ship as one model

mlx_lm.fuse \
  --model mlx-community/gemma-3-4b-it-bf16 \
  --adapter-path ./adapters \
  --save-path ./gemma-3-4b-tuned

Serve as an OpenAI-compatible HTTP server

mlx_lm.server \
  --model mlx-community/gemma-3-4b-it-4bit \
  --host 0.0.0.0 --port 8080

# now usable from anything that speaks OpenAI
curl http://localhost:8080/v1/chat/completions \
  -H "content-type: application/json" \
  -d '{
    "model": "default",
    "messages": [{"role":"user","content":"hi"}],
    "max_tokens": 64
  }'
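From Python, any OpenAI-style client works against this endpoint. A dependency-free sketch with urllib — the payload mirrors the curl call above, and the response shape assumes the standard chat-completions format:

```python
import json
from urllib import request

def build_chat_payload(prompt, max_tokens=64):
    """Request body in the OpenAI chat-completions shape."""
    return {
        "model": "default",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }

def chat(prompt, base_url="http://localhost:8080/v1", max_tokens=64):
    """POST to the local mlx_lm.server and return the assistant text."""
    req = request.Request(
        f"{base_url}/chat/completions",
        data=json.dumps(build_chat_payload(prompt, max_tokens)).encode(),
        headers={"content-type": "application/json"},
    )
    with request.urlopen(req) as resp:
        return json.load(resp)["choices"][0]["message"]["content"]

# chat("hi")   # requires mlx_lm.server running on :8080
```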

mlx-vlm — vision-language

Same shape, but with images. Gemma 3 4B's -it instruction-tuned variant is multimodal; mlx-vlm handles the image preprocessing.

pip install mlx-vlm

from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template

model, processor = load("mlx-community/gemma-3-4b-it-4bit")
prompt = apply_chat_template(processor, model.config,
    ["<describe what you see>"], num_images=1)

out = generate(model, processor, prompt, image=["./photo.jpg"], max_tokens=200)
print(out)

Integrate with this project

Concrete plays for the Organized AI / clip-pipeline setup:

  1. Replace cloud LLM with local Gemma 3 4B. Stand up mlx_lm.server on claw, point the autoresearch loops' OPENAI_BASE_URL at it. Free token throughput; offline-capable; latency falls from hundreds of ms to tens.
  2. Caption polish on-device. The caption-quality-boost skill currently calls Claude for proper-noun cleanup. A LoRA-fine-tuned Gemma on your past transcripts handles ~95% of cases locally.
  3. Vision moment scoring. mlx-vlm + Gemma 3 4B-multimodal on still frames from a long-form recording. Replace TwelveLabs for the basic "is this a high-energy moment" classification on cost-sensitive sessions.
  4. Speaker diarization assist. Pyannote does the audio diarization; Gemma reads the transcript and proposes speaker→identity mappings as a sanity layer (the LLM-bootstrap-speakers skill is already this shape — just swap the model).
  5. Token Machine routing. Token Machine's "local can clear the bar" path is exactly an mlx-lm endpoint. Wire claw's local Gemma as one of the routed backends; PostHog's $ai_generation events still flow.
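Play 5's routing decision can be sketched without any framework. A hypothetical gate decides whether the local Gemma endpoint clears the bar for a request — the backend names, complexity score, and thresholds below are all illustrative, not Token Machine's actual config:

```python
def route(task_complexity, context_tokens,
          local_max_tokens=8192, complexity_bar=0.6):
    """Pick a backend: local mlx-lm when the request clears the bar,
    cloud otherwise. Scores and thresholds are illustrative."""
    if context_tokens <= local_max_tokens and task_complexity <= complexity_bar:
        return "local-gemma-3-4b"   # the mlx_lm.server endpoint
    return "cloud-fallback"

print(route(0.3, 2000))    # → local-gemma-3-4b
print(route(0.9, 2000))    # → cloud-fallback
print(route(0.3, 20000))   # → cloud-fallback (context too long for local)
```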

Pitfalls