from ultralytics import YOLO, checks, hub
import pandas as pd
checks()  # presumably prints the environment summary captured in the output below
Ultralytics YOLOv8.2.5 🚀 Python-3.11.7 torch-2.2.0+cu121 CUDA:0 (NVIDIA TITAN Xp, 12190MiB)
Setup complete ✅ (16 CPUs, 187.6 GB RAM, 38.5/184.8 GB disk)
from ultralytics import YOLO
# Load a model
#model = YOLO('yolov8n.yaml') # build a new model from scratch
model = YOLO('yolov8n.pt')  # load the model from the 'yolov8n.pt' weights file
import requests
from PIL import Image
import matplotlib.pyplot as plt
from ultralytics import YOLO
# Download the image locally
image_url = 'https://ultralytics.com/images/bus.jpg'
response = requests.get(image_url, timeout=30)
# Fail loudly on an HTTP error instead of saving an error page as bus.jpg.
response.raise_for_status()
with open('bus.jpg', 'wb') as f:
    f.write(response.content)

# Load a model
model = YOLO('yolov8n.pt')
def show_image(img, confidence_threshold=0.5):
    """Run the global ``model`` on ``img`` and draw the confident detections.

    Args:
        img: Image path (or an already opened PIL image) passed to the model.
        confidence_threshold (float): Boxes scoring below this value are not
            drawn; they are still included in the returned Series.

    Returns:
        pandas.Series: Confidence of every detected box, indexed by class name.
    """
    results = model(img)[0]

    # Show the image that was actually passed in (the previous version always
    # opened 'bus.jpg', regardless of ``img``).
    image = img if isinstance(img, Image.Image) else Image.open(img)

    confidences = []
    classes = []
    for box in results.boxes:
        box_coords = box.xyxy[0].cpu().numpy()
        confidence = box.conf[0].cpu().numpy()
        class_name = model.names[int(box.cls)]

        # Record every detection, even the ones that are too weak to draw.
        confidences.append(confidence.item())
        classes.append(class_name)

        if confidence < confidence_threshold:
            continue

        plt.gca().add_patch(
            plt.Rectangle(
                (box_coords[0], box_coords[1]),
                box_coords[2] - box_coords[0],
                box_coords[3] - box_coords[1],
                fill=False, edgecolor='r', lw=2,
            )
        )
        # Put label
        plt.text(box_coords[0], box_coords[1], class_name, color='r')

    plt.imshow(image, alpha=0.5)
    return pd.Series(confidences, index=classes)
show_image('bus.jpg')
image 1/1 /home/nipun.batra/git/blog/posts/bus.jpg: 640x480 4 persons, 1 bus, 1 stop sign, 55.8ms
Speed: 3.2ms preprocess, 55.8ms inference, 1384.5ms postprocess per image at shape (1, 3, 640, 480)
bus 0.870545
person 0.868980
person 0.853604
person 0.819305
stop sign 0.346069
person 0.301294
dtype: float64
img = Image.open('bus.jpg')
res = model(img)[0]  # first (only) result for the single input image
0: 640x480 4 persons, 1 bus, 1 stop sign, 9.9ms
Speed: 20.6ms preprocess, 9.9ms inference, 2.1ms postprocess per image at shape (1, 3, 640, 480)
# NOTE: `Boxes` has no `prob` attribute; this probe raises the AttributeError
# recorded in the cell output.
res.boxes[0].prob
AttributeError: 'Boxes' object has no attribute 'prob'. See valid attributes below.
Manages detection boxes, providing easy access and manipulation of box coordinates, confidence scores, class
identifiers, and optional tracking IDs. Supports multiple formats for box coordinates, including both absolute and
normalized forms.
Attributes:
data (torch.Tensor): The raw tensor containing detection boxes and their associated data.
orig_shape (tuple): The original image size as a tuple (height, width), used for normalization.
is_track (bool): Indicates whether tracking IDs are included in the box data.
Properties:
xyxy (torch.Tensor | numpy.ndarray): Boxes in [x1, y1, x2, y2] format.
conf (torch.Tensor | numpy.ndarray): Confidence scores for each box.
cls (torch.Tensor | numpy.ndarray): Class labels for each box.
id (torch.Tensor | numpy.ndarray, optional): Tracking IDs for each box, if available.
xywh (torch.Tensor | numpy.ndarray): Boxes in [x, y, width, height] format, calculated on demand.
xyxyn (torch.Tensor | numpy.ndarray): Normalized [x1, y1, x2, y2] boxes, relative to `orig_shape`.
xywhn (torch.Tensor | numpy.ndarray): Normalized [x, y, width, height] boxes, relative to `orig_shape`.
Methods:
cpu(): Moves the boxes to CPU memory.
numpy(): Converts the boxes to a numpy array format.
cuda(): Moves the boxes to CUDA (GPU) memory.
to(device, dtype=None): Moves the boxes to the specified device.
# Access the detection results
results = model(img)
boxes = results[0].boxes

# Print class probabilities for each detected object
# NOTE: `Boxes` has no `probs` attribute, so this loop raises the
# AttributeError recorded in the cell output.
for box in boxes:
    class_probs = box.probs
    print(f"Bounding Box: {box.xyxy}")
    for class_id, prob in enumerate(class_probs):
        print(f"Class ID: {class_id}, Probability: {prob:.4f}")
