import torch
import collections

from transformers import PreTrainedTokenizerBase, AutoTokenizer


def apply_to_sample(f, sample):
    """Recursively apply ``f`` to every tensor found in ``sample``.

    ``sample`` may be a tensor or an arbitrarily nested combination of
    ``OrderedDict``s, dicts, lists, tuples and sets. Each tensor leaf is
    replaced by ``f(tensor)``; containers are rebuilt as the corresponding
    built-in type (so e.g. namedtuples come back as plain tuples and dict
    subclasses as plain dicts); any other value is returned unchanged.

    Args:
        f: Callable applied to each tensor leaf.
        sample: The (possibly nested) structure to transform.

    Returns:
        The transformed structure. If the top-level ``sample`` is sized and
        has length 0 (e.g. ``[]``, ``{}``, or a tensor whose first dimension
        is 0), ``{}`` is returned. Empty containers nested deeper are
        rebuilt normally.
    """
    # 0-d tensors (and other unsized objects) expose ``__len__`` but
    # ``len()`` raises TypeError on them. Treat them as non-empty so a
    # scalar tensor still goes through ``f`` instead of crashing.
    try:
        is_empty = len(sample) == 0
    except TypeError:
        is_empty = False
    if is_empty:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, collections.OrderedDict):
            # OrderedDict instances can carry attributes in ``__dict__``
            # that must survive the rebuild, so they are carried over. Note
            # that the new object shares the same ``__dict__`` mapping.
            od = collections.OrderedDict(
                (key, _apply(value)) for key, value in x.items()
            )
            od.__dict__ = x.__dict__
            return od
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(item) for item in x]
        elif isinstance(x, tuple):
            return tuple(_apply(item) for item in x)
        elif isinstance(x, set):
            return {_apply(item) for item in x}
        else:
            return x

    return _apply(sample)


def move_to_cuda(sample, device=None):
    """Move every tensor in ``sample`` to ``device``.

    Args:
        sample: A tensor or a nested container of tensors, traversed with
            ``apply_to_sample``.
        device: Target device passed to ``Tensor.to``. When ``None``, the
            current CUDA device (``torch.cuda.current_device()``) is used.

    Returns:
        ``sample`` with every tensor leaf moved to ``device``.
    """
    # Compare against None explicitly: ``device or ...`` would silently
    # discard a valid falsy device such as the integer index 0.
    if device is None:
        device = torch.cuda.current_device()

    def _move_to_cuda(tensor):
        # non_blocking is ignored if tensor is not pinned, so we can always set
        # to True (see github.com/PyTorchLightning/pytorch-lightning/issues/620)
        return tensor.to(device=device, non_blocking=True)

    return apply_to_sample(_move_to_cuda, sample)


def device(precision: str):
    """Translate a precision name into the matching ``torch.dtype``.

    Despite its name, this returns a dtype, not a ``torch.device``.

    Args:
        precision: One of ``"fp32"``, ``"fp16"`` or ``"bf16"``.

    Returns:
        ``torch.float32``, ``torch.float16`` or ``torch.bfloat16`` respectively.

    Raises:
        ValueError: If ``precision`` matches none of the supported names.
    """
    # NOTE(review): an identical ``device`` is defined again later in this
    # module and shadows this one at import time; one copy can be removed.
    known_precisions = (
        ("fp32", torch.float32),
        ("fp16", torch.float16),
        ("bf16", torch.bfloat16),
    )
    for name, dtype in known_precisions:
        if precision == name:
            return dtype
    raise ValueError(f"Do not support {precision} tensor.")


def get_tokenizer(tokenizer: str, token: str) -> PreTrainedTokenizerBase:
    """Load a tokenizer through ``AutoTokenizer.from_pretrained``.

    Args:
        tokenizer: Name or path forwarded as the first positional argument
            of ``AutoTokenizer.from_pretrained``.
        token: Value forwarded as the ``token`` keyword (an access token
            for the model hub — confirm what callers pass here).

    Returns:
        The tokenizer instance, always constructed with ``legacy=False``.
    """
    load_options = {"legacy": False, "token": token}
    return AutoTokenizer.from_pretrained(tokenizer, **load_options)


def device(precision: str):
    """Resolve a precision label (``fp32``/``fp16``/``bf16``) to a ``torch.dtype``.

    Despite its name, this returns a dtype, not a ``torch.device``.

    Args:
        precision: The precision label to look up.

    Returns:
        The matching ``torch.float32``, ``torch.float16`` or ``torch.bfloat16``.

    Raises:
        ValueError: If the label is not recognised.
    """
    # NOTE(review): this redefines an identical ``device`` declared earlier
    # in the module; this later definition is the one that takes effect.
    label_to_dtype = (
        ("fp32", torch.float32),
        ("fp16", torch.float16),
        ("bf16", torch.bfloat16),
    )
    resolved = next(
        (dtype for label, dtype in label_to_dtype if precision == label),
        None,
    )
    if resolved is None:
        raise ValueError(f"Do not support {precision} tensor.")
    return resolved


def get_trainable_parameters(model: torch.nn.Module):
    """Collect the parameters of ``model`` whose ``requires_grad`` flag is set.

    Args:
        model: Module whose ``parameters()`` are inspected.

    Returns:
        A list of the trainable parameters, in ``model.parameters()`` order.
    """
    return [param for param in model.parameters() if param.requires_grad]


def count_trainable_parameters(model: torch.nn.Module):
    """Return the total number of elements across trainable parameters.

    Only parameters with ``requires_grad`` set contribute.

    Args:
        model: Module whose ``parameters()`` are inspected.

    Returns:
        Sum of ``numel()`` over the trainable parameters (``0`` if none).
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def set_seed(seed: int):
    """Seed PyTorch's random number generators with ``seed``.

    Calls ``torch.manual_seed`` and then ``torch.cuda.manual_seed_all`` with
    the same value. Python's ``random`` module and NumPy are not seeded here.

    Args:
        seed: The seed value to apply.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)