Source code for distributions.dataset_context_distribution
import torch
from random import sample
from datasets import load_dataset
from .distribution import Distribution
class DatasetContextDistribution(Distribution):
    """
    Context distribution class, fetching the contexts from a Hugging Face dataset.
    It can be used as a template for other context distributions.
    """
    def __init__(self, dataset="", subset="", split="train", key="text", prefix=""):
        """
        Parameters
        ----------
        dataset: string
            name of the dataset in Hugging Face's Datasets
        subset: string
            name of the subset within the dataset
        split: string
            name of the split within the dataset/subset
        key: string
            key used on each row to pick the relevant field
        prefix: string
            text prepended to each context
        """
        try:
            # Pass None instead of an empty subset name so that load_dataset
            # falls back to the dataset's default configuration.
            self.dataset = load_dataset(dataset, subset or None, split=split)
        except IOError:
            self.dataset = list()
        assert self.dataset, "There's an issue with the parameters of the dataset."
        self.key = key
        self.prefix = prefix
    def log_score(self, contexts):
        """Computes plausible log-probabilities of the contexts.
        Note that there's no check that the contexts are part of the dataset,
        hence the "plausible" qualifier.
        Parameters
        ----------
        contexts: list(str)
            list of contexts to (log-)score
        Returns
        -------
        tensor of log-probabilities
        """
        assert contexts, "There must be contexts to (log-)score."
        # Each context is assigned the uniform probability 1 / num_rows.
        return torch.log(torch.full((len(contexts),), 1 / self.dataset.num_rows))
    def sample(self, sampling_size=32):
        """Samples random elements from the dataset.
        Parameters
        ----------
        sampling_size: int
            number of contexts to sample
        Returns
        -------
        tuple of (list of texts, tensor of log-probabilities)
        """
        assert self.dataset.num_rows >= sampling_size, "The dataset does not have enough elements to sample."
        # Draw row indices without replacement, then prepend the prefix to each row's text.
        contexts = [self.prefix + c[self.key]
                    for c in self.dataset.select(sample(range(self.dataset.num_rows), sampling_size))]
        return (contexts, self.log_score(contexts))
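
# A minimal usage sketch, not part of the module. It assumes the Hugging Face
# "imdb" dataset with its "plain_text" subset and "text" column is available;
# none of these names come from this file.
if __name__ == "__main__":
    distribution = DatasetContextDistribution(
        dataset="imdb", subset="plain_text", split="train",
        key="text", prefix="Movie review: ")
    contexts, log_probs = distribution.sample(sampling_size=4)
    # Every entry of log_probs equals log(1 / distribution.dataset.num_rows).
    print(contexts[0])
    print(log_probs)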