Source code for metrics.base
import torch
from disco.utils.device import get_device
class BaseDivergence:
    """
    Base class for divergence metrics, estimated by importance sampling.
    """
    @classmethod
    def divergence(cls, m1_log_scores, m2_log_scores, z, proposal_log_scores=None):
        """
        Computes an importance sampling (IS) estimate of the divergence
        between two distributions.

        Parameters
        ----------
        m1_log_scores: tensor of floats
            log-scores of the samples according to network 1
        m2_log_scores: tensor of floats
            log-scores of the samples according to network 2
        z: float
            partition function of network 1
        proposal_log_scores: tensor of floats, optional
            log-scores of the samples according to the proposal
            (defaults to m2_log_scores)

        Returns
        -------
        float, the estimated divergence between m1 and m2
        """
        # average the per-sample estimates provided by the concrete subclass
        return torch.mean(cls.pointwise_estimates(
            m1_log_scores, m2_log_scores, z, proposal_log_scores=proposal_log_scores))
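
# --- Illustrative sketch (not part of the library source) ---
# `divergence` simply averages whatever per-sample values
# `pointwise_estimates` returns, so a concrete metric only needs to
# implement that method. The subclass name `KLDivergence` and the
# estimator below are assumptions for illustration: for samples x
# drawn from the proposal q, KL(p1 || p2) = E_q[(p1(x)/q(x)) *
# (log p1(x) - log p2(x))], with p1(x) = exp(m1_log_scores(x)) / z
# and m2_log_scores assumed to be normalized log-probabilities.

import math

class KLDivergence(BaseDivergence):

    @classmethod
    def pointwise_estimates(cls, m1_log_scores, m2_log_scores, z,
                            proposal_log_scores=None):
        if proposal_log_scores is None:
            proposal_log_scores = m2_log_scores
        # log-probabilities of the normalized target: log p1 = log-score - log z
        m1_log_probs = m1_log_scores - math.log(z)
        # importance weights w(x) = p1(x) / q(x) for samples x ~ q
        weights = torch.exp(m1_log_probs - proposal_log_scores)
        # per-sample integrand of KL(p1 || p2)
        return weights * (m1_log_probs - m2_log_scores)

# Toy usage with fabricated log-scores for four samples:
#
#     m1 = torch.tensor([-1.0, -2.0, -0.5, -1.5])
#     m2 = torch.tensor([-1.2, -1.8, -0.7, -1.4])
#     estimate = KLDivergence.divergence(m1, m2, z=1.0)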