Source code for metrics.kl

import torch

from disco.utils.device import get_device
from .base import BaseDivergence


[docs]class KL(BaseDivergence): """ Kullback-Leibler divergence class """
[docs] @classmethod def pointwise_estimates(cls, m1_log_scores, m2_log_scores, z, proposal_log_scores=None): """ computes the KL divergence between 2 distributions Parameters ---------- m1_log_scores: floats log-scores for samples according to network 1 m2_log_scores: floats log-scores for samples according to network 2 z: float partition function of network 1 proposal_log_scores: floats log-scores for samples according to proposal (by default m2_log_scores) Returns ------- divergence between m1 and m2 """ device = get_device(m1_log_scores) m2_log_scores = m2_log_scores.to(device) if proposal_log_scores is None: proposal_log_scores = m2_log_scores else: proposal_log_scores = proposal_log_scores.to(device) importance_ratio = torch.exp(m1_log_scores - proposal_log_scores) unnormalized_pointwise_estimates = importance_ratio * (m1_log_scores - m2_log_scores) unnormalized_pointwise_estimates[ torch.isnan(unnormalized_pointwise_estimates)] = 0 return -1 * torch.log(z) + (1 / z) * unnormalized_pointwise_estimates