import math

import torch as t


def one_hot(
    batch_size,
    indices,
    n_classes,
    device="cpu",
):
    """Build a one-hot matrix of shape ``(batch_size, n_classes)``.

    The result is created with ``t.zeros`` (default floating dtype) on
    ``device``. Row ``i`` gets a ``1`` in column ``indices[i]``.

    ``indices`` presumably holds one class index per row (a sequence or 1-D
    tensor of length ``batch_size``). A scalar broadcasts to every row.
    """
    row_ids = range(batch_size)
    encoded = t.zeros((batch_size, n_classes), requires_grad=False, device=device)
    encoded[row_ids, indices] = 1
    return encoded


def zero_hot(
    batch_size,
    indices,
    n_classes,
    device="cpu",
):
    """Build the complement of a one-hot matrix, shape ``(batch_size, n_classes)``.

    The result is created with ``t.ones`` (default floating dtype) on
    ``device``. Row ``i`` gets a ``0`` in column ``indices[i]`` and ``1``
    everywhere else.

    ``indices`` presumably holds one class index per row (a sequence or 1-D
    tensor of length ``batch_size``). A scalar broadcasts to every row.
    """
    row_ids = range(batch_size)
    mask = t.ones((batch_size, n_classes), requires_grad=False, device=device)
    mask[row_ids, indices] = 0
    return mask


def select(
    tensor,
    index,
    keepdim=False,
):
    """Pick one entry per row: ``tensor[i, index[i]]`` for every ``i``.

    Rows are enumerated along the first axis. With ``keepdim=True`` the picked
    values get a new axis inserted at position 1 (``unsqueeze(1)``).
    """
    n_rows = tensor.shape[0]
    picked = tensor[range(n_rows), index]
    return picked.unsqueeze(1) if keepdim else picked


def copy(
    tensor,
    dim,
    repeat,
):
    """Insert a new axis at ``dim`` and tile ``tensor`` ``repeat`` times along it.

    Args:
        tensor: input tensor.
        dim: position of the new axis, as accepted by ``Tensor.unsqueeze``
            (negative values count from the end of the *unsqueezed* shape).
        repeat: number of copies along the new axis.

    Returns:
        If ``repeat == 0``, ``tensor`` itself, unchanged and with no new axis
        inserted. Otherwise a new tensor produced by ``Tensor.repeat``, which
        copies data.
    """
    if repeat == 0:
        return tensor

    tensor = tensor.unsqueeze(dim)
    # Repeat factors as plain Python ints: 1 for every existing axis and
    # ``repeat`` for the new one. This avoids allocating a tensor just to
    # describe a shape. ``int()`` mirrors the truncation the old int64-tensor
    # assignment performed.
    factors = [1] * tensor.dim()
    factors[dim] = int(repeat)
    return tensor.repeat(*factors)


def merge(
    tensor,
    dims,
):
    """Flatten a run of adjacent axes of ``tensor`` into a single axis.

    Args:
        tensor: input tensor.
        dims: ordered, indexable collection of adjacent axis indices, e.g.
            ``(1, 2)``. The span from ``dims[0]`` to ``dims[-1]`` is replaced
            by one axis whose size is the product of the sizes of every
            listed axis. Negative indices are accepted and normalized against
            ``tensor.dim()``. Previously ``(-2, -1)`` produced a bogus shape,
            because ``dims[-1] + 1 == 0`` re-appended the whole shape.

    Returns:
        ``tensor.view`` with the merged shape. As with any ``view``, this
        raises if the memory layout does not allow it.
    """
    ndim = tensor.dim()
    dims = [d + ndim if d < 0 else d for d in dims]
    size = tensor.shape
    shape = (
        tuple(size[: dims[0]])
        + (math.prod(size[d] for d in dims),)
        + tuple(size[dims[-1] + 1 :])
    )
    return tensor.view(*shape)


def unmerge(
    tensor,
    dim,
    shape,
):
    """Split axis ``dim`` of ``tensor`` into several axes given by ``shape``.

    Args:
        tensor: input tensor.
        dim: axis to split. Negative values are normalized against
            ``tensor.dim()``. Previously ``dim=-1`` produced a bogus shape,
            because ``dim + 1 == 0`` re-appended the whole shape.
        shape: iterable of sizes replacing axis ``dim``. Their product must
            equal ``tensor.shape[dim]`` (``view`` also allows a single -1).

    Returns:
        ``tensor.view`` with the expanded shape.
    """
    if dim < 0:
        dim += tensor.dim()
    size = tensor.shape
    new_shape = tuple(size[:dim]) + tuple(shape) + tuple(size[dim + 1 :])
    return tensor.view(*new_shape)


def accuracy(
    predictions,
    actual,
):
    """Return the fraction of elements where ``predictions`` equals ``actual``.

    The element-wise comparison is cast to float and averaged, giving a 0-dim
    float tensor.
    """
    hits = predictions == actual
    return hits.float().mean()


# Extracted from https://github.com/astooke/rlpyt
def update_state_dict(model, state_dict, tau=1, strip_ddp=True):
    """Load ``state_dict`` into ``model``, either copying it or blending it in.

    ``state_dict`` must have the same keys as ``model.state_dict()`` (after the
    optional ``'module.'`` prefix stripping).

    - ``tau == 1``: hard update; the values are loaded as-is.
    - ``tau > 0`` (otherwise): soft update, ``tau * new + (1 - tau) * old``.
    - any other ``tau`` (zero, negative, NaN): the model is left untouched.

    With ``strip_ddp=True`` the keys are first passed through
    ``strip_ddp_state_dict``.
    """
    source = strip_ddp_state_dict(state_dict) if strip_ddp else state_dict
    if tau == 1:
        model.load_state_dict(source)
    elif tau > 0:
        blended = {}
        for name, current in model.state_dict().items():
            blended[name] = tau * source[name] + (1 - tau) * current
        model.load_state_dict(blended)


def strip_ddp_state_dict(state_dict):
    """Return a copy of ``state_dict`` with one leading ``'module.'`` removed from keys.

    DistributedDataParallel prepends ``'module.'`` to every key, but the
    sampler models are not wrapped in DistributedDataParallel. (Solution from
    the PyTorch forums.) The result has the same mapping type as the input.
    Keys without the prefix are kept as they are. Entries are inserted in the
    input's iteration order.
    """
    prefix = "module."
    cut = len(prefix)
    stripped = type(state_dict)()
    for name, value in state_dict.items():
        if name[:cut] == prefix:
            name = name[cut:]
        stripped[name] = value
    return stripped
