import torch
import torch.nn as nn
import numpy as np


def soft_update_from_to(source, target, tau):
    """Polyak-average ``source``'s parameters into ``target`` in place.

    Every target parameter becomes ``target * (1 - tau) + source * tau``.
    Parameters are paired positionally in ``parameters()`` order; ``zip``
    stops at the shorter of the two sequences.
    """
    pairs = zip(target.parameters(), source.parameters())
    for dst_param, src_param in pairs:
        # Build the blended value out of place before copying, so the update
        # is still correct if a tensor is shared between source and target.
        blended = dst_param.data * (1.0 - tau) + src_param.data * tau
        dst_param.data.copy_(blended)


def copy_model_params_from_to(source, target):
    """Overwrite ``target``'s parameter values with ``source``'s, in place.

    Parameters are matched by position in ``parameters()`` order; ``zip``
    stops at the shorter sequence, so extra parameters are left untouched.
    """
    paired = zip(target.parameters(), source.parameters())
    for dst_param, src_param in paired:
        dst_param.data.copy_(src_param.data)


def set_requires_grad(net: nn.Module, allow_grad=True):
    """Set the ``requires_grad`` flag of every parameter of ``net`` to ``allow_grad``."""
    params = net.parameters()
    for p in params:
        p.requires_grad = allow_grad


def fanin_init(tensor):
    """Fill ``tensor`` in place with U(-1/sqrt(fan_in), 1/sqrt(fan_in)).

    ``fan_in`` is ``size[0]`` for 2-D tensors and ``prod(size[1:])`` for
    tensors of higher rank.

    NOTE(review): for an ``nn.Linear`` weight, whose shape is (out, in),
    ``size[0]`` is the *output* dimension. The >2-D branch and PyTorch's own
    fan-in convention both exclude dim 0 instead. This is kept unchanged to
    preserve existing initialisations; confirm it is intended.

    Args:
        tensor: tensor (or parameter) with at least 2 dimensions.

    Returns:
        ``tensor.data`` after filling; it shares storage with ``tensor``.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions. ValueError
            subclasses Exception, so existing ``except Exception`` callers
            still catch it.
    """
    size = tensor.size()
    if len(size) < 2:
        raise ValueError("Shape must have dimension at least 2.")
    fan_in = size[0] if len(size) == 2 else np.prod(size[1:])
    bound = 1. / np.sqrt(fan_in)
    return tensor.data.uniform_(-bound, bound)


def fanin_init_weights_like(tensor, *, torch_device=None):
    """Return a new float32 tensor shaped like ``tensor``, drawn from U(-b, b).

    ``b = 1 / sqrt(fan_in)``. ``fan_in`` is ``size[0]`` for 2-D shapes and
    ``prod(size[1:])`` for higher ranks, the same rule ``fanin_init`` uses.
    ``tensor`` is only used for its shape and is not modified.

    Args:
        tensor: reference tensor with at least 2 dimensions.
        torch_device: device for the result. Defaults to the module-level
            ``device``, as the other tensor helpers in this module do.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions (subclass of
            Exception, so existing broad handlers still catch it).
    """
    size = tensor.size()
    if len(size) < 2:
        raise ValueError("Shape must have dimension at least 2.")
    fan_in = size[0] if len(size) == 2 else np.prod(size[1:])
    bound = 1. / np.sqrt(fan_in)
    if torch_device is None:
        torch_device = device
    # Allocate directly on the target device. The legacy torch.FloatTensor
    # constructor used previously rejects any non-CPU ``device=`` argument.
    new_tensor = torch.empty(size, dtype=torch.float32, device=torch_device)
    return new_tensor.uniform_(-bound, bound)


def ortho_init(tensor, w_scale=1.0):
    """Orthogonally initialise ``tensor`` in place, scale it by ``w_scale``, and return it."""
    weights = tensor.data
    nn.init.orthogonal_(weights)
    weights.mul_(w_scale)
    return tensor


def uniform_init(tensor, init_w=1.0):
    """Fill ``tensor`` in place from U(-init_w, init_w) and return it."""
    low, high = -init_w, init_w
    tensor.data.uniform_(low, high)
    return tensor