0: 640x480 4 persons, 1 bus, 1 stop sign, 8.5ms
Speed: 23.1ms preprocess, 8.5ms inference, 1.8ms postprocess per image at shape (1, 3, 640, 480)
AttributeError: 'Boxes' object has no attribute 'probs'. See valid attributes below.
Manages detection boxes, providing easy access and manipulation of box coordinates, confidence scores, class
identifiers, and optional tracking IDs. Supports multiple formats for box coordinates, including both absolute and
normalized forms.
Attributes:
data (torch.Tensor): The raw tensor containing detection boxes and their associated data.
orig_shape (tuple): The original image size as a tuple (height, width), used for normalization.
is_track (bool): Indicates whether tracking IDs are included in the box data.
Properties:
xyxy (torch.Tensor | numpy.ndarray): Boxes in [x1, y1, x2, y2] format.
conf (torch.Tensor | numpy.ndarray): Confidence scores for each box.
cls (torch.Tensor | numpy.ndarray): Class labels for each box.
id (torch.Tensor | numpy.ndarray, optional): Tracking IDs for each box, if available.
xywh (torch.Tensor | numpy.ndarray): Boxes in [x, y, width, height] format, calculated on demand.
xyxyn (torch.Tensor | numpy.ndarray): Normalized [x1, y1, x2, y2] boxes, relative to `orig_shape`.
xywhn (torch.Tensor | numpy.ndarray): Normalized [x, y, width, height] boxes, relative to `orig_shape`.
Methods:
cpu(): Moves the boxes to CPU memory.
numpy(): Converts the boxes to a numpy array format.
cuda(): Moves the boxes to CUDA (GPU) memory.
to(device, dtype=None): Moves the boxes to the specified device.
results[0].boxes[0].data  # raw tensor row of the first detected box
tensor([[ 17.2858, 230.5922, 801.5182, 768.4058, 0.8705, 5.0000]], device='cuda:0')
model  # notebook echo: the module tree printed below is this object's repr
YOLO(
(model): DetectionModel(
(model): Sequential(
(0): Conv(
(conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(act): SiLU(inplace=True)
)
(1): Conv(
(conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(act): SiLU(inplace=True)
)
(2): C2f(
(cv1): Conv(
(conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0): Bottleneck(
(cv1): Conv(
(conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(3): Conv(
(conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(act): SiLU(inplace=True)
)
(4): C2f(
(cv1): Conv(
(conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0-1): 2 x Bottleneck(
(cv1): Conv(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(5): Conv(
(conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(act): SiLU(inplace=True)
)
(6): C2f(
(cv1): Conv(
(conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0-1): 2 x Bottleneck(
(cv1): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(7): Conv(
(conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(act): SiLU(inplace=True)
)
(8): C2f(
(cv1): Conv(
(conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(384, 256, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0): Bottleneck(
(cv1): Conv(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(9): SPPF(
(cv1): Conv(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1, ceil_mode=False)
)
(10): Upsample(scale_factor=2.0, mode='nearest')
(11): Concat()
(12): C2f(
(cv1): Conv(
(conv): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0): Bottleneck(
(cv1): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(13): Upsample(scale_factor=2.0, mode='nearest')
(14): Concat()
(15): C2f(
(cv1): Conv(
(conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0): Bottleneck(
(cv1): Conv(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(16): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(act): SiLU(inplace=True)
)
(17): Concat()
(18): C2f(
(cv1): Conv(
(conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0): Bottleneck(
(cv1): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(19): Conv(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(act): SiLU(inplace=True)
)
(20): Concat()
(21): C2f(
(cv1): Conv(
(conv): Conv2d(384, 256, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(384, 256, kernel_size=(1, 1), stride=(1, 1))
(act): SiLU(inplace=True)
)
(m): ModuleList(
(0): Bottleneck(
(cv1): Conv(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(cv2): Conv(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
)
)
)
(22): Detect(
(cv2): ModuleList(
(0): Sequential(
(0): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(1): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): Sequential(
(0): Conv(
(conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(1): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
)
(2): Sequential(
(0): Conv(
(conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(1): Conv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(cv3): ModuleList(
(0): Sequential(
(0): Conv(
(conv): Conv2d(64, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(1): Conv(
(conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(2): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1))
)
(1): Sequential(
(0): Conv(
(conv): Conv2d(128, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(1): Conv(
(conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(2): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1))
)
(2): Sequential(
(0): Conv(
(conv): Conv2d(256, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(1): Conv(
(conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(act): SiLU(inplace=True)
)
(2): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1))
)
)
(dfl): DFL(
(conv): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
)
)
)
import numpy as np
import os
import torch
from ultralytics import YOLO
from ultralytics.nn.modules.head import Detect
from ultralytics.utils import ops
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2
from PIL import Image
import json
class SaveIO:
    """Simple PyTorch hook to save the input and output of a nn.module.

    An instance is callable with the forward-hook signature
    ``(module, module_in, module_out)`` and keeps only the values from the
    most recent call.
    """

    def __init__(self):
        # Both stay None until the hook has been called at least once.
        self.input = None
        self.output = None

    def __call__(self, module, module_in, module_out):
        # ``module`` is accepted to satisfy the hook signature; it is not stored.
        self.input = module_in
        self.output = module_out
def load_and_prepare_model(model_path):
    """Load a YOLO model and attach forward hooks that capture its raw outputs.

    Args:
        model_path (str): Path or name of the weights file passed to ``YOLO``.

    Returns:
        tuple: ``(model, hooks)`` where ``hooks`` is the list
        ``[input_hook, detect, detect_hook, cv2_hooks, cv3_hooks]``.

    Raises:
        ValueError: If the model contains no ``Detect`` module to hook.
    """
    # we are going to register a PyTorch hook on the important parts of the YOLO model,
    # then reverse engineer the outputs to get boxes and logits
    # first, we have to register the hooks to the model *before running inference*
    # then, when inference is run, the hooks will save the inputs/outputs of their respective modules
    model = YOLO(model_path)
    detect = None
    cv2_hooks = None
    cv3_hooks = None
    detect_hook = SaveIO()

    for module in model.model.modules():
        # Exact type match on purpose: subclasses of Detect are not hooked.
        if type(module) is Detect:
            module.register_forward_hook(detect_hook)
            detect = module

            # One hook per detection level for each of the cv2 and cv3 branches.
            cv2_hooks = [SaveIO() for _ in range(module.nl)]
            cv3_hooks = [SaveIO() for _ in range(module.nl)]
            for level in range(module.nl):
                module.cv2[level].register_forward_hook(cv2_hooks[level])
                module.cv3[level].register_forward_hook(cv3_hooks[level])
            break

    if detect is None:
        # Without a Detect head every later step would fail on None hooks.
        raise ValueError(f"No Detect module found in model loaded from {model_path!r}")

    input_hook = SaveIO()
    model.model.register_forward_hook(input_hook)

    # save and return these for later
    hooks = [input_hook, detect, detect_hook, cv2_hooks, cv3_hooks]

    return model, hooks
def is_text_file(file_path):
    """Return True when ``file_path`` ends with a known text-file extension.

    The comparison is case-insensitive and looks only at the name, never at
    the file contents.
    """
    # Add more extensions if needed (e.g. '.csv', '.json', '.xml').
    text_extensions = ('.txt',)
    return file_path.lower().endswith(text_extensions)
def is_image_file(file_path):
    """Return True when ``file_path`` ends with a known image-file extension.

    The comparison is case-insensitive and looks only at the name, never at
    the file contents.
    """
    # Add more extensions if needed.
    image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp')
    return file_path.lower().endswith(image_extensions)
def plot_image(img_path, results, category_mapping=None, suffix='test', show_labels=True, include_legend=True):
    """
    Display the image with bounding boxes and their corresponding class scores.

    The figure is also saved as ``<image name without extension>_<suffix>.jpg``
    (a relative path, so it lands in the current working directory).

    Args:
        img_path (str): Path to the image file.
        results (list): List of dictionaries containing bounding box information.
            Each one must provide 'bbox' (four values, used as x0, y0, x1, y1)
            and 'activations' (a list of per-class scores).
        category_mapping (dict, optional): Maps a class index to a display name
            for the legend; indices missing from it are shown as "Unknown".
        suffix (str): what to append to the original image name when saving
        show_labels (bool): Whether to draw a "<class id> (<score>)" tag at each box.
        include_legend (bool): Whether to add a legend with one entry per box.

    Returns:
        None
    """
    img = Image.open(img_path)
    fig, ax = plt.subplots()
    ax.imshow(img)

    box_color = "r"  # red
    tag_color = "k"  # black

    for box in results:
        x0, y0, x1, y1 = map(int, box['bbox'])

        # The class with the highest activation labels the box.
        max_score = max(box['activations'])
        max_category_id = box['activations'].index(max_score)

        # Fall back to the numeric id when no name mapping is supplied.
        category_name = max_category_id
        if category_mapping:
            category_name = category_mapping.get(max_category_id, "Unknown")

        rect = patches.Rectangle(
            (x0, y0),
            x1 - x0,
            y1 - y0,
            edgecolor=box_color,
            label=f"{max_category_id}: {category_name} ({max_score:.2f})",
            facecolor='none',
        )
        ax.add_patch(rect)

        if show_labels:
            plt.text(
                x0,
                y0 - 50,  # place the tag above the top edge of the box
                f"{max_category_id} ({max_score:.2f})",
                fontsize="5",
                color=tag_color,
                backgroundcolor=box_color,
            )

    if include_legend:
        ax.legend(fontsize="5")

    plt.axis("off")
    plt.savefig(f'{os.path.basename(img_path).rsplit(".", 1)[0]}_{suffix}.jpg', bbox_inches="tight", dpi=300)
def write_json(results):
    """Write one prediction record per box to ``predictions.json``.

    Args:
        results (list): Box dictionaries providing 'image_id', 'bbox' and
            'activations' (a list of per-class scores).

    Each record holds the image file name (directory stripped), the index of
    the highest activation as 'category_id', the box, the highest activation
    as 'score', and the full activation list. The file is created in the
    current working directory and overwritten if it already exists.
    """
    # Create a list to store the predictions data
    predictions = []

    for result in results:
        activations = result['activations']
        score = max(activations)  # computed once, reused for the index lookup

        predictions.append({
            'image_id': os.path.basename(result['image_id']),
            'category_id': activations.index(score),
            'bbox': result['bbox'],
            'score': score,
            'activations': activations,
        })

    # Write the predictions list to a JSON file
    with open('predictions.json', 'w') as f:
        json.dump(predictions, f)
def calculate_iou(box1, box2):
    """
    Calculates the Intersection over Union (IoU) between two bounding boxes.

    Args:
        box1 (list): Bounding box coordinates [x1, y1, w1, h1].
        box2 (list): Bounding box coordinates [x2, y2, w2, h2].

    Returns:
        float: Intersection over Union (IoU) value in [0, 1]; 0.0 when the
        boxes do not overlap or when their union has no area.
    """
    x1, y1, w1, h1 = box1
    x2, y2, w2, h2 = box2

    # Overlap extent on each axis; clamped at 0 when the boxes are disjoint.
    # No "+ 1" here: the areas below are plain w * h, so adding a pixel to the
    # intersection only would make identical boxes score above 1.
    intersect_w = max(0, min(x1 + w1, x2 + w2) - max(x1, x2))
    intersect_h = max(0, min(y1 + h1, y2 + h2) - max(y1, y2))
    intersect_area = intersect_w * intersect_h

    union_area = w1 * h1 + w2 * h2 - intersect_area
    if union_area <= 0:
        # Degenerate (zero-area) boxes: avoid dividing by zero.
        return 0.0

    return intersect_area / float(union_area)
# Apply Non-Maximum Suppression
def nms(boxes, iou_threshold=0.7):
    """
    Applies Non-Maximum Suppression (NMS) to a list of bounding box dictionaries.

    Boxes are ranked by their highest activation. Each surviving box, taken in
    that order, suppresses every lower-ranked box whose IoU with it exceeds
    ``iou_threshold``. A box that has itself been suppressed does not suppress
    others.

    Args:
        boxes (list): List of dictionaries, each containing 'bbox', 'logits', and 'activations'.
        iou_threshold (float, optional): Intersection over Union (IoU) threshold for NMS. Default is 0.7.

    Returns:
        list: List of selected bounding box dictionaries after NMS, ordered by
        descending score.
    """
    # NOTE(review): calculate_iou documents its inputs as [x, y, w, h], while
    # results_predict stores 'bbox' as xyxy -- confirm the format before use.

    # Sort boxes by confidence score in descending order
    sorted_boxes = sorted(boxes, key=lambda x: max(x['activations']), reverse=True)

    # Indices of suppressed boxes; a set keeps membership tests O(1).
    suppressed = set()
    for i, keeper in enumerate(sorted_boxes):
        if i in suppressed:
            continue
        for j in range(i + 1, len(sorted_boxes)):
            if j in suppressed:
                continue
            if calculate_iou(keeper['bbox'], sorted_boxes[j]['bbox']) > iou_threshold:
                suppressed.add(j)

    return [box for idx, box in enumerate(sorted_boxes) if idx not in suppressed]
def results_predict(img_path, model, hooks, threshold=0.5, iou=0.7, save_image = False, category_mapping = None):
    """
    Run prediction with a YOLO model and apply Non-Maximum Suppression (NMS) to the results.

    Args:
        img_path (str): Path to an image file.
        model (YOLO): YOLO model object.
        hooks (list): List of hooks for the model, as returned by load_and_prepare_model().
        threshold (float, optional): Confidence threshold for detection. Default is 0.5.
        iou (float, optional): Intersection over Union (IoU) threshold for NMS. Default is 0.7.
        save_image (bool, optional): Accepted for interface compatibility; not used by this function.
        category_mapping (dict, optional): Accepted for interface compatibility; not used by this function.

    Returns:
        list: List of selected bounding box dictionaries after NMS. Each one has the keys
        'bbox', 'bbox_xywh', 'best_conf', 'best_cls', 'image_id', 'activations' and 'logits'.
    """
    # unpack hooks from load_and_prepare_model()
    input_hook, detect, detect_hook, cv2_hooks, cv3_hooks = hooks

    # run inference; we don't actually need to store the results because
    # the hooks store everything we need
    model(img_path)

    # now reverse engineer the outputs to find the logits
    # see Detect.forward(): https://github.com/ultralytics/ultralytics/blob/b638c4ed9a24270a6875cdd47d9eeda99204ef5a/ultralytics/nn/modules/head.py#L22
    shape = detect_hook.input[0][0].shape  # BCHW
    x = []
    for i in range(detect.nl):
        x.append(torch.cat((cv2_hooks[i].output, cv3_hooks[i].output), 1))
    x_cat = torch.cat([xi.view(shape[0], detect.no, -1) for xi in x], 2)
    # Only the class part is needed below; the box part of the split is unused.
    _box_part, cls_part = x_cat.split((detect.reg_max * 4, detect.nc), 1)

    # assumes batch size = 1 (i.e. you are just running with one image)
    # if you want to run with many images, throw this in a loop
    batch_idx = 0
    xywh_sigmoid = detect_hook.output[0][batch_idx]
    all_logits = cls_part[batch_idx]

    # figure out the original img shape and model img shape so we can transform the boxes
    img_shape = input_hook.input[0].shape[2:]
    orig_img_shape = model.predictor.batch[1][batch_idx].shape[:2]

    # compute predictions
    boxes = []
    for i in range(xywh_sigmoid.shape[-1]):  # for each predicted box...
        # first four values are the box, the remaining ones the per-class activations
        x0, y0, x1, y1, *class_probs_after_sigmoid = xywh_sigmoid[:, i]
        x0, y0, x1, y1 = ops.scale_boxes(img_shape, np.array([x0.cpu(), y0.cpu(), x1.cpu(), y1.cpu()]), orig_img_shape)
        logits = all_logits[:, i]

        boxes.append({
            'image_id': img_path,
            'bbox': [x0.item(), y0.item(), x1.item(), y1.item()],  # xyxy
            'bbox_xywh': [(x0.item() + x1.item())/2, (y0.item() + y1.item())/2, x1.item() - x0.item(), y1.item() - y0.item()],
            'logits': logits.cpu().tolist(),
            'activations': [p.item() for p in class_probs_after_sigmoid]
        })

    # for debugging
    # top10 = sorted(boxes, key=lambda x: max(x['activations']), reverse=True)[:10]
    # plot_image(img_path, top10, suffix="before_nms")

    # NMS
    # we can keep the activations and logits around via the YOLOv8 NMS method, but only if we
    # append them as an additional time to the prediction vector. It's a weird hacky way to do it, but
    # it works. We also have to pass in the num classes (nc) parameter to make it work.
    boxes_for_nms = torch.stack([
        torch.tensor([*b['bbox_xywh'], *b['activations'], *b['activations'], *b['logits']]) for b in boxes
    ], dim=1).unsqueeze(0)

    # do the NMS
    nms_results = ops.non_max_suppression(boxes_for_nms, conf_thres=threshold, iou_thres=iou, nc=detect.nc)[0]

    # unpack it and return it
    boxes = []
    for b in range(nms_results.shape[0]):
        row = nms_results[b, :]
        # layout: box (4), best confidence, best class, then the extra
        # activations and logits that were appended before NMS
        x0, y0, x1, y1, conf, best_cls, *acts_and_logits = row
        activations = acts_and_logits[:detect.nc]
        logits = acts_and_logits[detect.nc:]
        box_dict = {
            'bbox': [x0.item(), y0.item(), x1.item(), y1.item()],  # xyxy
            'bbox_xywh': [(x0.item() + x1.item())/2, (y0.item() + y1.item())/2, x1.item() - x0.item(), y1.item() - y0.item()],
            'best_conf': conf.item(),
            'best_cls': best_cls.item(),
            'image_id': img_path,
            'activations': [p.item() for p in activations],
            'logits': [p.item() for p in logits]
        }
        boxes.append(box_dict)

    return boxes
def run_predict(input_path, model, hooks, score_threshold=0.5, iou_threshold=0.7, save_image = False, save_json = False, category_mapping = None):
    """
    Run prediction with a YOLO model.

    Args:
        input_path (str): Path to an image file or txt file containing paths to image files.
        model (YOLO): YOLO model object.
        hooks (list): List of hooks for the model.
        score_threshold (float, optional): Confidence threshold for detection. Default is 0.5.
        iou_threshold (float, optional): Intersection over Union (IoU) threshold for NMS. Default is 0.7.
        save_image (bool, optional): Forwarded to results_predict(). Default is False.
        save_json (bool, optional): Whether to save the results in a json file. Default is False.
        category_mapping (dict, optional): Forwarded to results_predict().

    Returns:
        list: List of selected bounding box dictionaries for all the images given as input.
    """
    if is_text_file(input_path):
        # A list file holds one image path per line; blank lines are skipped
        # so that they are not handed to the model as empty paths.
        with open(input_path, 'r') as f:
            img_paths = [line for line in f.read().splitlines() if line.strip()]
    else:
        img_paths = [input_path]

    all_results = []
    for img_path in img_paths:
        results = results_predict(img_path, model, hooks, score_threshold, iou=iou_threshold, save_image=save_image, category_mapping=category_mapping)
        all_results.extend(results)

    if save_json:
        write_json(all_results)

    return all_results
### Start example script here ###
### (This shows how to use the methods in this file) ###

# change these, of course :)
SAVE_TEST_IMG = True
model_path = 'yolov8n.pt'
img_path = 'bus.jpg'
threshold = 0.5
nms_threshold = 0.7

# load the model
model, hooks = load_and_prepare_model(model_path)

# run inference
results = run_predict(img_path, model, hooks, score_threshold=threshold, iou_threshold=nms_threshold)

print("Processed", len(results), "boxes")
if results:
    # Guarded: results is empty when nothing passes the thresholds.
    print("The first one is", results[0])

if SAVE_TEST_IMG:
    plot_image(img_path, results)
image 1/1 /home/nipun.batra/git/blog/posts/bus.jpg: 640x480 4 persons, 1 bus, 1 stop sign, 6.9ms
Speed: 2.5ms preprocess, 6.9ms inference, 1.4ms postprocess per image at shape (1, 3, 640, 480)
Processed 4 boxes
The first one is {'bbox': [17.28582763671875, 230.59222412109375, 801.5181884765625, 768.4058227539062], 'bbox_xywh': [409.4020080566406, 499.4990234375, 784.2323608398438, 537.8135986328125], 'best_conf': 0.8705450892448425, 'best_cls': 5.0, 'image_id': 'bus.jpg', 'activations': [1.0311242476745974e-05, 9.112093721341807e-06, 0.0003611394204199314, 6.673172902083024e-05, 0.00022987156989984214, 0.8705450892448425, 0.0009488395880907774, 0.007086973637342453, 0.0001582797704031691, 9.052685527422e-06, 4.503397406097065e-07, 1.0215371730737388e-05, 1.730335316096898e-05, 4.165742666373262e-06, 5.312303414939379e-07, 1.9164923514836119e-07, 9.471239650338248e-07, 4.973830982635263e-06, 2.5008853654640006e-08, 4.033859681840113e-07, 0.0005047211889177561, 6.792038220737595e-07, 4.306784262553265e-07, 9.298752411268651e-05, 4.2397550714667886e-07, 6.426890467992052e-05, 3.322290240248549e-06, 1.4765126934435102e-06, 2.1423567886813544e-05, 1.0903042948484654e-06, 2.7694509299180936e-06, 3.31359643723772e-07, 1.1906087138413568e-06, 5.6864570069592446e-05, 5.980197329336079e-06, 9.176978892355692e-07, 6.090281658543972e-06, 8.041779437917285e-06, 4.321082087699324e-05, 2.43568738369504e-05, 4.93736079079099e-05, 1.2213927220727783e-05, 3.2120142350322567e-06, 5.898335757592577e-07, 6.751507157787273e-07, 1.514082214271184e-06, 7.679868474497198e-08, 5.769239237451984e-07, 2.8340421067696298e-06, 4.807910954696126e-06, 2.058711316976769e-07, 2.196402647314244e-06, 2.783847776299808e-06, 7.796339195920154e-05, 1.7758512171894836e-07, 2.9252541366986407e-07, 4.3183670641155913e-05, 8.449018423561938e-06, 1.572124688209442e-06, 4.166020971752005e-06, 1.1690817700582556e-05, 1.4325351571642386e-07, 3.190975985489786e-05, 1.2439878673831117e-06, 9.514394605503185e-07, 9.426859719496861e-07, 5.968039204162778e-06, 8.346189133590087e-05, 2.026550646405667e-05, 5.310462256602477e-06, 9.205886271956842e-06, 4.1557203900310924e-08, 4.2926520109176636e-05, 2.642763683979865e-05, 
2.686474908841774e-05, 2.5954559532692656e-05, 6.725371122229262e-07, 5.926154926783056e-07, 1.19427056688437e-06, 2.642658500917605e-07], 'logits': [-11.48226547241211, -11.6058988571167, -7.925885200500488, -9.614763259887695, -8.37775993347168, 1.9057865142822266, -6.95932149887085, -4.942384719848633, -8.750988006591797, -11.61244010925293, -14.613263130187988, -11.491606712341309, -10.964592933654785, -12.388611793518066, -14.44806957244873, -15.467598915100098, -13.869834899902344, -12.211315155029297, -17.50403594970703, -14.723371505737305, -7.590999603271484, -14.202343940734863, -14.657903671264648, -9.282952308654785, -14.673589706420898, -9.65237045288086, -12.614852905273438, -13.425826072692871, -10.750997543334961, -13.729052543640137, -12.796858787536621, -14.920061111450195, -13.641044616699219, -9.774781227111816, -12.027050971984863, -13.901396751403809, -12.008810043334961, -11.730852127075195, -10.049376487731934, -10.622672080993652, -9.916045188903809, -11.312921524047852, -12.648609161376953, -14.343424797058105, -14.208329200744629, -13.400699615478516, -16.382078170776367, -14.365554809570312, -12.7738037109375, -12.245243072509766, -15.396015167236328, -13.028687477111816, -12.79167366027832, -9.459193229675293, -15.543815612792969, -15.044713973999023, -10.050004959106445, -11.681451797485352, -13.363080978393555, -12.388545036315918, -11.356695175170898, -15.758649826049805, -10.352566719055176, -13.597187042236328, -13.865288734436035, -13.874531745910645, -12.029086112976074, -9.391036987304688, -10.806570053100586, -12.14582633972168, -11.595658302307129, -16.99619483947754, -10.055977821350098, -10.5410737991333, -10.52466869354248, -10.559137344360352, -14.212207794189453, -14.338719367980957, -13.63797378540039, -15.146309852600098]}
results  # notebook echo: the list printed below is this value's repr
[{'bbox': [17.28582763671875,
230.59222412109375,
801.5181884765625,
768.4058227539062],
'bbox_xywh': [409.4020080566406,
499.4990234375,
784.2323608398438,
537.8135986328125],
'best_conf': 0.8705450892448425,
'best_cls': 5.0,
'image_id': 'bus.jpg',
'activations': [1.0311242476745974e-05,
9.112093721341807e-06,
0.0003611394204199314,
6.673172902083024e-05,
0.00022987156989984214,
0.8705450892448425,
0.0009488395880907774,
0.007086973637342453,
0.0001582797704031691,
9.052685527422e-06,
4.503397406097065e-07,
1.0215371730737388e-05,
1.730335316096898e-05,
4.165742666373262e-06,
5.312303414939379e-07,
1.9164923514836119e-07,
9.471239650338248e-07,
4.973830982635263e-06,
2.5008853654640006e-08,
4.033859681840113e-07,
0.0005047211889177561,
6.792038220737595e-07,
4.306784262553265e-07,
9.298752411268651e-05,
4.2397550714667886e-07,
6.426890467992052e-05,
3.322290240248549e-06,
1.4765126934435102e-06,
2.1423567886813544e-05,
1.0903042948484654e-06,
2.7694509299180936e-06,
3.31359643723772e-07,
1.1906087138413568e-06,
5.6864570069592446e-05,
5.980197329336079e-06,
9.176978892355692e-07,
6.090281658543972e-06,
8.041779437917285e-06,
4.321082087699324e-05,
2.43568738369504e-05,
4.93736079079099e-05,
1.2213927220727783e-05,
3.2120142350322567e-06,
5.898335757592577e-07,
6.751507157787273e-07,
1.514082214271184e-06,
7.679868474497198e-08,
5.769239237451984e-07,
2.8340421067696298e-06,
4.807910954696126e-06,
2.058711316976769e-07,
2.196402647314244e-06,
2.783847776299808e-06,
7.796339195920154e-05,
1.7758512171894836e-07,
2.9252541366986407e-07,
4.3183670641155913e-05,
8.449018423561938e-06,
1.572124688209442e-06,
4.166020971752005e-06,
1.1690817700582556e-05,
1.4325351571642386e-07,
3.190975985489786e-05,
1.2439878673831117e-06,
9.514394605503185e-07,
9.426859719496861e-07,
5.968039204162778e-06,
8.346189133590087e-05,
2.026550646405667e-05,
5.310462256602477e-06,
9.205886271956842e-06,
4.1557203900310924e-08,
4.2926520109176636e-05,
2.642763683979865e-05,
2.686474908841774e-05,
2.5954559532692656e-05,
6.725371122229262e-07,
5.926154926783056e-07,
1.19427056688437e-06,
2.642658500917605e-07],
'logits': [-11.48226547241211,
-11.6058988571167,
-7.925885200500488,
-9.614763259887695,
-8.37775993347168,
1.9057865142822266,
-6.95932149887085,
-4.942384719848633,
-8.750988006591797,
-11.61244010925293,
-14.613263130187988,
-11.491606712341309,
-10.964592933654785,
-12.388611793518066,
-14.44806957244873,
-15.467598915100098,
-13.869834899902344,
-12.211315155029297,
-17.50403594970703,
-14.723371505737305,
-7.590999603271484,
-14.202343940734863,
-14.657903671264648,
-9.282952308654785,
-14.673589706420898,
-9.65237045288086,
-12.614852905273438,
-13.425826072692871,
-10.750997543334961,
-13.729052543640137,
-12.796858787536621,
-14.920061111450195,
-13.641044616699219,
-9.774781227111816,
-12.027050971984863,
-13.901396751403809,
-12.008810043334961,
-11.730852127075195,
-10.049376487731934,
-10.622672080993652,
-9.916045188903809,
-11.312921524047852,
-12.648609161376953,
-14.343424797058105,
-14.208329200744629,
-13.400699615478516,
-16.382078170776367,
-14.365554809570312,
-12.7738037109375,
-12.245243072509766,
-15.396015167236328,
-13.028687477111816,
-12.79167366027832,
-9.459193229675293,
-15.543815612792969,
-15.044713973999023,
-10.050004959106445,
-11.681451797485352,
-13.363080978393555,
-12.388545036315918,
-11.356695175170898,
-15.758649826049805,
-10.352566719055176,
-13.597187042236328,
-13.865288734436035,
-13.874531745910645,
-12.029086112976074,
-9.391036987304688,
-10.806570053100586,
-12.14582633972168,
-11.595658302307129,
-16.99619483947754,
-10.055977821350098,
-10.5410737991333,
-10.52466869354248,
-10.559137344360352,
-14.212207794189453,
-14.338719367980957,
-13.63797378540039,
-15.146309852600098]},
{'bbox': [48.739479064941406,
399.263916015625,
244.50173950195312,
902.501220703125],
'bbox_xywh': [146.62060928344727,
650.882568359375,
195.76226043701172,
503.2373046875],
'best_conf': 0.8689801096916199,
'best_cls': 0.0,
'image_id': 'bus.jpg',
'activations': [0.8689801096916199,
2.159083720698618e-07,
7.267921660059073e-07,
4.954289352099295e-07,
3.9690485209575854e-06,
3.6866176742478274e-06,
1.924761590998969e-06,
5.504815590029466e-07,
8.782559461906203e-07,
1.8101994783137343e-06,
1.012005441225483e-06,
1.725299512145284e-07,
1.9031394913326949e-06,
1.7214931347098172e-07,
2.6558041099633556e-06,
1.7690776985546108e-06,
4.047087259095861e-06,
1.6409081808888004e-06,
7.941423518786905e-07,
1.1039174978577648e-06,
1.4459410522249527e-06,
1.3210571978561347e-06,
4.1252138771596947e-07,
1.2257790331204887e-05,
9.701152521301992e-07,
5.426437610367429e-07,
4.6731869929317327e-07,
1.5046108501337585e-06,
6.04137326831733e-08,
3.9817462038627127e-07,
1.8838558162315167e-06,
1.1097992000941304e-06,
1.1555634955584537e-06,
7.810695024090819e-06,
2.828919605235569e-06,
2.9941506340946944e-07,
5.739046855524066e-07,
1.1345385928507312e-06,
4.880649271399307e-07,
2.7935711841564626e-06,
1.2807398661607294e-06,
3.5370536011214426e-07,
3.775756454160728e-07,
2.3527961729996605e-06,
9.562232889948064e-07,
4.1244192061640206e-07,
6.125056870587287e-07,
2.33453351938806e-06,
5.989752480672905e-07,
1.124500272453588e-06,
9.676870149633032e-07,
1.5129870007513091e-05,
3.98353768105153e-06,
6.566142474184744e-06,
6.885069296913571e-07,
5.487340786203276e-07,
7.767624765619985e-07,
5.031957925893948e-07,
2.001876282520243e-06,
9.961429441318614e-07,
1.2013529158139136e-06,
4.2538695765870216e-07,
2.504941676306771e-06,
2.0721746807339514e-07,
2.624780393034598e-07,
7.277583904397034e-07,
4.963558453141559e-08,
1.2240196838320117e-06,
3.9281539443436486e-07,
1.3602493709186092e-06,
2.1281984174947866e-07,
9.780134746506519e-08,
1.0565305501586408e-06,
4.288276045372186e-07,
3.775760148982954e-07,
2.0054291383075906e-07,
1.2226444141560933e-06,
1.0907050636888016e-05,
1.6243636764556868e-06,
5.100530984236684e-07],
'logits': [1.8919711112976074,
-15.348411560058594,
-14.134624481201172,
-14.517841339111328,
-12.436980247497559,
-12.510797500610352,
-13.160706520080566,
-14.412471771240234,
-13.945326805114746,
-13.222071647644043,
-13.80357551574707,
-15.572694778442383,
-13.172003746032715,
-15.57490348815918,
-12.838760375976562,
-13.245050430297852,
-12.417509078979492,
-13.320259094238281,
-14.046002388000488,
-13.716644287109375,
-13.446748733520508,
-13.537076950073242,
-14.700977325439453,
-11.30933666229248,
-13.845849990844727,
-14.426812171936035,
-14.576253890991211,
-13.406974792480469,
-16.62204933166504,
-14.736374855041504,
-13.182188034057617,
-13.71133041381836,
-13.670921325683594,
-11.760008811950684,
-12.775612831115723,
-15.021434783935547,
-14.37080192565918,
-13.68928337097168,
-14.532816886901855,
-12.788187026977539,
-13.568071365356445,
-14.854801177978516,
-14.789494514465332,
-12.959903717041016,
-13.860273361206055,
-14.701169967651367,
-14.305706977844238,
-12.967696189880371,
-14.328044891357422,
-13.69817066192627,
-13.848356246948242,
-11.098824501037598,
-12.43333625793457,
-11.933577537536621,
-14.188739776611328,
-14.415651321411133,
-14.068130493164062,
-14.502285957336426,
-13.121423721313477,
-13.819374084472656,
-13.632061004638672,
-14.670266151428223,
-12.897242546081543,
-15.389496803283691,
-15.153098106384277,
-14.133296012878418,
-16.818557739257812,
-13.61336898803711,
-14.74992561340332,
-13.507841110229492,
-15.36281967163086,
-16.14032745361328,
-13.760519027709961,
-14.662210464477539,
-14.789493560791016,
-15.422237396240234,
-13.614493370056152,
-11.426090240478516,
-13.330392837524414,
-14.488750457763672]},
{'bbox': [670.269287109375,
380.2840270996094,
809.858154296875,
875.6907958984375],
'bbox_xywh': [740.063720703125,
627.9874114990234,
139.5888671875,
495.4067687988281],
'best_conf': 0.8536035418510437,
'best_cls': 0.0,
'image_id': 'bus.jpg',
'activations': [0.8536035418510437,
8.206617394534987e-07,
1.2038118484269944e-06,
1.1462942666184972e-06,
1.5739009313620045e-06,
3.3411379263270646e-06,
3.73882085114019e-06,
4.4993788606007e-07,
1.2424654869391816e-06,
3.6132853438175516e-06,
2.934745680249762e-06,
1.105569964465758e-07,
1.0381992296970566e-06,
5.935368676546204e-07,
1.4730724160472164e-06,
3.304354777355911e-06,
4.168759005551692e-06,
3.27019665746775e-06,
5.449832656267972e-07,
5.046072146797087e-07,
1.3525707345252158e-06,
9.852926723397104e-07,
1.1722339650077629e-06,
7.1923359428183176e-06,
2.579861757112667e-06,
7.465223461622372e-07,
1.0459139048180077e-06,
1.3547731896323967e-06,
3.88375610782532e-07,
3.233400320823421e-07,
1.6130819631143822e-06,
8.577525818509457e-07,
1.0866494903893908e-06,
3.3366293337167008e-06,
2.0756035610247636e-06,
3.148159635202319e-07,
6.232777991499461e-07,
7.360301310654904e-07,
9.060752290679375e-07,
2.585813035693718e-06,
1.3577079016613425e-06,
4.043033925427153e-07,
3.105890868937422e-07,
1.355843437522708e-06,
3.1056660532158276e-07,
1.5147954002259212e-07,
3.9362166148748656e-07,
1.426429321327305e-06,
1.4964024330765824e-07,
4.153795600814192e-07,
3.13703480969707e-07,
2.565983777458314e-06,
1.2322219617999508e-06,
9.984036068999558e-07,
1.9590237343436456e-07,
4.869207828051003e-07,
2.924327418440953e-06,
1.931917267938843e-06,
7.152341822802555e-06,
5.519689239008585e-06,
1.5743992207717383e-06,
2.096727712341817e-06,
2.2192562028067186e-06,
3.0115475624370447e-07,
3.523186933307443e-07,
3.960224432830728e-07,
3.5972377077087e-08,
8.289633797176066e-07,
3.8832413906675356e-07,
8.207001087612298e-07,
1.5001531039615656e-07,
8.185835298490929e-08,
1.6670093145876308e-06,
2.5175890527862066e-07,
3.8390692225220846e-07,
7.335270879593736e-07,
1.5532028783127316e-06,
1.81348677870119e-05,
1.4697145616082707e-06,
5.530769158212934e-07],
'logits': [1.763148307800293,
-14.013154029846191,
-13.630016326904297,
-13.678975105285645,
-13.36195182800293,
-12.609195709228516,
-12.496736526489258,
-14.614155769348145,
-13.598411560058594,
-12.530889511108398,
-12.738886833190918,
-16.01773452758789,
-13.778021812438965,
-14.337165832519531,
-13.4281587600708,
-12.62026596069336,
-12.387887954711914,
-12.630657196044922,
-14.422510147094727,
-14.49948501586914,
-13.51350212097168,
-13.830326080322266,
-13.656598091125488,
-11.842487335205078,
-12.867772102355957,
-14.107839584350586,
-13.770618438720703,
-13.51187515258789,
-14.761292457580566,
-14.944561004638672,
-13.337362289428711,
-13.968949317932129,
-13.732410430908203,
-12.610546112060547,
-13.085256576538086,
-14.971277236938477,
-14.288272857666016,
-14.121994018554688,
-13.914142608642578,
-12.86546802520752,
-13.509711265563965,
-14.721099853515625,
-14.984794616699219,
-13.511085510253906,
-14.984867095947266,
-15.702815055847168,
-14.747875213623047,
-13.460334777832031,
-15.715031623840332,
-14.694072723388672,
-14.974817276000977,
-12.87316608428955,
-13.606690406799316,
-13.817107200622559,
-15.445649147033691,
-14.535163879394531,
-12.742443084716797,
-13.15699577331543,
-11.848063468933105,
-12.107183456420898,
-13.361635208129883,
-13.075130462646484,
-13.018336296081543,
-15.015641212463379,
-14.858729362487793,
-14.74179458618164,
-17.140514373779297,
-14.00308895111084,
-14.761425018310547,
-14.013107299804688,
-15.712528228759766,
-16.318275451660156,
-13.30447769165039,
-15.194793701171875,
-14.772865295410156,
-14.12540054321289,
-13.375189781188965,
-10.917655944824219,
-13.430440902709961,
-14.407768249511719]},
{'bbox': [221.39376831054688,
405.79168701171875,
344.7171936035156,
857.3920288085938],
'bbox_xywh': [283.05548095703125,
631.5918579101562,
123.32342529296875,
451.600341796875],
'best_conf': 0.8193051218986511,
'best_cls': 0.0,
'image_id': 'bus.jpg',
'activations': [0.8193051218986511,
4.0964692971101613e-07,
2.936613100246177e-06,
1.382118284709577e-06,
1.6149664588738233e-05,
8.412195711571258e-06,
9.894216645989218e-07,
1.2228890682308702e-06,
1.2646942195715383e-06,
1.9066765162278898e-06,
6.240153425096651e-07,
7.205974128510206e-08,
2.2608073777519166e-06,
1.6783783962637244e-07,
4.693410573963774e-06,
3.116812195003149e-06,
7.556559467047919e-06,
3.5700793432624778e-06,
8.197254715014424e-07,
1.5916730262688361e-06,
2.326829417143017e-06,
1.0168097333007609e-06,
2.1919508697010315e-07,
4.784199973073555e-06,
1.3443328725770698e-06,
4.94625453484332e-07,
4.1110206439043395e-07,
6.275794248722377e-07,
1.2701826790362247e-07,
3.41750705956656e-07,
1.7371836520396755e-06,
1.0280481319568935e-06,
1.6011792922654422e-06,
5.354379027266987e-06,
4.389916284708306e-06,
3.4768410728247545e-07,
1.3610668929686653e-06,
1.0928538358712103e-06,
7.004660460552259e-07,
3.702946060002432e-06,
2.8141175789642148e-06,
6.319215231087583e-07,
4.0783669419397484e-07,
1.7733005961417803e-06,
1.152611616817012e-06,
2.805371650538291e-07,
8.198293812711199e-07,
3.066387762373779e-06,
8.176439223461784e-07,
6.666024887636013e-07,
1.3133955008015619e-06,
9.064304322237149e-06,
3.986513547715731e-06,
2.927407876995858e-06,
2.5427269179090217e-07,
4.65552773221134e-07,
2.3654131382500054e-06,
7.987733283698617e-07,
1.5470428706976236e-06,
5.12271185471036e-07,
4.3607190036709653e-07,
5.852278945894795e-07,
1.5100877135409974e-06,
3.0730282674085174e-07,
5.032850936004252e-07,
7.947044764478051e-07,
6.477876013377681e-08,
2.3970601432665717e-06,
1.9725838740214385e-07,
3.763364588849072e-07,
2.017427789269277e-07,
1.4345364718337805e-07,
8.217072036131867e-07,
3.7514539030780725e-07,
2.626693742513453e-07,
2.2584059422570135e-07,
9.443388080399018e-07,
1.2512692592281383e-05,
2.4881660465325695e-06,
7.62891602335003e-07],
'logits': [1.5116467475891113,
-14.707969665527344,
-12.738250732421875,
-13.491891860961914,
-11.033595085144043,
-11.685819625854492,
-13.826144218444824,
-13.614293098449707,
-13.580678939819336,
-13.170146942138672,
-14.287090301513672,
-16.445770263671875,
-12.999786376953125,
-15.60026741027832,
-12.269346237182617,
-12.678696632385254,
-11.793087005615234,
-12.542919158935547,
-14.01429557800293,
-13.350723266601562,
-12.971001625061035,
-13.798839569091797,
-15.333303451538086,
-12.250186920166016,
-13.519611358642578,
-14.519464492797852,
-14.704423904418945,
-14.281394958496094,
-15.878934860229492,
-14.88918399810791,
-13.263243675231934,
-13.787847518920898,
-13.344768524169922,
-12.137590408325195,
-12.336195945739746,
-14.871971130371094,
-13.507240295410156,
-13.726716995239258,
-14.17151927947998,
-12.506378173828125,
-12.780858993530273,
-14.274499893188477,
-14.712398529052734,
-13.242666244506836,
-13.673479080200195,
-15.086559295654297,
-14.014168739318848,
-12.69500732421875,
-14.016838073730469,
-14.221071243286133,
-13.542893409729004,
-11.611157417297363,
-12.432589530944824,
-12.741390228271484,
-15.184858322143555,
-14.580039978027344,
-12.95455551147461,
-14.04018783569336,
-13.37916374206543,
-14.484411239624023,
-14.645458221435547,
-14.351263999938965,
-13.403341293334961,
-14.995431900024414,
-14.502108573913574,
-14.045294761657715,
-16.552288055419922,
-12.941265106201172,
-15.438751220703125,
-14.792781829833984,
-15.416272163391113,
-15.757253646850586,
-14.011880874633789,
-14.795951843261719,
-15.152369499206543,
-15.303436279296875,
-13.872779846191406,
-11.2887544631958,
-12.903962135314941,
-14.086149215698242]}]