import numpy as np
import torch
import glob
import matplotlib.pyplot as plt
from sys import platform
from tqdm import tqdm
import scipy
import os
import ipdb


def shuffle_batch(images, labels):
    """Shuffle ``images`` and ``labels`` together along their first axis.

    A single random permutation (drawn from NumPy's global RNG) is applied to
    both arrays, so image ``i`` and label ``i`` stay paired after shuffling.

    Returns:
        Tuple ``(shuffled_images, shuffled_labels)`` (new arrays, produced by
        fancy indexing).
    """
    order = np.random.permutation(images.shape[0])
    shuffled_images = images[order]
    shuffled_labels = labels[order]
    return shuffled_images, shuffled_labels


def extract_data(data, augment_data):
    """Flatten per-character instance groups into an image array and a label array.

    Args:
        data: sequence of characters; each character is a sequence of
            instances whose first element (``instance[0]``) is an image.
            When ``augment_data`` is True this must be a NumPy array (it is
            sliced and ``np.vstack``-ed).
        augment_data: if True, append the 90/180/270-degree rotated copies of
            every character (via ``augment_character_set``) before flattening.
            Each rotated copy becomes a new, separate character.

    Returns:
        ``(images, char_number)``: ``images`` is ``np.array`` of every
        instance image with a new axis inserted at position 3, and
        ``char_number[i]`` is the position (in the possibly augmented
        ``data``) of the character that image ``i`` came from.
    """
    if augment_data:
        # augment_character_set(base, c) returns ``base`` followed by the three
        # rotated copies of ``c``; passing an empty slice yields just those
        # copies.  Stacking all of them once (instead of re-vstacking the
        # growing array per character) keeps this linear, with the same order:
        # originals first, then the rotations of each character in turn.
        rotated_blocks = [augment_character_set(data[:0], character) for character in data]
        data = np.vstack([data, *rotated_blocks])

    images, char_nums = [], []
    for character_index, character in enumerate(data):
        for instance in character:
            images.append(instance[0])
            char_nums.append(character_index)

    images = np.expand_dims(np.array(images), 3)
    char_number = np.array(char_nums)
    return images, char_number


def augment_character_set(data, character_set):
    """Append the 90/180/270-degree rotations of one character to ``data``.

    Args:
        data: array whose axis 0 indexes characters.  It is ``np.vstack``-ed
            with a ``(3, len(character_set), 3)`` object array, so its trailing
            dimensions must match that block (presumably an object array of
            ``(image, char_num, char_language_num)`` triples — confirm against
            the data file).
        character_set: iterable of ``(image, char_num, char_language_num)``
            triples describing one character.

    Returns:
        ``data`` followed by three new rows: the character rotated by 90, 180
        and 270 degrees (``np.rot90`` with k=1, 2, 3), each keeping the
        original ``char_num`` and ``char_language_num``.
    """
    instances = list(character_set)
    # Build the object array explicitly: letting np.array infer it from
    # ragged (image, int, int) tuples raises ValueError on NumPy >= 1.24.
    rotated = np.empty((3, len(instances), 3), dtype=object)
    for i, (image, char_num, char_language_num) in enumerate(instances):
        for k in (1, 2, 3):
            rotated[k - 1, i, 0] = np.rot90(image, k=k)
            rotated[k - 1, i, 1] = char_num
            rotated[k - 1, i, 2] = char_language_num
    return np.vstack((data, rotated))


def rotate_images(imgs, rot_angle):
    """Rotate a batch of images by ``rot_angle`` degrees, keeping the input size.

    Args:
        imgs: 4-D array indexed as (batch, height, width, channels); the
            rotation is applied in the plane of axes 1 and 2.
        rot_angle: rotation angle in degrees (passed to
            ``scipy.ndimage.rotate``).

    Returns:
        The rotated batch, centre-cropped back to ``imgs``' height and width.
        If the rotated canvas is smaller than the input along an axis (e.g. a
        non-square image rotated by 90 degrees), that axis is not cropped.
    """
    # `import scipy` alone does not load the ndimage submodule on older SciPy.
    from scipy import ndimage

    height, width = imgs.shape[1], imgs.shape[2]
    rotated = ndimage.rotate(imgs, rot_angle, axes=(1, 2))
    # reshape=True (the default) enlarges the output so the whole rotated
    # image fits; crop the centre back to the original size, with separate
    # offsets for height and width so non-square images keep their shape.
    top = max(0, int((rotated.shape[1] - height) / 2))
    left = max(0, int((rotated.shape[2] - width) / 2))
    return rotated[:, top:top + height, left:left + width, :]


