def make_detections(xyxy, class_ids, class_names):
    """Build ``sv.Detections`` from parallel box / class-id / class-name sequences.

    Boxes are cast to float32 and class ids to int32. Names are stored as an
    object array under ``data['class_name']``. Empty ``xyxy`` yields
    ``sv.Detections.empty()``.
    """
    if not len(xyxy):
        return sv.Detections.empty()
    boxes = np.asarray(xyxy, dtype=np.float32)
    ids = np.asarray(class_ids, dtype=np.int32)
    names = np.asarray(class_names, dtype=object)
    return sv.Detections(xyxy=boxes, class_id=ids, data={'class_name': names})
def detection_labels(detections):
    """Return one label string per detection.

    Uses ``detections.data['class_name']`` when that key is present.
    Otherwise it falls back to the stringified ``class_id`` values. An empty
    detections object yields an empty list.
    """
    if not len(detections):
        return []
    names = detections.data.get('class_name')
    # class_id is only touched when no class names are stored.
    source = detections.class_id if names is None else names
    return list(map(str, source))
def coco_gt_detections(sample):
    """Convert COCO GT boxes from xywh to supervision detections.

    Reads ``sample['objects']['bbox']`` (xywh boxes) and
    ``sample['objects']['category']`` (integer category ids). An object is
    kept only if:

    * its category id has a name in ``COCO_CATS``, and
    * its xyxy box area is not below ``MIN_GT_AREA``. Negative extents are
      clamped to 0 before the area is computed.

    Returns:
        The detections built by ``make_detections``, with class names in
        ``data['class_name']``. Returns ``sv.Detections.empty()`` when the
        sample has no boxes or no object survives the filters.
    """
    objects = sample['objects']
    boxes_xywh = np.asarray(objects['bbox'], dtype=np.float32)
    category_ids = np.asarray(objects['category'], dtype=np.int32)
    if not len(boxes_xywh):
        return sv.Detections.empty()
    boxes_xyxy = sv.xywh_to_xyxy(boxes_xywh)

    def clamped_area(box):
        # Width/height are clamped at 0 so inverted boxes count as empty.
        box_w = max(0.0, float(box[2] - box[0]))
        box_h = max(0.0, float(box[3] - box[1]))
        return box_w * box_h

    survivors = []  # (row index into boxes_xyxy, class name)
    for row, (box, category_id) in enumerate(zip(boxes_xyxy, category_ids)):
        name = COCO_CATS.get(int(category_id))
        if name is None:
            continue
        if clamped_area(box) < MIN_GT_AREA:
            continue
        survivors.append((row, name))

    if not survivors:
        return sv.Detections.empty()
    rows = np.asarray([row for row, _ in survivors], dtype=int)
    names = [name for _, name in survivors]
    return make_detections(boxes_xyxy[rows], category_ids[rows], names)
def parse_gemma_detections(text, width, height):
    """Parse Gemma's JSON response into supervision detections.

    ``text`` is passed through ``extract_json_fragment(..., kind='array')``
    (presumably to strip surrounding prose or code fences) and decoded as
    JSON. Each dict in the resulting list provides:

    * a label: the first truthy value of ``label`` / ``class`` / ``name`` /
      ``object``, passed through ``normalize_label``;
    * a box: the first truthy value of ``box_2d`` / ``box`` / ``bbox`` /
      ``bounding_box``, read as ``[y1, x1, y2, x2]`` on a 0-1000 scale and
      rescaled (and clipped) to a ``width`` x ``height`` image.

    An entry is skipped when any of these hold:

    * the label is not in ``COCO_NAME_TO_ID``;
    * the box is not a 4-element list/tuple;
    * any coordinate is not convertible to a finite float;
    * the rescaled box is at most 1 px wide or tall.

    Args:
        text: Raw model output. Falsy input returns no detections.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        ``sv.Detections`` built by ``make_detections``. Returns
        ``sv.Detections.empty()`` when nothing usable is found.
    """
    if not text:
        return sv.Detections.empty()
    try:
        dets = json.loads(extract_json_fragment(text, kind='array'))
    except Exception:
        # Deliberate best effort: any extraction/decoding failure means
        # "no detections" rather than an error for the caller.
        return sv.Detections.empty()
    if not isinstance(dets, list):
        return sv.Detections.empty()

    xyxy, class_ids, class_names = [], [], []
    for det in dets:
        if not isinstance(det, dict):
            continue
        label = normalize_label(
            det.get('label') or det.get('class') or det.get('name') or det.get('object')
        )
        box = det.get('box_2d') or det.get('box') or det.get('bbox') or det.get('bounding_box')
        if label not in COCO_NAME_TO_ID or not isinstance(box, (list, tuple)) or len(box) != 4:
            continue
        try:
            y1_n, x1_n, y2_n, x2_n = [float(v) for v in box]
        except (ValueError, TypeError, OverflowError):
            # OverflowError: json.loads turns huge integer literals into
            # Python ints that float() cannot represent.
            continue
        # json.loads accepts NaN/Infinity literals and float() accepts 'nan'.
        # NaN survives np.clip and makes the size check below evaluate False,
        # so such a box would otherwise be kept. Reject non-finite values.
        if not np.isfinite([y1_n, x1_n, y2_n, x2_n]).all():
            continue
        x1 = float(np.clip(x1_n / 1000 * width, 0, width))
        x2 = float(np.clip(x2_n / 1000 * width, 0, width))
        y1 = float(np.clip(y1_n / 1000 * height, 0, height))
        y2 = float(np.clip(y2_n / 1000 * height, 0, height))
        # Drop degenerate (or inverted) boxes no more than 1 px in either extent.
        if x2 <= x1 + 1 or y2 <= y1 + 1:
            continue
        xyxy.append([x1, y1, x2, y2])
        class_ids.append(COCO_NAME_TO_ID[label])
        class_names.append(label)
    return make_detections(xyxy, class_ids, class_names)
