Source code for distributions.lm_distribution

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

from collections import namedtuple

from .base_distribution import BaseDistribution
from disco.utils.helpers import get_token_first_indices


TextSample = namedtuple('TextSample', ['token_ids', 'text'])

class LMDistribution(BaseDistribution):
    """Language model distribution class, a core class for all NLP use-cases,
    relying on Huggingface's Transformers library.
    """

    def __init__(self,
            network="gpt2", tokenizer="gpt2", nature="causal",
            freeze=True, length=40, device="cpu", **config
        ):
        """
        Parameters
        ----------
        network: string
            Transformers' name of a causal or seq2seq language model
        tokenizer: string
            Transformers' name for the related tokenizer
        nature: string
            "causal" or "seq2seq", to select the matching model class from Transformers
        freeze: boolean
            whether to freeze the network's parameters
        length: int
            number of tokens in the samples
        device: string
            reference of the computing device
        config: kwargs
            parameters and values passed to Transformers' ``generate(...)``
        """

        assert nature in ["causal", "seq2seq"], "only causal and seq2seq models are handled."
        self.nature = nature

        klass = AutoModelForCausalLM if "causal" == nature else AutoModelForSeq2SeqLM

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.network = klass.from_pretrained(
            network,
            pad_token_id=self.tokenizer.eos_token_id
        )

        self.device = device
        self.network.to(self.device)

        self.network.eval()  # to make sure scoring is consistent
        if freeze:
            self.freeze(True)

        self.length = length

        default_params = {
            "top_k": 0,
            "top_p": 1.0,
            "typical_p": 1.0,
            "temperature": 1.0,
            "num_beams": 1
        }
        self.params = default_params.copy()
        self.params.update(config)

        # the distribution is only scorable when all sampling parameters keep
        # their neutral defaults, so that generate() samples from the
        # network's own token distribution
        self.scorable = all(
            default_params[k] == self.params[k] for k in default_params
        )

    def to(self, device):
        """Moves the distribution's network to the given device."""
        self.device = device
        self.network.to(self.device)

    def freeze(self, frozen=True):
        """Freeze (or unfreeze) parameters for gradient computation.

        Parameters
        ----------
        frozen: boolean (True)
            state to transition to, the default being to freeze
        """
        self.network.requires_grad_(not frozen)

    def log_score(self, samples, context="", grad=False, sum=True):
        """Computes log-probabilities of the samples according to the
        language model network, in the given context.

        Parameters
        ----------
        samples: list(Sample)
            samples to (log-)score, as a list()
        context: text
            context for which to (log-)score the samples
        grad: boolean
            whether to compute gradients, e.g. when fitting
        sum: boolean
            if False, return token-level log-probabilities instead of their sum

        Returns
        -------
        tensor of log-probabilities
        """

        assert self.scorable, "this distribution's parameters make it unscorable."

        shapes = set([s.token_ids.shape for s in samples])
        assert 1 == len(shapes), f"sequences of token_ids should have the same shape, but got: {shapes}."

        device = self.device

        context = self.tokenizer.convert_ids_to_tokens(self.tokenizer.bos_token_id) if "" == context else context
        tokenized_context = self.tokenizer([context] * len(samples), return_tensors="pt", add_special_tokens=True)
        tokenized_context["input_ids"] = tokenized_context["input_ids"].to(device)
        tokenized_context["attention_mask"] = tokenized_context["attention_mask"].to(device)

        tokenized_samples = dict()
        tokenized_samples["input_ids"] = torch.stack([sample.token_ids for sample in samples]).to(device)
        first_eos_indices = get_token_first_indices(
            tokenized_samples["input_ids"],
            self.tokenizer.eos_token_id
        )
        # mask out padding tokens, but keep the first <eos> of each sample
        tokenized_samples["attention_mask"] = torch.where(
            self.tokenizer.pad_token_id == tokenized_samples["input_ids"], 0, 1
        )
        for i, ix in enumerate(first_eos_indices):
            tokenized_samples["attention_mask"][i][ix] = 1
        tokenized_samples["attention_mask"] = tokenized_samples["attention_mask"].to(device)

        if "causal" == self.nature:
            shift = tokenized_context["input_ids"].shape[-1] - 1
            last = -1
            inputs = {
                "input_ids": torch.cat((tokenized_context["input_ids"], tokenized_samples["input_ids"]), 1),
                "attention_mask": torch.cat((tokenized_context["attention_mask"], tokenized_samples["attention_mask"]), 1)
            }
            labels = inputs["input_ids"]
        else:
            shift = None
            last = None
            inputs = tokenized_context
            labels = tokenized_samples["input_ids"]

        if grad:
            outputs = self.network(**inputs, labels=labels)
        else:
            with torch.no_grad():
                outputs = self.network(**inputs, labels=labels)

        all_logprobs = outputs.logits[:, shift:last, :].log_softmax(-1)  # [n_samples, length, vocab]
        seq_logprobs = torch.gather(
            all_logprobs, 2, tokenized_samples["input_ids"][:, :, None]
        ).squeeze(-1)  # [n_samples, length]
        seq_logprobs = torch.where(1 == tokenized_samples["attention_mask"],
            seq_logprobs, torch.tensor(0.).to(device))

        return seq_logprobs.sum(dim=1) if sum else seq_logprobs

    def sample(self, context="", sampling_size=32, sum=True):
        """Samples sequences from the language model, in the given context.

        Parameters
        ----------
        context: text
            contextual text for which to sample
        sampling_size: int
            number of sequences to sample
        sum: boolean
            if False, return token-level log-probabilities instead of their sum

        Returns
        -------
        tuple of (list of Sample(tokens, text), tensor of logprobs)
        """

        context = self.tokenizer.convert_ids_to_tokens(self.tokenizer.bos_token_id) if "" == context else context
        input_ids = self.tokenizer([context] * sampling_size, return_tensors="pt", add_special_tokens=True).input_ids.to(self.device)
        n_context_tokens = input_ids.shape[-1]

        if "causal" == self.nature:
            shift = n_context_tokens
            last = None
        else:
            shift = 1
            last = None

        outputs = self.network.generate(
            input_ids,
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=self.length,
            do_sample=True,
            **self.params
        )

        all_logprobs = torch.stack(outputs.scores, dim=1).log_softmax(-1)  # [sampling_size, length, vocab]
        token_seq_logprobs = torch.gather(
            all_logprobs, 2, outputs.sequences[:, shift:last][:, :, None]
        ).squeeze(-1)  # [sampling_size, length]

        # we need to zero the (log-)scores of extra <eos> tokens
        first_eos_indices = get_token_first_indices(
            outputs.sequences[:, shift:last],  # starting at 1 to skip a possible bos token
            self.tokenizer.eos_token_id
        )
        non_pad_tokens = torch.cat(
            (outputs.sequences[:, shift:last][:, 0].unsqueeze(1),
                torch.where(
                    self.tokenizer.pad_token_id == outputs.sequences[:, shift:last][:, 1:],
                    -1,
                    outputs.sequences[:, shift:last][:, 1:])
            ),
            dim=1
        )
        non_pad_log_scores = torch.where(-1 != non_pad_tokens,
            token_seq_logprobs, torch.tensor(0.).to(self.device))
        for i, ix in enumerate(first_eos_indices):
            non_pad_log_scores[i][ix] = token_seq_logprobs[i][ix]

        seq_logprobs = non_pad_log_scores.sum(dim=1) if sum else non_pad_log_scores

        output_tokens = outputs.sequences[:, shift:]  # [sampling_size, length]

        return (
            [TextSample(ots, self.tokenizer.decode(ots)) for ots in output_tokens],
            seq_logprobs
        )

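A minimal construction sketch, assuming the class is importable from the disco package as written below (the exact import path may differ) and that the default "gpt2" checkpoint is available from the Hugging Face hub; the variable names are purely illustrative:

    from disco.distributions.lm_distribution import LMDistribution

    # a scorable base distribution: all generate() parameters keep their neutral defaults
    base = LMDistribution("gpt2", "gpt2", nature="causal", length=40, device="cpu")

    # a non-default sampling parameter (here, temperature) turns the distribution
    # into a pure sampler: scorable is False and log_score() will raise an assertion
    proposal = LMDistribution("gpt2", "gpt2", temperature=1.2)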
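Sampling with such a distribution then mirrors sample(...) above: it returns a list of TextSample namedtuples together with the corresponding log-probabilities. The context string is just an example:

    samples, logprobs = base.sample(context="The movie was", sampling_size=8)
    for s in samples[:3]:
        print(s.text)        # decoded continuation of the context
    print(logprobs.shape)    # torch.Size([8]); with sum=False, one score per generated token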
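A scoring sketch under the same assumptions: samples drawn from one LMDistribution can be log-scored by another, as long as the scoring one kept its neutral sampling parameters; with grad=True the scores carry gradients, provided the scoring network has been unfrozen first:

    samples, _ = proposal.sample(context="The movie was", sampling_size=8)
    scores = base.log_score(samples, context="The movie was")    # tensor of shape [8]

    base.freeze(False)    # unfreeze before asking for gradients
    scores = base.log_score(samples, context="The movie was", grad=True)
    scores.sum().backward()    # gradients w.r.t. base.network's parameters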