import numpy as np
import sklearn
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from latexify import *
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
Why use logits

Loss functions like binary cross-entropy can be fed either predicted probabilities or the raw logits that produce them. The two formulations are mathematically equivalent, but the logits version is far more robust numerically; the experiments below show why.
import torch
import torch.nn.functional as F
import ipywidgets as widgets
from ipywidgets import interactive
from IPython.display import display
# Example ground truth probabilities
ground_truth_probs = torch.tensor([0.3, 0.7])
# Example model predictions (logits)
model_logits = torch.tensor([-2.0, 2.0])
# Applying sigmoid to the logits to get probabilities
model_probs = torch.sigmoid(model_logits)
# Cross-entropy loss using probabilities
loss_probs = F.binary_cross_entropy(model_probs, ground_truth_probs)
# Cross-entropy loss using logits
loss_logits = F.binary_cross_entropy_with_logits(model_logits, ground_truth_probs)
print("Loss using probabilities:", loss_probs.item())
print("Loss using logits:", loss_logits.item())
Loss using probabilities: 0.7269279956817627
Loss using logits: 0.7269280552864075
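For moderate logits the two numbers agree to several decimal places, which is expected: the two losses are algebraically identical. With $p = \sigma(z) = 1/(1 + e^{-z})$,

$$\mathrm{BCE}(p, y) = -\bigl(y \log p + (1 - y)\log(1 - p)\bigr) = \max(z, 0) - z\,y + \log\bigl(1 + e^{-|z|}\bigr)$$

The right-hand side is the standard numerically stable rearrangement: it never exponentiates a positive number and never takes the log of a quantity that can underflow to 0. Presumably binary_cross_entropy_with_logits computes something equivalent internally (an assumption about the implementation, but it matches the numbers above), which is why it stays accurate where the probability-based path cannot.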
l1 = widgets.FloatSlider(value=0.3, min=-300, max=300, step=0.01, description='Logits ex 1')
l2 = widgets.FloatSlider(value=0.7, min=-300, max=300, step=0.01, description='Logits ex 2')
box = widgets.VBox([l1, l2])
def print_loss_using_both_methods(l1, l2):
    logits = torch.tensor([l1, l2])
    probs = torch.sigmoid(logits)
    loss_probs = F.binary_cross_entropy(probs, ground_truth_probs)
    loss_logits = F.binary_cross_entropy_with_logits(logits, ground_truth_probs)
    print("Loss using probabilities:", loss_probs.item())
    print("Loss using logits:", loss_logits.item())
# add interactivity
interactive(print_loss_using_both_methods, l1=l1, l2=l2)
print_loss_using_both_methods(l1.value, l2.value)
Loss using probabilities: 3.164093017578125
Loss using logits: 3.164093017578125
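Dragging the sliders toward the +/-300 extremes is where the two methods split. Here is a non-interactive sketch of the same failure; the exact value on the probability path depends on PyTorch clamping its internal log outputs at -100 (per the BCELoss docs), while the logits path returns the mathematically correct value:

# At |logit| = 300, float32 sigmoid saturates to exactly 0.0 and 1.0
extreme_logits = torch.tensor([-300.0, 300.0])
extreme_probs = torch.sigmoid(extreme_logits)  # tensor([0., 1.])

# Probability path: log(0) is clamped internally, so the loss silently saturates (~30)
print(F.binary_cross_entropy(extreme_probs, ground_truth_probs).item())

# Logits path: uses the stable rearrangement, so the answer is correct (90.0)
print(F.binary_cross_entropy_with_logits(extreme_logits, ground_truth_probs).item())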
def our_sigmoid(z):
    return 1/(1+torch.exp(-z))

our_sigmoid(torch.tensor(-90.0))
tensor(0.)
our_sigmoid(torch.tensor(-92.0))
tensor(0.)
our_sigmoid(torch.tensor(9.0))
tensor(0.9999)
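So a hand-rolled sigmoid returns exactly 0 once the logit is very negative: exp(90) overflows float32's maximum (about 3.4e38) to inf, and 1/(1 + inf) is 0. Any log taken after that saturation is -inf, which is precisely the failure the logits API sidesteps. A small contrast with PyTorch's stable F.logsigmoid:

z = torch.tensor(-90.0)

# Naive: the probability underflows to 0 before the log ever runs
print(torch.log(our_sigmoid(z)))   # tensor(-inf)

# Stable: log-sigmoid computed directly in log-space
print(F.logsigmoid(z))             # tensor(-90.)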
F.binary_cross_entropy_with_logits?
Signature:
F.binary_cross_entropy_with_logits(
    input: torch.Tensor,
    target: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = 'mean',
    pos_weight: Optional[torch.Tensor] = None,
) -> torch.Tensor
Docstring:
Function that measures Binary Cross Entropy between target and input logits.

See :class:`~torch.nn.BCEWithLogitsLoss` for details.

Args:
    input: Tensor of arbitrary shape as unnormalized scores (often referred to as logits).
    target: Tensor of the same shape as input with values between 0 and 1
    weight (Tensor, optional): a manual rescaling weight
        if provided it's repeated to match input tensor shape
    size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
        the losses are averaged over each loss element in the batch. Note that for
        some losses, there are multiple elements per sample. If the field
        :attr:`size_average` is set to ``False``, the losses are instead summed for
        each minibatch. Ignored when reduce is ``False``. Default: ``True``
    reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
        losses are averaged or summed over observations for each minibatch depending
        on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
        batch element instead and ignores :attr:`size_average`. Default: ``True``
    reduction (string, optional): Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
        ``'mean'``: the sum of the output will be divided by the number of elements
        in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
        and :attr:`reduce` are in the process of being deprecated, and in the meantime,
        specifying either of those two args will override :attr:`reduction`.
        Default: ``'mean'``
    pos_weight (Tensor, optional): a weight of positive examples.
        Must be a vector with length equal to the number of classes.

Examples::

    >>> input = torch.randn(3, requires_grad=True)
    >>> target = torch.empty(3).random_(2)
    >>> loss = F.binary_cross_entropy_with_logits(input, target)
    >>> loss.backward()

File:      ~/miniconda3/lib/python3.9/site-packages/torch/nn/functional.py
Type:      function
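To close the loop, here is a hand-rolled version of the logits-based loss using the max/|z| rearrangement derived above. The name our_bce_with_logits is ours, and PyTorch's actual implementation may differ in detail, but it agrees numerically with the results above:

def our_bce_with_logits(logits, targets):
    # Stable form: never exponentiates a positive number, never logs an exact 0
    per_element = (torch.clamp(logits, min=0)
                   - logits * targets
                   + torch.log1p(torch.exp(-torch.abs(logits))))
    return per_element.mean()

print(our_bce_with_logits(model_logits, ground_truth_probs))                   # ~0.7269, matches above
print(our_bce_with_logits(torch.tensor([-300.0, 300.0]), ground_truth_probs))  # tensor(90.)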