import torch
import random
import numpy as np


def to_bool(value):
    """Convert a bool or a boolean-like literal to ``bool``.

    Accepted literals (case-insensitive, surrounding whitespace ignored):
    ``'true'``, ``'t'``, ``'1'`` -> ``True`` and ``'false'``, ``'f'``, ``'0'``
    -> ``False``. ``bool`` values are returned unchanged. Any other object
    (e.g. the ints ``1``/``0`` or a ``numpy.bool_``) is converted with
    ``str()`` first, so unsupported inputs consistently raise ``ValueError``
    rather than an ``AttributeError`` from calling ``.lower()`` on a non-str.

    Args:
        value: A ``bool`` or an object whose string form is a boolean literal.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If ``value`` is not a recognised boolean literal.
    """
    if isinstance(value, bool):
        return value

    valid = {'true': True, 't': True, '1': True,
             'false': False, 'f': False, '0': False,
             }

    text = value if isinstance(value, str) else str(value)
    try:
        return valid[text.strip().lower()]
    except KeyError:
        # Wrap in a 1-tuple so tuple inputs don't break the % formatting.
        raise ValueError('invalid literal for boolean: "%s"' % (value,)) from None


def set_global_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and return a DataLoader worker init fn.

    When CUDA is available, this also switches cuDNN to deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed. It must be accepted by ``random.seed``,
            ``np.random.seed`` (0 <= seed < 2**32) and ``torch.manual_seed``.

    Returns:
        A ``worker_init_fn(worker_id)`` suitable for ``torch.utils.data.DataLoader``.
        It reseeds NumPy and ``random`` inside each worker from the torch seed.
    """
    random.seed(seed)
    np.random.seed(seed)

    # torch.manual_seed seeds the CPU generator and (lazily) all CUDA devices;
    # manual_seed_all is kept as an explicit statement of intent.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def _worker_init_fn(worker_id):
        """Derive per-worker NumPy / ``random`` seeds from the torch seed."""
        # Reduce modulo 2**32 *after* adding worker_id: NumPy rejects seeds
        # >= 2**32, which the previous `initial_seed() % 2**32 + worker_id`
        # could produce. Values are unchanged whenever no overflow occurs.
        worker_seed = (torch.initial_seed() + worker_id) % 2 ** 32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    return _worker_init_fn


def unzip(log_probs, top_k, vocab_size):
    """Expand packed top-k log-probabilities into a dense ``(batch, vocab_size)`` tensor.

    ``log_probs`` is split along dim 1. The first ``top_k`` columns hold
    log-probability values. The remaining columns hold the vocabulary indices
    those values belong to; they are cast to int64 before scattering. Every
    vocabulary position not listed gets ``log(1e-10)``.

    NOTE(review): ``scatter_`` requires the index width to be <= ``top_k``, so
    this presumably expects ``log_probs.size(1) == 2 * top_k`` — confirm with
    the producer of the packed tensor.

    Args:
        log_probs: 2-D tensor of packed ``[values | indices]`` columns.
        top_k: Number of value columns.
        vocab_size: Width of the dense output.

    Returns:
        A ``(batch, vocab_size)`` tensor with the same dtype and device as
        ``log_probs``.
    """
    values, locs = torch.split(log_probs, [top_k, log_probs.size(1) - top_k], dim=1)

    # Fill with log(1e-10) instead of -inf to avoid NaNs in later arithmetic.
    # Match the input's dtype/device: scatter_ requires src and self to agree,
    # so a default-float32 CPU buffer fails for float64/float16 or GPU inputs.
    full_log_probs = torch.full(
        (values.size(0), vocab_size),
        np.log(1e-10),
        dtype=values.dtype,
        device=values.device,
    )
    full_log_probs.scatter_(1, locs.to(torch.int64), values)

    return full_log_probs


def preprocess_tensor(tensor, max_len=512):
    """Pad with ``+inf`` rows or truncate ``tensor`` along dim 0 to ``max_len`` rows.

    Args:
        tensor: Tensor with at least one dimension. Any trailing dimensions
            are preserved.
        max_len: Target number of rows (default 512).

    Returns:
        The first ``max_len`` rows if ``tensor`` has at least that many.
        Otherwise, ``tensor`` followed by enough ``+inf``-filled rows to reach
        ``max_len``. Padding is created on the input's device. For floating
        inputs it uses the input's dtype; non-floating dtypes cannot hold
        inf, so they are padded in the default float dtype and
        ``torch.cat`` promotes the result.
    """
    padding_size = max_len - tensor.size(0)
    if padding_size <= 0:
        return tensor[:max_len]

    if tensor.is_floating_point():
        pad_dtype = tensor.dtype
    else:
        pad_dtype = torch.get_default_dtype()
    padding = torch.full(
        (padding_size, *tensor.shape[1:]),
        float('inf'),
        dtype=pad_dtype,
        device=tensor.device,
    )
    return torch.cat([tensor, padding], dim=0)


def process_tensor(tensor):
    """Strip trailing padding rows, marked by ``+inf`` in column 0.

    The tensor is cut just before the first row whose first column equals
    ``+inf``. This looks like the inverse of the ``+inf`` row padding done by
    ``preprocess_tensor``. Only column 0 is inspected.

    Args:
        tensor: 2-D tensor (it is indexed as ``tensor[:, 0]``).

    Returns:
        The rows before the first padding row, which may be an empty tensor if
        row 0 is padding. Returns ``None`` when no padding row is found.

    NOTE(review): an unpadded input (e.g. one that was truncated to full
    length) therefore yields ``None`` rather than itself — confirm callers
    rely on that sentinel.
    """
    is_pad_row = tensor[:, 0] == float('inf')
    pad_rows = torch.nonzero(is_pad_row, as_tuple=True)[0]
    if pad_rows.numel() == 0:
        return None
    first_pad = int(pad_rows[0])
    return tensor[:first_pad, :]


def truncate(tensor_1, tensor_2):
    """Cut both tensors along dim 0 to the shorter of their two lengths.

    Returns:
        A ``(tensor_1_trimmed, tensor_2_trimmed)`` tuple of slices. Both have
        the same dim-0 size.
    """
    len_1, len_2 = tensor_1.size(0), tensor_2.size(0)
    shared_len = len_1 if len_1 < len_2 else len_2
    trimmed_1 = tensor_1[:shared_len]
    trimmed_2 = tensor_2[:shared_len]
    return trimmed_1, trimmed_2