import torch
from torch.utils.data import Dataset
from typing import Optional
import os
from torchvision.datasets.utils import download_and_extract_archive
from typing import Optional, Callable, Tuple, Any, List
import torchvision.datasets as datasets
from torchvision.datasets.folder import default_loader
import scipy.io
import numpy as np
from tensorflow.keras.datasets import mnist
import sklearn.preprocessing
from scipy import ndimage
from PIL import Image


class DomainDataset(Dataset):
    """Unlabeled samples that each carry a weight.

    No labels are supplied, so every sample gets the placeholder target
    ``-1``. Indexing yields ``(sample, target, weight)`` triples.
    """

    def __init__(self, x, weight, transform=None):
        """Keep a detached CPU version of ``x`` alongside ``weight``.

        Args:
            x: Tensor of samples, indexed along its first dimension.
            weight: Per-sample weights, indexed with the same indices as ``x``.
            transform: Optional callable applied to a sample when it is read.
        """
        self.data = x.detach().cpu()
        # One float -1 per sample stands in for the missing label.
        self.targets = -torch.ones(len(self.data))
        self.weight = weight
        self.transform = transform

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, target, weight)`` for position ``idx``."""
        sample = self.data[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.targets[idx], self.weight[idx]


class EncodeDataset(Dataset):
    """Pair samples ``x`` with targets ``y``, optionally transforming samples.

    When a transform is set, its output is additionally converted with
    ``.float()``; without a transform the stored sample is returned untouched.
    """

    def __init__(self, x, y, transform=None):
        """Store samples, targets and the optional per-sample transform."""
        self.data, self.targets, self.transform = x, y, transform

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, target)`` for position ``idx``."""
        sample = self.data[idx]
        if self.transform is not None:
            # The transform result must expose ``.float()`` (e.g. a tensor).
            sample = self.transform(sample).float()
        return sample, self.targets[idx]


"""
    Make portraits dataset
"""
def shuffle(xs, ys):
    """Reorder ``xs`` and ``ys`` with one shared random permutation.

    The permutation covers ``len(xs)`` positions and is drawn from NumPy's
    global random state, so seeding ``np.random`` makes it reproducible.

    Returns:
        Tuple ``(xs[order], ys[order])``; pairs that were aligned before the
        call are still aligned afterwards.
    """
    order = [*range(len(xs))]
    np.random.shuffle(order)
    shuffled_xs = xs[order]
    shuffled_ys = ys[order]
    return shuffled_xs, shuffled_ys


def split_sizes(array, sizes):
    """Cut ``array`` along its first axis into consecutive chunks.

    ``sizes[k]`` is the length of chunk ``k``. Whatever is left after the
    listed sizes comes back as one extra trailing chunk, so the result holds
    ``len(sizes) + 1`` arrays.
    """
    # np.split expects cut positions, i.e. the running total of the sizes.
    return np.split(array, np.cumsum(sizes))


def load_portraits_data(load_file='dataset_32x32.mat'):
    """Read the portraits ``.mat`` file from the current working directory.

    Args:
        load_file: File name, resolved relative to ``./``.

    Returns:
        Tuple ``(xs, ys)``: the file's ``'Xs'`` entry and the first row of its
        ``'Ys'`` entry.
    """
    mat_path = './' + load_file
    contents = scipy.io.loadmat(mat_path)
    images = contents['Xs']
    # 'Ys' is indexed with [0] to pull out its first row.
    labels = contents['Ys'][0]
    return images, labels