class OmniglotData(torch.utils.data.Dataset):
    """Episodic few-shot dataset built from a pre-packed Omniglot ``.npy`` file.

    The rows of the loaded array (one per character) are shuffled once and
    split into disjoint train / validation / test character sets.  Each
    ``__getitem__`` call ignores its index and samples a fresh ``n_ways``-way
    task from the split selected by ``mode``: the first ``n_support`` shuffled
    instances of every sampled character form the support ("train") set and
    the remaining instances form the query ("test") set.
    """

    # Tasks served per epoch in each mode.  With batch size 32, the 1920
    # training tasks are 60 gradient steps per epoch, i.e. 6,000 steps over
    # 100 epochs.
    _TASKS_PER_EPOCH = {'train': 1920, 'validation': 100, 'test': 100}

    def __init__(self, n_ways, n_support, n_query, train_size, validation_size, mode, device,
                 split_quadrants=True, data_path="./data/Omniglot/omniglot.npy"):
        """Load the data file and build the character splits.

        Args:
            n_ways: number of distinct characters (classes) per task.
            n_support: support images per character in each task.
            n_query: stored as ``instances_per_char``; not otherwise used here.
            train_size: number of characters in the training split.
            validation_size: number of characters in the validation split
                (0 means no validation split is built).
            mode: split served by ``__getitem__`` / ``__len__``
                ('train', 'validation' or 'test').
            device: torch device the tensors from ``__getitem__`` are moved to.
            split_quadrants: if True, ``__getitem__`` cuts every image into its
                four quadrants and returns them as separate images.
            data_path: path of the pickled ``.npy`` file to load.

        Raises:
            ValueError: if ``mode`` is 'validation' but ``validation_size`` is 0.
        """
        if mode == 'validation' and validation_size == 0:
            raise ValueError("mode='validation' requires validation_size > 0")

        self.name = 'omniglot'
        self.mode = mode
        self.device = device
        self.ways = n_ways
        self.num_support = n_support
        # NOTE(review): stored but never read in this class -- the query set is
        # every instance after the first n_support ones. Confirm intent.
        self.instances_per_char = n_query
        self.split_quadrants = split_quadrants

        self.original_size = 28
        self.image_height = self.original_size
        self.image_width = self.original_size
        self.img_width = self.image_width
        self.image_channels = 1

        # allow_pickle=True can execute code embedded in the file: only load
        # trusted data files here.
        omniglot_data = np.load(data_path, allow_pickle=True)
        np.random.shuffle(omniglot_data)
        self.total_chars = omniglot_data.shape[0]

        validation_end = train_size + validation_size
        self.train_images, self.train_char_nums = extract_data(omniglot_data[:train_size], augment_data=False)
        if validation_size != 0:
            self.validation_images, self.validation_char_nums = extract_data(
                omniglot_data[train_size:validation_end], augment_data=False)
        self.test_images, self.test_char_nums = extract_data(omniglot_data[validation_end:], augment_data=False)

        if mode in self._TASKS_PER_EPOCH:
            self.num_images = len(self._split_arrays(mode)[0])

    def get_num_classes(self):
        return self.total_chars

    def get_image_height(self):
        return self.image_height

    def get_image_width(self):
        return self.image_width

    def get_image_channels(self):
        return self.image_channels

    def __len__(self):
        """Number of tasks per epoch for the current mode (see ``_TASKS_PER_EPOCH``)."""
        if self.mode not in self._TASKS_PER_EPOCH:
            raise ValueError(f"Unknown mode {self.mode!r}; expected one of {tuple(self._TASKS_PER_EPOCH)}")
        return self._TASKS_PER_EPOCH[self.mode]

    def __getitem__(self, idx):
        """Sample a fresh task from the split chosen by ``self.mode`` (``idx`` is ignored).

        Returns:
            ``(train_images, train_labels, test_images, test_labels)`` moved to
            ``self.device``.  Images are float tensors in (N, C, H, W) layout;
            labels are integer class ids in ``[0, ways)``.  With
            ``split_quadrants`` every image becomes four consecutive
            half-size quadrant images that share its label.
        """
        images, character_indices = self._split_arrays(self.mode)
        im_train, im_test, lbl_train, lbl_test = self._generate_random_task(
            images, character_indices, self.num_support, self.ways)

        if self.split_quadrants:
            half = self.img_width // 2
            train_images = self.split_img(torch.from_numpy(im_train).float()).view(-1, half, half, 1).permute(0, 3, 1, 2)
            test_images = self.split_img(torch.from_numpy(im_test).float()).view(-1, half, half, 1).permute(0, 3, 1, 2)
            # Quadrants of image n sit at rows 4n..4n+3, so repeat each
            # one-hot label row 4 times in place to keep them aligned.
            train_labels = torch.from_numpy(lbl_train).unsqueeze(1).repeat(1, 4, 1).reshape(-1, self.ways)
            test_labels = torch.from_numpy(lbl_test).unsqueeze(1).repeat(1, 4, 1).reshape(-1, self.ways)
        else:
            train_images = torch.from_numpy(im_train).permute(0, 3, 1, 2).float()
            train_labels = torch.from_numpy(lbl_train).float()
            test_images = torch.from_numpy(im_test).permute(0, 3, 1, 2).float()
            test_labels = torch.from_numpy(lbl_test).float()

        # One-hot -> integer class ids.
        train_labels = torch.argmax(train_labels, -1)
        test_labels = torch.argmax(test_labels, -1)

        return (train_images.to(self.device), train_labels.to(self.device),
                test_images.to(self.device), test_labels.to(self.device))

    def split_img(self, img):
        """Split images into their four quadrants.

        Args:
            img: tensor indexed as (..., dim, dim, channels) with
                ``dim == self.img_width``.

        Returns:
            Tensor of shape (..., 4, dim // 2, dim // 2, channels) holding, in
            order, the slices [:mid, :mid], [mid:, :mid], [:mid, mid:] and
            [mid:, mid:] of the two spatial axes (``mid = dim // 2``).
        """
        mid = self.img_width // 2
        q1 = img[..., :mid, :mid, :].unsqueeze(1)
        q2 = img[..., mid:, :mid, :].unsqueeze(1)
        q3 = img[..., :mid, mid:, :].unsqueeze(1)
        q4 = img[..., mid:, mid:, :].unsqueeze(1)
        return torch.cat((q1, q2, q3, q4), 1)

    def _split_arrays(self, source):
        """Return ``(images, character_indices)`` for 'train', 'validation' or 'test'."""
        if source == 'train':
            return self.train_images, self.train_char_nums
        if source == 'validation':
            return self.validation_images, self.validation_char_nums
        if source == 'test':
            return self.test_images, self.test_char_nums
        raise ValueError(f"Unknown split {source!r}; expected 'train', 'validation' or 'test'")

    def _padding(self):
        """``np.pad`` widths that grow each image from ``original_size`` to
        ``image_height`` x ``image_width`` (about a quarter of the extra rows
        and columns before the image, the rest after; all zero when the sizes
        are equal)."""
        dh = self.image_height - self.original_size
        dw = self.image_width - self.original_size
        return ((0, 0), (int(dh / 4), dh - int(dh / 4)), (int(dw / 4), dw - int(dw / 4)), (0, 0))

    def get_batch(self, source, batch_size, num_support):
        """Sample ``batch_size`` tasks from split ``source`` (see ``_yield_random_task_batch``)."""
        images, character_indices = self._split_arrays(source)
        return self._yield_random_task_batch(batch_size, images, character_indices, num_support, self.ways)

    def _yield_random_task_batch(self, batch_size, images, character_indices, shot, way):
        """Sample ``batch_size`` independent tasks and stack them into tensors.

        Returns:
            ``(train_images, train_labels, test_images, test_labels)``.  Images
            are permuted from (batch, N, H, W, C) to (batch, N, C, H, W);
            labels are integer class ids of shape (batch, N).  Stacking
            requires every task to have the same number of images.
        """
        train_images_to_return, test_images_to_return = [], []
        train_labels_to_return, test_labels_to_return = [], []
        for _ in range(batch_size):
            im_train, im_test, lbl_train, lbl_test = self._generate_random_task(images, character_indices, shot, way)
            train_images_to_return.append(im_train)
            test_images_to_return.append(im_test)
            train_labels_to_return.append(np.argmax(lbl_train, axis=1))
            test_labels_to_return.append(np.argmax(lbl_test, axis=1))
        train_images = torch.from_numpy(np.array(train_images_to_return)).permute(0, 1, 4, 2, 3)
        test_images = torch.from_numpy(np.array(test_images_to_return)).permute(0, 1, 4, 2, 3)
        train_labels = torch.from_numpy(np.array(train_labels_to_return))
        test_labels = torch.from_numpy(np.array(test_labels_to_return))
        return train_images, train_labels, test_images, test_labels

    def _generate_random_task(self, images, character_indices, shot, way):
        """Sample one ``way``-way task of distinct characters.

        For each sampled character its instances are shuffled; the first
        ``shot`` go to the support set and the rest to the query set.

        Returns:
            ``(train_images, test_images, train_labels, test_labels)`` as NumPy
            arrays; labels are one-hot rows of ``np.eye(way)``, and each set is
            shuffled (images and labels together) with ``shuffle_batch``.

        Raises:
            ValueError: (from ``np.random.choice``) if the split has fewer than
                ``way`` distinct characters.
        """
        pad = self._padding()
        train_images_list, test_images_list = [], []
        train_counts, test_counts = [], []
        # replace=False: a task must not contain the same character twice,
        # otherwise one class would appear under two different labels.
        task_characters = np.random.choice(np.unique(character_indices), way, replace=False)
        for character in task_characters:
            idx = np.where(character_indices == character)[0]
            np.random.shuffle(idx)
            character_images = images[idx]
            support, query = character_images[:shot], character_images[shot:]
            train_images_list.append(np.pad(support, pad))
            test_images_list.append(np.pad(query, pad))
            train_counts.append(len(support))
            test_counts.append(len(query))

        train_images_to_return = np.vstack(train_images_list)
        test_images_to_return = np.vstack(test_images_list)
        # Repeat each one-hot row by that character's actual image count, so
        # labels stay aligned even if characters have different instance counts.
        one_hot = np.eye(way)
        train_labels_to_return = np.repeat(one_hot, train_counts, axis=0)
        test_labels_to_return = np.repeat(one_hot, test_counts, axis=0)

        train_images_to_return, train_labels_to_return = shuffle_batch(train_images_to_return, train_labels_to_return)
        test_images_to_return, test_labels_to_return = shuffle_batch(test_images_to_return, test_labels_to_return)
        return train_images_to_return, test_images_to_return, train_labels_to_return, test_labels_to_return

    def get_symbol(self, symbol):
        """Return every test-split image of character ``symbol`` as two tensors.

        Both tensors contain the same images, in the same shuffled order,
        permuted to (N, C, H, W); they are separate copies.
        """
        pad = self._padding()
        idx = np.where(self.test_char_nums == symbol)[0]
        np.random.shuffle(idx)
        character_images = self.test_images[idx]
        train_images = torch.from_numpy(np.pad(character_images, pad)).permute(0, 3, 1, 2)
        test_images = torch.from_numpy(np.pad(character_images, pad)).permute(0, 3, 1, 2)
        return train_images, test_images


if __name__ == '__main__':
    # Smoke test: build a 2-way, 5-shot training dataset (1000 train / 100
    # validation characters) on the CPU and report the shapes of one batch.
    omni = OmniglotData(2, 5, 20, 1000, 100, "train", "cpu")
    dloader = torch.utils.data.DataLoader(omni, batch_size=5)

    train_images, train_labels, test_images, test_labels = next(iter(dloader))
    for label, tensor in (("train_images", train_images), ("train_labels", train_labels),
                          ("test_images", test_images), ("test_labels", test_labels)):
        print(f"{label}: {tuple(tensor.shape)} {tensor.dtype}")
