import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt


def t_to_np(X):
    """Convert a float32/float64 torch tensor to a NumPy array.

    The tensor is detached from the autograd graph and copied to host memory
    before conversion. Any input whose ``dtype`` is not ``torch.float32`` or
    ``torch.float64`` is handed back untouched.
    """
    float_dtypes = (torch.float32, torch.float64)
    if X.dtype not in float_dtypes:
        return X
    return X.detach().cpu().numpy()


def np_to_t(X, device='cuda'):
    """Convert a float32/float64 NumPy array to a float32 torch tensor.

    The requested ``device`` is overridden with ``'cpu'`` whenever CUDA is not
    available. Inputs whose ``dtype`` is neither float32 nor float64 are
    returned unchanged.
    """
    target = device if torch.cuda.is_available() else 'cpu'

    convertible = [np.dtype('float32'), np.dtype('float64')]
    if X.dtype not in convertible:
        return X

    # Always store as single precision, regardless of the input precision.
    as_float32 = X.astype(np.float32)
    return torch.from_numpy(as_float32).to(target)


def reorder(sequence):
    """Move the last axis of ``sequence`` to position 2.

    float32/float64 torch tensors are permuted with ``(0, 1, 4, 2, 3)`` (so
    they must be 5-D); everything else goes through ``np.moveaxis``.
    Presumably this turns channels-last frames into channels-first ones --
    confirm against the callers.
    """
    is_float_tensor = sequence.dtype in (torch.float32, torch.float64)
    if not is_float_tensor:
        return np.moveaxis(sequence, -1, 2)
    return sequence.permute(0, 1, 4, 2, 3)


def set_seed_device(seed):
    """Seed the torch and NumPy RNGs and return the device to run on.

    cuDNN is switched to deterministic mode with benchmarking disabled.
    The returned device is ``cuda:0`` when CUDA is available, otherwise
    ``cpu``.
    """
    # Seed both random number generators with the same value.
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def load_checkpoint(model, opt, checkpoint_name):
    """Restore model and optimizer state from a checkpoint file.

    The checkpoint is expected to be a mapping with a ``'state_dict'`` entry
    (model weights) and an ``'optimizer'`` entry (optimizer state).

    Args:
        model: module whose ``load_state_dict`` receives ``'state_dict'``.
        opt: optimizer whose ``load_state_dict`` receives ``'optimizer'``.
        checkpoint_name: path (or file-like object) passed to ``torch.load``.

    Returns:
        The same ``(model, opt)`` pair, updated in place.
    """
    print("Loading Checkpoint from '{}'".format(checkpoint_name))

    # On a CPU-only host, remap storages to the CPU so that checkpoints saved
    # from a GPU run can still be loaded; with CUDA present keep the default.
    map_location = None if torch.cuda.is_available() else 'cpu'

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_name, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    opt.load_state_dict(checkpoint['optimizer'])

    return model, opt


def load_classifier(args):
    """Build a ``Classifier_CNN`` and load its weights from a checkpoint.

    Side effects on ``args``: ``hidden_dim``, ``g_dim``, ``conv_dim`` and
    ``k_dim`` are overwritten with fixed values before the classifier is
    constructed (presumably the configuration the stored weights were trained
    with -- confirm), and ``resume`` is set to ``args.checkpoint_name``.

    Args:
        args: namespace providing ``device`` and ``checkpoint_name``.

    Returns:
        The classifier, placed on ``args.device`` with the checkpoint loaded.
    """
    from classifier_cnn import Classifier_CNN
    args.hidden_dim = 128
    args.g_dim = 128
    args.conv_dim = 32
    args.k_dim = 40
    classifier = Classifier_CNN(args).to(device=args.device)

    print("Loading Checkpoint from '{}'".format(args.checkpoint_name))
    args.resume = args.checkpoint_name
    # Map storages onto the target device so a checkpoint saved on a GPU can
    # still be loaded when running on a CPU-only host (and vice versa).
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    loaded_dict = torch.load(args.resume, map_location=args.device)
    classifier.load_state_dict(loaded_dict)
    classifier = classifier.to(args.device)

    return classifier


def load_dataset(args):
    """Load the Sprites test arrays and wrap them in a dataset and a loader.

    Reads ``sprites_X_test.npy``, ``sprites_A_test.npy`` and
    ``sprites_D_test.npy`` from ``args.dataset_path``.

    Args:
        args: namespace providing ``dataset_path`` (directory holding the
            ``.npy`` files, with or without a trailing separator) and
            ``batch_size``.

    Returns:
        ``(test_data, test_loader)``: the ``Sprite`` dataset and a
        non-shuffling ``DataLoader`` over it that drops the last incomplete
        batch.
    """
    import os

    root = args.dataset_path

    def _load(file_name):
        # os.path.join keeps the path valid whether or not ``root`` ends with
        # a separator (plain string concatenation required one).
        with open(os.path.join(root, file_name), 'rb') as f:
            return np.load(f)

    X_test = _load('sprites_X_test.npy')
    A_test = _load('sprites_A_test.npy')
    D_test = _load('sprites_D_test.npy')

    test_data = Sprite(data=X_test, A_label=A_test, D_label=D_test)

    test_loader = DataLoader(test_data,
                             num_workers=4,
                             batch_size=args.batch_size,
                             shuffle=False,
                             drop_last=True,
                             pin_memory=True)

    return test_data, test_loader


