import numpy as np
import colorsys
import torch
import gym

def hash_args(args_dict, no_hash):
    """Return an 8-character hex fingerprint of a configuration mapping.

    Entries whose key appears in ``no_hash`` are excluded before hashing, so
    they can change without changing the fingerprint. The remaining entries
    are serialized as key-sorted, indented JSON, hashed with MD5, and the
    first 8 hex digits of the digest are returned.
    """
    import hashlib
    import json

    kept = {}
    for key in args_dict:
        if key not in no_hash:
            kept[key] = args_dict[key]

    serialized = json.dumps(kept, sort_keys=True, indent=4)
    digest = hashlib.md5(serialized.encode()).hexdigest()
    return digest[:8]

def get_env_name(name):
    """Append the ``NoFrameskip-v4`` suffix to a base environment name.

    Example: ``'Pong'`` -> ``'PongNoFrameskip-v4'``.
    """
    suffix = 'NoFrameskip-v4'
    return name + suffix

def make_env(env_id, **kwargs):
    """Create a Gym environment from a short id.

    For the names below (presumably Gym's MuJoCo locomotion tasks), the
    ``-v3`` version is created and every keyword argument is forwarded to
    ``gym.make``.

    For any other ``env_id``, ``env_id`` is passed to ``gym.make`` unchanged
    and ``reset_noise_scale`` is removed from ``kwargs`` if present. The code
    assumes only the v3 environments accept that argument. The remaining
    kwargs are forwarded as-is.
    """
    versioned_envs = ('Ant', 'Swimmer', 'Humanoid', 'Walker2d', 'HalfCheetah', 'Hopper')
    if env_id in versioned_envs:
        return gym.make(env_id + '-v3', **kwargs)
    # Default of None: callers that never passed `reset_noise_scale` must not
    # get a KeyError here.
    kwargs.pop('reset_noise_scale', None)
    return gym.make(env_id, **kwargs)