def make_portraits_data(n_src_tr, n_src_val, n_inter, n_target_unsup, n_trg_val, n_trg_tst,
                        load_file='dataset_32x32.mat'):
    """Split the portraits file into source, intermediate and target sets.

    The file is consumed in order: the first ``n_src_tr + n_src_val`` samples
    form the source domain, the next ``n_inter`` the intermediate domain and
    the following ``n_trg_val + n_trg_tst`` the target domain. Source and
    target samples are shuffled (via NumPy's global random state) before
    being divided into their two parts; intermediate samples keep file order.

    Args:
        n_src_tr: Number of source training samples.
        n_src_val: Number of source validation samples.
        n_inter: Number of intermediate-domain samples.
        n_target_unsup: How many samples from the END of the intermediate
            block to return as ``dir_inter_*``. ``0`` yields empty arrays;
            values above ``n_inter`` yield the whole intermediate block.
        n_trg_val: Number of target validation samples.
        n_trg_tst: Number of target test samples.
        load_file: File name handed to ``load_portraits_data``.

    Returns:
        Tuple ``(src_tr_x, src_tr_y, src_val_x, src_val_y, inter_x, inter_y,
        dir_inter_x, dir_inter_y, trg_val_x, trg_val_y, trg_test_x,
        trg_test_y)``.
    """
    xs, ys = load_portraits_data(load_file)
    src_end = n_src_tr + n_src_val
    inter_end = src_end + n_inter
    trg_end = inter_end + n_trg_val + n_trg_tst
    # Source is shuffled before target so the RNG is consumed in a fixed order.
    src_x, src_y = shuffle(xs[:src_end], ys[:src_end])
    trg_x, trg_y = shuffle(xs[inter_end:trg_end], ys[inter_end:trg_end])
    src_tr_x, src_val_x = split_sizes(src_x, [n_src_tr])
    src_tr_y, src_val_y = split_sizes(src_y, [n_src_tr])
    trg_val_x, trg_test_x = split_sizes(trg_x, [n_trg_val])
    trg_val_y, trg_test_y = split_sizes(trg_y, [n_trg_val])
    inter_x, inter_y = xs[src_end:inter_end], ys[src_end:inter_end]
    # A plain ``[-n:]`` slice returns EVERYTHING when n == 0 (because -0 == 0),
    # so compute the start index explicitly; clamping at 0 keeps the old
    # "take all" result when more samples are requested than exist.
    tail_start = max(len(inter_x) - n_target_unsup, 0)
    dir_inter_x, dir_inter_y = inter_x[tail_start:], inter_y[tail_start:]
    return (src_tr_x, src_tr_y, src_val_x, src_val_y, inter_x, inter_y,
            dir_inter_x, dir_inter_y, trg_val_x, trg_val_y, trg_test_x, trg_test_y)


"""
    Make rotated MNIST dataset
"""
def get_preprocessed_mnist():
    """Load MNIST, scale pixels, shuffle the train split, add a last axis.

    Pixel values of both splits are divided by 255.0. Only the training split
    is shuffled (images and labels together). Each image array then gains a
    trailing axis of length 1.

    Returns:
        ``((train_x, train_y), (test_x, test_y))``.
    """
    def with_trailing_axis(images):
        return np.expand_dims(np.array(images), axis=-1)

    train_split, test_split = mnist.load_data()
    train_x, train_y = train_split
    test_x, test_y = test_split
    train_x = train_x / 255.0
    test_x = test_x / 255.0
    train_x, train_y = shuffle(train_x, train_y)
    return (with_trailing_axis(train_x), train_y), (with_trailing_axis(test_x), test_y)


def sample_rotate_images(xs, start_angle, end_angle):
    """Rotate every image in ``xs`` by its own angle from a range.

    When ``start_angle == end_angle`` that exact angle is used and no random
    number is drawn. Otherwise one angle per image is drawn uniformly from
    ``[start_angle, end_angle)`` using NumPy's global random state. Images
    keep their shape (``reshape=False``).

    Returns:
        ``np.ndarray`` stacking the rotated images in input order.
    """
    def pick_angle():
        if start_angle == end_angle:
            return start_angle
        return np.random.uniform(low=start_angle, high=end_angle)

    return np.array([
        ndimage.rotate(xs[k], pick_angle(), reshape=False)
        for k in range(xs.shape[0])
    ])


