Source code for metrics.js

import torch
import logging
from disco.utils.device import get_device
from .base import BaseDivergence
from .kl import KL


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Lower bound on log-scores that can still be exponentiated in double
# precision; finite scores below it trigger a precision warning in JS.
# NOTE(review): -745 is roughly where exp() underflows to 0 in float64
# (smallest subnormal ~4.9e-324, ln ~ -744.4) -- confirm this is the intent.
MIN_EXPONENTIABLE_DOUBLE = -745

class JS(BaseDivergence):
    """
    Jensen-Shannon divergence class.
    """

    @classmethod
    def pointwise_estimates(cls, m1_log_scores, m2_log_scores, z, proposal_log_scores=None):
        """
        Computes the Jensen-Shannon divergence between 2 distributions,
        as the average of the KL estimates of each distribution against
        their equal-weight mixture m = (m1 / z + m2) / 2

        Parameters
        ----------
        m1_log_scores: floats
            log-scores for samples according to network 1
        m2_log_scores: floats
            log-scores for samples according to network 2
        z: float
            partition function of network 1 (a python number or a tensor)
        proposal_log_scores: floats
            log-scores for samples according to proposal; forwarded
            unchanged to both KL.pointwise_estimates calls, which decide
            how a None value is handled

        Returns
        -------
        divergence between m1 and m2, i.e. the sum of the two halved
        KL.pointwise_estimates results
        """

        def very_small(t):
            # Infinite scores are excluded: only finite scores below the
            # threshold indicate a loss of precision when exponentiating.
            return (MIN_EXPONENTIABLE_DOUBLE > t[~torch.isinf(t)]).any()

        device = get_device(m1_log_scores)
        m2_log_scores = m2_log_scores.to(device)

        # as_tensor is a no-op for tensors and lets plain python numbers
        # through, since torch.log only accepts tensors.
        normalized_m1_log_scores = m1_log_scores - torch.log(torch.as_tensor(z))

        # Check the values that are actually exponentiated below, i.e. the
        # normalized scores of network 1 rather than the raw ones.
        if very_small(normalized_m1_log_scores) or very_small(m2_log_scores):
            logger.warning("Scores below minimal precision in JS pointwise estimator")

        # Log-scores of the mixture; the exponentials are taken in double
        # precision to delay underflow, then cast back to single precision.
        m_log_scores = torch.log(
            (normalized_m1_log_scores.double().exp() + m2_log_scores.double().exp()) / 2
        ).float()

        # Both arguments of each KL term are already normalized at this
        # point, hence the partition function of 1.
        divergence = (
            KL.pointwise_estimates(
                normalized_m1_log_scores, m_log_scores,
                torch.as_tensor(1), proposal_log_scores
            ) / 2
            + KL.pointwise_estimates(
                m2_log_scores, m_log_scores,
                torch.as_tensor(1), proposal_log_scores
            ) / 2
        )

        return divergence