def greedy_match(iou_matrix, threshold):
    """Greedily pair rows with columns of an IoU matrix, highest IoU first.

    Rows are treated as ground-truth indices and columns as prediction
    indices. Only cells with ``iou >= threshold`` are candidates. Each row
    and each column is used at most once. Equal IoUs are resolved
    deterministically in row-major order: lower gt index first, then lower
    pred index.

    Args:
        iou_matrix: 2-D numpy array of shape (n_gt, n_pred).
        threshold: Minimum IoU for a cell to be considered.

    Returns:
        List of ``(gt_idx, pred_idx, iou)`` tuples in acceptance order
        (non-increasing IoU). Empty if the matrix is empty or no cell
        reaches the threshold.
    """
    matches = []
    if iou_matrix.size == 0:
        return matches
    gt_idx, pred_idx = np.where(iou_matrix >= threshold)
    if len(gt_idx) == 0:
        return matches
    scores = iou_matrix[gt_idx, pred_idx]
    # Descending by IoU. A stable sort on the negated scores keeps equal IoUs
    # in np.where's row-major order, so tie-breaking does not depend on the
    # default (unstable) sort algorithm.
    order = np.argsort(-scores, kind='stable')
    used_gt, used_pred = set(), set()
    for idx in order:
        gt = int(gt_idx[idx])
        pred = int(pred_idx[idx])
        if gt in used_gt or pred in used_pred:
            continue
        used_gt.add(gt)
        used_pred.add(pred)
        matches.append((gt, pred, float(iou_matrix[gt, pred])))
    return matches
def summarize_matches(matches, n_pred, n_gt):
    """Compute precision/recall/F1 and lookup tables for a list of matches.

    Args:
        matches: List of ``(gt_idx, pred_idx, iou)`` tuples.
        n_pred: Number of predictions (precision denominator).
        n_gt: Number of ground-truth objects (recall denominator).

    Returns:
        Dict with these keys:

        * ``tp``: number of matches;
        * ``precision``, ``recall``, ``f1``: each 0.0 when its denominator
          is zero;
        * ``pred_to_iou``: pred index -> matched IoU;
        * ``pred_to_gt``: pred index -> matched gt index;
        * ``gt_indices``: sorted matched gt indices;
        * ``mean_iou``: average IoU over the matches, 0.0 when there are none.
    """
    tp = len(matches)
    precision = 0.0
    if n_pred:
        precision = tp / n_pred
    recall = 0.0
    if n_gt:
        recall = tp / n_gt
    pr_sum = precision + recall
    f1 = 0.0
    if pr_sum:
        f1 = 2 * precision * recall / pr_sum
    ious = [iou for _, _, iou in matches]
    return {
        'tp': tp,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'pred_to_iou': {pred: iou for _, pred, iou in matches},
        'pred_to_gt': {pred: gt for gt, pred, _ in matches},
        'gt_indices': sorted(gt for gt, _, _ in matches),
        'mean_iou': float(np.mean(ious)) if ious else 0.0,
    }