def continually_rotate_images(xs, start_angle, end_angle, block_size=4200, step_angle=5):
    """Rotate images by an angle that grows stepwise with their position.

    Image ``i`` is rotated by ``(i // block_size) * step_angle`` degrees, so
    each run of ``block_size`` consecutive images shares one angle and the
    angle rises by ``step_angle`` from one run to the next, starting at 0.
    Images keep their shape (``reshape=False``).

    Args:
        xs: Array of images, indexed along its first axis.
        start_angle: Currently unused; accepted so existing callers keep
            working. The schedule always starts at 0 degrees.
        end_angle: Currently unused; see ``start_angle``.
        block_size: Number of consecutive images that share one angle.
        step_angle: Angle increment, in degrees, between successive runs.

    Returns:
        ``np.ndarray`` stacking the rotated images in input order.

    Raises:
        ValueError: If ``block_size`` is smaller than 1.
    """
    if block_size < 1:
        raise ValueError("block_size must be a positive integer")
    return np.array([
        ndimage.rotate(xs[i], (i // block_size) * step_angle, reshape=False)
        for i in range(xs.shape[0])
    ])


def _transition_rotation_dataset(train_x, train_y, test_x, test_y,
                                 source_angles, target_angles, inter_func,
                                 src_train_end, src_val_end, inter_end, target_end):
    """Carve the training data into rotated source/intermediate/target sets.

    The four ``*_end`` arguments are cut positions along the first axis of
    ``train_x``/``train_y``: ``[0, src_train_end)`` is source train,
    ``[src_train_end, src_val_end)`` source validation,
    ``[src_val_end, inter_end)`` intermediate and ``[inter_end, target_end)``
    target validation. The whole test split becomes the target test set.

    Source parts are rotated with angles from ``source_angles``, target parts
    with angles from ``target_angles``. The intermediate images are produced
    twice: once through ``inter_func`` (``inter_x``) and once rotated straight
    into the target range (``dir_inter_x``).

    Returns:
        Tuple ``(src_tr_x, src_tr_y, src_val_x, src_val_y, inter_x, inter_y,
        dir_inter_x, dir_inter_y, trg_val_x, trg_val_y, trg_test_x,
        trg_test_y)``.
    """
    n_train = train_x.shape[0]
    assert target_end <= n_train
    assert n_train == train_y.shape[0]

    def rotate_in(images, angle_range):
        return sample_rotate_images(images, angle_range[0], angle_range[1])

    src_tr_span = slice(None, src_train_end)
    src_val_span = slice(src_train_end, src_val_end)
    inter_span = slice(src_val_end, inter_end)
    trg_val_span = slice(inter_end, target_end)

    # The rotation calls draw from NumPy's global RNG, so their order is kept
    # fixed: source train, source val, intermediate, direct, target val, test.
    src_tr_x = rotate_in(train_x[src_tr_span], source_angles)
    src_val_x = rotate_in(train_x[src_val_span], source_angles)
    raw_inter_x = train_x[inter_span]
    inter_x = inter_func(raw_inter_x)
    dir_inter_x = rotate_in(raw_inter_x, target_angles)
    assert inter_x.shape == dir_inter_x.shape
    trg_val_x = rotate_in(train_x[trg_val_span], target_angles)
    trg_test_x = rotate_in(test_x, target_angles)

    src_tr_y = train_y[src_tr_span]
    src_val_y = train_y[src_val_span]
    inter_y = train_y[inter_span]
    dir_inter_y = np.array(inter_y)  # independent copy of the labels
    trg_val_y = train_y[trg_val_span]
    return (src_tr_x, src_tr_y, src_val_x, src_val_y, inter_x, inter_y,
            dir_inter_x, dir_inter_y, trg_val_x, trg_val_y, trg_test_x, test_y)


def dial_rotation_proportions(xs, source_angles, target_angles):
    """Rotate images so that later ones are increasingly target-like.

    Image ``i`` of ``N`` is treated as a target-domain image with probability
    ``i / (N - 1)`` (0 for the first image, 1 for the last) and as a
    source-domain image otherwise. Its rotation angle is then drawn uniformly
    from ``target_angles`` or ``source_angles`` respectively. All randomness
    comes from NumPy's global random state.

    Args:
        xs: Array of images, indexed along its first axis.
        source_angles: ``(low, high)`` angle range for source-style images.
        target_angles: ``(low, high)`` angle range for target-style images.

    Returns:
        ``np.ndarray`` stacking the rotated images in input order. A single
        input image is always rotated with a source angle.
    """
    N = xs.shape[0]
    # Guard the denominator: with N == 1 the ramp would be 0 / 0 = nan, which
    # np.random.binomial rejects. For N >= 2 the values are unchanged.
    rotate_ps = np.arange(N) / float(max(N - 1, 1))
    is_target = np.random.binomial(n=1, p=rotate_ps)
    assert is_target.shape == (N,)
    new_xs = []
    for i in range(N):
        angle_range = target_angles if is_target[i] else source_angles
        angle = np.random.uniform(low=angle_range[0], high=angle_range[1])
        new_xs.append(ndimage.rotate(xs[i], angle, reshape=False))
    return np.array(new_xs)


def dial_proportions_rotated_dataset(train_x, train_y, test_x, test_y,
                                     source_angles, target_angles,
                                     src_train_end, src_val_end, inter_end, target_end):
    """Build a rotation dataset whose intermediate set mixes both domains.

    Delegates to ``_transition_rotation_dataset``; the intermediate images
    are produced by ``dial_rotation_proportions`` using the same source and
    target angle ranges. Returns that function's 12-tuple unchanged.
    """
    def mix_domains(images):
        return dial_rotation_proportions(images, source_angles, target_angles)

    return _transition_rotation_dataset(
        train_x, train_y, test_x, test_y,
        source_angles, target_angles, mix_domains,
        src_train_end, src_val_end, inter_end, target_end)


def make_rotated_dataset(train_x, train_y, test_x, test_y,
                         source_angles, inter_angles, target_angles,
                         src_train_end, src_val_end, inter_end, target_end):
    """Build a rotation dataset with a stepwise-rotated intermediate set.

    Delegates to ``_transition_rotation_dataset``; the intermediate images
    are produced by ``continually_rotate_images``, which receives
    ``inter_angles[0]`` and ``inter_angles[1]``. Returns that function's
    12-tuple unchanged.
    """
    def rotate_gradually(images):
        return continually_rotate_images(images, inter_angles[0], inter_angles[1])

    return _transition_rotation_dataset(
        train_x, train_y, test_x, test_y,
        source_angles, target_angles, rotate_gradually,
        src_train_end, src_val_end, inter_end, target_end)


def make_population_rotated_dataset(xs, ys, delta_angle, num_angles):
    """Replicate the whole dataset at ``num_angles`` evenly spaced rotations.

    Copy ``k`` (``k = 0 .. num_angles - 1``) holds every image of ``xs``
    rotated by exactly ``k * delta_angle``; ``ys`` is repeated once per copy.

    Returns:
        Tuple ``(images, labels)`` with all copies concatenated along the
        first axis.
    """
    rotated_copies = [
        sample_rotate_images(xs, step * delta_angle, step * delta_angle)
        for step in range(num_angles)
    ]
    images = np.concatenate(rotated_copies, axis=0)
    labels = np.concatenate([ys] * num_angles, axis=0)
    # Concatenation must only grow the first axis.
    assert images.shape[1:] == xs.shape[1:]
    assert labels.shape[1:] == ys.shape[1:]
    return images, labels


def make_rotated_dataset_continuous(dataset, start_angle, end_angle, num_points):
    """Sample ``num_points`` training images with linearly increasing rotation.

    The training split from ``dataset.load_data()`` is shuffled and divided by
    255.0, then ``num_points`` distinct images are picked at random. The
    ``i``-th picked image is rotated by
    ``start_angle + i * (end_angle - start_angle) / num_points``, so
    ``end_angle`` itself is never reached.

    Returns:
        Tuple ``(images, labels)`` of arrays in sampling order.
    """
    (train_x, train_y), (_, _) = dataset.load_data()
    train_x, train_y = shuffle(train_x, train_y)
    train_x = train_x / 255.0
    assert num_points < train_x.shape[0]
    chosen = np.random.choice(train_x.shape[0], size=num_points, replace=False)
    per_point = float(end_angle - start_angle) / num_points
    images = [
        ndimage.rotate(train_x[chosen[k]], per_point * k + start_angle, reshape=False)
        for k in range(num_points)
    ]
    labels = [train_y[chosen[k]] for k in range(num_points)]
    return np.array(images), np.array(labels)


def make_rotated_mnist(start_angle, end_angle, num_points, normalize=False):
    """Build a continuously rotated MNIST sample set.

    Args:
        start_angle: Rotation angle of the first sampled image.
        end_angle: Upper end of the rotation range.
        num_points: Number of images to sample.
        normalize: When true, each image is flattened and L2-normalized, then
            all values are rescaled so the overall mean matches the mean
            before normalization.

    Returns:
        Tuple ``(Xs, Ys)`` where ``Xs`` has an extra trailing axis of
        length 1 and ``Ys`` holds the matching labels.
    """
    # make_rotated_dataset takes pre-split arrays and returns a 12-tuple; the
    # builder matching this (dataset, start, end, n) call is the
    # *_continuous variant.
    Xs, Ys = make_rotated_dataset_continuous(mnist, start_angle, end_angle, num_points)
    if normalize:
        Xs = np.reshape(Xs, (Xs.shape[0], -1))
        old_mean = np.mean(Xs)
        Xs = sklearn.preprocessing.normalize(Xs, norm='l2')
        new_mean = np.mean(Xs)
        # Undo the change in overall brightness caused by the normalization.
        Xs = Xs * (old_mean / new_mean)
    return np.expand_dims(np.array(Xs), axis=-1), Ys
