# Source code for utils.helpers

import torch
import tqdm.autonotebook as tqdm

def batchify(func, batch, samples=None, **args):
    """Apply ``func`` to ``samples`` in consecutive chunks and concatenate the results.

    A tqdm progress bar, labelled with the callable's name, tracks how many
    samples have been processed so far.

    Args:
        func: Callable invoked as ``func(chunk, **args)`` once per chunk. Its
            return values are passed to ``torch.cat``, so they are presumably
            tensors whose shapes agree on every dim except dim 0.
        batch: Maximum number of samples per chunk. Must be >= 1.
        samples: Sized, sliceable sequence of inputs (list, tensor, array, ...).
            ``None`` (the default) is treated as an empty list.
        **args: Extra keyword arguments forwarded unchanged to every ``func`` call.

    Returns:
        ``torch.cat`` of the per-chunk outputs, in order.

    Raises:
        ValueError: if ``batch`` is smaller than 1.
        RuntimeError: raised by ``torch.cat`` when ``samples`` is empty, because
            there is nothing to concatenate.
    """
    if samples is None:
        samples = []
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch!r}")

    # functools.partial and some callable objects have no __name__.
    label = getattr(func, "__name__", type(func).__name__)
    total = len(samples)
    outputs = []
    with tqdm.tqdm(total=total, desc=label) as pbar:
        # Stepping by `batch` yields only non-empty chunks. This avoids a
        # truthiness check, which would raise on multi-element tensors/arrays.
        for start in range(0, total, batch):
            chunk = samples[start:start + batch]
            outputs.append(func(chunk, **args))
            # The last chunk may be shorter than `batch`; count what was processed.
            pbar.update(len(chunk))
    return torch.cat(outputs)
def get_token_first_indices(x, token):
    """Return, for each row of ``x``, the column index of the first ``token``.

    Args:
        x: Tensor searched along dim 1. Presumably shaped ``(batch, seq)``, given
            the dim=1 reduction and the ``x.shape[0]``-sized result; confirm
            against callers before passing higher-rank input.
        token: Value compared element-wise against ``x`` with ``==``.

    Returns:
        An int64 tensor with one entry per row, on ``x``'s device. Each entry is
        the index of the first match in that row, or -1 if the row has no match.
        If ``x.shape[-1] == 0``, every entry is -1.
    """
    if x.shape[-1] == 0:
        # torch.max cannot reduce over an empty dimension, so build the
        # "not found" result directly. Put it on x's device so CUDA inputs
        # get a CUDA result, matching the non-empty path.
        return torch.full((x.shape[0],), -1, dtype=torch.long, device=x.device)

    mask = x == token
    # On the boolean mask, max returns True if any element matches. Its index
    # is the first maximal position, i.e. the first match.
    found, first_indices = torch.max(mask, dim=1)
    # Rows with no match have max False at index 0; report them as -1 instead.
    return first_indices.masked_fill(~found, -1)