Gemma 4 Locally on Mac: Performance & Vision Benchmarks

LLM
Gemma-4
MLX
local-inference
multimodal
vision
Apple-Silicon
benchmarks
Author

Nipun Batra

Published

April 3, 2026

Google released Gemma 4 yesterday — open-weight multimodal models under Apache 2.0. In a companion post, we compared Gemma 4 and Gemini via the API. Here, we run Gemma 4 entirely locally on Apple Silicon using mlx-vlm and treat the notebook as a practical tour: a few qualitative demos, two lightweight benchmarks, and a throughput check.

  1. Qualitative examples — scene understanding, chart reading, prompted polygon segmentation
  2. Quantitative benchmarks — VQA exact-match accuracy on VQAv2 and class-aware localization on COCO
  3. Runtime performance — tokens/sec, prompt processing speed, peak memory

Hardware used for the results shown here: Mac Studio M2 Max with 64 GB unified memory.

Setup

uv pip install -U mlx-vlm==0.4.4 mlx-lm mlx datasets pycocotools supervision
Code
import time, json, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import supervision as sv
from IPython.display import display, Markdown
import mlx.core as mx

%config InlineBackend.figure_format = 'retina'

print(f"MLX version: {mx.__version__}")
print(f"Metal device: {mx.default_device()}")
MLX version: 0.31.1
Metal device: Device(gpu, 0)

Loading the Model

We test Gemma 4 31B IT (4-bit) via mlx-vlm. At 4-bit quantization it needs ~19 GB — well within 64 GB unified memory.
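The ~19 GB figure is easy to sanity-check with back-of-envelope arithmetic. The overhead factor below (quantization scales, embeddings, runtime buffers) is a rough assumption, not a measured quantity:

```python
# Back-of-envelope check of the quantized model's footprint.
# The 20% overhead factor is an assumption, not a measured number.
params = 31e9               # 31B parameters
bits_per_weight = 4         # 4-bit quantization
weights_gb = params * bits_per_weight / 8 / 1024**3
print(f"Weights alone:      {weights_gb:.1f} GB")   # ~14.4 GB
print(f"With ~20% overhead: {weights_gb * 1.2:.1f} GB")
```

The gap between this estimate and the observed ~19 GB peak is plausibly activations and the vision tower, which the weight count alone ignores.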

Code
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

MODEL_LABEL = 'Gemma 4 31B IT (4-bit)'
HARDWARE_LABEL = 'Mac Studio M2 Max, 64 GB unified memory'
MODEL_PATH = 'mlx-community/gemma-4-31b-it-4bit'

t0 = time.time()
model, processor = load(MODEL_PATH)
config = load_config(MODEL_PATH)
load_time = time.time() - t0
print(f"Model loaded in {load_time:.1f}s")
print(f"Model type: {config.get('model_type')}")
Fetching 11 files: 100%|██████████| 11/11 [00:00<00:00, 62516.73it/s]
Model loaded in 2.9s
Model type: gemma4

Inference Helper

The helper below keeps the rest of the notebook short: it runs local inference, returns timing/token stats, and exposes small parsing utilities reused in the later sections.

Code
def strip_code_fences(text):
    text = text.strip()
    text = re.sub(r'```(?:json)?\s*', '', text)
    return re.sub(r'```', '', text)


def extract_json_fragment(text, kind='array'):
    cleaned = strip_code_fences(text)
    pattern = r'\[[\s\S]*\]' if kind == 'array' else r'\{[\s\S]*\}'
    match = re.search(pattern, cleaned)
    return match.group(0) if match else cleaned


def run_local(prompt_text, image_path=None, max_tokens=512, display_image=True, quiet=False):
    """Run Gemma 4 locally and return results with timing."""
    if image_path and display_image and not quiet:
        img = Image.open(image_path) if isinstance(image_path, str) else image_path
        fig, ax = plt.subplots(figsize=(4, 3))
        ax.imshow(img); ax.axis('off'); plt.tight_layout(); plt.show()

    if image_path and not isinstance(image_path, str):
        tmp = '/tmp/_mlx_tmp_img.png'
        image_path.save(tmp)
        image_path = tmp

    prompt = apply_chat_template(
        processor, config, prompt_text,
        num_images=1 if image_path else 0
    )

    t0 = time.time()
    result = generate(
        model, processor, prompt,
        image=image_path if image_path else None,
        max_tokens=max_tokens,
        verbose=False
    )
    elapsed = time.time() - t0

    out = {
        'response': result.text,
        'prompt_tokens': result.prompt_tokens,
        'gen_tokens': result.generation_tokens,
        'prompt_tps': result.prompt_tps,
        'gen_tps': result.generation_tps,
        'peak_mem_gb': result.peak_memory / (1024**3) if hasattr(result, 'peak_memory') and result.peak_memory else None,
        'time_s': elapsed,
    }

    if not quiet:
        display(Markdown(
            f"**{MODEL_LABEL} (local)** | {elapsed:.1f}s | {out['gen_tokens']} tok | {out['gen_tps']:.1f} tok/s"
        ))
        display(Markdown(out['response'][:2000]))
    return out

Part 1: Qualitative Examples

Scene Understanding

Code
results_scene = run_local(
    "Describe this image in 2-3 sentences. What is happening and what objects do you see?",
    image_path='classroom.jpg',
    max_tokens=256
)

Gemma 4 31B IT (4-bit) (local) | 7.4s | 54 tok | 17.6 tok/s

A group of young students are sitting at their desks in a classroom, focused on writing and reading in their notebooks. The room is filled with educational materials, including colorful posters on the walls, bookshelves, and various school supplies like pencils and erasers on the desks.

Chart Analysis

Code
np.random.seed(42)
epochs = np.arange(1, 21)
train_loss = 2.5 * np.exp(-0.15 * epochs) + 0.1 + np.random.normal(0, 0.05, 20)
val_loss = 2.5 * np.exp(-0.12 * epochs) + 0.3 + np.random.normal(0, 0.08, 20)
val_loss[14:] += np.linspace(0, 0.4, 6)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(epochs, train_loss, 'b-o', label='Train', ms=4)
ax1.plot(epochs, val_loss, 'r-s', label='Val', ms=4)
ax1.set(xlabel='Epoch', ylabel='Loss', title='Training vs Validation Loss')
ax1.legend(); ax1.grid(True, alpha=0.3)

bars = ['CNN', 'ResNet', 'ViT', 'Ours']
acc = [78.2, 85.6, 89.1, 92.3]
ax2.bar(bars, acc, color=['#aaa','#aaa','#aaa','#e74c3c'])
ax2.set(ylabel='Accuracy (%)', title='Model Comparison', ylim=(70,100))
for i, v in enumerate(acc): ax2.text(i, v+0.5, f'{v}%', ha='center', fontsize=9)
plt.tight_layout()
fig.savefig('/tmp/research_plot.png', dpi=150, bbox_inches='tight')
plt.show()

Code
results_chart = run_local(
    """Analyze these two plots concisely:
1. Left: what does it show and at which epoch does overfitting begin?
2. Right: which model is best and by how much over the runner-up?""",
    image_path='/tmp/research_plot.png',
    max_tokens=300,
    display_image=False
)

Gemma 4 31B IT (4-bit) (local) | 10.5s | 108 tok | 17.0 tok/s

Based on the provided plots:

  1. Left Plot: It shows the training and validation loss over 20 epochs. Overfitting begins around epoch 12, where the validation loss stops decreasing and starts to trend upward while the training loss continues to drop.
  2. Right Plot: The “Ours” model is the best, with an accuracy of 92.3%. It outperforms the runner-up (ViT, 89.1%) by 3.2%.

Segmentation

This is a qualitative segmentation sanity check, not a segmentation benchmark. We ask Gemma to emit a polygon for the main animal and visualize the returned shape on top of the image.

Code
results_seg = run_local(
    """Segment the main animal in this image.
Return JSON: {"polygon": [[x1,y1], [x2,y2], ...], "label": "..."}
Coordinates in [0, 1000] range. Use 20+ points.
Return ONLY valid JSON.""",
    image_path='happy-doggy.jpg',
    max_tokens=2048
)

Gemma 4 31B IT (4-bit) (local) | 22.6s | 303 tok | 17.0 tok/s

{"polygon": [[100, 100], [160, 140], [230, 210], [300, 230], [400, 200], [500, 180], [600, 200], [700, 250], [800, 200], [900, 100], [950, 200], [950, 350], [850, 450], [800, 600], [900, 800], [950, 950], [800, 1000], [500, 1000], [200, 1000], [0, 950], [0, 800], [100, 700], [50, 600], [30, 550], [100, 450], [200, 400], [150, 300], [100, 200], [100, 100]], "label": "main animal"}
Code
from matplotlib.patches import Polygon as MPLPoly

seg_img = Image.open('happy-doggy.jpg')
img_np = np.array(seg_img)
h, w = img_np.shape[:2]

fig, ax = plt.subplots(figsize=(6, 5))
ax.set_title('Gemma 4 31B (local) — Prompted polygon segmentation')
ax.imshow(img_np)

try:
    seg = json.loads(extract_json_fragment(results_seg['response'], kind='object'))
    pts_raw = seg.get('polygon') or seg.get('segmentation') or seg.get('points', [])
    pts = [(x * w / 1000, y * h / 1000) for x, y in pts_raw]
    ax.add_patch(MPLPoly(pts, closed=True, fill=True, fc='red', alpha=0.35, ec='red', lw=2))
    ax.set_xlabel(f"{len(pts)} points — {seg.get('label', '?')}")
except Exception as e:
    ax.set_xlabel(f"Parse error: {e}")

ax.axis('off')
plt.tight_layout(); plt.show()


Part 2: Quantitative Benchmarks

These are small, reproducible sanity checks rather than leaderboard-grade evaluations. The point is to see what a local VLM can already do, and where it still struggles.

VQA: VQAv2 Subset

We sample 50 questions from the merve/vqav2-small validation split and measure strict exact-match accuracy against the ground-truth answer.

That strictness matters: "3" and "3 people" count as different answers here, so the metric is intentionally unforgiving.
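For contrast, a slightly more lenient matcher (a hypothetical alternative, not used in the evaluation below) would credit predictions that contain the ground truth as a whole word, so "3 people" vs "3" would score:

```python
import re

def vqa_lenient_match(prediction, ground_truth):
    """Score 1.0 if the normalized ground truth equals the prediction
    or appears as a whole-word token inside it. A hypothetical, more
    forgiving alternative to strict exact match."""
    norm = lambda s: re.sub(r'[^\w\s]', '', s.strip().lower())
    pred, gt = norm(prediction), norm(ground_truth)
    if pred == gt:
        return 1.0
    return 1.0 if re.search(rf'\b{re.escape(gt)}\b', pred) else 0.0

print(vqa_lenient_match('3 people', '3'))       # 1.0 — strict match scores this 0
print(vqa_lenient_match('Two elephants', '2'))  # 0.0 — word/digit mismatch still misses
```

Even this looser rule would not rescue answers like "Two elephants" vs "2", so some of the strict-metric penalty is about number formatting rather than vision errors.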

Code
from datasets import load_dataset

vqa_ds = load_dataset('merve/vqav2-small', split='validation', streaming=True)

# Sample 50 examples deterministically
np.random.seed(42)
N_VQA = 50

vqa_samples = []
for i, ex in enumerate(vqa_ds):
    if len(vqa_samples) >= N_VQA:
        break
    if np.random.random() < 0.05:  # ~5% sampling rate
        vqa_samples.append(ex)

print(f"Loaded {len(vqa_samples)} VQA samples")
print(f"Example: Q='{vqa_samples[0]['question']}', A='{vqa_samples[0]['multiple_choice_answer']}'")
Loaded 50 VQA samples
Example: Q='Where was this photo taken?', A='la brea ave'
Code
def vqa_exact_match(prediction, ground_truth):
    """Exact match accuracy (case-insensitive, strip punctuation)."""
    pred = prediction.strip().lower().rstrip('.').strip()
    gt = ground_truth.strip().lower().rstrip('.').strip()
    return 1.0 if pred == gt else 0.0

vqa_results = []
for i, sample in enumerate(vqa_samples):
    img = sample['image']
    question = sample['question']
    gt_answer = sample['multiple_choice_answer']

    tmp_path = '/tmp/_vqa_tmp.png'
    img.save(tmp_path)

    prompt = f"""Answer this question about the image with a single short answer (1-3 words max).
Question: {question}
Answer:"""

    r = run_local(prompt, image_path=tmp_path, max_tokens=20, quiet=True)
    pred = r['response'].strip().split('\n')[0]

    acc = vqa_exact_match(pred, gt_answer)
    vqa_results.append({
        'question': question,
        'prediction': pred,
        'gt_answer': gt_answer,
        'accuracy': acc,
        'time_s': r['time_s'],
        'gen_tps': r['gen_tps'],
    })

    if (i+1) % 10 == 0:
        running_acc = np.mean([r['accuracy'] for r in vqa_results])
        print(f"  [{i+1}/{N_VQA}] Running accuracy: {running_acc:.1%}")

overall_vqa_acc = np.mean([r['accuracy'] for r in vqa_results])
avg_time = np.mean([r['time_s'] for r in vqa_results])
avg_tps = np.mean([r['gen_tps'] for r in vqa_results])
vqa_summary = {
    'accuracy': overall_vqa_acc,
    'n_samples': len(vqa_results),
    'avg_time_s': avg_time,
    'avg_gen_tps': avg_tps,
}
print(f"\nVQA Accuracy: {overall_vqa_acc:.1%} ({len(vqa_results)} samples)")
print(f"Avg time per question: {avg_time:.1f}s")
print(f"Avg generation speed: {avg_tps:.1f} tok/s")
  [10/50] Running accuracy: 50.0%
  [20/50] Running accuracy: 35.0%
  [30/50] Running accuracy: 46.7%
  [40/50] Running accuracy: 52.5%
  [50/50] Running accuracy: 54.0%

VQA Accuracy: 54.0% (50 samples)
Avg time per question: 4.2s
Avg generation speed: 24.3 tok/s
Code
# Show some correct and incorrect examples
correct = [r for r in vqa_results if r['accuracy'] > 0]
wrong = [r for r in vqa_results if r['accuracy'] == 0]

display(Markdown(f"### VQA Results: {overall_vqa_acc:.1%} accuracy on {len(vqa_results)} samples"))
display(Markdown(f"{len(correct)} correct, {len(wrong)} incorrect"))

display(Markdown("#### Sample correct predictions"))
for r in correct[:5]:
    print(f"  Q: {r['question']}")
    print(f"  Pred: {r['prediction']} | GT: {r['gt_answer']}")
    print()

display(Markdown("#### Sample incorrect predictions"))
for r in wrong[:5]:
    print(f"  Q: {r['question']}")
    print(f"  Pred: {r['prediction']} | GT: {r['gt_answer']}")
    print()

VQA Results: 54.0% accuracy on 50 samples

27 correct, 23 incorrect

Sample correct predictions

  Q: Where was this photo taken?
  Pred: La Brea Ave. | GT: la brea ave

  Q: What color is the plate?
  Pred: White | GT: white

  Q: Is the child dressed up?
  Pred: Yes | GT: yes

  Q: Are there various jelly choices?
  Pred: No | GT: no

  Q: What mode of transit is this?
  Pred: Bus | GT: bus

Sample incorrect predictions

  Q: How many elephants?
  Pred: Two elephants | GT: 2

  Q: What is in the mug?
  Pred: sauce | GT: butter

  Q: Why is there more than one tab open on the computer?
  Pred: Unknown | GT: working

  Q: How many people in the picture?
  Pred: 3 people | GT: 3

  Q: What type of boards are on the vehicle?
  Pred: Surfboards | GT: canoe

Class-Aware Localization on COCO

We take 20 COCO validation images and prompt Gemma 4 to return class-labeled boxes, then match predictions against ground truth after converting COCO’s native [x, y, w, h] boxes into supervision detections.

This is a simple localization sanity check, not COCO mAP:

  • only medium-to-large GT boxes are scored (area >= 32^2)
  • we report both a coarse class-aware view (IoU >= 0.3) and a strict one (IoU >= 0.5)
  • all parsing, rendering, and IoU matching go through supervision

That distinction matters for VLMs: many boxes are directionally right but still too loose for strict detector-style scoring.
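To see how quickly IoU punishes loose boxes, consider a 100×100 ground-truth box and a same-size prediction shifted diagonally. This is a standalone sketch with a plain-numpy IoU (a stand-in for supervision's box_iou_batch, used so the example is self-contained):

```python
import numpy as np

def iou_xyxy(a, b):
    """IoU of two [x1, y1, x2, y2] boxes."""
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area = lambda box: (box[2] - box[0]) * (box[3] - box[1])
    return inter / (area(a) + area(b) - inter)

gt = np.array([100, 100, 200, 200], dtype=float)  # 100x100 GT box
for shift in (10, 25, 40):
    pred = gt + shift  # same-size box shifted diagonally by `shift` px
    print(f"shift={shift:>2}px -> IoU={iou_xyxy(gt, pred):.2f}")
# shift=10px -> IoU=0.68
# shift=25px -> IoU=0.39
# shift=40px -> IoU=0.22
```

A 25-pixel offset on a 100-pixel box already fails the strict 0.5 threshold, which is why the coarse 0.3 view is worth reporting separately.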

Code
coco_ds = load_dataset('detection-datasets/coco', split='val', streaming=True)

# Hugging Face exposes this nested feature differently across datasets versions:
# sometimes as a dict-like object, sometimes wrapped in one or more List/Sequence layers.
objects_feature = coco_ds.features['objects']
category_feature = objects_feature['category'] if isinstance(objects_feature, dict) else objects_feature.feature['category']
while hasattr(category_feature, 'feature'):
    category_feature = category_feature.feature

# The HF dataset uses contiguous 0-based category ids, not the sparse original COCO ids.
COCO_CATS = {
    idx: name for idx, name in enumerate(category_feature.names)
}

# Aliases: common VLM outputs → canonical COCO names
COCO_NAME_ALIASES = {
    'tv monitor': 'tv', 'tvmonitor': 'tv', 'television': 'tv', 'monitor': 'tv',
    'mobile phone': 'cell phone', 'cellphone': 'cell phone', 'phone': 'cell phone',
    'smartphone': 'cell phone', 'iphone': 'cell phone',
    'sofa': 'couch', 'loveseat': 'couch',
    'diningtable': 'dining table', 'table': 'dining table', 'desk': 'dining table',
    'pottedplant': 'potted plant', 'houseplant': 'potted plant', 'plant': 'potted plant',
    'aeroplane': 'airplane', 'plane': 'airplane', 'jet': 'airplane',
    'motorbike': 'motorcycle', 'scooter': 'motorcycle',
    'wine': 'wine glass', 'glass': 'wine glass',
    'laptop computer': 'laptop', 'notebook': 'laptop',
    'bike': 'bicycle',
    'teddy': 'teddy bear', 'stuffed bear': 'teddy bear', 'plush bear': 'teddy bear',
    'traffic signal': 'traffic light', 'signal': 'traffic light',
    'automobile': 'car', 'vehicle': 'car',
    'human': 'person', 'man': 'person', 'woman': 'person', 'child': 'person',
    'boy': 'person', 'girl': 'person', 'people': 'person', 'pedestrian': 'person',
}


def normalize_label(label):
    if label is None:
        return None
    label = re.sub(r'[_-]+', ' ', str(label).strip().lower())
    label = re.sub(r'\s+', ' ', label)
    return COCO_NAME_ALIASES.get(label, label)


COCO_NAME_TO_ID = {normalize_label(name): cid for cid, name in COCO_CATS.items()}

# Minimum GT box area (pixels²) — we align the benchmark with medium/large objects.
MIN_GT_AREA = 32 * 32

N_OD = 20

od_samples = []
for ex in coco_ds:
    if len(od_samples) >= N_OD:
        break
    cats = [COCO_CATS[c] for c in ex['objects']['category'] if c in COCO_CATS]
    if len(cats) >= 2:
        od_samples.append(ex)

print(f'Loaded {len(od_samples)} COCO images')
Loaded 20 COCO images
Code
def make_detections(xyxy, class_ids, class_names):
    if len(xyxy) == 0:
        return sv.Detections.empty()
    return sv.Detections(
        xyxy=np.asarray(xyxy, dtype=np.float32),
        class_id=np.asarray(class_ids, dtype=np.int32),
        data={'class_name': np.asarray(class_names, dtype=object)},
    )


def detection_labels(detections):
    if len(detections) == 0:
        return []
    names = detections.data.get('class_name')
    if names is None:
        return [str(class_id) for class_id in detections.class_id]
    return [str(name) for name in names]


def coco_gt_detections(sample):
    """Convert COCO GT boxes from xywh to supervision detections."""
    xywh = np.asarray(sample['objects']['bbox'], dtype=np.float32)
    class_ids = np.asarray(sample['objects']['category'], dtype=np.int32)
    if len(xywh) == 0:
        return sv.Detections.empty()

    xyxy = sv.xywh_to_xyxy(xywh)
    keep, class_names = [], []
    for idx, (box, class_id) in enumerate(zip(xyxy, class_ids)):
        label = COCO_CATS.get(int(class_id))
        if label is None:
            continue
        width = max(0.0, float(box[2] - box[0]))
        height = max(0.0, float(box[3] - box[1]))
        if width * height < MIN_GT_AREA:
            continue
        keep.append(idx)
        class_names.append(label)

    if not keep:
        return sv.Detections.empty()
    keep = np.asarray(keep, dtype=int)
    return make_detections(xyxy[keep], class_ids[keep], class_names)


def parse_gemma_detections(text, width, height):
    """Parse Gemma's JSON response into supervision detections."""
    if not text:
        return sv.Detections.empty()

    try:
        dets = json.loads(extract_json_fragment(text, kind='array'))
    except Exception:
        return sv.Detections.empty()

    if not isinstance(dets, list):
        return sv.Detections.empty()

    xyxy, class_ids, class_names = [], [], []
    for det in dets:
        if not isinstance(det, dict):
            continue
        label = normalize_label(
            det.get('label') or det.get('class') or det.get('name') or det.get('object')
        )
        box = det.get('box_2d') or det.get('box') or det.get('bbox') or det.get('bounding_box')
        if label not in COCO_NAME_TO_ID or not isinstance(box, (list, tuple)) or len(box) != 4:
            continue
        try:
            y1_n, x1_n, y2_n, x2_n = [float(v) for v in box]
        except (ValueError, TypeError):
            continue

        x1 = float(np.clip(x1_n / 1000 * width, 0, width))
        x2 = float(np.clip(x2_n / 1000 * width, 0, width))
        y1 = float(np.clip(y1_n / 1000 * height, 0, height))
        y2 = float(np.clip(y2_n / 1000 * height, 0, height))
        if x2 <= x1 + 1 or y2 <= y1 + 1:
            continue

        xyxy.append([x1, y1, x2, y2])
        class_ids.append(COCO_NAME_TO_ID[label])
        class_names.append(label)

    return make_detections(xyxy, class_ids, class_names)


def greedy_match(iou_matrix, threshold):
    matches = []
    if iou_matrix.size == 0:
        return matches

    gt_idx, pred_idx = np.where(iou_matrix >= threshold)
    if len(gt_idx) == 0:
        return matches

    scores = iou_matrix[gt_idx, pred_idx]
    order = np.argsort(scores)[::-1]
    used_gt, used_pred = set(), set()
    for idx in order:
        gt = int(gt_idx[idx])
        pred = int(pred_idx[idx])
        if gt in used_gt or pred in used_pred:
            continue
        used_gt.add(gt)
        used_pred.add(pred)
        matches.append((gt, pred, float(iou_matrix[gt, pred])))
    return matches


def summarize_matches(matches, n_pred, n_gt):
    tp = len(matches)
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gt if n_gt else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    pred_to_iou = {pred: iou for gt, pred, iou in matches}
    pred_to_gt = {pred: gt for gt, pred, iou in matches}
    gt_indices = sorted(gt for gt, pred, iou in matches)
    mean_iou = float(np.mean([iou for _, _, iou in matches])) if matches else 0.0
    return {
        'tp': tp,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'pred_to_iou': pred_to_iou,
        'pred_to_gt': pred_to_gt,
        'gt_indices': gt_indices,
        'mean_iou': mean_iou,
    }


def evaluate_detections(predictions, targets, coarse_iou=0.3, strict_iou=0.5):
    """Class-aware localization analysis with coarse and strict views."""
    n_pred, n_gt = len(predictions), len(targets)
    pred_labels = detection_labels(predictions)
    gt_labels = detection_labels(targets)

    if n_pred and n_gt:
        pair_ious = sv.box_iou_batch(targets.xyxy, predictions.xyxy)
        class_match = targets.class_id[:, None] == predictions.class_id[None, :]
        same_class_ious = np.where(class_match, pair_ious, -1.0)
    else:
        pair_ious = np.zeros((n_gt, n_pred), dtype=np.float32)
        class_match = np.zeros((n_gt, n_pred), dtype=bool)
        same_class_ious = np.zeros((n_gt, n_pred), dtype=np.float32)

    strict_matches = greedy_match(same_class_ious, strict_iou)
    coarse_matches = greedy_match(same_class_ious, coarse_iou)
    strict_stats = summarize_matches(strict_matches, n_pred, n_gt)
    coarse_stats = summarize_matches(coarse_matches, n_pred, n_gt)

    strict_tp_mask = np.zeros(n_pred, dtype=bool)
    strict_tp_ious = np.zeros(n_pred, dtype=np.float32)
    for gt_idx, pred_idx, iou in strict_matches:
        strict_tp_mask[pred_idx] = True
        strict_tp_ious[pred_idx] = iou

    best_same_iou = np.zeros(n_pred, dtype=np.float32)
    best_any_iou = np.zeros(n_pred, dtype=np.float32)
    best_any_gt_idx = np.full(n_pred, -1, dtype=int)

    if n_pred and n_gt:
        best_any_gt_idx = pair_ious.argmax(axis=0)
        best_any_iou = pair_ious[best_any_gt_idx, np.arange(n_pred)]
        has_same = class_match.any(axis=0)
        if has_same.any():
            same_pred_cols = np.where(has_same)[0]
            same_gt_idx = same_class_ious[:, same_pred_cols].argmax(axis=0)
            best_same_iou[same_pred_cols] = same_class_ious[same_gt_idx, same_pred_cols]

    pred_status = []
    pred_display_labels = []
    for pred_idx, label in enumerate(pred_labels):
        if pred_idx in strict_stats['pred_to_iou']:
            iou = strict_stats['pred_to_iou'][pred_idx]
            pred_status.append('tp')
            pred_display_labels.append(f'{label} (TP {iou:.2f})')
        elif best_same_iou[pred_idx] >= coarse_iou:
            pred_status.append('near')
            pred_display_labels.append(f'{label} (near {best_same_iou[pred_idx]:.2f})')
        elif best_any_iou[pred_idx] >= strict_iou and best_any_gt_idx[pred_idx] >= 0:
            gt_label = gt_labels[best_any_gt_idx[pred_idx]]
            pred_status.append('class')
            pred_display_labels.append(f'{label} (vs {gt_label} {best_any_iou[pred_idx]:.2f})')
        else:
            pred_status.append('fp')
            pred_display_labels.append(f'{label} (FP)')

    status_counts = {
        status: int(sum(s == status for s in pred_status))
        for status in ['tp', 'near', 'class', 'fp']
    }

    return {
        'tp': coarse_stats['tp'],
        'n_pred': n_pred,
        'n_gt': n_gt,
        'precision': coarse_stats['precision'],
        'recall': coarse_stats['recall'],
        'f1': coarse_stats['f1'],
        'strict_precision': strict_stats['precision'],
        'strict_recall': strict_stats['recall'],
        'strict_f1': strict_stats['f1'],
        'mean_tp_iou': strict_stats['mean_iou'],
        'strict_tp_mask': strict_tp_mask,
        'tp_ious': strict_tp_ious,
        'pred_status': np.asarray(pred_status, dtype=object),
        'pred_display_labels': pred_display_labels,
        'matched_gt_indices': coarse_stats['gt_indices'],
        'strict_matched_gt_indices': strict_stats['gt_indices'],
        'status_counts': status_counts,
    }
Code
STATUS_COLORS = {
    'tp': sv.Color(22, 163, 74),
    'near': sv.Color(245, 158, 11),
    'class': sv.Color(168, 85, 247),
    'fp': sv.Color(220, 38, 38),
}


def annotate_detections(scene, detections, labels=None, color_lookup=sv.ColorLookup.CLASS, color=None):
    if len(detections) == 0:
        return scene

    if color is None:
        box_annotator = sv.BoxAnnotator(thickness=2, color_lookup=color_lookup)
        label_annotator = sv.LabelAnnotator(
            text_scale=0.4,
            text_padding=3,
            color_lookup=color_lookup,
        )
    else:
        box_annotator = sv.BoxAnnotator(thickness=2, color=color)
        label_annotator = sv.LabelAnnotator(
            text_scale=0.4,
            text_padding=3,
            color=color,
        )

    scene = box_annotator.annotate(scene=scene, detections=detections)
    scene = label_annotator.annotate(
        scene=scene,
        detections=detections,
        labels=labels or detection_labels(detections),
    )
    return scene


def render_detections(image, detections, labels=None, color_lookup=sv.ColorLookup.CLASS, color=None):
    scene = np.array(image.convert('RGB')).copy()
    return annotate_detections(
        scene,
        detections,
        labels=labels,
        color_lookup=color_lookup,
        color=color,
    )


def annotate_prediction_status(image, predictions, pred_status, pred_display_labels):
    scene = np.array(image.convert('RGB')).copy()
    if len(predictions) == 0:
        return scene

    for status in ['fp', 'class', 'near', 'tp']:
        mask = pred_status == status
        if not mask.any():
            continue
        dets = predictions[mask]
        labels = [label for label, keep in zip(pred_display_labels, mask) if keep]
        scene = annotate_detections(
            scene,
            dets,
            labels=labels,
            color=STATUS_COLORS[status],
        )

    return scene
Code
od_prompt = """Detect the clearly visible medium-to-large objects in this image and return bounding boxes.
Return ONLY a JSON array — no explanation, no markdown.
Each element: {"label": "<coco_class>", "box_2d": [y_min, x_min, y_max, x_max]}
Coordinates are integers in [0, 1000] (normalized to image size).
Use only standard COCO class names.
Skip tiny distant instances; focus on objects with clear boundaries."""

od_results = []
for i, sample in enumerate(od_samples):
    img = sample['image']
    w, h = img.size
    gt_detections = coco_gt_detections(sample)
    tmp = '/tmp/_od_tmp.jpg'
    img.save(tmp)
    out = run_local(od_prompt, image_path=tmp, max_tokens=1024, quiet=True)
    pred_detections = parse_gemma_detections(out['response'], w, h)
    metrics = evaluate_detections(pred_detections, gt_detections)
    metrics['image'] = img
    metrics['gt_detections'] = gt_detections
    metrics['pred_detections'] = pred_detections
    metrics['time_s'] = out['time_s']
    metrics['raw_response'] = out['response']
    od_results.append(metrics)
    if (i + 1) % 5 == 0:
        coarse_p = np.mean([r['precision'] for r in od_results])
        coarse_r = np.mean([r['recall'] for r in od_results])
        strict_r = np.mean([r['strict_recall'] for r in od_results])
        print(f'  [{i+1}/{N_OD}] coarse P/R={coarse_p:.2f}/{coarse_r:.2f} | strict R={strict_r:.2f}')

od_summary = {
    'precision': np.mean([r['precision'] for r in od_results]),
    'recall': np.mean([r['recall'] for r in od_results]),
    'f1': np.mean([r['f1'] for r in od_results]),
    'strict_precision': np.mean([r['strict_precision'] for r in od_results]),
    'strict_recall': np.mean([r['strict_recall'] for r in od_results]),
    'strict_f1': np.mean([r['strict_f1'] for r in od_results]),
    'mean_tp_iou': np.mean([r['mean_tp_iou'] for r in od_results]),
    'total_tp': sum(r['tp'] for r in od_results),
    'total_pred': sum(r['n_pred'] for r in od_results),
    'total_gt': sum(r['n_gt'] for r in od_results),
    'avg_time_s': np.mean([r['time_s'] for r in od_results]),
}

status_totals = {
    status: sum(r['status_counts'][status] for r in od_results)
    for status in ['tp', 'near', 'class', 'fp']
}

print(f'\n--- COCO Localization ({N_OD} images, GT boxes ≥ {MIN_GT_AREA}px²) ---')
print(f'Coarse P/R/F1 @0.3: {od_summary["precision"]:.2f} / {od_summary["recall"]:.2f} / {od_summary["f1"]:.2f}')
print(f'Strict P/R/F1 @0.5: {od_summary["strict_precision"]:.2f} / {od_summary["strict_recall"]:.2f} / {od_summary["strict_f1"]:.2f}')
print(f'Mean strict TP IoU: {od_summary["mean_tp_iou"]:.2f}')
print(f'Predictions by status: TP={status_totals["tp"]}, near={status_totals["near"]}, class mismatch={status_totals["class"]}, FP={status_totals["fp"]}')
print(f'Avg time/image: {od_summary["avg_time_s"]:.1f}s')
  [5/20] coarse P/R=0.51/0.27 | strict R=0.10
  [10/20] coarse P/R=0.35/0.21 | strict R=0.07
  [15/20] coarse P/R=0.40/0.24 | strict R=0.08
  [20/20] coarse P/R=0.38/0.23 | strict R=0.11

--- COCO Localization (20 images, GT boxes ≥ 1024px²) ---
Coarse P/R/F1 @0.3: 0.38 / 0.23 / 0.28
Strict P/R/F1 @0.5: 0.18 / 0.11 / 0.14
Mean strict TP IoU: 0.25
Predictions by status: TP=8, near=9, class mismatch=0, FP=57
Avg time/image: 13.5s

GT vs. Predictions: Strict Hits vs Near Misses

Code
viz_idx = np.linspace(0, len(od_results) - 1, 6, dtype=int)

fig, axes = plt.subplots(len(viz_idx), 2, figsize=(14, 4 * len(viz_idx)))

for row, idx in enumerate(viz_idx):
    r = od_results[idx]
    gt_scene = render_detections(r['image'], r['gt_detections'])
    pred_scene = annotate_prediction_status(
        r['image'],
        r['pred_detections'],
        r['pred_status'],
        r['pred_display_labels'],
    )

    axes[row, 0].imshow(gt_scene)
    axes[row, 0].set_title(f'Ground truth ({r["n_gt"]} objects)', fontsize=10)
    axes[row, 0].axis('off')

    axes[row, 1].imshow(pred_scene)
    axes[row, 1].set_title(
        f'Coarse P/R@0.3={r["precision"]:.2f}/{r["recall"]:.2f} | '
        f'Strict P/R@0.5={r["strict_precision"]:.2f}/{r["strict_recall"]:.2f}',
        fontsize=10,
    )
    axes[row, 1].axis('off')

fig.suptitle(
    'COCO localization overlays: green=TP@0.5, amber=near miss, purple=class mismatch, red=no usable match',
    fontsize=13,
    fontweight='bold',
)
plt.tight_layout()
plt.show()

Detection Quality Breakdown

Code
from collections import Counter

fig, axes = plt.subplots(1, 3, figsize=(16, 4.5))

# 1 — Per-image coarse P vs R scatter
precs = [r['precision'] for r in od_results]
recs = [r['recall'] for r in od_results]
axes[0].scatter(recs, precs, s=60, alpha=0.7, edgecolors='black', linewidth=0.5, c='#06b6d4')
axes[0].set_xlabel('Recall @0.3'); axes[0].set_ylabel('Precision @0.3')
axes[0].set_xlim(-0.05, 1.05); axes[0].set_ylim(-0.05, 1.05)
axes[0].set_title('Per-image coarse precision vs recall')
axes[0].axhline(od_summary['precision'], color='gray', ls='--', alpha=0.5)
axes[0].axvline(od_summary['recall'], color='gray', ls='--', alpha=0.5)
axes[0].text(
    od_summary['recall'] + 0.02,
    od_summary['precision'] + 0.04,
    f"mean ({od_summary['precision']:.2f}, {od_summary['recall']:.2f})",
    fontsize=8,
    color='gray',
)

# 2 — IoU distribution of strict TPs
tp_ious = [float(iou) for r in od_results for iou, is_tp in zip(r['tp_ious'], r['strict_tp_mask']) if is_tp]
if tp_ious:
    axes[1].hist(tp_ious, bins=15, color='#22c55e', edgecolor='white', alpha=0.85)
    axes[1].axvline(0.5, color='red', ls='--', lw=2, label='IoU=0.5 threshold')
    axes[1].set_xlabel('IoU'); axes[1].set_ylabel('Count')
    axes[1].legend(fontsize=8)
axes[1].set_title(f'IoU of strict matches (n={len(tp_ious)})')

# 3 — Per-class coarse recall (top 10 GT classes)
gt_counts = Counter()
tp_counts = Counter()
for r in od_results:
    gt_labels = detection_labels(r['gt_detections'])
    gt_counts.update(gt_labels)
    matched_gt = r['gt_detections'][np.asarray(r['matched_gt_indices'], dtype=int)] if r['matched_gt_indices'] else sv.Detections.empty()
    tp_counts.update(detection_labels(matched_gt))

top = gt_counts.most_common(10)
cls_names = [name for name, _ in top]
cls_recall = [tp_counts[name] / gt_counts[name] for name in cls_names]
bar_colors = ['#22c55e' if value >= 0.5 else '#f97316' if value >= 0.25 else '#ef4444' for value in cls_recall]

axes[2].barh(cls_names[::-1], cls_recall[::-1], color=bar_colors[::-1])
axes[2].set_xlabel('Recall @0.3')
axes[2].set_xlim(0, 1.05)
axes[2].set_title('Per-class coarse recall (top 10 by GT count)')
for row, (name, recall_value) in enumerate(zip(cls_names[::-1], cls_recall[::-1])):
    total = gt_counts[name]
    axes[2].text(recall_value + 0.02, row, f'{tp_counts[name]}/{total}', va='center', fontsize=8)

plt.suptitle('Localization quality analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
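The per-class recall computed above is just Counter bookkeeping: tally every ground-truth label, tally the labels of matched ground-truth boxes, and divide per class. A minimal sketch with toy labels (not from the benchmark run):

```python
from collections import Counter

# Toy ground-truth labels and the subset that a prediction matched.
gt_labels = ["person", "person", "person", "dog"]
matched_labels = ["person", "dog"]

gt_counts = Counter(gt_labels)
tp_counts = Counter(matched_labels)

# Per-class recall: matched count / ground-truth count for each class.
per_class_recall = {cls: tp_counts[cls] / n for cls, n in gt_counts.items()}
print(per_class_recall)  # → {'person': 0.3333333333333333, 'dog': 1.0}
```

Classes with no matches simply get a zero numerator, since `Counter` returns 0 for missing keys.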

Best and Worst Detections

Code
from collections import Counter

f1s = [r['f1'] for r in od_results]
best_i, worst_i = int(np.argmax(f1s)), int(np.argmin(f1s))

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

for row, (idx, label) in enumerate([(best_i, 'Best'), (worst_i, 'Worst')]):
    r = od_results[idx]
    axes[row, 0].imshow(render_detections(r['image'], r['gt_detections']))
    axes[row, 0].set_title(f'{label} (coarse F1={r["f1"]:.2f}) — GT: {r["n_gt"]} objects', fontsize=11)
    axes[row, 0].axis('off')

    axes[row, 1].imshow(
        annotate_prediction_status(
            r['image'],
            r['pred_detections'],
            r['pred_status'],
            r['pred_display_labels'],
        )
    )
    axes[row, 1].set_title(
        f'{label} — coarse P/R={r["precision"]:.2f}/{r["recall"]:.2f} | strict R={r["strict_recall"]:.2f}',
        fontsize=11,
    )
    axes[row, 1].axis('off')

plt.suptitle('Best vs worst localization (by coarse F1)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

rw = od_results[worst_i]
gt_labels = Counter(detection_labels(rw['gt_detections']))
pred_labels = Counter(detection_labels(rw['pred_detections']))
print(f'Worst image — GT classes:   {dict(gt_labels)}')
print(f'Worst image — Pred classes: {dict(pred_labels)}')
print(f'Worst image — status counts: {rw["status_counts"]}')
missed = set(gt_labels) - set(pred_labels)
if missed:
    print(f'Missed entirely: {missed}')

Worst image — GT classes:   {'person': 7, 'tennis racket': 1}
Worst image — Pred classes: {'person': 3}
Worst image — status counts: {'tp': 0, 'near': 0, 'class': 0, 'fp': 3}
Missed entirely: {'tennis racket'}

Part 3: Runtime Performance

Prompt Length vs Throughput

How does generation speed change as prompt complexity increases? We test text-only and image prompts of varying length.

The longest image prompt is intentionally a stress test rather than a recommended everyday prompt.
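The timing numbers below come from `run_local`, defined earlier in the notebook, which reads token counts from mlx-vlm itself. As a rough illustration of the measurement idea only, here is a generic wall-clock wrapper; the lambda stand-in and whitespace-based token count are toy assumptions, not the mlx-vlm tokenizer:

```python
import time

def timed_generate(gen_fn, *args, **kwargs):
    """Time any generate-style callable and report approximate tokens/sec.

    The token count here is a crude whitespace split; real throughput
    numbers should come from the tokenizer's own counts.
    """
    t0 = time.perf_counter()
    text = gen_fn(*args, **kwargs)
    elapsed = time.perf_counter() - t0
    n_tokens = len(text.split())
    return {
        "text": text,
        "time_s": elapsed,
        "approx_tps": n_tokens / max(elapsed, 1e-9),  # guard tiny elapsed
    }

# Toy stand-in for a model call:
result = timed_generate(lambda prompt: "four is the answer", "What is 2+2?")
print(f"{result['approx_tps']:.0f} approx tok/s in {result['time_s']:.4f}s")
```

The real harness additionally separates prompt processing from generation, which is why the table reports both prompt tok/s and gen tok/s.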

Code
perf_tests = [
    ('Short text', 'What is 2+2?', None, 50),
    ('Medium text', 'Explain the transformer architecture in detail, covering attention, positional encoding, and training.', None, 300),
    ('Long text', 'Write a detailed tutorial on building a CNN for image classification in PyTorch. Cover data loading, model architecture with conv/pool/fc layers, training loop, evaluation, and common pitfalls. Include code snippets.', None, 500),
    ('Image + short', 'What is this?', 'classroom.jpg', 50),
    ('Image + medium', 'Describe everything you see in this image in detail.', 'classroom.jpg', 300),
    ('Image + long', 'Analyze this image. Describe the setting, count all people, describe their activities, identify all objects, estimate the time of day, and suggest what event is taking place. Be thorough.', 'classroom.jpg', 500),
]

perf_rows = []
for name, prompt, img, max_tok in perf_tests:
    print(f"Running: {name}...", end=' ')
    r = run_local(prompt, image_path=img, max_tokens=max_tok, quiet=True)
    perf_rows.append({
        'Test': name,
        'Prompt tokens': r['prompt_tokens'],
        'Gen tokens': r['gen_tokens'],
        'Prompt tok/s': round(r['prompt_tps'], 1),
        'Gen tok/s': round(r['gen_tps'], 1),
        'Time (s)': round(r['time_s'], 1),
        'Peak mem (GB)': round(r['peak_mem_gb'], 1) if r['peak_mem_gb'] else 'N/A',
    })
    print(f"{r['time_s']:.1f}s, {r['gen_tps']:.1f} tok/s")

df_perf = pd.DataFrame(perf_rows)
display(df_perf.style.hide(axis='index'))
Running: Short text... 1.3s, 18.8 tok/s
Running: Medium text... 18.1s, 17.1 tok/s
Running: Long text... 29.9s, 17.2 tok/s
Running: Image + short... 6.4s, 17.3 tok/s
Running: Image + medium... 21.4s, 17.1 tok/s
Running: Image + long... 33.9s, 16.8 tok/s
| Test | Prompt tokens | Gen tokens | Prompt tok/s | Gen tok/s | Time (s) | Peak mem (GB) |
|---|---:|---:|---:|---:|---:|---:|
| Short text | 20 | 8 | 38.2 | 18.8 | 1.3 | 0.0 |
| Medium text | 29 | 300 | 57.0 | 17.1 | 18.1 | 0.0 |
| Long text | 55 | 500 | 63.7 | 17.2 | 29.9 | 0.0 |
| Image + short | 280 | 44 | 74.4 | 17.3 | 6.4 | 0.0 |
| Image + medium | 286 | 300 | 76.0 | 17.1 | 21.4 | 0.0 |
| Image + long | 313 | 500 | 76.2 | 16.8 | 33.9 | 0.0 |
Code
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

labels = df_perf['Test']
colors = ['#3498db' if 'Image' not in t else '#e74c3c' for t in labels]

axes[0].barh(labels, df_perf['Gen tok/s'], color=colors, alpha=0.8)
axes[0].set_xlabel('Tokens/sec'); axes[0].set_title('Generation Speed')
for i, v in enumerate(df_perf['Gen tok/s']):
    axes[0].text(v + 0.3, i, f'{v}', va='center', fontsize=9)

axes[1].barh(labels, df_perf['Prompt tok/s'], color=colors, alpha=0.8)
axes[1].set_xlabel('Tokens/sec'); axes[1].set_title('Prompt Processing Speed')
for i, v in enumerate(df_perf['Prompt tok/s']):
    axes[1].text(v + 0.3, i, f'{v}', va='center', fontsize=9)

axes[2].barh(labels, df_perf['Time (s)'], color=colors, alpha=0.8)
axes[2].set_xlabel('Seconds'); axes[2].set_title('Total Time')
for i, v in enumerate(df_perf['Time (s)']):
    axes[2].text(v + 0.3, i, f'{v}s', va='center', fontsize=9)

fig.suptitle('Gemma 4 31B (4-bit) on M2 Max — Blue: text only, Red: with image', fontsize=11)
plt.tight_layout(); plt.show()


Summary

Code
summary = pd.DataFrame([
    {"Metric": "Model", "Value": MODEL_LABEL},
    {"Metric": "Framework", "Value": "mlx-vlm 0.4.4 (PyPI)"},
    {"Metric": "Hardware", "Value": HARDWARE_LABEL},
    {"Metric": "Peak Memory", "Value": "~19 GB"},
    {"Metric": "Gen Speed (text)", "Value": f"{df_perf.iloc[0]['Gen tok/s']} tok/s"},
    {"Metric": "Gen Speed (vision)", "Value": f"{df_perf.iloc[3]['Gen tok/s']} tok/s"},
    {"Metric": "VQA Accuracy (VQAv2, n=50)", "Value": f"{vqa_summary['accuracy']:.1%}"},
    {"Metric": "Localization Precision @0.3 (COCO, n=20)", "Value": f"{od_summary['precision']:.2f}"},
    {"Metric": "Localization Recall @0.3 (COCO, n=20)", "Value": f"{od_summary['recall']:.2f}"},
    {"Metric": "Strict Recall @0.5 (COCO, n=20)", "Value": f"{od_summary['strict_recall']:.2f}"},
    {"Metric": "Mean strict TP IoU (COCO, n=20)", "Value": f"{od_summary['mean_tp_iou']:.2f}"},
])
display(summary.style.hide(axis="index"))
| Metric | Value |
|---|---|
| Model | Gemma 4 31B IT (4-bit) |
| Framework | mlx-vlm 0.4.4 (PyPI) |
| Hardware | Mac Studio M2 Max, 64 GB unified memory |
| Peak Memory | ~19 GB |
| Gen Speed (text) | 18.8 tok/s |
| Gen Speed (vision) | 17.3 tok/s |
| VQA Accuracy (VQAv2, n=50) | 54.0% |
| Localization Precision @0.3 (COCO, n=20) | 0.38 |
| Localization Recall @0.3 (COCO, n=20) | 0.23 |
| Strict Recall @0.5 (COCO, n=20) | 0.11 |
| Mean strict TP IoU (COCO, n=20) | 0.25 |
Code
fig, axes = plt.subplots(1, 2, figsize=(11, 4))

axes[0].bar(["VQA Accuracy"], [vqa_summary['accuracy']], color="#2ecc71", alpha=0.8)
axes[0].set_ylim(0, 1)
axes[0].set_ylabel("Accuracy")
axes[0].set_title(f"VQAv2 ({vqa_summary['n_samples']} samples)")
axes[0].text(0, vqa_summary['accuracy'] + 0.02, f"{vqa_summary['accuracy']:.1%}", ha="center", fontsize=14)

od_bar = {
    "Precision @0.3": od_summary["precision"],
    "Recall @0.3": od_summary["recall"],
    "Strict Recall @0.5": od_summary["strict_recall"],
    "Mean TP IoU": od_summary["mean_tp_iou"],
}
axes[1].bar(od_bar.keys(), od_bar.values(), color=["#3498db", "#e74c3c", "#9b59b6", "#f39c12"], alpha=0.8)
axes[1].set_ylim(0, 1)
axes[1].set_title(f"COCO Class-Aware Localization ({N_OD} images)")
for j, (label, value) in enumerate(od_bar.items()):
    axes[1].text(j, value + 0.02, f"{value:.2f}", ha="center", fontsize=11)

plt.suptitle(f"{MODEL_LABEL} (local) — Quantitative Results", fontsize=12)
plt.tight_layout()
plt.show()

Takeaways

  • Local multimodal inference is already practical: Gemma 4 31B (4-bit) runs with workable latency on Apple Silicon for interactive image prompting
  • The qualitative demos are the model at its best: scene understanding, chart reading, and simple prompted polygons all work surprisingly well locally
  • The VQA number is strict: exact-match grading penalizes formatting differences, so the benchmark is intentionally harsher than a human evaluator would be
  • COCO looks better under a coarse grounding lens than under strict detector-style scoring: many boxes are useful at IoU >= 0.3, but far fewer survive a strict IoU >= 0.5 match
  • Not every miss is equally bad: the overlays distinguish strict hits, near misses, class mismatches, and clear failures instead of lumping everything into one error bucket
  • supervision helps a lot: cleaner overlays and explicit status colors make the localization section easier to interpret
  • Privacy is still the biggest advantage: the whole pipeline stays offline on Apple Silicon
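The coarse-vs-strict gap in the takeaways follows directly from the IoU thresholds: a box that is merely offset from the ground truth can pass the 0.3 gate yet fail at 0.5. A minimal sketch with toy boxes (not benchmark data):

```python
def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2) pixel coordinates."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter) if inter else 0.0

# A box shifted right by 40% of its width: overlap is still substantial,
# but the union grows, so IoU lands between the two thresholds.
gt = (0, 0, 100, 100)
pred = (40, 0, 140, 100)
print(f"IoU = {iou(gt, pred):.2f}")  # → IoU = 0.43 (coarse hit, strict miss)
```

So a prediction counted as "useful" at IoU ≥ 0.3 can still be a strict miss at IoU ≥ 0.5, which is exactly the pattern in the COCO numbers above.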

For API-based comparison with Gemini models, see the companion post.