# Source code for distributions.single_context_distribution
import torch
from .distribution import Distribution
class SingleContextDistribution(Distribution):
    """
    Single context distribution class, useful to sample the
    same context over and over, that is, to fall back to a
    fixed-context case.
    """

    def __init__(self, context=''):
        """
        Parameters
        ----------
        context: string
            unique context to return when sampling
        """
        self.context = context

    def log_score(self, contexts):
        """Computes log-probabilities of the contexts
        to match the instance's context

        Parameters
        ----------
        contexts: list(str)
            list of contexts to (log-)score

        Returns
        -------
        tensor of log-probabilities, one per context: 0. for a context
        equal to the instance's context, -inf otherwise
        """
        # Both branches are float literals so the tensor dtype is always
        # floating point. With an int 0, an all-matching input would have
        # produced an int64 tensor, while any mismatch produced a float one.
        return torch.tensor(
            [0.0 if self.context == context else float("-inf") for context in contexts]
        )

    def sample(self, sampling_size=32):
        """Samples multiple copies of the instance's context

        Parameters
        ----------
        sampling_size: int
            number of contexts to sample

        Returns
        -------
        tuple of (list of texts, list of log-probabilities)
            every sample is the instance's context, and each
            log-probability is 0
        """
        return (
            [self.context] * sampling_size,
            [0] * sampling_size,
        )