Source code for samplers.quasi_rejection_sampler

import torch

from . import Sampler
from disco.utils.device import get_device


class QuasiRejectionSampler(Sampler):
    """
    Quasi Rejection-Sampling class

    Draws candidates from a proposal distribution and keeps each one with
    probability min(1, exp(target_log_score - proposal_log_score) / beta).
    Running counts of drawn and accepted samples are kept on the instance
    so that the acceptance rate can be queried at any time.
    """

    def __init__(self, target, proposal, beta=1):
        """
        Parameters
        ----------
        target: distribution
            Energy-based model to (log-)score the samples
        proposal: distribution
            distribution to generate the samples
        beta: float
            coefficient to control the sampling
        """

        super().__init__(target, proposal)
        self.beta = beta
        # running totals over every call to sample()
        self.n_samples = 0
        self.n_accepted_samples = 0

    def sample(self, sampling_size=32, context=''):
        """Generates samples according to the QRS algorithm

        Parameters
        ----------
        sampling_size: int
            number of requested samples when sampling
        context: text
            contextual text for which to sample

        Returns
        -------
        tuple of accepted samples and their log-scores
        (the log-scores are the proposal's ones, restricted to the
        accepted samples)
        """

        samples, proposal_log_scores = self.proposal.sample(
            sampling_size=sampling_size, context=context
        )
        self.n_samples += len(samples)

        # score with the target on the same device as the proposal scores
        device = get_device(proposal_log_scores)
        target_log_scores = self.target.log_score(
            samples=samples, context=context
        ).to(device)

        # acceptance probabilities: min(1, P(x) / (beta * q(x)))
        rs = torch.clamp(
            torch.exp(target_log_scores - proposal_log_scores) / self.beta,
            min=0.0,
            max=1.0
        )
        us = torch.rand(len(rs)).to(device)

        # compute the accept/reject mask once; a single tolist() avoids
        # evaluating one 0-d tensor per sample in the comprehension below
        accepted_mask = us < rs
        accepted_flags = accepted_mask.tolist()

        accepted_samples = [x for keep, x in zip(accepted_flags, samples) if keep]
        self.n_accepted_samples += len(accepted_samples)

        # boolean-mask indexing keeps dtype and device; detach() so the
        # result stays a plain copy outside of any autograd graph
        accepted_log_scores = proposal_log_scores[accepted_mask].detach()

        return accepted_samples, accepted_log_scores

    def get_acceptance_rate(self):
        """Computes the acceptance rate, that is the number of
        accepted samples over the total sampled ones

        Returns
        -------
        acceptance rate as float between 0 and 1
        (0.0 when no sample has been drawn yet)
        """

        # guard against a ZeroDivisionError before the first sample() call
        if self.n_samples == 0:
            return 0.0

        return self.n_accepted_samples / self.n_samples