def evaluate_detections(predictions, targets, coarse_iou=0.3, strict_iou=0.5):
    """Class-aware localization analysis with coarse and strict views.

    Builds an IoU grid with targets as rows and predictions as columns.
    Cross-class cells are masked to -1. Two matchings are then run with
    ``greedy_match``: one at ``strict_iou`` and one at ``coarse_iou``.

    Each prediction also gets a status, checked in this order:

    * ``'tp'``: strictly matched to a same-class target.
    * ``'near'``: not strictly matched, but its best same-class IoU is
      ``>= coarse_iou``.
    * ``'class'``: its best IoU against any target is ``>= strict_iou``.
      This is typically a class confusion; the display label names that
      target's label.
    * ``'fp'``: none of the above.

    Returns:
        A dict with:

        * ``tp``, ``precision``, ``recall``, ``f1``: from the coarse matching.
        * ``strict_precision``, ``strict_recall``, ``strict_f1``,
          ``mean_tp_iou``: from the strict matching.
        * ``strict_tp_mask``, ``tp_ious``: per-prediction arrays (strict
          flag, and matched IoU or 0).
        * ``pred_status``: status per prediction, as an object array.
        * ``pred_display_labels``: display label per prediction.
        * ``matched_gt_indices``, ``strict_matched_gt_indices``: matched gt
          indices from the coarse and strict matchings.
        * ``status_counts``: count per status.
    """
    n_pred = len(predictions)
    n_gt = len(targets)
    pred_labels = detection_labels(predictions)
    gt_labels = detection_labels(targets)
    have_both = bool(n_pred and n_gt)

    # IoU grid (rows: targets, cols: predictions). Cross-class cells become -1
    # so that no non-negative threshold can match them.
    if have_both:
        pair_ious = sv.box_iou_batch(targets.xyxy, predictions.xyxy)
        class_match = targets.class_id[:, None] == predictions.class_id[None, :]
        same_class_ious = np.where(class_match, pair_ious, -1.0)
    else:
        grid_shape = (n_gt, n_pred)
        pair_ious = np.zeros(grid_shape, dtype=np.float32)
        class_match = np.zeros(grid_shape, dtype=bool)
        same_class_ious = np.zeros(grid_shape, dtype=np.float32)

    strict_matches = greedy_match(same_class_ious, strict_iou)
    coarse_matches = greedy_match(same_class_ious, coarse_iou)
    strict_stats = summarize_matches(strict_matches, n_pred, n_gt)
    coarse_stats = summarize_matches(coarse_matches, n_pred, n_gt)

    # Per-prediction strict-TP flag and its matched IoU (0 when unmatched).
    strict_tp_mask = np.zeros(n_pred, dtype=bool)
    strict_tp_ious = np.zeros(n_pred, dtype=np.float32)
    for _, matched_pred, matched_iou in strict_matches:
        strict_tp_mask[matched_pred] = True
        strict_tp_ious[matched_pred] = matched_iou

    # Best IoU per prediction against same-class targets and against any target.
    best_same_iou = np.zeros(n_pred, dtype=np.float32)
    best_any_iou = np.zeros(n_pred, dtype=np.float32)
    best_any_gt_idx = np.full(n_pred, -1, dtype=int)
    if have_both:
        best_any_gt_idx = pair_ious.argmax(axis=0)
        best_any_iou = pair_ious[best_any_gt_idx, np.arange(n_pred)]
        same_pred_cols = np.flatnonzero(class_match.any(axis=0))
        if same_pred_cols.size:
            same_gt_idx = same_class_ious[:, same_pred_cols].argmax(axis=0)
            best_same_iou[same_pred_cols] = same_class_ious[same_gt_idx, same_pred_cols]

    strict_pred_to_iou = strict_stats['pred_to_iou']
    pred_status = []
    pred_display_labels = []
    for pred_idx, label in enumerate(pred_labels):
        if pred_idx in strict_pred_to_iou:
            status = 'tp'
            display = f'{label} (TP {strict_pred_to_iou[pred_idx]:.2f})'
        elif best_same_iou[pred_idx] >= coarse_iou:
            status = 'near'
            display = f'{label} (near {best_same_iou[pred_idx]:.2f})'
        elif best_any_iou[pred_idx] >= strict_iou and best_any_gt_idx[pred_idx] >= 0:
            gt_label = gt_labels[best_any_gt_idx[pred_idx]]
            status = 'class'
            display = f'{label} (vs {gt_label} {best_any_iou[pred_idx]:.2f})'
        else:
            status = 'fp'
            display = f'{label} (FP)'
        pred_status.append(status)
        pred_display_labels.append(display)

    status_counts = {key: pred_status.count(key) for key in ('tp', 'near', 'class', 'fp')}

    return {
        'tp': coarse_stats['tp'],
        'n_pred': n_pred,
        'n_gt': n_gt,
        'precision': coarse_stats['precision'],
        'recall': coarse_stats['recall'],
        'f1': coarse_stats['f1'],
        'strict_precision': strict_stats['precision'],
        'strict_recall': strict_stats['recall'],
        'strict_f1': strict_stats['f1'],
        'mean_tp_iou': strict_stats['mean_iou'],
        'strict_tp_mask': strict_tp_mask,
        'tp_ious': strict_tp_ious,
        'pred_status': np.asarray(pred_status, dtype=object),
        'pred_display_labels': pred_display_labels,
        'matched_gt_indices': coarse_stats['gt_indices'],
        'strict_matched_gt_indices': strict_stats['gt_indices'],
        'status_counts': status_counts,
    }