import torch
from tqdm.autonotebook import trange
from collections import defaultdict
from transformers import (get_constant_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_cosine_schedule_with_warmup)
from disco.tuners.losses import *
from disco.samplers import AccumulationSampler
from disco.distributions.single_context_distribution import SingleContextDistribution
from disco.metrics import KL, TV, JS
from disco.utils.helpers import batchify
from disco.utils.observable import Observable, forward
from disco.utils.device import to_same_device, get_device
from disco.utils.moving_average import MovingAverage
from disco.utils.moving_average import average
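# maps a metric name to a function computing pointwise importance-sampling
# estimates of the corresponding divergence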
divergence_pointwise_estimates_funcs = {
'tv': TV.pointwise_estimates,
'kl': KL.pointwise_estimates,
'js': JS.pointwise_estimates}
class Tuner():
"""
    Generic tuning class.

    Observables
-----------
    step_idx: reports the index of the current gradient update
step_idx: integer
ministep_idx: reports the current minibatch index
ministep_idx: integer
metric_updated: reports the value of a given metric
name: string
value: scalar
proposal_updated: reports the new proposal distribution when it is updated
proposal: Distribution
eval_samples_updated: reports a fresh set of samples that the network has not yet been trained on
context: text
samples: list
proposal_log_scores: list of floats
model_log_scores: list of floats
target_log_scores: list of floats
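
    Example
    -------
    A minimal sketch of subscribing to an observable (``log_metric`` is a
    hypothetical callback, not part of the library)::

        def log_metric(name, value):
            print(f"{name}: {value}")

        tuner.metric_updated.enroll(log_metric)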
"""
default_params = {
"optimizer": "Adam",
"learning_rate": 1.41e-5,
"scheduler": "constant",
"warmup_steps":2*6,
"n_gradient_steps": 2**10, # number of gradient updates
"n_samples_per_step": 2**10, # number of samples used per update step
"scoring_size": 2**6, # number of samples used for one computation of the loss
"sampling_size": 2**5, # number of samples requested per sampling
"context_sampling_size": 2**4, # number of different contexts to sample
"divergence_evaluation_interval": 2**4, # number of gradient steps between evaluation of divergence
        # (when tuning offline, the proposal may also be updated at this interval)
"proposal_update_metric": "kl" # the proposal will be updated if the model is better according to this metric
}
def __init__(self, model, target, proposal=None, context_distribution=SingleContextDistribution(), loss=KLLoss(), features=[],
track_metrics=["kl", "tv", "js"], track_divergence_from_base=False, **params):
"""
Parameters
----------
model: distribution
model distribution, to be tuned
target: product
EBM made of a distribution and one or multiple (log-)scorers
proposal: distribution
sampling distribution, if specified tuning is offline
else online (model is also used to sample from)
context_distribution: distribution
to contextualize the sampling from the proposal
loss: function
            used to compute the loss at each step
features: list of (label, feature)
            features monitored during tuning
track_metrics: list of strings
metrics used to report differences between the target and the
model/proposal distributions.
track_divergence_from_base: boolean
            whether or not to track divergence from the base model of the EBM
params: dictionary
fine-tuning parameters
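
        Example
        -------
        A minimal sketch, assuming ``model`` and ``target`` are pre-built
        distribution and EBM objects; any key of ``default_params`` can be
        overridden through ``params``::

            tuner = Tuner(model, target, learning_rate=1e-5, n_gradient_steps=2**8)
            tuner.tune()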
"""
        self.params = dict(self.default_params)  # copy, so class-level defaults are not mutated
self.params.update(params)
self.target = target
if proposal:
self.proposal = proposal
self.learning = "offline"
else:
self.proposal = model
self.learning = "online"
self.model = model
self.context_distribution = context_distribution
        track_metrics = list(track_metrics)  # copy, to avoid mutating the default argument
        if self.params["proposal_update_metric"] not in track_metrics:
            track_metrics.append(self.params["proposal_update_metric"])
self.z = defaultdict(MovingAverage)
self.divergence_estimates_target_proposal = dict()
self.divergence_estimates_target_model = dict()
for metric in track_metrics:
assert metric in divergence_pointwise_estimates_funcs, \
f"Unknown metric {metric}. " \
f"Options are: {list(divergence_pointwise_estimates_funcs.keys())}"
self.divergence_estimates_target_proposal[metric] = defaultdict(MovingAverage)
self.divergence_estimates_target_model[metric] = defaultdict(MovingAverage)
self.track_divergence_from_base = track_divergence_from_base
if self.track_divergence_from_base:
self.divergence_estimates_proposal_base = dict()
self.divergence_estimates_model_base = dict()
for metric in track_metrics:
assert metric in divergence_pointwise_estimates_funcs, \
f"Unknown metric {metric}. " \
f"Options are: {list(divergence_pointwise_estimates_funcs.keys())}"
self.divergence_estimates_proposal_base[metric] = defaultdict(MovingAverage)
self.divergence_estimates_model_base[metric] = defaultdict(MovingAverage)
self._loss = loss
self.features = list(features)
if "AdamW" == self.params["optimizer"]:
self.optimizer = torch.optim.AdamW(self.model.network.parameters(), lr=self.params["learning_rate"])
if "SGD" == self.params["optimizer"]:
self.optimizer = torch.optim.SGD(self.model.network.parameters(), lr=self.params["learning_rate"])
else:
self.optimizer = torch.optim.Adam(self.model.network.parameters(), lr=self.params["learning_rate"])
if "linear" == self.params["scheduler"]:
self.scheduler = get_linear_schedule_with_warmup(self.optimizer, self.params["warmup_steps"])
elif "cosine" == self.params["scheduler"]:
self.scheduler = get_cosine_schedule_with_warmup(self.optimizer, self.params["warmup_steps"], self.params["n_gradient_steps"])
else:
self.scheduler = get_constant_schedule_with_warmup(self.optimizer, self.params["warmup_steps"])
# observables
self.parameters_updated = Observable()
self.step_idx_updated = Observable()
self.ministep_idx_updated = Observable()
self.metric_updated = Observable()
self.proposal_updated = Observable()
self.eval_samples_updated = Observable()
forward(self._loss.metric_updated, self.metric_updated)
if self.features:
self.eval_samples_updated.enroll(self.report_feature_moments)
    def report_feature_moments(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
device = get_device(proposal_log_scores)
model_log_scores = model_log_scores.to(device)
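        # importance ratios turn proposal samples into estimates of
        # expectations under the model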
logweights = model_log_scores - proposal_log_scores
importance_ratios = torch.exp(logweights)
moments = {}
for (label, feature) in self.features:
proposal_moment_pointwise_estimates = feature.log_score(samples, context).exp().to(device)
moments[f"proposal_{label}"] = torch.mean(proposal_moment_pointwise_estimates)
moments[f"model_{label}"] = torch.mean((importance_ratios * proposal_moment_pointwise_estimates))
for k, v in moments.items():
self.metric_updated.dispatch(k, v)
def _update_moving_z(self, proposal_log_scores, target_log_scores, context):
"""
Improves the `z` importance sampling estimate of Z
by averaging new samples
Parameters
----------
proposal_log_scores: array of floats
log-probabilities of the samples according to the proposal
target_log_scores: array of floats
log-probabilities of the samples according to the target
context: text
context for the samples
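
        Notes
        -----
        The pointwise estimates are plain importance weights, so that
        Z ≈ mean(exp(target_log_scores - proposal_log_scores))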
"""
target_log_scores, proposal_log_scores = to_same_device(target_log_scores, proposal_log_scores)
z_pointwise_estimates = torch.exp(target_log_scores - proposal_log_scores)
self.z[context] += z_pointwise_estimates
self.metric_updated.dispatch('z', average(self.z))
def _update_divergence_estimates_target_proposal(self, proposal_log_scores, target_log_scores, context):
"""
Improves the importance sampling estimate of D(p||q)
for every divergence D by averaging new samples
Parameters
----------
proposal_log_scores: array of floats
log-probabilities of the samples according to the proposal
target_log_scores: array of floats
log-probabilities of the samples according to the target
context: text
context for the samples
"""
target_log_scores, proposal_log_scores = to_same_device(target_log_scores, proposal_log_scores)
if self.z[context].value > 0:
for divergence_type, _ in self.divergence_estimates_target_proposal.items():
self.divergence_estimates_target_proposal[divergence_type][context] += \
divergence_pointwise_estimates_funcs[divergence_type](
target_log_scores, proposal_log_scores, self.z[context].value)
def _update_divergence_estimates_target_model(self, proposal_log_scores, target_log_scores, model_log_scores, context):
"""
Improves the importance sampling estimates of D(p||q)
for every divergence D by averaging new samples
Parameters
----------
proposal_log_scores: array of floats
log-probabilities of the samples according to the proposal
target_log_scores: array of floats
log-probabilities of the samples according to the target
model_log_scores: array of floats
log-probabilities of the samples according to the model
context: text
context for the samples
"""
target_log_scores, model_log_scores, proposal_log_scores = to_same_device(
target_log_scores, model_log_scores, proposal_log_scores)
if self.z[context].value > 0:
for divergence_type, _ in self.divergence_estimates_target_model.items():
self.divergence_estimates_target_model[divergence_type][context] += \
divergence_pointwise_estimates_funcs[divergence_type](
target_log_scores, model_log_scores, self.z[context].value,
proposal_log_scores=proposal_log_scores)
def _update_divergence_estimates_proposal_base(self, proposal_log_scores, base_log_scores, context):
"""
Improves the importance sampling estimate of D(p||q)
for every divergence D by averaging new samples
Parameters
----------
proposal_log_scores: array of floats
log-probabilities of the samples according to the proposal
base_log_scores: array of floats
log-probabilities of the samples according to the base
context: text
context for the samples
"""
base_log_scores, proposal_log_scores = to_same_device(base_log_scores, proposal_log_scores)
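        # the proposal is a normalized distribution, so its partition function is 1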
        for divergence_type in self.divergence_estimates_proposal_base:
            self.divergence_estimates_proposal_base[divergence_type][context] += \
                divergence_pointwise_estimates_funcs[divergence_type](
                    proposal_log_scores, base_log_scores, torch.as_tensor(1),
                    proposal_log_scores=proposal_log_scores)
def _update_divergence_estimates_model_base(self, proposal_log_scores, model_log_scores, base_log_scores, context):
"""
Improves the importance sampling estimate of D(p||q)
for every divergence D by averaging new samples
Parameters
----------
proposal_log_scores: array of floats
log-probabilities of the samples according to the proposal
model_log_scores: array of floats
log-probabilities of the samples according to the model
base_log_scores: array of floats
log-probabilities of the samples according to the base
context: text
context for the samples
"""
model_log_scores, base_log_scores, proposal_log_scores = \
to_same_device(model_log_scores, base_log_scores, proposal_log_scores)
        for divergence_type in self.divergence_estimates_model_base:
            self.divergence_estimates_model_base[divergence_type][context] += \
                divergence_pointwise_estimates_funcs[divergence_type](
                    model_log_scores, base_log_scores, torch.as_tensor(1),
                    proposal_log_scores=proposal_log_scores)
def _report_and_reset_divergence_estimate(self, divergence_estimates_dict, arguments_name):
"""
        Reports all tracked divergences in divergence_estimates_dict, using as
        metric name the concatenation of the divergence type and arguments_name,
        then resets the corresponding moving averages
        Parameters
        ----------
        divergence_estimates_dict: dictionary of strings to MovingAverage
            the dictionary tracking divergence estimates
        arguments_name: string
            a name identifying the pair of distributions whose divergence is tracked
"""
for divergence_type, moving_averages in divergence_estimates_dict.items():
self.metric_updated.dispatch(f"{divergence_type}_{arguments_name}",
average(moving_averages))
divergence_estimates_dict[divergence_type] = defaultdict(MovingAverage)
def _update_proposal_if_better(self):
"""
        Checks if D(p||.) is lower for the model than for the proposal
and if so, updates the proposal
"""
if average(self.divergence_estimates_target_proposal[self.params["proposal_update_metric"]]) > \
average(self.divergence_estimates_target_model[self.params["proposal_update_metric"]]):
print("updating proposal according to KL divergence")
self.proposal.network.load_state_dict(self.model.network.state_dict())
self.metric_updated.dispatch('proposal_updated', 1)
self.proposal_updated.dispatch(self.proposal)
else:
self.metric_updated.dispatch('proposal_updated', 0)
def _compute_gradient(self, samples, proposal_log_scores, target_log_scores, model_log_scores, context, n_steps):
"""
Computes the gradient on a minibatch of samples
Parameters
----------
samples: list of items
samples from the proposal network
proposal_log_scores: array of floats
log-probabilities for the samples according to the proposal
target_log_scores: array of floats
log-probabilities for the samples according to the target
model_log_scores: array of floats
log-probabilities for the samples according to the model
context: text
context for the samples
n_steps: int
number of accumulation steps
"""
proposal_log_scores, target_log_scores, model_log_scores, z_value = to_same_device(
proposal_log_scores, target_log_scores, model_log_scores, self.z[context].value)
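        # skip the update until an estimate of Z is available; the loss is
        # divided by n_steps so that gradients accumulated over the minibatches
        # average out correctly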
if z_value > 0:
loss = self._loss(samples, context, proposal_log_scores, target_log_scores, model_log_scores, z_value) / n_steps
self.metric_updated.dispatch('loss', loss.item())
loss.backward()
def _step(self):
"""
        Performs a tuning step of the model distribution's network

        Performs a single step of gradient updates on a batch of samples:
        - obtains samples and their log-scores from the proposal network
        - repeats gradient computations over minibatches
- applies the accumulated gradients
"""
sampler = AccumulationSampler(self.proposal, total_size=self.params["n_samples_per_step"])
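        # the n_samples_per_step samples are scored and backpropagated through
        # in n_steps minibatches of scoring_size, accumulating gradients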
n_steps = self.params["n_samples_per_step"] // self.params["scoring_size"]
contexts, _ = self.context_distribution.sample(self.params["context_sampling_size"])
for context in contexts:
samples, proposal_log_scores = sampler.sample(sampling_size=self.params["sampling_size"], context=context)
target_log_scores = batchify(self.target.log_score, self.params["scoring_size"], samples=samples, context=context)
self._update_moving_z(proposal_log_scores, target_log_scores, context)
self._update_divergence_estimates_target_proposal(proposal_log_scores, target_log_scores, context)
if self.track_divergence_from_base:
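                # the first scorer of the EBM product is taken to be its base distribution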
base = self.target.scorers[0]
                base_log_scores = batchify(base.log_score, self.params["scoring_size"], samples=samples, context=context)
self._update_divergence_estimates_proposal_base(proposal_log_scores, base_log_scores, context)
for s in range(n_steps):
self.ministep_idx_updated.dispatch(s)
minibatch_slice = slice(s * self.params["scoring_size"], (s + 1) * self.params["scoring_size"])
mb_samples = samples[minibatch_slice]
mb_proposal_log_scores = proposal_log_scores[minibatch_slice]
mb_target_log_scores = target_log_scores[minibatch_slice]
mb_model_log_scores = self.model.log_score(mb_samples, context=context, grad=True)
self._update_divergence_estimates_target_model(
mb_proposal_log_scores, mb_target_log_scores, mb_model_log_scores, context)
if self.track_divergence_from_base:
mb_base_log_scores = base_log_scores[minibatch_slice]
self._update_divergence_estimates_model_base(mb_proposal_log_scores, mb_model_log_scores, mb_base_log_scores, context)
self.eval_samples_updated.dispatch(
context, mb_samples, mb_proposal_log_scores, mb_model_log_scores, mb_target_log_scores)
self._compute_gradient(
mb_samples, mb_proposal_log_scores, mb_target_log_scores, mb_model_log_scores,
context, n_steps * self.params["context_sampling_size"])
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
    def tune(self):
"""
        Fine-tunes the model distribution's network

        Fine-tunes the network of the model distribution:
        - repeats n_gradient_steps tuning steps
        - when tuning offline, periodically updates the proposal according to the proposal_update_metric divergence
"""
self.parameters_updated.dispatch(self.params)
torch.cuda.empty_cache()
with trange(self.params["n_gradient_steps"], desc='fine-tuning') as t:
for s in t:
self.step_idx_updated.dispatch(s)
self._step()
if 0 == (s + 1) % self.params["divergence_evaluation_interval"]:
if "offline" == self.learning:
self._update_proposal_if_better()
self._report_and_reset_divergence_estimate(self.divergence_estimates_target_model, 'target_model')
self._report_and_reset_divergence_estimate(self.divergence_estimates_target_proposal, 'target_proposal')
if self.track_divergence_from_base:
self._report_and_reset_divergence_estimate(self.divergence_estimates_model_base, 'model_base')
self._report_and_reset_divergence_estimate(self.divergence_estimates_proposal_base, 'proposal_base')