Source code for metrics.tv

import torch

from disco.utils.device import get_device
from .base import BaseDivergence


[docs]class TV(BaseDivergence):
[docs] @classmethod def pointwise_estimates(cls, m1_log_scores, m2_log_scores, z, proposal_log_scores=None): """ computes the TVD between 2 distributions Parameters ---------- m1_log_scores: floats log-scores for samples according to network 1 m2_log_scores: floats log-scores for samples according to network 2 z: float partition function of network 1 proposal_log_scores: floats log-scores for samples according to proposal (by default m2_log_scores) Returns ------- divergence between m1 and m2 """ device = get_device(m1_log_scores) m2_log_scores = m2_log_scores.to(device) if proposal_log_scores is None: proposal_log_scores = m2_log_scores else: proposal_log_scores = proposal_log_scores.to(device) normalized_m1_log_scores = m1_log_scores - torch.log(z) return 1/2 * (torch.abs(torch.exp(m2_log_scores - proposal_log_scores) - torch.exp(normalized_m1_log_scores - proposal_log_scores)))