from ultralytics import YOLO, checks, hub
import pandas as pd
---
author: Nipun Batra
badges: true
categories:
- ML
- computer-vision
- object-detection
- deep-learning
- neural-networks
- yolo
- pytorch
date: '2024-05-30'
title: Object detection
toc: true
---
checks()
Ultralytics YOLOv8.2.5 🚀 Python-3.11.7 torch-2.2.0+cu121 CUDA:0 (NVIDIA TITAN Xp, 12190MiB)
Setup complete ✅ (16 CPUs, 187.6 GB RAM, 38.5/184.8 GB disk)
from ultralytics import YOLO
# Load a model
# model = YOLO('yolov8n.yaml')  # build a new model from scratch
model = YOLO('yolov8n.pt')  # load pretrained weights
import requests
from PIL import Image
import matplotlib.pyplot as plt
from ultralytics import YOLO
# Download the image locally
image_url = 'https://ultralytics.com/images/bus.jpg'
# A timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() stops an HTTP error page from being saved as 'bus.jpg'.
response = requests.get(image_url, timeout=30)
response.raise_for_status()
with open('bus.jpg', 'wb') as f:
    f.write(response.content)
# Load a model
model = YOLO('yolov8n.pt')
def show_image(img, confidence_threshold=0.5):
    """Run the global ``model`` on *img*, draw its confident detections, and return all scores.

    Args:
        img: Image to run detection on -- a file path, or an already-opened
            image object that both ``model`` and ``plt.imshow`` accept.
        confidence_threshold: Only detections whose confidence is at or above
            this value are drawn. Every detection is still returned.

    Returns:
        pd.Series: Confidence of every detection (including those below the
        threshold), indexed by class name.
    """
    import os  # local import: in this notebook `os` is only imported further down

    results = model(img)[0]
    # Show the image that was actually passed in; previously 'bus.jpg' was
    # re-opened no matter what *img* was.
    image = Image.open(img) if isinstance(img, (str, os.PathLike)) else img
    ax = plt.gca()
    confidences = []
    classes = []
    for box in results.boxes:
        x0, y0, x1, y1 = box.xyxy[0].cpu().numpy()
        confidence = box.conf[0].item()
        class_name = model.names[int(box.cls)]
        # Record every detection, even ones that will not be drawn.
        confidences.append(confidence)
        classes.append(class_name)
        if confidence < confidence_threshold:
            continue
        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                                   fill=False, edgecolor='r', lw=2))
        # Put label at the top-left corner of the box
        ax.text(x0, y0, class_name, color='r')
    ax.imshow(image, alpha=0.5)
    return pd.Series(confidences, index=classes)


show_image('bus.jpg')
image 1/1 /home/nipun.batra/git/blog/posts/bus.jpg: 640x480 4 persons, 1 bus, 1 stop sign, 55.8ms
Speed: 3.2ms preprocess, 55.8ms inference, 1384.5ms postprocess per image at shape (1, 3, 640, 480)
bus 0.870545
person 0.868980
person 0.853604
person 0.819305
stop sign 0.346069
person 0.301294
dtype: float64

# Open the downloaded test image and run detection on it. `model(...)` is a
# shorthand for `model.predict(...)`; one input image yields one Results
# object, so element 0 is kept.
img = Image.open('bus.jpg')
res = model.predict(img)[0]
0: 640x480 4 persons, 1 bus, 1 stop sign, 9.9ms
Speed: 20.6ms preprocess, 9.9ms inference, 2.1ms postprocess per image at shape (1, 3, 640, 480)
# NOTE: `Boxes` has no `prob` attribute, so this raises the AttributeError recorded
# below. Each box exposes a single confidence (`.conf`) and class label (`.cls`),
# as listed in the error message's attribute summary.
res.boxes[0].prob--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) Cell In[20], line 1 ----> 1 res.boxes[0].prob File ~/miniconda3/envs/ml/lib/python3.11/site-packages/ultralytics/utils/__init__.py:160, in SimpleClass.__getattr__(self, attr) 158 """Custom attribute access error message with helpful information.""" 159 name = self.__class__.__name__ --> 160 raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") AttributeError: 'Boxes' object has no attribute 'prob'. See valid attributes below. Manages detection boxes, providing easy access and manipulation of box coordinates, confidence scores, class identifiers, and optional tracking IDs. Supports multiple formats for box coordinates, including both absolute and normalized forms. Attributes: data (torch.Tensor): The raw tensor containing detection boxes and their associated data. orig_shape (tuple): The original image size as a tuple (height, width), used for normalization. is_track (bool): Indicates whether tracking IDs are included in the box data. Properties: xyxy (torch.Tensor | numpy.ndarray): Boxes in [x1, y1, x2, y2] format. conf (torch.Tensor | numpy.ndarray): Confidence scores for each box. cls (torch.Tensor | numpy.ndarray): Class labels for each box. id (torch.Tensor | numpy.ndarray, optional): Tracking IDs for each box, if available. xywh (torch.Tensor | numpy.ndarray): Boxes in [x, y, width, height] format, calculated on demand. xyxyn (torch.Tensor | numpy.ndarray): Normalized [x1, y1, x2, y2] boxes, relative to `orig_shape`. xywhn (torch.Tensor | numpy.ndarray): Normalized [x, y, width, height] boxes, relative to `orig_shape`. Methods: cpu(): Moves the boxes to CPU memory. numpy(): Converts the boxes to a numpy array format. cuda(): Moves the boxes to CUDA (GPU) memory. to(device, dtype=None): Moves the boxes to the specified device.
# Access the detection results
results = model(img)
boxes = results[0].boxes
# Detection `Boxes` carry no per-class probability vector (`box.probs` raises
# AttributeError). Each box only exposes one class label (`cls`) and its
# confidence (`conf`); full per-class scores need the hook-based approach below.
for box in boxes:
    print(f"Bounding Box: {box.xyxy}")
    class_id = int(box.cls)
    print(f"Class ID: {class_id} ({model.names[class_id]}), Confidence: {box.conf.item():.4f}")