def has_nan(tensor):
    """Return True if ``tensor`` contains at least one NaN element.

    Returns a plain Python ``bool``. The previous version returned a
    ``numpy.bool_``; the two compare equal and have the same truthiness,
    so existing callers keep working. An empty tensor yields False.
    """
    return bool(torch.isnan(tensor).any().item())


"""
GPU wrappers
"""

# Module-level device state; set_gpu_mode() rewrites all three globals.
_use_gpu = False  # last ``mode`` passed to set_gpu_mode(); returned by gpu_enabled()
# Default device used by the tensor helpers in this module when no
# ``torch_device`` is given. It is None until set_gpu_mode() is called
# (presumably the caller does this at startup -- confirm).
device = None
_gpu_id = 0  # GPU index last passed to set_gpu_mode(); not read elsewhere in this section


def set_gpu_mode(mode, gpu_id=0):
    """Record the GPU mode and rebuild the module-level default ``device``.

    When ``mode`` is truthy, ``device`` becomes ``cuda:<gpu_id>``; otherwise
    it becomes ``cpu``. ``mode`` and ``gpu_id`` are stored as given.
    """
    global _use_gpu, device, _gpu_id
    _gpu_id = gpu_id
    _use_gpu = mode
    if _use_gpu:
        device = torch.device("cuda:" + str(gpu_id))
    else:
        device = torch.device("cpu")


def gpu_enabled():
    """Return the ``mode`` most recently recorded by ``set_gpu_mode``."""
    enabled = _use_gpu
    return enabled


def set_device(gpu_id):
    """Make ``gpu_id`` the current CUDA device via ``torch.cuda.set_device``.

    This does not touch the module-level ``device`` global.
    """
    cuda = torch.cuda
    cuda.set_device(gpu_id)


# noinspection PyPep8Naming
def FloatTensor(*args, torch_device=None, **kwargs):
    """Build a float32 tensor with the legacy ``torch.FloatTensor`` constructor.

    ``args`` and ``kwargs`` are forwarded to ``torch.FloatTensor``, e.g. sizes
    or a data sequence. The result is placed on ``torch_device``, which
    defaults to the module-level ``device``.
    """
    if torch_device is None:
        torch_device = device
    # The legacy constructor only builds CPU tensors and raises if it is given
    # a non-CPU ``device=`` (e.g. after set_gpu_mode(True)). Build on CPU,
    # then move the result.
    result = torch.FloatTensor(*args, **kwargs)
    return result if torch_device is None else result.to(torch_device)


def from_numpy(*args, **kwargs):
    """Convert a NumPy array to a float32 tensor on the module-level ``device``."""
    as_tensor = torch.from_numpy(*args, **kwargs)
    return as_tensor.float().to(device)


def get_numpy(tensor):
    """Detach ``tensor`` from autograd, move it to CPU, and return it as a NumPy array."""
    return tensor.detach().cpu().numpy()


def zeros(*sizes, torch_device=None, **kwargs):
    """``torch.zeros`` on ``torch_device``, defaulting to the module-level ``device``."""
    target = device if torch_device is None else torch_device
    return torch.zeros(*sizes, device=target, **kwargs)


def ones(*sizes, torch_device=None, **kwargs):
    """``torch.ones`` on ``torch_device``, defaulting to the module-level ``device``."""
    target = device if torch_device is None else torch_device
    return torch.ones(*sizes, device=target, **kwargs)


def ones_like(*args, torch_device=None, **kwargs):
    """``torch.ones_like`` placed on ``torch_device`` (default: module-level ``device``).

    The explicit device overrides the input tensor's own device.
    """
    target = device if torch_device is None else torch_device
    return torch.ones_like(*args, device=target, **kwargs)


def eye(*args, torch_device=None, **kwargs):
    """``torch.eye`` on ``torch_device``, defaulting to the module-level ``device``."""
    target = device if torch_device is None else torch_device
    return torch.eye(*args, device=target, **kwargs)


def arange(*args, torch_device=None, **kwargs):
    """``torch.arange`` on ``torch_device``, defaulting to the module-level ``device``."""
    target = device if torch_device is None else torch_device
    return torch.arange(*args, device=target, **kwargs)


