Source code for samplers.accumulation_sampler

from . import Sampler
import torch
from tqdm.autonotebook import trange


class AccumulationSampler(Sampler):
    """
    Sampler that repeatedly draws batches from a distribution
    until a fixed total number of samples has been gathered
    """

    def __init__(self, distribution, total_size=512):
        """
        Parameters
        ----------
        distribution: distribution
            distribution the batches are drawn from
        total_size: int
            overall number of samples to gather
        """

        self.distribution = distribution
        self.total_size = total_size

    def sample(self, sampling_size=32, context=""):
        """Gathers batches of samples from the distribution

        Batches are requested one after the other, the last one being
        truncated so that no more than total_size samples are returned.

        Parameters
        ----------
        sampling_size: int
            number of samples requested from the distribution at each draw
        context: text
            contextual text forwarded to the distribution

        Returns
        -------
        a tuple of the gathered samples and their scores
        """

        label = f"sampling from {type(self.distribution).__name__}"
        with trange(self.total_size, desc=label) as bar:
            # placeholders, returned untouched when nothing has to be drawn
            gathered = list()
            gathered_scores = torch.empty([0])
            missing = self.total_size

            while missing > 0:
                batch, batch_scores = self.distribution.sample(
                    context=context, sampling_size=sampling_size
                )

                # never keep more than what is still missing
                keep = min(missing, len(batch))
                batch = batch[:keep]
                batch_scores = batch_scores[:keep]

                if gathered:
                    gathered = gathered + batch
                    gathered_scores = torch.cat((gathered_scores, batch_scores))
                else:
                    # nothing gathered yet: the batch replaces the placeholders
                    gathered, gathered_scores = batch, batch_scores

                missing -= len(batch)
                bar.update(len(batch))

        return (gathered, gathered_scores)