class Sprite(Dataset):
    """Map-style dataset pairing each sample with its A and D labels.

    ``data``, ``A_label`` and ``D_label`` are indexed with the same sample
    index; the dataset length is ``data.shape[0]``.
    """

    def __init__(self, data, A_label, D_label):
        self.data, self.A_label, self.D_label = data, A_label, D_label
        self.N = data.shape[0]

    def __len__(self):
        """Return the number of samples."""
        return self.N

    def __getitem__(self, index):
        """Return one sample as a dict of images, both labels and its index."""
        # Per-sample shapes noted by the original author (not checked here):
        # images (8, 64, 64, 3), A_label (4,), D_label ().
        sample = {
            "images": self.data[index],
            "A_label": self.A_label[index],
            "D_label": self.D_label[index],
            "index": index,
        }
        return sample


def get_unique_num(D, I, static_number):
    """Translate a number of unique components into a number of entries.

    A unique component is either a single value with zero imaginary part or
    a pair of complex conjugates. Starting from the last entry of ``I`` and
    moving backwards, ``static_number`` components are consumed: a value
    ``D[I[k]]`` with zero imaginary part advances by one entry, any other
    value advances by two (the value and, presumably, its adjacent
    conjugate).

    Returns the total number of entries of ``I`` covered this way.
    """
    consumed = 0
    last = len(I) - 1
    for _ in range(static_number):
        value = D[I[last - consumed]]
        consumed += 1 if value.imag == 0 else 2

    return consumed


def static_dynamic_split(D, static_size):
    """Split the indices of ``D`` into a static and a dynamic group.

    The entries of ``D`` (complex values, presumably eigenvalues -- confirm)
    are ranked by their distance to ``1 + 0j`` in the complex plane. The
    requested ``static_size`` is first converted by ``get_unique_num`` into a
    number of entries; that many closest-to-one indices form the static
    group and the remainder the dynamic group.

    Returns:
        ``(I, Id, Is)``: the full ascending-distance ordering, the dynamic
        indices and the static indices.
    """
    real_part = np.real(D)
    imag_part = np.imag(D)

    # Euclidean distance of every entry to the point 1 + 0j.
    offset_from_one = real_part - np.ones(len(real_part))
    distance = np.sqrt(offset_from_one ** 2 + imag_part ** 2)
    order = np.argsort(distance)

    # NOTE(review): get_unique_num inspects entries starting from the END of
    # ``order`` (largest distance), while the static group below is taken from
    # the START (smallest distance) -- confirm this is intended.
    n_static = get_unique_num(D, order, static_size)
    static_idx = order[:n_static]
    dynamic_idx = order[n_static:]

    return order, dynamic_idx, static_idx


def imshow_seq(DATA, titles=None, figsize=(15, 3), fontsize=20, wspace=.01):
    """Show a grid of image sequences, each drawn as one horizontal strip.

    ``DATA[ii][jj]`` is a 4-D sequence unpacked as ``(T, C, H, W)`` (torch
    tensor or NumPy array); its ``T`` frames are laid side by side into a
    single ``(H, T * W, C)`` image shown in row ``ii``, column ``jj``.

    Args:
        DATA: nested sequence of sequences; rows are assumed to have the same
            length as ``DATA[0]``.
        titles: optional nested sequence with one title per cell, indexed as
            ``titles[ii][jj]``.
        figsize: figure size forwarded to ``plt.subplots``.
        fontsize: font size of the cell titles.
        wspace: horizontal spacing between subplots.
    """
    n_rows, n_cols = len(DATA), len(DATA[0])
    # The grid follows the actual layout of DATA (it used to be fixed to
    # len(DATA[0]) x 2); squeeze=False keeps ``axs`` two-dimensional even for
    # a single row or column so ``axs[ii][jj]`` is always valid.
    _, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    for ii, data in enumerate(DATA):
        for jj, img in enumerate(data):
            img = t_to_np(img)
            tsz, _, hsz, wsz = img.shape
            # (T, C, H, W) -> (H, T, W, C) -> (H, T * W, C): frames side by side.
            img = img.transpose((2, 0, 3, 1)).reshape((hsz, tsz * wsz, -1))

            ax = axs[ii][jj]
            ax.imshow(img)
            ax.set_axis_off()
            if titles is not None:
                ax.set_title(titles[ii][jj], fontsize=fontsize)

    plt.subplots_adjust(wspace=wspace, hspace=0)
    plt.show()

