Why use logits


Nipun Batra


February 27, 2024

import numpy as np
import sklearn 
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from latexify import *
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import torch.nn.functional as F
import ipywidgets as widgets
from ipywidgets import interactive
from IPython.display import display

# Target probabilities for two binary examples
ground_truth_probs = torch.tensor([0.3, 0.7])

# Raw model outputs (logits) for the same two examples
model_logits = torch.tensor([-2.0, 2.0])

# Route 1: pass the logits straight to the loss, which applies the sigmoid itself
loss_logits = F.binary_cross_entropy_with_logits(model_logits, ground_truth_probs)

# Route 2: squash the logits with an element-wise sigmoid (not a softmax),
# then compute binary cross-entropy on the resulting probabilities
model_probs = torch.sigmoid(model_logits)
loss_probs = F.binary_cross_entropy(model_probs, ground_truth_probs)

# Report both values side by side; they should agree up to floating-point error
print("Loss using probabilities:", loss_probs.item())
print("Loss using logits:", loss_logits.item())
Loss using probabilities: 0.7269279956817627
Loss using logits: 0.7269280552864075
# One slider per example logit; both span [-300, 300] in steps of 0.01
l1 = widgets.FloatSlider(
    description='Logits ex 1',
    value=0.3,
    min=-300,
    max=300,
    step=0.01,
)
l2 = widgets.FloatSlider(
    description='Logits ex 2',
    value=0.7,
    min=-300,
    max=300,
    step=0.01,
)

# Vertical container holding the two sliders
box = widgets.VBox(children=[l1, l2])
def print_loss_using_both_methods(l1, l2):
    """Print the binary cross-entropy for two logits, computed two ways.

    The logits ``l1`` and ``l2`` are packed into a tensor and scored against
    the module-level ``ground_truth_probs``: once by applying a sigmoid and
    calling ``F.binary_cross_entropy`` on the probabilities, and once by
    calling ``F.binary_cross_entropy_with_logits`` on the raw logits.

    Args:
        l1: Logit for the first example (a Python number).
        l2: Logit for the second example (a Python number).

    Returns:
        None. Both loss values are written to stdout.
    """
    z = torch.tensor([l1, l2])

    # NOTE(review): presumably the two results diverge for large-magnitude
    # logits, which is what the sliders are meant to explore - confirm.
    via_probs = F.binary_cross_entropy(torch.sigmoid(z), ground_truth_probs)
    via_logits = F.binary_cross_entropy_with_logits(z, ground_truth_probs)

    for label, loss in (
        ("Loss using probabilities:", via_probs),
        ("Loss using logits:", via_logits),
    ):
        print(label, loss.item())
# add interactivity: bind the two sliders to the loss-printing function.
# NOTE(review): the returned widget is neither assigned nor passed to
# display(); presumably this relied on being the last expression of a
# notebook cell in order to render - confirm.
interactive(print_loss_using_both_methods, l1=l1, l2=l2)
# Print both losses once for the sliders' current values.
print_loss_using_both_methods(l1.value, l2.value)
Loss using probabilities: 3.164093017578125
Loss using logits: 3.164093017578125
def our_sigmoid(z):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-z))``.

    This is the direct textbook formula, written by hand.

    Args:
        z: A ``torch.Tensor`` (``torch.exp`` is applied to ``-z``).

    Returns:
        The sigmoid of ``z``, computed element-wise.
    """
    denominator = 1 + torch.exp(-z)
    return 1 / denominator
