# MLX — Apple's array framework for Apple Silicon: tensors, autograd, training, and on-device LLMs (with Gemma 3 4B walkthrough)
MLX is Apple's array framework for Apple Silicon, written in C++ with idiomatic Python and Swift bindings. It's PyTorch-shaped on the surface — arrays, autograd, neural-network modules, optimizers — but designed top-to-bottom for unified memory: arrays live in one address space accessed by both CPU and GPU with no .to(device) dance. Lazy evaluation lets the compiler fuse ops; Metal kernels run on the GPU; CPU fallback is automatic.
For LLMs, the practical entry point is mlx-lm — a separate package on top of MLX that handles tokenizers, model loaders, generation loops, quantization, and LoRA fine-tuning. Most users start here, drop down to raw MLX only when they need to.
# core framework
pip install mlx
# LLM stack on top — what you actually use day-to-day
pip install mlx-lm
# vision-language models (Gemma 3 multimodal, Llava, etc.)
pip install mlx-vlm
# quick verify
python -c "import mlx.core as mx; print(mx.array([1,2,3]) * 2)"
# → array([2, 4, 6], dtype=int32)
Apple Silicon only. MLX requires M1 or later. Intel Macs get nothing — there's no fallback target. macOS 13.5+, Python 3.9+.
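Scripts meant to run elsewhere can fail fast with a tiny guard — a sketch; the `mlx_supported` helper here is ours, not part of MLX:

```python
import platform

def mlx_supported(machine=None, system=None):
    """True only on Apple Silicon Macs (arm64 on Darwin) — the sole hardware MLX targets."""
    machine = machine or platform.machine()
    system = system or platform.system()
    return system == "Darwin" and machine == "arm64"

# gate the optional import:
# if mlx_supported():
#     import mlx.core as mx
```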
| concept | what it means |
|---|---|
| Unified memory | One address space. CPU and GPU read the same arrays. No x.cuda() / x.cpu() shuffle. |
| Lazy evaluation | Operations build a graph; computation defers until you call mx.eval(x) or read a value. Lets the compiler fuse ops and elide intermediates. |
| Composable transforms | mx.grad, mx.value_and_grad, mx.vmap, mx.compile compose like JAX — wrap a Python fn, get the transformed fn back. |
| Multi-device | CPU and GPU streams (mx.cpu / mx.gpu). Most ops auto-route; explicit placement via mx.set_default_device(). |
| NumPy-shaped API | Familiar names — mx.array, mx.matmul, mx.softmax — but lazy and Metal-backed. |
| | numpy | torch | jax | mlx |
|---|---|---|---|---|
| evaluation | eager | eager | lazy | lazy |
| devices | cpu | cpu/cuda | cpu/gpu | unified |
| autodiff | no grad | autograd | transform | transform |
| containers | − | tensors | arrays | arrays |
| placement / compile | − | .cuda() | jit() | compile() |
| GPU backend | − | − | − | metal kernels |
import mlx.core as mx
x = mx.array([[1, 2], [3, 4]], dtype=mx.float32)
y = mx.ones((2, 2))
z = x @ y # matmul, lazy
mx.eval(z) # force compute
print(z)
# array([[3, 3], [7, 7]], dtype=float32)
# dtype control matters — bf16 / float16 halve memory + speed up matmul
w = mx.random.normal((1024, 1024), dtype=mx.bfloat16)
# transforms — autograd (shapes must align: use a 2×2 weight to match x and y above)
w2 = mx.random.normal((2, 2))
loss_fn = lambda w, x, y: ((x @ w - y) ** 2).mean()
grad_fn = mx.grad(loss_fn)  # differentiates w.r.t. the first argument
g = grad_fn(w2, x, y)       # gradient of the loss w.r.t. w2
import mlx.core as mx
import mlx.nn as nn
class MLP(nn.Module):
    def __init__(self, in_dim: int, hidden: int, out_dim: int):
        super().__init__()
        self.l1 = nn.Linear(in_dim, hidden)
        self.l2 = nn.Linear(hidden, out_dim)

    def __call__(self, x):
        return self.l2(nn.relu(self.l1(x)))
model = MLP(784, 256, 10)
mx.eval(model.parameters()) # materialize params on first use
out = model(mx.random.normal((32, 784))) # forward pass
import mlx.optimizers as optim

def loss_fn(model, x, y):  # takes the model as first arg, unlike the lambda above
    return nn.losses.cross_entropy(model(x), y).mean()

opt = optim.AdamW(learning_rate=1e-4, weight_decay=0.01)
loss_and_grad = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad(model, x_batch, y_batch)
opt.update(model, grads)
mx.eval(model.parameters(), opt.state) # materialize after step
for epoch in range(EPOCHS):
    for x_batch, y_batch in loader:
        loss, grads = loss_and_grad(model, x_batch, y_batch)
        opt.update(model, grads)
        mx.eval(model.parameters(), opt.state)  # break the lazy graph each step
    print(f"epoch {epoch}: loss={loss.item():.4f}")
Always mx.eval(...) after each step. Lazy evaluation will happily build a graph that spans your whole epoch; without periodic evaluation, the unmaterialized graph grows unbounded and memory blows up. Put mx.eval(model.parameters(), opt.state) after every optimizer step.
mlx-lm wraps tokenizer + model + sampler + cache for dozens of open-weight architectures: Llama, Mistral, Qwen, Phi, DeepSeek, Gemma, OLMo, etc. One load() + one generate() covers most use cases.
from mlx_lm import load, generate, stream_generate
model, tok = load("mlx-community/gemma-3-4b-it-bf16")
# one-shot
text = generate(model, tok,
prompt="Explain unified memory in one sentence.",
max_tokens=128, temp=0.7)
print(text)
# streaming
for chunk in stream_generate(model, tok, prompt="Hello", max_tokens=64):
print(chunk.text, end="", flush=True)
Google's Gemma 3 family (270M / 1B / 4B / 12B / 27B) hits a sweet spot for local inference on Macs with the 4B variant — about 8 GB of weights at bf16 (comfortable on a 16 GB machine; use the 4-bit quant on 8 GB Macs), fast on M2/M3/M4. Multimodal variants (gemma-3-4b-it includes vision) work via mlx-vlm.
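The bf16-vs-4-bit sizing is back-of-envelope arithmetic (not an MLX API; overhead from KV cache and activations comes on top):

```python
def weight_gb(n_params, bits_per_weight):
    """Raw weight storage in GB, ignoring KV cache and activations."""
    return n_params * bits_per_weight / 8 / 1e9

print(weight_gb(4e9, 16))   # Gemma 3 4B at bf16: 8.0 GB of weights alone
print(weight_gb(4e9, 4.5))  # 4-bit incl. group-quant metadata: 2.25 GB
```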
# the mlx-community on HF hosts pre-converted MLX weights for most popular models
# bf16 — full quality, ~8 GB
mlx_lm.generate \
--model mlx-community/gemma-3-4b-it-bf16 \
--prompt "List 3 surprising uses of unified memory." \
--max-tokens 256
# 4-bit quantized — ~2.5 GB, ~3-4× faster, slight quality drop
mlx_lm.generate \
--model mlx-community/gemma-3-4b-it-4bit \
--prompt "..."
from mlx_lm import load, generate
model, tok = load("mlx-community/gemma-3-4b-it-bf16")
messages = [
{"role": "system", "content": "You are a precise assistant."},
{"role": "user", "content": "Why is bf16 a good default for inference?"},
]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
out = generate(model, tok, prompt=prompt, max_tokens=512, temp=0.7)
print(out)
| knob | effect |
|---|---|
| temp · top_p · top_k | Sampling shape. temp=0 = greedy. |
| repetition_penalty | Reduces loops. 1.1 is a safe baseline. |
| max_kv_size | KV cache cap (tokens). Sets a memory ceiling for long contexts. |
| kv_bits=8 | Quantize the KV cache for longer contexts. ~2× context for ~5% quality cost. |
| prompt_cache | Reuse KV across calls when prompts share a prefix (system prompts, RAG context). |
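To see what max_kv_size and kv_bits actually buy, here's the standard KV-cache sizing formula with an illustrative transformer config (the layer/head numbers below are placeholders, not Gemma 3 4B's real architecture):

```python
def kv_cache_gb(layers, kv_heads, head_dim, tokens, bytes_per_elt=2):
    # keys + values: one vector per token, per layer, per KV head
    return 2 * layers * kv_heads * head_dim * tokens * bytes_per_elt / 1e9

fp16 = kv_cache_gb(layers=32, kv_heads=8, head_dim=128, tokens=8192)
print(round(fp16, 2))      # ~1.07 GB at 8K context in fp16
print(round(fp16 / 2, 2))  # kv_bits=8 roughly halves it
```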
mlx-lm ships mlx_lm.convert to convert and quantize any HF model into MLX format. 4-bit quantization is the standard for local inference; use 8-bit when 4-bit hurts your task too much.
# take a HF model, convert + 4-bit quantize, save locally
mlx_lm.convert \
--hf-path google/gemma-3-4b-it \
--mlx-path ./gemma-3-4b-it-4bit \
--quantize \
--q-bits 4 \
--q-group-size 64
# or push the converted weights back to HF (requires write token)
mlx_lm.convert \
--hf-path google/gemma-3-4b-it \
--mlx-path gemma-3-4b-it-4bit \
--upload-repo my-username/gemma-3-4b-it-mlx-4bit \
--quantize
Quantization caveat. Group quantization (q-group-size 64) is the standard: smaller groups give better fidelity but larger files. For Gemma 3 4B, --q-bits 4 --q-group-size 64 is the sweet spot — quality stays high enough for production use.
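The group-size/file-size trade-off is easy to quantify: each group carries its own scale and bias (assumed fp16 each here, per MLX's affine quantization scheme), so smaller groups add per-weight metadata overhead. A sketch:

```python
def effective_bits(q_bits, group_size, scale_bits=16, bias_bits=16):
    # per-weight storage = quantized bits + amortized group metadata
    return q_bits + (scale_bits + bias_bits) / group_size

print(effective_bits(4, 64))  # 4.5 bits/weight at the default group size
print(effective_bits(4, 32))  # 5.0 bits/weight: better fidelity, ~11% bigger files
```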
mlx-lm supports LoRA + DoRA fine-tuning out of the box. On a 16 GB M-series, you can fine-tune Gemma 3 4B in bf16 in an afternoon on a few thousand examples.
# data/train.jsonl — one example per line
{"messages": [{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
{"messages": [...]}
# or text-completion style
{"text": "prompt → completion all in one string"}
mlx_lm.lora \
--model mlx-community/gemma-3-4b-it-bf16 \
--train \
--data ./data \
--iters 1000 \
--batch-size 4 \
--lora-layers 16 \
--learning-rate 1e-4 \
--adapter-path ./adapters
mlx_lm.generate \
--model mlx-community/gemma-3-4b-it-bf16 \
--adapter-path ./adapters \
--prompt "..."
mlx_lm.fuse \
--model mlx-community/gemma-3-4b-it-bf16 \
--adapter-path ./adapters \
--save-path ./gemma-3-4b-tuned
mlx_lm.server \
--model mlx-community/gemma-3-4b-it-4bit \
--host 0.0.0.0 --port 8080
# now usable from anything that speaks OpenAI
curl http://localhost:8080/v1/chat/completions \
-H "content-type: application/json" \
-d '{
"model": "default",
"messages": [{"role":"user","content":"hi"}],
"max_tokens": 64
}'
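From Python, any OpenAI-style client works against that endpoint; a dependency-free sketch with urllib (payload shape mirrors the curl above; the server must be running for the final call):

```python
import json
import urllib.request

def chat_payload(content, max_tokens=64):
    # same body as the curl example: model name is ignored by the server
    return {
        "model": "default",
        "messages": [{"role": "user", "content": content}],
        "max_tokens": max_tokens,
    }

def chat(base_url, content):
    req = urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(chat_payload(content)).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())["choices"][0]["message"]["content"]

# chat("http://localhost:8080", "hi")  # uncomment with mlx_lm.server running
```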
Same shape, but with images. Gemma 3 4B's -it instruction-tuned variant is multimodal; mlx-vlm handles the image preprocessing.
pip install mlx-vlm
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
model, processor = load("mlx-community/gemma-3-4b-it-4bit")
prompt = apply_chat_template(processor, model.config,
["<describe what you see>"], num_images=1)
out = generate(model, processor, prompt, image=["./photo.jpg"], max_tokens=200)
print(out)
Concrete plays for the Organized AI / clip-pipeline setup:

- Run mlx_lm.server on claw and point the autoresearch loops' OPENAI_BASE_URL at it. Free token throughput; offline-capable; latency falls from hundreds of ms to tens. $ai_generation events still flow.
- Missing mx.eval() in training loops = exploding memory. The lazy graph spans every op until you force evaluation: mx.eval(model.parameters(), opt.state) after every step.
- Cast inputs with x.astype(mx.bfloat16) at the model boundary.
- Watch memory_pressure in Activity Monitor. For a 4B model + KV cache at 8K context, plan ~10 GB.
- Gemma weights are gated. Run huggingface-cli login once with a write-scoped token, then load() works.
- --kv-bits 8 halves KV-cache memory for ~5% quality.
- curl the server first; add an SDK after.
mlx_lm.server on claw, point the autoresearch loops' OPENAI_BASE_URL at it. Free token throughput; offline-capable; latency falls from hundreds of ms to tens.$ai_generation events still flow.mx.eval() in training loops = exploding memory. The lazy graph spans every op until you force evaluation. mx.eval(model.parameters(), opt.state) after every step.x.astype(mx.bfloat16) at the model boundary.memory_pressure in Activity Monitor. For a 4B model + KV cache at 8K context, plan ~10 GB.huggingface-cli login. Gemma weights are gated. Run huggingface-cli login once with a write-scoped token, then load() works.--kv-bits 8 halves that for ~5% quality.curl first; add SDK after.