def randn(*args, torch_device=None, **kwargs):
    """``torch.randn`` on ``torch_device``, defaulting to the module-level ``device``."""
    target = device if torch_device is None else torch_device
    return torch.randn(*args, device=target, **kwargs)


def randint(*args, torch_device=None, **kwargs):
    """``torch.randint`` on ``torch_device``, defaulting to the module-level ``device``."""
    target = device if torch_device is None else torch_device
    return torch.randint(*args, device=target, **kwargs)


def randperm(*args, torch_device=None, **kwargs):
    """``torch.randperm`` on ``torch_device``, defaulting to the module-level ``device``."""
    target = device if torch_device is None else torch_device
    return torch.randperm(*args, device=target, **kwargs)


def randn_like(*args, torch_device=None, **kwargs):
    """``torch.randn_like`` placed on ``torch_device`` (default: module-level ``device``).

    The explicit device overrides the input tensor's own device.
    """
    target = device if torch_device is None else torch_device
    return torch.randn_like(*args, device=target, **kwargs)


def zeros_like(*args, torch_device=None, **kwargs):
    """``torch.zeros_like`` placed on ``torch_device`` (default: module-level ``device``).

    The explicit device overrides the input tensor's own device.
    """
    target = device if torch_device is None else torch_device
    return torch.zeros_like(*args, device=target, **kwargs)


def empty(*args, torch_device=None, **kwargs):
    """``torch.empty`` on ``torch_device``, defaulting to the module-level ``device``."""
    target = device if torch_device is None else torch_device
    return torch.empty(*args, device=target, **kwargs)


def tensor(*args, torch_device=None, **kwargs):
    """``torch.tensor`` on ``torch_device``, defaulting to the module-level ``device``."""
    target = device if torch_device is None else torch_device
    return torch.tensor(*args, device=target, **kwargs)


def normal(*args, **kwargs):
    """Sample with ``torch.normal`` and move the result to the module-level ``device``."""
    sample = torch.normal(*args, **kwargs)
    return sample.to(device)


def range(end, **kwargs):
    """Return ``0, 1, ..., end - 1`` as an int64 tensor built by this module's ``tensor``.

    ``kwargs`` (e.g. ``torch_device``) are forwarded to ``tensor``.

    NOTE: this definition shadows the builtin ``range`` for all code in this
    module. Any function here that calls ``range(...)`` gets this one.
    """
    indices = np.arange(end)
    # np.arange's default integer dtype is platform-dependent, so force int64.
    return tensor(indices, **kwargs).to(torch.long)


def expand_to_ts_form(x, num_particles, num_models):
    """Regroup a flat (N * n_p, d) batch into per-model form (n_e, N / n_e * n_p, d).

    Notation: n_e = ``num_models``, n_p = ``num_particles``, d = last dim.
    Rows are read as (N / n_e, n_e, n_p) in row-major order, and the model
    axis is moved to the front. ``flatten_from_ts`` reverses this mapping.
    ``x.view`` requires the row count to be divisible by n_e * n_p.
    """
    feat_dim = x.size(-1)
    # (N * n_p, d) -> (N / n_e, n_e, n_p, d) -> (n_e, N / n_e, n_p, d)
    model_major = x.view(-1, num_models, num_particles, feat_dim).transpose(0, 1)
    # Merge the per-model batch and particle axes: (n_e, N / n_e * n_p, d)
    return model_major.contiguous().view(num_models, -1, feat_dim)

def flatten_from_ts(x, num_particles, num_models):
    """Undo ``expand_to_ts_form``: map (n_e, N / n_e * n_p, d) back to (N * n_p, d).

    Notation: n_e = ``num_models``, n_p = ``num_particles``, d = last dim.
    """
    feat_dim = x.shape[-1]
    # (n_e, N / n_e * n_p, d) -> (n_e, N / n_e, n_p, d)
    per_model = x.view(num_models, -1, num_particles, feat_dim)
    # -> (N / n_e, n_e, n_p, d), materialised so it can be flattened by view
    batch_major = per_model.transpose(0, 1).contiguous()
    # -> (N * n_p, d)
    return batch_major.view(-1, feat_dim)