def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    Also sets cuDNN to deterministic mode and disables cuDNN benchmarking
    (autotuning), trading some speed for reproducibility.
    """
    import random

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)


def unstack(a, axis):
    """Split ``a`` along ``axis`` into a list of slices with that axis removed.

    This undoes ``np.stack(parts, axis=axis)``: the result has
    ``a.shape[axis]`` arrays, each of rank ``a.ndim - 1``.
    """
    n_slices = a.shape[axis]
    pieces = np.split(a, n_slices, axis=axis)
    return [np.squeeze(piece, axis) for piece in pieces]


def virtual_display():
    """Start a ``pyvirtualdisplay.Display`` and return it.

    Returning the handle lets the caller shut the display down later with
    ``display.stop()``. Callers that ignore the return value behave as before.
    """
    from pyvirtualdisplay import Display

    display = Display()
    display.start()
    return display


def batch_dot(x, y):
    """Dot product of ``x`` and ``y`` over their last dimension.

    The operands are multiplied elementwise (broadcasting as usual) and summed
    over axis -1. This works for both NumPy arrays and torch tensors.
    """
    elementwise = x * y
    return elementwise.sum(-1)


def get_device(net):
    '''
    Returns the `torch.device` on which the network resides.

    Only the first parameter is inspected, so this method only makes sense
    when all module parameters reside on the **same** device.

    Raises:
        ValueError: if ``net`` has no parameters (its device is undefined).
    '''
    # Take the first parameter lazily instead of materializing all of them.
    first_param = next(iter(net.parameters()), None)
    if first_param is None:
        raise ValueError('get_device: module has no parameters, so its device is undefined')
    return first_param.device

def check_weights_nan(parameters):
    """Return True if any tensor in ``parameters`` contains a NaN, else False.

    Scanning stops at the first tensor that contains a NaN. An empty
    iterable yields False.
    """
    return any(bool(torch.isnan(tensor).any()) for tensor in parameters)


def get_weights_norm(parameters, norm_type=2.0):
    """Return the global ``norm_type``-norm of a collection of tensors as a float.

    The norm of each tensor (over all of its elements) is computed first, and
    then the norm of those per-tensor norms. For p >= 1 and p = inf, this
    equals the norm of all tensors flattened and concatenated.

    Args:
        parameters: iterable of tensors (e.g. ``model.parameters()``); a
            generator is fine, since it is consumed exactly once.
        norm_type: order ``p`` of the norm (default 2.0).

    Returns:
        The norm as a Python float, or 0.0 if ``parameters`` is empty.
    """
    with torch.no_grad():
        norms = [torch.norm(p, norm_type) for p in parameters]
        if not norms:
            # torch.stack([]) would raise; the norm of no values is 0.
            return 0.0
        return torch.norm(torch.stack(norms), norm_type).item()


def get_grads_norm(parameters, norm_type=2.0):
    """Return the global ``norm_type``-norm of the gradients of ``parameters``.

    Parameters whose ``.grad`` is ``None`` are skipped, matching
    ``torch.nn.utils.clip_grad_norm_``. This happens for parameters that did
    not take part in the backward pass or that are frozen. Otherwise the
    result is the norm of the per-gradient norms.

    Args:
        parameters: iterable of tensors (e.g. ``model.parameters()``).
        norm_type: order ``p`` of the norm (default 2.0).

    Returns:
        The norm as a Python float, or 0.0 if no parameter has a gradient.
    """
    with torch.no_grad():
        norms = [torch.norm(p.grad, norm_type) for p in parameters if p.grad is not None]
        if not norms:
            # torch.stack([]) would raise; no gradients means a zero norm.
            return 0.0
        return torch.norm(torch.stack(norms), norm_type).item()

def scatter_3d(ax, x, y, z, c, elev=30, azim=-60, **kwargs):
    """Draw a 3D scatter plot whose points are painted back-to-front.

    The camera is set with ``ax.view_init(elev, azim)``. Points are then sorted
    by their projection onto the direction towards the camera,
    ``(cos(elev)cos(azim), cos(elev)sin(azim), sin(elev))``, in ascending
    order. Points nearer the viewer are therefore drawn last and overlap
    the ones behind them.

    Args:
        ax: a matplotlib 3D axes (presumably created with ``projection='3d'``).
        x, y, z: 1-D coordinate sequences of equal length.
        c: per-point colors, with one entry (or row) per point.
        elev, azim: camera elevation / azimuth in degrees.
        **kwargs: forwarded to ``ax.scatter``.
    """
    x, y, z, c = (np.asarray(v) for v in (x, y, z, c))
    elev_rad = np.deg2rad(elev)
    azim_rad = np.deg2rad(azim)
    ax.view_init(elev, azim)
    # The horizontal part of the view direction is scaled by cos(elev); without
    # it, the ordering is wrong for any elev != 0.
    # NOTE(review): depth is measured in data coordinates, so the order is
    # exact only when the three axes share the same scale.
    depth = (np.cos(elev_rad) * (np.cos(azim_rad) * x + np.sin(azim_rad) * y)
             + np.sin(elev_rad) * z)
    order = np.argsort(depth, kind='heapsort')
    ax.scatter(x[order], y[order], z[order], c=c[order], **kwargs)

class SubEncoder(torch.nn.Module):
    """Wrap an encoder and keep only a contiguous slice of its output columns.

    ``forward(x)`` returns ``enc(x)[:, le:ri]``: every row of the encoder
    output, restricted to columns ``le`` (inclusive) through ``ri`` (exclusive).
    """

    def __init__(self, enc, le, ri):
        super().__init__()
        # Assigned as a Module attribute, so `enc` is registered as a child
        # module (its parameters appear under the `enc.` prefix).
        self.enc = enc
        self.le, self.ri = le, ri

    def forward(self, x):
        features = self.enc(x)
        columns = slice(self.le, self.ri)
        return features[:, columns]

def _hue_colors(n_theta_1, n_theta_2):
    # Default palette for `visualize`: an (n_theta_1 * n_theta_2, 3) RGB array
    # whose hue sweeps the color wheel along the second grid index and is
    # repeated identically for every value of the first index.
    row = [colorsys.hls_to_rgb(j / n_theta_2, 0.5, 1) for j in range(n_theta_2)]
    row = np.asarray(row, dtype=float).reshape(n_theta_2, 3)
    return np.tile(row, (n_theta_1, 1))


def _hide_3d_panes_and_grid(ax):
    # Make the pane backgrounds and grid lines of a 3D axes fully transparent.
    # NOTE: `_axinfo` is private matplotlib state and may change between versions.
    transparent = (1.0, 1.0, 1.0, 0.0)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(transparent)
        axis._axinfo['grid']['color'] = (1, 1, 1, 0)


def visualize(args, ax, enc, data, s=1, normalize=False, **kwargs):
    """Encode a 2-D grid of inputs with ``enc`` and scatter-plot the codes on ``ax``.

    ``data`` is indexed as a grid: ``data.shape[:2] == (n_theta_1, n_theta_2)``.
    Each row ``data[i]`` is converted to a float32 tensor on ``enc``'s device
    and encoded as one batch; the output must fit ``(n_theta_2, args.proj_size)``.

    Args:
        args: object with an integer ``proj_size`` attribute (2 or 3), the
            code dimensionality and hence the plot dimensionality.
        ax: matplotlib axes; must be a 3D axes when ``proj_size == 3``.
        enc: ``torch.nn.Module`` encoder (all parameters on one device).
        data: array-like grid of inputs, see above.
        s: marker size passed to the scatter call.
        normalize: if True, every code is scaled to unit L2 norm (all-zero
            codes stay zero instead of becoming NaN).
        **kwargs: forwarded to the scatter call. A ``c`` entry overrides the
            default colors, whose hue varies along the second grid index.

    Raises:
        ValueError: if ``args.proj_size`` is neither 2 nor 3.
    """
    proj_size = args.proj_size
    # Validate before doing any (possibly expensive) encoding; an `assert`
    # here would vanish under `python -O`.
    if proj_size not in (2, 3):
        raise ValueError('visualize supports proj_size 2 or 3, got {!r}'.format(proj_size))

    device = get_device(enc)
    n_theta_1, n_theta_2 = data.shape[:2]

    code = torch.empty((n_theta_1, n_theta_2, proj_size))
    with torch.no_grad():
        for i in range(n_theta_1):
            code[i] = enc(torch.as_tensor(data[i], dtype=torch.float32, device=device))
    code = code.reshape(-1, proj_size)
    if normalize:
        # F.normalize divides by max(norm, eps): same result for non-zero rows,
        # but no NaN for zero rows.
        code = torch.nn.functional.normalize(code, dim=-1)
    code = code.cpu().numpy()

    colors = kwargs.pop('c') if 'c' in kwargs else _hue_colors(n_theta_1, n_theta_2)

    if proj_size == 3:
        _hide_3d_panes_and_grid(ax)
        scatter_3d(ax, code[:, 0], code[:, 1], code[:, 2], c=colors, s=s, **kwargs)
    else:
        ax.scatter(code[:, 0], code[:, 1], c=colors, s=s, **kwargs)

def compare_models(model_1, model_2):
    """Compare the ``state_dict`` tensors of two modules, pairing entries by position.

    For every pair whose tensors are not ``torch.equal``, ``'Mismatch found
    at <name>'`` is printed, provided both models use the same entry name.
    Names are only checked for differing pairs, so identical weights stored
    under different names (e.g. a wrapper-added prefix) still count as a
    match. If nothing differs, ``'Models match perfectly! :)'`` is printed.

    Returns:
        True if every paired tensor is equal, otherwise False.

    Raises:
        ValueError: if the state dicts have different numbers of entries, or a
            differing pair has different entry names (the models do not line up).
    """
    state_1 = model_1.state_dict()
    state_2 = model_2.state_dict()
    # zip() would silently drop the surplus entries of the longer dict.
    if len(state_1) != len(state_2):
        raise ValueError(
            'State dicts have different sizes: {} vs {} entries'.format(len(state_1), len(state_2)))

    models_differ = 0
    for (name_1, tensor_1), (name_2, tensor_2) in zip(state_1.items(), state_2.items()):
        if torch.equal(tensor_1, tensor_2):
            continue
        models_differ += 1
        if name_1 != name_2:
            raise ValueError(
                'State dict entries do not line up: {!r} vs {!r}'.format(name_1, name_2))
        print('Mismatch found at', name_1)

    if models_differ == 0:
        print('Models match perfectly! :)')
    return models_differ == 0