Source code for disco.distributions.base_distribution
import torch
from tqdm.autonotebook import trange
from disco.scorers.positive_scorer import Product
from disco.scorers.exponential_scorer import ExponentialScorer
from disco.scorers.boolean_scorer import BooleanScorer
from .distribution import Distribution
from .single_context_distribution import SingleContextDistribution
from disco.samplers.accumulation_sampler import AccumulationSampler
from disco.utils.device import get_device
from disco.utils.helpers import batchify
from disco.utils.moving_average import MovingAverage

class BaseDistribution(Distribution):
    """
    Base distribution class, which can be used
    to build an EBM.
    """

    def constrain(self,
            features, moments=None,
            proposal=None, context_distribution=SingleContextDistribution(''), context_sampling_size=1,
            n_samples=2**9, iterations=1000, learning_rate=0.05, tolerance=1e-5, sampling_size=2**5
        ):
"""
Constrains features to the base according to their moments,
so producing an EBM
Parameters
----------
features: list(feature)
multiple features to constrain
moments: list(float)
moments for the features. There should be as many moments as there are features
proposal: distribution
distribution to sample from, if different from self
context_distribution: distribution
to contextualize the sampling and scoring
context_sampling_size:
size of the batch when sampling context
n_samples: int
number of samples to use to fit the coefficients
learning_rate: float
multipliers of the delta used when fitting the coefficients
tolerance: float
accepted difference between the targets and moments
sampling_size:
size of the batch when sampling samples
Returns
-------
exponential scorer with fitted coefficients
"""
        if not isinstance(features, list):
            raise TypeError("features should be passed as a list.")
        if not moments:
            return Product(self, *features)
        if not isinstance(moments, list):
            raise TypeError("moments should be passed as a list.")
        if len(features) != len(moments):
            raise TypeError("there should be as many moments as there are features.")
        if all(isinstance(f, BooleanScorer) for f in features) \
                and all(float(m) == 1.0 for m in moments):
            return Product(self, *features)
        if not proposal:
            proposal = self
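        # sample contexts once, then, for each context, draw proposal samples
        # and cache the log-scores and feature values that the fitting loop
        # below will reuse at every iteration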
        context_samples, context_log_scores = context_distribution.sample(context_sampling_size)
        proposal_samples = dict()
        proposal_log_scores = dict()
        joint_log_scores = dict()
        feature_scores = dict()
        for (context, log_score) in zip(context_samples, context_log_scores):
            accumulator = AccumulationSampler(proposal, total_size=n_samples)
            proposal_samples[context], proposal_log_scores[context] = accumulator.sample(
                sampling_size=sampling_size, context=context
            )
            device = get_device(proposal_log_scores[context])
            reference_log_scores = batchify(
                self.log_score, sampling_size, samples=proposal_samples[context], context=context
            ).to(device)
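            # joint log-score of (context, sample) under the base:
            # log P(context) + log P(sample | context)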
            joint_log_scores[context] = torch.tensor(log_score).repeat(n_samples).to(device) + reference_log_scores
            feature_scores[context] = torch.stack(
                [f.score(proposal_samples[context], context).to(device) for f in features]
            )
        coefficients = torch.tensor(0.0).repeat(len(features)).to(device)
        targets = torch.tensor(moments).to(device)
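        # fit the exponential coefficients by gradient updates: at each step,
        # estimate the feature moments of the candidate EBM with
        # self-normalized importance sampling from the cached proposal
        # samples, then move the coefficients against the moment mismatch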
        with trange(iterations, desc='fitting exponential scorer') as t:
            for i in t:
                scorer = ExponentialScorer(features, coefficients)
                numerator = torch.tensor(0.0).repeat(len(features)).to(device)
                denominator = torch.tensor(0.0).repeat(len(features)).to(device)
                for context in context_samples:
                    target_log_scores = joint_log_scores[context] + scorer.log_score(
                        proposal_samples[context], context
                    ).to(device)
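                    # unnormalized importance weights of the cached samples
                    # under the candidate EBM relative to the proposal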
                    importance_ratios = torch.exp(target_log_scores - proposal_log_scores[context])
                    numerator += (importance_ratios * feature_scores[context]).sum(dim=1)
                    denominator += importance_ratios.sum()
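                # self-normalized importance-sampling estimate of the feature
                # moments, and its gap to the targets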
                estimated_moments = numerator / denominator
                grad_coefficients = estimated_moments - targets
                err = grad_coefficients.abs().max().item()
                t.set_postfix(err=err)
                if err < tolerance:
                    t.total = i
                    t.refresh()
                    break
                coefficients -= learning_rate * grad_coefficients
        return self * ExponentialScorer(features, coefficients)
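

# ----------------------------------------------------------------------
# Usage sketch (not part of the original module). A minimal, hedged
# example of how `constrain` builds an EBM, following the "amazing"
# feature from the disco README; the exact import paths, the default
# model behind LMDistribution, and the 0.5 target moment are assumptions.
if __name__ == '__main__':
    from disco.distributions.lm_distribution import LMDistribution
    from disco.scorers.boolean_scorer import BooleanScorer

    base = LMDistribution()
    # boolean feature: does the sampled text contain the word "amazing"?
    amazing = BooleanScorer(lambda s, c: 'amazing' in s.text)

    # pointwise constraint: a boolean feature with target moment 1.0 is
    # imposed as a filter, so no coefficient fitting happens
    filtered = base.constrain([amazing], [1.0])

    # distributional constraint: fit an exponential scorer so that
    # "amazing" appears in 50% of the samples on average
    ebm = base.constrain([amazing], [0.5])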