# Source code for distributions.context_distribution
from collections import Counter
from random import sample

import numpy as np
import torch

from .distribution import Distribution
class ContextDistribution(Distribution):
    """
    Context distribution class, fetching the contexts from a text file.
    It can be used as a template for other context distributions.

    Each line of the file is one context. The distribution is the empirical
    distribution over those lines: a context's probability is the number of
    times it appears in the file divided by the number of lines.
    """

    def __init__(self, path="contexts.txt"):
        """
        Parameters
        ----------
        path: string
            path to context file, one context per line

        Raises
        ------
        AssertionError
            if the file cannot be opened/read or contains no lines
        """
        try:
            with open(path) as f:
                # readlines() keeps each line's trailing newline, so stored
                # (and therefore sampled) contexts include it.
                self.contexts = f.readlines()
        except IOError:
            # An unreadable file is reported by the assertion below, the same
            # way as an empty one.
            self.contexts = list()
        assert self.contexts, "there's an issue with the context file provided."

    def log_score(self, contexts):
        """Computes log-probabilities of the contexts

        Parameters
        ----------
        contexts: list(str)
            list of contexts to (log-)score; a context must equal a stored
            line exactly (including any trailing newline) to be found

        Returns
        -------
        tensor of logprobabilities (float64), one entry per input context:
        log(occurrences in the file / number of lines in the file), or -inf
        for contexts that do not occur in the file

        Raises
        ------
        AssertionError
            if `contexts` is empty
        """
        assert contexts, "there needs to be contexts to (log-)score."
        # One O(N) counting pass instead of list.count / `in` (O(N) each)
        # for every scored context.
        counts = Counter(self.contexts)
        # Normalize by the size of the whole context pool: a context's
        # probability must not depend on how many contexts are being scored.
        n_pool = len(self.contexts)
        return torch.tensor(
            [
                np.log(counts[context] / n_pool) if context in counts
                else -float("inf")
                for context in contexts
            ],
            # Explicit dtype: otherwise the result is float64 when some
            # context is found (numpy scalars) but float32 when none is.
            dtype=torch.float64,
        )

    def sample(self, sampling_size=32):
        """Samples random elements from the list of contexts

        Sampling is without replacement (``random.sample``): each line of the
        file is picked at most once per call, although identical lines
        appearing several times in the file may each be picked.

        Parameters
        ----------
        sampling_size: int
            number of contexts to sample

        Returns
        -------
        tuple of (list of texts, tensor of logprobs)

        Raises
        ------
        AssertionError
            if fewer than `sampling_size` contexts are available
        """
        assert len(self.contexts) >= sampling_size, "the contexts does not have enough elements to sample."
        # `sample` here is the module-level random.sample, not this method.
        contexts = sample(self.contexts, sampling_size)
        return (contexts, self.log_score(contexts))