Source code for tuners.losses.kl

import math

import torch
from .base import BaseLoss

class KLLoss(BaseLoss):
    """Kullback-Leibler divergence loss for DPG."""

    def __init__(self, use_baseline=True):
        """
        Parameters
        ----------
        use_baseline: bool
            use a baseline to reduce the variance of the gradient estimate
        """
        super().__init__()
        self.use_baseline = use_baseline

    def __call__(self, samples, context, proposal_log_scores,
                 target_log_scores, model_log_scores, z):
        """
        Computes the KL loss on a given minibatch of samples:

            ∇ loss = -(target(x) / q(x)) * ∇ log π(x)

        Parameters
        ----------
        samples: list of items
            samples from the proposal network
        context: text
            context for the samples
        proposal_log_scores: array of floats
            log-probabilities of the samples according to the proposal
        target_log_scores: array of floats
            log-probabilities of the samples according to the target
        model_log_scores: array of floats
            log-probabilities of the samples according to the model network
        z: float
            estimate of the partition function of the EBM

        Returns
        -------
        mean loss across the minibatch
        """
        # Normalize the EBM scores by the partition function estimate.
        normalized_target_log_scores = target_log_scores - math.log(z)
        # Importance weights p(x) / q(x): normalized target probability
        # over proposal probability.
        rewards = torch.exp(normalized_target_log_scores - proposal_log_scores)
        self.metric_updated.dispatch('rewards', rewards.mean())
        if self.use_baseline:
            # Baseline π(x) / q(x) has expectation 1 under the proposal q,
            # so subtracting it keeps the gradient estimate unbiased.
            importance_ratios = (model_log_scores.detach() - proposal_log_scores).exp()
            advantage = rewards - importance_ratios
            loss = -torch.mean(advantage * model_log_scores)
            self.metric_updated.dispatch('importance_ratios', importance_ratios.mean())
            self.metric_updated.dispatch('advantage', advantage.mean())
        else:
            loss = -torch.mean(rewards * model_log_scores)
        return loss
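
A minimal usage sketch, assuming the package above is importable and that BaseLoss wires up the metric_updated dispatcher; the sample strings, context, and log-score tensors below are hypothetical stand-ins for the outputs of real proposal, target, and model networks:

    import torch

    from tuners.losses.kl import KLLoss

    # Hypothetical minibatch of 4 samples drawn from the proposal network.
    samples = ["a", "b", "c", "d"]
    context = "some prompt"

    # Stand-in log-probabilities; in practice these come from scoring the
    # samples with the proposal, the target EBM, and the tuned model.
    proposal_log_scores = torch.tensor([-2.1, -1.8, -2.5, -2.0])
    target_log_scores = torch.tensor([-2.0, -2.2, -2.4, -1.9])
    model_log_scores = torch.tensor([-2.0, -1.9, -2.3, -2.1],
                                    requires_grad=True)

    z = 1.0  # running estimate of the EBM partition function

    loss_fn = KLLoss(use_baseline=True)
    loss = loss_fn(samples, context, proposal_log_scores,
                   target_log_scores, model_log_scores, z)
    loss.backward()  # gradients flow only through model_log_scores

Note that only model_log_scores carries gradients: the rewards and the baseline are detached from the model, so the backward pass implements the importance-weighted policy-gradient estimate from the docstring, with the baseline shifting the rewards to reduce variance without biasing the gradient.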