0: 640x480 4 persons, 1 bus, 1 stop sign, 8.5ms
Speed: 23.1ms preprocess, 8.5ms inference, 1.8ms postprocess per image at shape (1, 3, 640, 480)
--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) Cell In[22], line 8 6 # Print class probabilities for each detected object 7 for box in boxes: ----> 8 class_probs = box.probs 9 print(f"Bounding Box: {box.xyxy}") 10 for class_id, prob in enumerate(class_probs): File ~/miniconda3/envs/ml/lib/python3.11/site-packages/ultralytics/utils/__init__.py:160, in SimpleClass.__getattr__(self, attr) 158 """Custom attribute access error message with helpful information.""" 159 name = self.__class__.__name__ --> 160 raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") AttributeError: 'Boxes' object has no attribute 'probs'. See valid attributes below. Manages detection boxes, providing easy access and manipulation of box coordinates, confidence scores, class identifiers, and optional tracking IDs. Supports multiple formats for box coordinates, including both absolute and normalized forms. Attributes: data (torch.Tensor): The raw tensor containing detection boxes and their associated data. orig_shape (tuple): The original image size as a tuple (height, width), used for normalization. is_track (bool): Indicates whether tracking IDs are included in the box data. Properties: xyxy (torch.Tensor | numpy.ndarray): Boxes in [x1, y1, x2, y2] format. conf (torch.Tensor | numpy.ndarray): Confidence scores for each box. cls (torch.Tensor | numpy.ndarray): Class labels for each box. id (torch.Tensor | numpy.ndarray, optional): Tracking IDs for each box, if available. xywh (torch.Tensor | numpy.ndarray): Boxes in [x, y, width, height] format, calculated on demand. xyxyn (torch.Tensor | numpy.ndarray): Normalized [x1, y1, x2, y2] boxes, relative to `orig_shape`. xywhn (torch.Tensor | numpy.ndarray): Normalized [x, y, width, height] boxes, relative to `orig_shape`. Methods: cpu(): Moves the boxes to CPU memory. 
numpy(): Converts the boxes to a numpy array format. cuda(): Moves the boxes to CUDA (GPU) memory. to(device, dtype=None): Moves the boxes to the specified device.
# Raw tensor behind the first box. The six values appear to be
# [x1, y1, x2, y2, conf, cls] -- 0.8705 matches the bus confidence printed
# earlier -- but that layout is inferred from this output, not checked here.
results[0].boxes[0].datatensor([[ 17.2858, 230.5922, 801.5182, 768.4058, 0.8705, 5.0000]], device='cuda:0')
model
YOLO(
(model): DetectionModel(
(model): Sequential(
(0): Conv(
(conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(act): SiLU(inplace=True)
)
(1): Conv(
(conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(act): SiLU(inplace=True)
)
(2): C2f(
(cv1): Conv(
(conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0): Bottleneck(
(cv1): Conv(
(conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(3): Conv(
(conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(act): SiLU(inplace=True)
)
(4): C2f(
(cv1): Conv(
(conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0-1): 2 x Bottleneck(
(cv1): Conv(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(5): Conv(
(conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(act): SiLU(inplace=True)
)
(6): C2f(
(cv1): Conv(
(conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0-1): 2 x Bottleneck(
(cv1): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(7): Conv(
(conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(act): SiLU(inplace=True)
)
(8): C2f(
(cv1): Conv(
(conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(384, 256, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0): Bottleneck(
(cv1): Conv(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(9): SPPF(
(cv1): Conv(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1, ceil_mode=False)
)
(10): Upsample(scale_factor=2.0, mode='nearest')
(11): Concat()
(12): C2f(
(cv1): Conv(
(conv): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0): Bottleneck(
(cv1): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(13): Upsample(scale_factor=2.0, mode='nearest')
(14): Concat()
(15): C2f(
(cv1): Conv(
(conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0): Bottleneck(
(cv1): Conv(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(16): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(act): SiLU(inplace=True)
)
(17): Concat()
(18): C2f(
(cv1): Conv(
(conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0): Bottleneck(
(cv1): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(19): Conv(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(act): SiLU(inplace=True)
)
(20): Concat()
(21): C2f(
(cv1): Conv(
(conv): Conv2d(384, 256, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(384, 256, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0): Bottleneck(
(cv1): Conv(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(22): Detect(
(cv2): ModuleList(
(0): Sequential(
(0): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(1): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): Sequential(
(0): Conv(
(conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(1): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
)
(2): Sequential(
(0): Conv(
(conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(1): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(cv3): ModuleList(
(0): Sequential(
(0): Conv(
(conv): Conv2d(64, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(1): Conv(
(conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(2): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1))
)
(1): Sequential(
(0): Conv(
(conv): Conv2d(128, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(1): Conv(
(conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(2): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1))
)
(2): Sequential(
(0): Conv(
(conv): Conv2d(256, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(1): Conv(
(conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(2): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1))
)
)
(dfl): DFL(
(conv): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
)
)
)
import numpy as np
import os
import torch
from ultralytics import YOLO
from ultralytics.nn.modules.head import Detect
from ultralytics.utils import ops
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2
from PIL import Image
import json
class SaveIO:
    """PyTorch forward hook that remembers the latest input and output of a module.

    Register an instance with ``module.register_forward_hook(instance)``. After
    each forward pass, ``input`` holds the inputs the module received and
    ``output`` holds what it returned.
    """

    def __init__(self):
        # Nothing is captured until the hooked module runs.
        self.input, self.output = None, None

    def __call__(self, module, module_in, module_out):
        # The hooked module itself is not needed; only the data passing through it.
        self.input, self.output = module_in, module_out
def load_and_prepare_model(model_path):
    """Load a YOLO model and register the forward hooks needed to recover per-class scores.

    The hooks must be registered *before* inference runs. Each inference then
    stores the inputs/outputs of the hooked modules, from which the boxes and
    class logits can be reverse engineered.

    Args:
        model_path (str): Path/name of the YOLO weights, passed to ``YOLO``.

    Returns:
        tuple: ``(model, hooks)`` where ``hooks`` is
        ``[input_hook, detect, detect_hook, cv2_hooks, cv3_hooks]``:
        ``input_hook`` captures the input of the whole network, ``detect`` is
        the ``Detect`` head module, ``detect_hook`` captures its input/output,
        and ``cv2_hooks`` / ``cv3_hooks`` capture the output of each level's
        ``cv2`` / ``cv3`` branch.

    Raises:
        ValueError: If the model has no ``Detect`` head to hook.
    """
    model = YOLO(model_path)
    detect = None
    cv2_hooks = None
    cv3_hooks = None
    detect_hook = SaveIO()
    for module in model.model.modules():
        # Exact type match: only the plain detection head is handled here.
        if type(module) is Detect:
            module.register_forward_hook(detect_hook)
            detect = module
            cv2_hooks = [SaveIO() for _ in range(module.nl)]
            cv3_hooks = [SaveIO() for _ in range(module.nl)]
            for level in range(module.nl):
                module.cv2[level].register_forward_hook(cv2_hooks[level])
                module.cv3[level].register_forward_hook(cv3_hooks[level])
            break
    if detect is None:
        # Without the head hooks, results_predict() would later fail on None.
        raise ValueError(f"No Detect head found in model loaded from {model_path!r}")
    input_hook = SaveIO()
    model.model.register_forward_hook(input_hook)
    # save and return these for later
    hooks = [input_hook, detect, detect_hook, cv2_hooks, cv3_hooks]
    return model, hooks
def is_text_file(file_path):
    """Return True if *file_path* has a ``.txt`` extension (case-insensitive).

    Only ``.txt`` counts as a path-list file; other candidates such as
    ``.csv``, ``.json`` or ``.xml`` are deliberately not recognised.
    """
    return file_path.lower().endswith(('.txt',))
def is_image_file(file_path):
    """Return True if *file_path* ends in a common image extension (case-insensitive).

    Recognised extensions: .jpg, .jpeg, .png, .gif, .bmp (extend the tuple to add more).
    """
    suffix_options = ('.jpg', '.jpeg', '.png', '.gif', '.bmp')
    return file_path.lower().endswith(suffix_options)
def plot_image(img_path, results, category_mapping=None, suffix='test', show_labels=True, include_legend=True):
    """
    Draw bounding boxes and their best class scores on an image and save it as a JPEG.

    For each box, the class with the highest activation is the one reported.
    The figure is written to ``<image name without extension>_<suffix>.jpg``
    in the current working directory.

    Args:
        img_path (str): Path to the image file.
        results (list): Box dictionaries, each with 'bbox' (corner coordinates
            [x0, y0, x1, y1]) and 'activations' (one score per class).
        category_mapping (dict, optional): Class index -> class name, used in
            the legend labels; indices missing from it are shown as "Unknown".
        suffix (str): What to append to the original image name when saving.
        show_labels (bool): Draw a "<class id> (<score>)" tag above each box.
        include_legend (bool): Add a legend with one entry per box.

    Returns:
        None
    """
    fig, ax = plt.subplots()
    # Close the file handle once the pixels are in the axes.
    with Image.open(img_path) as img:
        ax.imshow(img)
    box_color = "r"  # red
    tag_color = "k"  # black
    for box in results:
        x0, y0, x1, y1 = map(int, box['bbox'])
        activations = box['activations']
        max_score = max(activations)
        max_category_id = activations.index(max_score)
        category_name = max_category_id
        if category_mapping:
            category_name = category_mapping.get(max_category_id, "Unknown")
        rect = patches.Rectangle(
            (x0, y0),
            x1 - x0,
            y1 - y0,
            edgecolor=box_color,
            label=f"{max_category_id}: {category_name} ({max_score:.2f})",
            facecolor='none',
        )
        ax.add_patch(rect)
        if show_labels:
            # Tag is drawn 50 data units above the box's top-left corner.
            ax.text(
                x0,
                y0 - 50,
                f"{max_category_id} ({max_score:.2f})",
                fontsize=5,
                color=tag_color,
                backgroundcolor=box_color,
            )
    # With no boxes there are no labelled artists and legend() only warns.
    if include_legend and results:
        ax.legend(fontsize=5)
    ax.axis("off")
    stem = os.path.basename(img_path).rsplit(".", 1)[0]
    fig.savefig(f'{stem}_{suffix}.jpg', bbox_inches="tight", dpi=300)
def write_json(results, output_path='predictions.json'):
    """
    Save box predictions to a JSON file as a list with one entry per box.

    Each entry records the image file name, the class with the highest
    activation ('category_id'), that activation ('score'), the box and the
    full activation list.

    Args:
        results (list): Box dictionaries with 'image_id' (image path),
            'bbox' and 'activations' (one score per class).
        output_path (str, optional): Destination file. Default is
            'predictions.json' in the current working directory.

    Returns:
        None
    """
    predictions = []
    for result in results:
        activations = result['activations']
        score = max(activations)
        predictions.append({
            # Only the file name is kept, not the directory part of the path.
            'image_id': os.path.basename(result['image_id']),
            'category_id': activations.index(score),
            'bbox': result['bbox'],
            'score': score,
            'activations': activations,
        })
    with open(output_path, 'w') as f:
        json.dump(predictions, f)
def calculate_iou(box1, box2):
    """
    Calculates the Intersection over Union (IoU) between two bounding boxes.

    Boxes use continuous [x, y, w, h] coordinates with (x, y) the top-left
    corner, so a box spans x..x+w horizontally and y..y+h vertically.

    Args:
        box1 (list): Bounding box coordinates [x1, y1, w1, h1].
        box2 (list): Bounding box coordinates [x2, y2, w2, h2].

    Returns:
        float: IoU in [0, 1] for well-formed boxes; 0.0 when the union has no area.
    """
    x1, y1, w1, h1 = box1
    x2, y2, w2, h2 = box2
    # No "+1": that pixel-inclusive convention is for integer corner indices.
    # Applied only to the intersection (as before), it made identical boxes
    # score IoU > 1 and merely touching boxes count as overlapping.
    inter_w = max(0, min(x1 + w1, x2 + w2) - max(x1, x2))
    inter_h = max(0, min(y1 + h1, y2 + h2) - max(y1, y2))
    intersect_area = inter_w * inter_h
    union_area = w1 * h1 + w2 * h2 - intersect_area
    if union_area <= 0:
        # Degenerate (zero-area) boxes: avoid ZeroDivisionError.
        return 0.0
    return intersect_area / float(union_area)
# Apply Non-Maximum Suppression
def nms(boxes, iou_threshold=0.7):
    """
    Applies greedy Non-Maximum Suppression (NMS) to a list of bounding box dictionaries.

    Boxes are visited from highest to lowest max activation. A box is kept only
    if its IoU with every box already kept is at most ``iou_threshold``.

    Args:
        boxes (list): List of dictionaries, each containing 'bbox' as corner
            coordinates [x0, y0, x1, y1] (the format results_predict() emits)
            and 'activations' (one score per class).
        iou_threshold (float, optional): Intersection over Union (IoU) threshold for NMS. Default is 0.7.

    Returns:
        list: Kept bounding box dictionaries, in descending confidence order.
    """
    def _corners_to_xywh(bbox):
        # calculate_iou() expects [x, y, w, h] with (x, y) the top-left corner;
        # passing xyxy corners straight through gave wrong overlaps.
        x0, y0, x1, y1 = bbox
        return [x0, y0, x1 - x0, y1 - y0]

    # Sort boxes by confidence score in descending order
    sorted_boxes = sorted(boxes, key=lambda b: max(b['activations']), reverse=True)
    kept = []
    kept_xywh = []
    for box in sorted_boxes:
        candidate = _corners_to_xywh(box['bbox'])
        # Only boxes that survived may suppress later ones (standard greedy NMS).
        if all(calculate_iou(candidate, other) <= iou_threshold for other in kept_xywh):
            kept.append(box)
            kept_xywh.append(candidate)
    return kept
def results_predict(img_path, model, hooks, threshold=0.5, iou=0.7, save_image = False, category_mapping = None):
    """
    Run prediction with a YOLO model and apply Non-Maximum Suppression (NMS) to the results.

    The hooks registered by load_and_prepare_model() capture the raw Detect-head
    tensors. From them, every candidate box gets its per-class sigmoid scores
    ("activations") and pre-sigmoid class scores ("logits"). Both are carried
    through Ultralytics' NMS so that each surviving box keeps its full score vectors.

    Args:
        img_path (str): Path to an image file.
        model (YOLO): YOLO model object.
        hooks (list): Hooks returned by load_and_prepare_model().
        threshold (float, optional): Confidence threshold for detection. Default is 0.5.
        iou (float, optional): Intersection over Union (IoU) threshold for NMS. Default is 0.7.
        save_image (bool, optional): Whether to save the image with boxes plotted
            (via plot_image()). Default is False.
        category_mapping (dict, optional): Class index -> name, passed to
            plot_image() when ``save_image`` is True.

    Returns:
        list: One dict per box kept by NMS, with 'bbox' (xyxy in original-image
        pixels), 'bbox_xywh' (center x, center y, width, height), 'best_conf',
        'best_cls', 'image_id', 'activations' and 'logits'.
    """
    # unpack hooks from load_and_prepare_model()
    input_hook, detect, detect_hook, cv2_hooks, cv3_hooks = hooks
    # run inference; we don't actually need to store the results because
    # the hooks store everything we need
    model(img_path)

    # Rebuild the concatenated head output the way Detect.forward() does, to
    # split off the raw class logits.
    # see Detect.forward(): https://github.com/ultralytics/ultralytics/blob/b638c4ed9a24270a6875cdd47d9eeda99204ef5a/ultralytics/nn/modules/head.py#L22
    shape = detect_hook.input[0][0].shape  # BCHW
    x = [torch.cat((cv2_hooks[i].output, cv3_hooks[i].output), 1) for i in range(detect.nl)]
    x_cat = torch.cat([xi.view(shape[0], detect.no, -1) for xi in x], 2)
    _, cls = x_cat.split((detect.reg_max * 4, detect.nc), 1)

    # assumes batch size = 1 (i.e. you are just running with one image)
    # if you want to run with many images, throw this in a loop
    batch_idx = 0
    # Decoded head output, shape (4 + nc, num_candidates): rows 0-3 are the
    # box as center x, center y, width, height in model-input pixels; the
    # remaining nc rows are the sigmoid class scores.
    xywh_sigmoid = detect_hook.output[0][batch_idx].detach().float().cpu()
    all_logits = cls[batch_idx].detach().float().cpu()  # (nc, num_candidates)

    # figure out the original img shape and model img shape so we can transform the boxes
    img_shape = input_hook.input[0].shape[2:]
    orig_img_shape = model.predictor.batch[1][batch_idx].shape[:2]

    # Convert center-xywh to corners BEFORE scale_boxes(), which expects xyxy.
    # The previous per-box loop fed (cx, cy, w, h) in as if it were
    # (x0, y0, x1, y1). All candidates are converted in one vectorized call.
    boxes_xyxy = ops.scale_boxes(img_shape, ops.xywh2xyxy(xywh_sigmoid[:4].T), orig_img_shape)
    boxes_xywh = ops.xyxy2xywh(boxes_xyxy)  # NMS expects center-xywh input
    activations = xywh_sigmoid[4:].T  # (num_candidates, nc)
    logits = all_logits.T  # (num_candidates, nc)

    # NMS
    # we can keep the activations and logits around via the YOLOv8 NMS method, but only if we
    # append them as an additional time to the prediction vector. It's a weird hacky way to do it, but
    # it works. We also have to pass in the num classes (nc) parameter to make it work.
    boxes_for_nms = torch.cat([boxes_xywh, activations, activations, logits], dim=1).T.unsqueeze(0)
    nms_results = ops.non_max_suppression(boxes_for_nms, conf_thres=threshold, iou_thres=iou, nc=detect.nc)[0]

    # unpack it and return it: each row is [x0, y0, x1, y1, conf, cls, *activations, *logits]
    boxes = []
    for row in nms_results:
        x0, y0, x1, y1, conf, cls_id, *acts_and_logits = row.tolist()
        boxes.append({
            'bbox': [x0, y0, x1, y1],  # xyxy
            'bbox_xywh': [(x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0],
            'best_conf': conf,
            'best_cls': cls_id,
            'image_id': img_path,
            'activations': acts_and_logits[:detect.nc],
            'logits': acts_and_logits[detect.nc:],
        })
    if save_image:
        plot_image(img_path, boxes, category_mapping=category_mapping)
    return boxes
def run_predict(input_path, model, hooks, score_threshold=0.5, iou_threshold=0.7, save_image=False, save_json=False, category_mapping=None):
    """
    Run prediction with a YOLO model on one image or a list of images.

    Args:
        input_path (str): Path to an image file, or to a text file (as
            detected by ``is_text_file``) listing one image path per line.
            Blank lines in the text file are ignored.
        model (YOLO): YOLO model object.
        hooks (list): List of hooks for the model.
        score_threshold (float, optional): Confidence threshold for detection.
            Forwarded to ``results_predict``. Default is 0.5.
        iou_threshold (float, optional): Intersection over Union (IoU)
            threshold for NMS. Forwarded to ``results_predict`` as ``iou``.
            Default is 0.7.
        save_image (bool, optional): Forwarded to ``results_predict``; whether
            to save the image with boxes plotted. Default is False.
        save_json (bool, optional): Whether to write all results with
            ``write_json``. Default is False.
        category_mapping (optional): Forwarded unchanged to
            ``results_predict``. Default is None.

    Returns:
        list: Selected bounding box dictionaries for all input images,
        concatenated in input order.
    """
    if is_text_file(input_path):
        # One image path per line. Skip blank lines (e.g. separators or
        # trailing empty lines) so inference is never run on an empty path.
        with open(input_path, 'r', encoding='utf-8') as f:
            img_paths = [line for line in f.read().splitlines() if line.strip()]
    else:
        img_paths = [input_path]

    all_results = []
    for img_path in img_paths:
        results = results_predict(img_path, model, hooks, score_threshold, iou=iou_threshold, save_image=save_image, category_mapping=category_mapping)
        all_results.extend(results)

    if save_json:
        write_json(all_results)
    return all_results
### Start example script here ###
### (This shows how to use the methods in this file) ###
# Change these settings to point at your own model / image.
SAVE_TEST_IMG = True
model_path = 'yolov8n.pt'
img_path = 'bus.jpg'
threshold = 0.5  # confidence (score) threshold passed as score_threshold
nms_threshold = 0.7  # IoU threshold passed to NMS as iou_threshold
# load the model
model, hooks = load_and_prepare_model(model_path)
# run inference
results = run_predict(img_path, model, hooks, score_threshold=threshold, iou_threshold=nms_threshold)
print("Processed", len(results), "boxes")
# results[0] would raise IndexError when no box passes the thresholds.
if results:
    print("The first one is", results[0])
if SAVE_TEST_IMG:
    plot_image(img_path, results)
image 1/1 /home/nipun.batra/git/blog/posts/bus.jpg: 640x480 4 persons, 1 bus, 1 stop sign, 6.9ms
Speed: 2.5ms preprocess, 6.9ms inference, 1.4ms postprocess per image at shape (1, 3, 640, 480)
Processed 4 boxes
The first one is {'bbox': [17.28582763671875, 230.59222412109375, 801.5181884765625, 768.4058227539062], 'bbox_xywh': [409.4020080566406, 499.4990234375, 784.2323608398438, 537.8135986328125], 'best_conf': 0.8705450892448425, 'best_cls': 5.0, 'image_id': 'bus.jpg', 'activations': [1.0311242476745974e-05, 9.112093721341807e-06, 0.0003611394204199314, 6.673172902083024e-05, 0.00022987156989984214, 0.8705450892448425, 0.0009488395880907774, 0.007086973637342453, 0.0001582797704031691, 9.052685527422e-06, 4.503397406097065e-07, 1.0215371730737388e-05, 1.730335316096898e-05, 4.165742666373262e-06, 5.312303414939379e-07, 1.9164923514836119e-07, 9.471239650338248e-07, 4.973830982635263e-06, 2.5008853654640006e-08, 4.033859681840113e-07, 0.0005047211889177561, 6.792038220737595e-07, 4.306784262553265e-07, 9.298752411268651e-05, 4.2397550714667886e-07, 6.426890467992052e-05, 3.322290240248549e-06, 1.4765126934435102e-06, 2.1423567886813544e-05, 1.0903042948484654e-06, 2.7694509299180936e-06, 3.31359643723772e-07, 1.1906087138413568e-06, 5.6864570069592446e-05, 5.980197329336079e-06, 9.176978892355692e-07, 6.090281658543972e-06, 8.041779437917285e-06, 4.321082087699324e-05, 2.43568738369504e-05, 4.93736079079099e-05, 1.2213927220727783e-05, 3.2120142350322567e-06, 5.898335757592577e-07, 6.751507157787273e-07, 1.514082214271184e-06, 7.679868474497198e-08, 5.769239237451984e-07, 2.8340421067696298e-06, 4.807910954696126e-06, 2.058711316976769e-07, 2.196402647314244e-06, 2.783847776299808e-06, 7.796339195920154e-05, 1.7758512171894836e-07, 2.9252541366986407e-07, 4.3183670641155913e-05, 8.449018423561938e-06, 1.572124688209442e-06, 4.166020971752005e-06, 1.1690817700582556e-05, 1.4325351571642386e-07, 3.190975985489786e-05, 1.2439878673831117e-06, 9.514394605503185e-07, 9.426859719496861e-07, 5.968039204162778e-06, 8.346189133590087e-05, 2.026550646405667e-05, 5.310462256602477e-06, 9.205886271956842e-06, 4.1557203900310924e-08, 4.2926520109176636e-05, 2.642763683979865e-05, 
2.686474908841774e-05, 2.5954559532692656e-05, 6.725371122229262e-07, 5.926154926783056e-07, 1.19427056688437e-06, 2.642658500917605e-07], 'logits': [-11.48226547241211, -11.6058988571167, -7.925885200500488, -9.614763259887695, -8.37775993347168, 1.9057865142822266, -6.95932149887085, -4.942384719848633, -8.750988006591797, -11.61244010925293, -14.613263130187988, -11.491606712341309, -10.964592933654785, -12.388611793518066, -14.44806957244873, -15.467598915100098, -13.869834899902344, -12.211315155029297, -17.50403594970703, -14.723371505737305, -7.590999603271484, -14.202343940734863, -14.657903671264648, -9.282952308654785, -14.673589706420898, -9.65237045288086, -12.614852905273438, -13.425826072692871, -10.750997543334961, -13.729052543640137, -12.796858787536621, -14.920061111450195, -13.641044616699219, -9.774781227111816, -12.027050971984863, -13.901396751403809, -12.008810043334961, -11.730852127075195, -10.049376487731934, -10.622672080993652, -9.916045188903809, -11.312921524047852, -12.648609161376953, -14.343424797058105, -14.208329200744629, -13.400699615478516, -16.382078170776367, -14.365554809570312, -12.7738037109375, -12.245243072509766, -15.396015167236328, -13.028687477111816, -12.79167366027832, -9.459193229675293, -15.543815612792969, -15.044713973999023, -10.050004959106445, -11.681451797485352, -13.363080978393555, -12.388545036315918, -11.356695175170898, -15.758649826049805, -10.352566719055176, -13.597187042236328, -13.865288734436035, -13.874531745910645, -12.029086112976074, -9.391036987304688, -10.806570053100586, -12.14582633972168, -11.595658302307129, -16.99619483947754, -10.055977821350098, -10.5410737991333, -10.52466869354248, -10.559137344360352, -14.212207794189453, -14.338719367980957, -13.63797378540039, -15.146309852600098]}

results
[{'bbox': [17.28582763671875,
230.59222412109375,
801.5181884765625,
768.4058227539062],
'bbox_xywh': [409.4020080566406,
499.4990234375,
784.2323608398438,
537.8135986328125],
'best_conf': 0.8705450892448425,
'best_cls': 5.0,
'image_id': 'bus.jpg',
'activations': [1.0311242476745974e-05,
9.112093721341807e-06,
0.0003611394204199314,
6.673172902083024e-05,
0.00022987156989984214,
0.8705450892448425,
0.0009488395880907774,
0.007086973637342453,
0.0001582797704031691,
9.052685527422e-06,
4.503397406097065e-07,
1.0215371730737388e-05,
1.730335316096898e-05,
4.165742666373262e-06,
5.312303414939379e-07,
1.9164923514836119e-07,
9.471239650338248e-07,
4.973830982635263e-06,
2.5008853654640006e-08,
4.033859681840113e-07,
0.0005047211889177561,
6.792038220737595e-07,
4.306784262553265e-07,
9.298752411268651e-05,
4.2397550714667886e-07,
6.426890467992052e-05,
3.322290240248549e-06,
1.4765126934435102e-06,
2.1423567886813544e-05,
1.0903042948484654e-06,
2.7694509299180936e-06,
3.31359643723772e-07,
1.1906087138413568e-06,
5.6864570069592446e-05,
5.980197329336079e-06,
9.176978892355692e-07,
6.090281658543972e-06,
8.041779437917285e-06,
4.321082087699324e-05,
2.43568738369504e-05,
4.93736079079099e-05,
1.2213927220727783e-05,
3.2120142350322567e-06,
5.898335757592577e-07,
6.751507157787273e-07,
1.514082214271184e-06,
7.679868474497198e-08,
5.769239237451984e-07,
2.8340421067696298e-06,
4.807910954696126e-06,
2.058711316976769e-07,
2.196402647314244e-06,
2.783847776299808e-06,
7.796339195920154e-05,
1.7758512171894836e-07,
2.9252541366986407e-07,
4.3183670641155913e-05,
8.449018423561938e-06,
1.572124688209442e-06,
4.166020971752005e-06,
1.1690817700582556e-05,
1.4325351571642386e-07,
3.190975985489786e-05,
1.2439878673831117e-06,
9.514394605503185e-07,
9.426859719496861e-07,
5.968039204162778e-06,
8.346189133590087e-05,
2.026550646405667e-05,
5.310462256602477e-06,
9.205886271956842e-06,
4.1557203900310924e-08,
4.2926520109176636e-05,
2.642763683979865e-05,
2.686474908841774e-05,
2.5954559532692656e-05,
6.725371122229262e-07,
5.926154926783056e-07,
1.19427056688437e-06,
2.642658500917605e-07],
'logits': [-11.48226547241211,
-11.6058988571167,
-7.925885200500488,
-9.614763259887695,
-8.37775993347168,
1.9057865142822266,
-6.95932149887085,
-4.942384719848633,
-8.750988006591797,
-11.61244010925293,
-14.613263130187988,
-11.491606712341309,
-10.964592933654785,
-12.388611793518066,
-14.44806957244873,
-15.467598915100098,
-13.869834899902344,
-12.211315155029297,
-17.50403594970703,
-14.723371505737305,
-7.590999603271484,
-14.202343940734863,
-14.657903671264648,
-9.282952308654785,
-14.673589706420898,
-9.65237045288086,
-12.614852905273438,
-13.425826072692871,
-10.750997543334961,
-13.729052543640137,
-12.796858787536621,
-14.920061111450195,
-13.641044616699219,
-9.774781227111816,
-12.027050971984863,
-13.901396751403809,
-12.008810043334961,
-11.730852127075195,
-10.049376487731934,
-10.622672080993652,
-9.916045188903809,
-11.312921524047852,
-12.648609161376953,
-14.343424797058105,
-14.208329200744629,
-13.400699615478516,
-16.382078170776367,
-14.365554809570312,
-12.7738037109375,
-12.245243072509766,
-15.396015167236328,
-13.028687477111816,
-12.79167366027832,
-9.459193229675293,
-15.543815612792969,
-15.044713973999023,
-10.050004959106445,
-11.681451797485352,
-13.363080978393555,
-12.388545036315918,
-11.356695175170898,
-15.758649826049805,
-10.352566719055176,
-13.597187042236328,
-13.865288734436035,
-13.874531745910645,
-12.029086112976074,
-9.391036987304688,
-10.806570053100586,
-12.14582633972168,
-11.595658302307129,
-16.99619483947754,
-10.055977821350098,
-10.5410737991333,
-10.52466869354248,
-10.559137344360352,
-14.212207794189453,
-14.338719367980957,
-13.63797378540039,
-15.146309852600098]},
{'bbox': [48.739479064941406,
399.263916015625,
244.50173950195312,
902.501220703125],
'bbox_xywh': [146.62060928344727,
650.882568359375,
195.76226043701172,
503.2373046875],
'best_conf': 0.8689801096916199,
'best_cls': 0.0,
'image_id': 'bus.jpg',
'activations': [0.8689801096916199,
2.159083720698618e-07,
7.267921660059073e-07,
4.954289352099295e-07,
3.9690485209575854e-06,
3.6866176742478274e-06,
1.924761590998969e-06,
5.504815590029466e-07,
8.782559461906203e-07,
1.8101994783137343e-06,
1.012005441225483e-06,
1.725299512145284e-07,
1.9031394913326949e-06,
1.7214931347098172e-07,
2.6558041099633556e-06,
1.7690776985546108e-06,
4.047087259095861e-06,
1.6409081808888004e-06,
7.941423518786905e-07,
1.1039174978577648e-06,
1.4459410522249527e-06,
1.3210571978561347e-06,
4.1252138771596947e-07,
1.2257790331204887e-05,
9.701152521301992e-07,
5.426437610367429e-07,
4.6731869929317327e-07,
1.5046108501337585e-06,
6.04137326831733e-08,
3.9817462038627127e-07,
1.8838558162315167e-06,
1.1097992000941304e-06,
1.1555634955584537e-06,
7.810695024090819e-06,
2.828919605235569e-06,
2.9941506340946944e-07,
5.739046855524066e-07,
1.1345385928507312e-06,
4.880649271399307e-07,
2.7935711841564626e-06,
1.2807398661607294e-06,
3.5370536011214426e-07,
3.775756454160728e-07,
2.3527961729996605e-06,
9.562232889948064e-07,
4.1244192061640206e-07,
6.125056870587287e-07,
2.33453351938806e-06,
5.989752480672905e-07,
1.124500272453588e-06,
9.676870149633032e-07,
1.5129870007513091e-05,
3.98353768105153e-06,
6.566142474184744e-06,
6.885069296913571e-07,
5.487340786203276e-07,
7.767624765619985e-07,
5.031957925893948e-07,
2.001876282520243e-06,
9.961429441318614e-07,
1.2013529158139136e-06,
4.2538695765870216e-07,
2.504941676306771e-06,
2.0721746807339514e-07,
2.624780393034598e-07,
7.277583904397034e-07,
4.963558453141559e-08,
1.2240196838320117e-06,
3.9281539443436486e-07,
1.3602493709186092e-06,
2.1281984174947866e-07,
9.780134746506519e-08,
1.0565305501586408e-06,
4.288276045372186e-07,
3.775760148982954e-07,
2.0054291383075906e-07,
1.2226444141560933e-06,
1.0907050636888016e-05,
1.6243636764556868e-06,
5.100530984236684e-07],
'logits': [1.8919711112976074,
-15.348411560058594,
-14.134624481201172,
-14.517841339111328,
-12.436980247497559,
-12.510797500610352,
-13.160706520080566,
-14.412471771240234,
-13.945326805114746,
-13.222071647644043,
-13.80357551574707,
-15.572694778442383,
-13.172003746032715,
-15.57490348815918,
-12.838760375976562,
-13.245050430297852,
-12.417509078979492,
-13.320259094238281,
-14.046002388000488,
-13.716644287109375,
-13.446748733520508,
-13.537076950073242,
-14.700977325439453,
-11.30933666229248,
-13.845849990844727,
-14.426812171936035,
-14.576253890991211,
-13.406974792480469,
-16.62204933166504,
-14.736374855041504,
-13.182188034057617,
-13.71133041381836,
-13.670921325683594,
-11.760008811950684,
-12.775612831115723,
-15.021434783935547,
-14.37080192565918,
-13.68928337097168,
-14.532816886901855,
-12.788187026977539,
-13.568071365356445,
-14.854801177978516,
-14.789494514465332,
-12.959903717041016,
-13.860273361206055,
-14.701169967651367,
-14.305706977844238,
-12.967696189880371,
-14.328044891357422,
-13.69817066192627,
-13.848356246948242,
-11.098824501037598,
-12.43333625793457,
-11.933577537536621,
-14.188739776611328,
-14.415651321411133,
-14.068130493164062,
-14.502285957336426,
-13.121423721313477,
-13.819374084472656,
-13.632061004638672,
-14.670266151428223,
-12.897242546081543,
-15.389496803283691,
-15.153098106384277,
-14.133296012878418,
-16.818557739257812,
-13.61336898803711,
-14.74992561340332,
-13.507841110229492,
-15.36281967163086,
-16.14032745361328,
-13.760519027709961,
-14.662210464477539,
-14.789493560791016,
-15.422237396240234,
-13.614493370056152,
-11.426090240478516,
-13.330392837524414,
-14.488750457763672]},
{'bbox': [670.269287109375,
380.2840270996094,
809.858154296875,
875.6907958984375],
'bbox_xywh': [740.063720703125,
627.9874114990234,
139.5888671875,
495.4067687988281],
'best_conf': 0.8536035418510437,
'best_cls': 0.0,
'image_id': 'bus.jpg',
'activations': [0.8536035418510437,
8.206617394534987e-07,
1.2038118484269944e-06,
1.1462942666184972e-06,
1.5739009313620045e-06,
3.3411379263270646e-06,
3.73882085114019e-06,
4.4993788606007e-07,
1.2424654869391816e-06,
3.6132853438175516e-06,
2.934745680249762e-06,
1.105569964465758e-07,
1.0381992296970566e-06,
5.935368676546204e-07,
1.4730724160472164e-06,
3.304354777355911e-06,
4.168759005551692e-06,
3.27019665746775e-06,
5.449832656267972e-07,
5.046072146797087e-07,
1.3525707345252158e-06,
9.852926723397104e-07,
1.1722339650077629e-06,
7.1923359428183176e-06,
2.579861757112667e-06,
7.465223461622372e-07,
1.0459139048180077e-06,
1.3547731896323967e-06,
3.88375610782532e-07,
3.233400320823421e-07,
1.6130819631143822e-06,
8.577525818509457e-07,
1.0866494903893908e-06,
3.3366293337167008e-06,
2.0756035610247636e-06,
3.148159635202319e-07,
6.232777991499461e-07,
7.360301310654904e-07,
9.060752290679375e-07,
2.585813035693718e-06,
1.3577079016613425e-06,
4.043033925427153e-07,
3.105890868937422e-07,
1.355843437522708e-06,
3.1056660532158276e-07,
1.5147954002259212e-07,
3.9362166148748656e-07,
1.426429321327305e-06,
1.4964024330765824e-07,
4.153795600814192e-07,
3.13703480969707e-07,
2.565983777458314e-06,
1.2322219617999508e-06,
9.984036068999558e-07,
1.9590237343436456e-07,
4.869207828051003e-07,
2.924327418440953e-06,
1.931917267938843e-06,
7.152341822802555e-06,
5.519689239008585e-06,
1.5743992207717383e-06,
2.096727712341817e-06,
2.2192562028067186e-06,
3.0115475624370447e-07,
3.523186933307443e-07,
3.960224432830728e-07,
3.5972377077087e-08,
8.289633797176066e-07,
3.8832413906675356e-07,
8.207001087612298e-07,
1.5001531039615656e-07,
8.185835298490929e-08,
1.6670093145876308e-06,
2.5175890527862066e-07,
3.8390692225220846e-07,
7.335270879593736e-07,
1.5532028783127316e-06,
1.81348677870119e-05,
1.4697145616082707e-06,
5.530769158212934e-07],
'logits': [1.763148307800293,
-14.013154029846191,
-13.630016326904297,
-13.678975105285645,
-13.36195182800293,
-12.609195709228516,
-12.496736526489258,
-14.614155769348145,
-13.598411560058594,
-12.530889511108398,
-12.738886833190918,
-16.01773452758789,
-13.778021812438965,
-14.337165832519531,
-13.4281587600708,
-12.62026596069336,
-12.387887954711914,
-12.630657196044922,
-14.422510147094727,
-14.49948501586914,
-13.51350212097168,
-13.830326080322266,
-13.656598091125488,
-11.842487335205078,
-12.867772102355957,
-14.107839584350586,
-13.770618438720703,
-13.51187515258789,
-14.761292457580566,
-14.944561004638672,
-13.337362289428711,
-13.968949317932129,
-13.732410430908203,
-12.610546112060547,
-13.085256576538086,
-14.971277236938477,
-14.288272857666016,
-14.121994018554688,
-13.914142608642578,
-12.86546802520752,
-13.509711265563965,
-14.721099853515625,
-14.984794616699219,
-13.511085510253906,
-14.984867095947266,
-15.702815055847168,
-14.747875213623047,
-13.460334777832031,
-15.715031623840332,
-14.694072723388672,
-14.974817276000977,
-12.87316608428955,
-13.606690406799316,
-13.817107200622559,
-15.445649147033691,
-14.535163879394531,
-12.742443084716797,
-13.15699577331543,
-11.848063468933105,
-12.107183456420898,
-13.361635208129883,
-13.075130462646484,
-13.018336296081543,
-15.015641212463379,
-14.858729362487793,
-14.74179458618164,
-17.140514373779297,
-14.00308895111084,
-14.761425018310547,
-14.013107299804688,
-15.712528228759766,
-16.318275451660156,
-13.30447769165039,
-15.194793701171875,
-14.772865295410156,
-14.12540054321289,
-13.375189781188965,
-10.917655944824219,
-13.430440902709961,
-14.407768249511719]},
{'bbox': [221.39376831054688,
405.79168701171875,
344.7171936035156,
857.3920288085938],
'bbox_xywh': [283.05548095703125,
631.5918579101562,
123.32342529296875,
451.600341796875],
'best_conf': 0.8193051218986511,
'best_cls': 0.0,
'image_id': 'bus.jpg',
'activations': [0.8193051218986511,
4.0964692971101613e-07,
2.936613100246177e-06,
1.382118284709577e-06,
1.6149664588738233e-05,
8.412195711571258e-06,
9.894216645989218e-07,
1.2228890682308702e-06,
1.2646942195715383e-06,
1.9066765162278898e-06,
6.240153425096651e-07,
7.205974128510206e-08,
2.2608073777519166e-06,
1.6783783962637244e-07,
4.693410573963774e-06,
3.116812195003149e-06,
7.556559467047919e-06,
3.5700793432624778e-06,
8.197254715014424e-07,
1.5916730262688361e-06,
2.326829417143017e-06,
1.0168097333007609e-06,
2.1919508697010315e-07,
4.784199973073555e-06,
1.3443328725770698e-06,
4.94625453484332e-07,
4.1110206439043395e-07,
6.275794248722377e-07,
1.2701826790362247e-07,
3.41750705956656e-07,
1.7371836520396755e-06,
1.0280481319568935e-06,
1.6011792922654422e-06,
5.354379027266987e-06,
4.389916284708306e-06,
3.4768410728247545e-07,
1.3610668929686653e-06,
1.0928538358712103e-06,
7.004660460552259e-07,
3.702946060002432e-06,
2.8141175789642148e-06,
6.319215231087583e-07,
4.0783669419397484e-07,
1.7733005961417803e-06,
1.152611616817012e-06,
2.805371650538291e-07,
8.198293812711199e-07,
3.066387762373779e-06,
8.176439223461784e-07,
6.666024887636013e-07,
1.3133955008015619e-06,
9.064304322237149e-06,
3.986513547715731e-06,
2.927407876995858e-06,
2.5427269179090217e-07,
4.65552773221134e-07,
2.3654131382500054e-06,
7.987733283698617e-07,
1.5470428706976236e-06,
5.12271185471036e-07,
4.3607190036709653e-07,
5.852278945894795e-07,
1.5100877135409974e-06,
3.0730282674085174e-07,
5.032850936004252e-07,
7.947044764478051e-07,
6.477876013377681e-08,
2.3970601432665717e-06,
1.9725838740214385e-07,
3.763364588849072e-07,
2.017427789269277e-07,
1.4345364718337805e-07,
8.217072036131867e-07,
3.7514539030780725e-07,
2.626693742513453e-07,
2.2584059422570135e-07,
9.443388080399018e-07,
1.2512692592281383e-05,
2.4881660465325695e-06,
7.62891602335003e-07],
'logits': [1.5116467475891113,
-14.707969665527344,
-12.738250732421875,
-13.491891860961914,
-11.033595085144043,
-11.685819625854492,
-13.826144218444824,
-13.614293098449707,
-13.580678939819336,
-13.170146942138672,
-14.287090301513672,
-16.445770263671875,
-12.999786376953125,
-15.60026741027832,
-12.269346237182617,
-12.678696632385254,
-11.793087005615234,
-12.542919158935547,
-14.01429557800293,
-13.350723266601562,
-12.971001625061035,
-13.798839569091797,
-15.333303451538086,
-12.250186920166016,
-13.519611358642578,
-14.519464492797852,
-14.704423904418945,
-14.281394958496094,
-15.878934860229492,
-14.88918399810791,
-13.263243675231934,
-13.787847518920898,
-13.344768524169922,
-12.137590408325195,
-12.336195945739746,
-14.871971130371094,
-13.507240295410156,
-13.726716995239258,
-14.17151927947998,
-12.506378173828125,
-12.780858993530273,
-14.274499893188477,
-14.712398529052734,
-13.242666244506836,
-13.673479080200195,
-15.086559295654297,
-14.014168739318848,
-12.69500732421875,
-14.016838073730469,
-14.221071243286133,
-13.542893409729004,
-11.611157417297363,
-12.432589530944824,
-12.741390228271484,
-15.184858322143555,
-14.580039978027344,
-12.95455551147461,
-14.04018783569336,
-13.37916374206543,
-14.484411239624023,
-14.645458221435547,
-14.351263999938965,
-13.403341293334961,
-14.995431900024414,
-14.502108573913574,
-14.045294761657715,
-16.552288055419922,
-12.941265106201172,
-15.438751220703125,
-14.792781829833984,
-15.416272163391113,
-15.757253646850586,
-14.011880874633789,
-14.795951843261719,
-15.152369499206543,
-15.303436279296875,
-13.872779846191406,
-11.2887544631958,
-12.903962135314941,
-14.086149215698242]}]