import os
import pickle
import random
import torch
import copy
from xmeta.utils.seed import set_seed
import learn2learn as l2l
import numpy as np
from torchvision.transforms.functional import rotate


def get_tasksets(num_tasks=100, seed=42, **kwargs):
    """Build learn2learn benchmark tasksets and touch their first tasks.

    All ``kwargs`` (plus ``num_tasks``) are forwarded to
    ``l2l.vision.benchmarks.get_tasksets``.

    When ``num_tasks > 0``, the global seed is set to ``seed``. Then indices
    ``0 .. num_tasks-1`` are accessed once on the train split, then the
    validation split, then the test split. This is presumably so the sampled
    tasks are fixed by ``seed``.

    Returns:
        The tasksets object returned by learn2learn.
    """
    tasksets = l2l.vision.benchmarks.get_tasksets(num_tasks=num_tasks, **kwargs)
    if num_tasks <= 0:
        return tasksets

    set_seed(seed)
    # Fixed split order (train -> validation -> test) keeps RNG consumption stable.
    for split in (tasksets.train, tasksets.validation, tasksets.test):
        for task_idx in range(num_tasks):
            _ = split[task_idx]
    return tasksets


def save_tasks(tasksets, name: str, savedir: str = './cache'):
    """Pickle the first ``tasksets.num_tasks`` tasks of each split to disk.

    Writes ``<savedir>/<name>_train.pkl``, ``<name>_validation.pkl`` and
    ``<name>_test.pkl``. Each file holds a list of the split's tasks in index
    order. ``savedir`` is created if it does not exist yet.

    Args:
        tasksets: object with ``num_tasks`` and indexable ``train``,
            ``validation`` and ``test`` attributes.
        name: filename prefix.
        savedir: output directory.
    """
    # Materialize every split before writing anything, in train/validation/test
    # order, so task access order (and any sampling it triggers) is unchanged.
    splits = {
        'train': [tasksets.train[ii] for ii in range(tasksets.num_tasks)],
        'validation': [tasksets.validation[ii] for ii in range(tasksets.num_tasks)],
        'test': [tasksets.test[ii] for ii in range(tasksets.num_tasks)],
    }

    os.makedirs(savedir, exist_ok=True)
    for split_name, tasks in splits.items():
        pkl_path = os.path.join(savedir, f'{name}_{split_name}.pkl')
        with open(pkl_path, 'wb') as f:
            pickle.dump(tasks, f)
        print(f'saved {pkl_path}')


def mask_image(x, mask_labels):
    """Zero every image whose label is in ``mask_labels``.

    Args:
        x: ``(imgs, labels)`` pair; ``labels == value`` is used as a boolean
            index into ``imgs``.
        mask_labels: iterable of label values to blank out.

    Returns:
        ``(masked_imgs, labels)``. The images are a deep copy, so the input
        tensor is left untouched. ``labels`` is returned as-is.
    """
    imgs, labels = x
    masked = copy.deepcopy(imgs)
    for selector in (labels == target for target in mask_labels):
        masked[selector] *= 0
    return masked, labels


def mask_images(taskset, masks, ways):
    """Mask ``masks`` randomly chosen labels (out of ``range(ways)``) in every task.

    A fresh label subset is drawn with ``random.sample`` for each task, in task
    index order.

    Returns:
        List of ``(imgs, labels)`` pairs produced by ``mask_image``.
    """
    def _mask_one(task_idx):
        # Draw the labels before fetching the task to keep the original RNG call order.
        hidden_labels = random.sample(range(ways), masks)
        return mask_image(taskset[task_idx], hidden_labels)

    return [_mask_one(task_idx) for task_idx in range(len(taskset))]


def mask_tasks(taskset, mask_idces, ways):
    """Return all tasks as a list, fully masking those whose index is in ``mask_idces``.

    A masked task has the images of every label in ``range(ways)`` zeroed via
    ``mask_image``. Other tasks are returned unchanged.
    """
    every_label = list(range(ways))
    return [
        mask_image(taskset[k], mask_labels=every_label) if k in mask_idces else taskset[k]
        for k in range(len(taskset))
    ]


def bgr_image(x):
    """Swap channel 0 and channel 2 (RGB <-> BGR) of a 4-D image batch.

    Returns ``(imgs[:, [2, 1, 0]], labels)``. The fancy index yields a new tensor.
    """
    images, labels = x
    channel_order = [2, 1, 0]
    swapped = images[:, channel_order, :, :]
    return swapped, labels


def recolor_image(x):
    """Recolor each image by zeroing channels according to its label.

    ``label % 8`` is read as a 3-bit code. Bit ``j`` (least significant first)
    decides channel ``j``: set keeps the channel, cleared zeroes it. For
    example, label 5 (0b101) keeps channels 0 and 2 and zeroes channel 1.

    Args:
        x: ``(imgs, labels)`` with ``imgs`` indexed as
            ``[sample, channel, h, w]`` and at least 3 channels.

    Returns:
        ``(recolored_imgs, labels)``. The images are a copy, so the input is
        not modified. ``labels`` is returned as-is.
    """
    imgs, labels = x
    # Work on a copy instead of mutating the caller's tensor in place.
    imgs = copy.deepcopy(imgs)
    for idx in range(len(labels)):
        code = int(labels[idx]) % 8
        for channel in range(3):
            # Integer bit test: the multiplier stays a Python int (0/1), so
            # integer-typed images (e.g. uint8) can be scaled in place too.
            imgs[idx, channel, :, :] *= (code >> channel) & 1
    return imgs, labels


def rotate_image(x, angles):
    """Rotate each image of a batch by its own angle.

    Args:
        x: ``(imgs, labels)`` with ``imgs`` indexed as ``[sample, c, h, w]``.
        angles: one angle per image, passed to torchvision's ``rotate``.

    Returns:
        ``(rotated_imgs, labels)``. The images are a deep copy, so the input
        is not modified.
    """
    imgs, labels = x
    assert len(angles) == len(imgs)
    rotated = copy.deepcopy(imgs)
    for idx, angle in enumerate(angles):
        rotated[idx, :, :, :] = rotate(imgs[idx], angle)
    return rotated, labels


def dark_image(x, alpha=0.5):
    """Scale image intensities by ``alpha`` (darkens for ``alpha < 1``).

    Returns a new ``(imgs * alpha, labels)`` pair. The input tensor is not modified.
    """
    images, labels = x
    return images * alpha, labels


def noise_image(x):
    """Replace a task's images with uniform noise in ``[0, 255)``.

    The noise has the same shape and device as the input images. Its dtype is
    torch's default float dtype. Labels are returned unchanged.

    NOTE(review): the 255 scale assumes pixel values in ``[0, 255]``. If the
    images are normalized to ``[0, 1]``, this noise is on a different scale;
    confirm that this is intended.
    """
    imgs, labels = x
    # Allocate on the input's device so downstream ops don't hit a device mismatch.
    noise = torch.rand(imgs.shape, device=imgs.device) * 255
    return noise, labels


def noise_tasks(taskset, noise_idces):
    """Return all tasks as a list, replacing those indexed by ``noise_idces`` with noise.

    Noisy tasks go through ``noise_image``. The rest are returned unchanged.
    """
    return [
        noise_image(taskset[k]) if k in noise_idces else taskset[k]
        for k in range(len(taskset))
    ]


def shuffle_label(x):
    """Randomly permute a task's labels with ``torch.randperm``, leaving images in place.

    This breaks the image/label correspondence, producing a label-noise task.
    """
    imgs, labels = x
    permutation = torch.randperm(len(labels))
    shuffled = labels[permutation]
    return imgs, shuffled


def shuffle_tasks(taskset, sfl_idces):
    """Return all tasks as a list, shuffling labels of those indexed by ``sfl_idces``.

    Selected tasks go through ``shuffle_label``. The rest are returned unchanged.
    """
    return [
        shuffle_label(taskset[k]) if k in sfl_idces else taskset[k]
        for k in range(len(taskset))
    ]


def mix_tasks(task_0, task_1, ratio=0.5, shuffle=False):
    """Mix two tasks by splitting the label set between them.

    A random subset of ``int(ratio * n)`` of the ``n`` distinct labels of
    ``task_0`` is chosen with ``random.sample``. The result contains every
    sample of ``task_0`` whose label is in that subset, followed by every
    sample of ``task_1`` whose label is NOT in it.

    Args:
        task_0, task_1: ``(data, labels)`` pairs. ``labels`` is a 1-D tensor
            used as a boolean row selector on ``data``.
        ratio: fraction of ``task_0``'s distinct labels taken from ``task_0``.
        shuffle: if True, randomly permute the combined samples (uses ``random``).

    Returns:
        ``[data, labels]`` list for the mixed task.
    """
    data_0, label_0 = task_0
    data_1, label_1 = task_1
    unique_label_0 = label_0.unique().tolist()
    n_way_0 = int(ratio * len(unique_label_0))
    chosen_labels = random.sample(unique_label_0, k=n_way_0)

    # Build the lookup tensors with each label tensor's own dtype and device.
    # ``torch.Tensor(...)`` always produced a CPU float32 tensor. That mismatches
    # CUDA labels and is inexact for integer labels beyond float32 precision.
    lookup_0 = torch.as_tensor(chosen_labels, dtype=label_0.dtype, device=label_0.device)
    lookup_1 = torch.as_tensor(chosen_labels, dtype=label_1.dtype, device=label_1.device)
    pos_0 = torch.isin(label_0, lookup_0)
    pos_1 = torch.isin(label_1, lookup_1, invert=True)

    data = torch.cat([data_0[pos_0], data_1[pos_1]])
    label = torch.cat([label_0[pos_0], label_1[pos_1]])

    if shuffle:
        idces = random.sample(range(len(data)), k=len(data))
        data = data[idces]
        label = label[idces]

    return [data, label]


def degrade_images(task, ratio=0.5, alpha=1.):
    """Scale a random ``ratio`` fraction of a task's samples by ``1 - alpha``.

    With the default ``alpha=1`` the selected samples are zeroed. Samples are
    picked with ``random.sample``; both data and labels are deep-copied first,
    so the input is not modified.

    Returns:
        ``[degraded_data, labels_copy]`` list.
    """
    data, label = task
    degraded = copy.deepcopy(data)
    label_copy = copy.deepcopy(label)

    n_samples = len(degraded)
    victims = random.sample(range(n_samples), k=int(ratio * n_samples))
    degraded[victims] = degraded[victims] * (1 - alpha)

    return [degraded, label_copy]


class ImpureTaskset(l2l.data.task_dataset.TaskDataset):
    """A learn2learn ``TaskDataset`` whose individual tasks can be overridden.

    It is rebuilt from the wrapped taskset's dataset, task count, transforms
    and collate function. Right after seeding with ``seed``, every task index
    is accessed once, presumably so each index maps to a reproducible task.

    ``self.impurities`` maps a task index to one of two things:
      * a precomputed task, which is returned as-is, or
      * a callable, which is applied to the clean task on every access.
    """

    def __init__(self, taskset, seed=42):
        super().__init__(
            dataset=taskset.dataset,
            num_tasks=taskset.num_tasks,
            task_transforms=taskset.task_transforms,
            task_collate=taskset.task_collate,
        )
        set_seed(seed)
        for task_idx in range(self.num_tasks):
            _ = super().__getitem__(task_idx)

        self.impurities = {}

    def __getitem__(self, i):
        """Return task ``i``, applying or substituting its impurity if one is registered."""
        assert isinstance(i, int), f'{i} is not an integer'
        if i not in self.impurities:
            return super().__getitem__(i)
        impurity = self.impurities[i]
        if not callable(impurity):
            return impurity
        return impurity(super().__getitem__(i))

    def add_impurities(self, idxes, converter, convert_now=True):
        """Register ``converter`` as the impurity for every task index in ``idxes``.

        Args:
            idxes: task indices to corrupt.
            converter: callable taking a clean task and returning the corrupted one.
            convert_now: if True, apply ``converter`` immediately and store the
                result. If False, store ``converter`` itself so it is re-applied
                on each access.
        """
        # Zero-argument super() inside a comprehension would bind to the
        # comprehension's own scope rather than this method's ``self``, so the
        # proxy is captured here. See https://teratail.com/questions/210791
        parent = super()
        if convert_now:
            new_entries = {idx: converter(parent.__getitem__(idx)) for idx in idxes}
        else:
            new_entries = dict.fromkeys(idxes, converter)
        self.impurities.update(new_entries)


class RotatedTaskset(ImpureTaskset):
    """Taskset that rotates every image of a task by pre-drawn random angles.

    For each task, ``n_augment`` sets of ``shots * ways`` angles are drawn
    with ``np.random.uniform`` over ``angle_range`` at construction time. Each
    access to task ``i`` uses that task's next set, cycling back after
    ``n_augment`` accesses. If ``taskset`` carries an ``impurities`` table, it
    is shared by reference, so impure tasks are rotated as well.

    Args:
        taskset: source taskset (a learn2learn ``TaskDataset`` or ``ImpureTaskset``).
        shots, ways: together fix the number of images per task
            (``shots * ways``); ``rotate_image`` asserts that each task matches.
        seed: seed used when re-sampling tasks. It also seeds the angles when
            ``rot_seed`` is None.
        rot_seed: seed for drawing the rotation angles. ``None`` (default)
            keeps the previous behaviour of seeding them with ``seed``.
        n_augment: number of angle sets drawn per task.
        angle_range: ``[low, high]`` bounds of the uniform angle distribution
            (values passed to ``rotate``). Defaults to ``[-90., 90.]``.
    """

    def __init__(self, taskset: l2l.data.task_dataset.TaskDataset,
                 shots: int = 10, ways: int = 5,
                 seed: int = 42, rot_seed: int = None,
                 n_augment: int = 8, angle_range: list = None):
        super().__init__(taskset=taskset, seed=seed)
        if hasattr(taskset, 'impurities'):
            self.impurities = taskset.impurities
        else:
            self.impurities = {}
        # None sentinel avoids a shared mutable default list across instances.
        if angle_range is None:
            angle_range = [-90., 90.]
        self.angle_range = angle_range
        self.n_augment = n_augment
        # Per-task cursor selecting which of the n_augment angle sets is used next.
        self.aug_idxes = np.zeros(self.num_tasks, dtype=int)
        # ``rot_seed`` used to be accepted but ignored (angles were always seeded
        # with ``seed``). Honor it when given; fall back to ``seed`` otherwise.
        set_seed(seed if rot_seed is None else rot_seed)
        self.angles = np.random.uniform(
            low=angle_range[0], high=angle_range[1],
            size=(self.num_tasks, n_augment, shots * ways),
        )

    def __getitem__(self, i):
        """Return task ``i`` rotated by its next pre-drawn set of angles."""
        task = super().__getitem__(i)
        ret = rotate_image(task, angles=self.angles[i, self.aug_idxes[i]])
        # Advance (and wrap) this task's cursor so the next access uses a new set.
        self.aug_idxes[i] = (self.aug_idxes[i] + 1) % self.n_augment
        return ret


class ShuffledTaskset(ImpureTaskset):
    """Taskset that serves the source's tasks in a fixed random order.

    The permutation of ``range(taskset.num_tasks)`` is drawn with
    ``random.shuffle`` after seeding with ``shuffle_seed``. Index ``i`` returns
    source task ``self.idxes[i % len(self.idxes)]``. Impurities are shared by
    reference with ``taskset`` when it has them, and stay keyed by the
    *original* task index.
    """

    def __init__(self, taskset, seed=42, shuffle_seed=0):
        set_seed(seed)
        super().__init__(taskset=taskset, seed=seed)
        self.impurities = taskset.impurities if hasattr(taskset, 'impurities') else {}

        order = list(range(taskset.num_tasks))
        set_seed(shuffle_seed)
        random.shuffle(order)
        self.idxes = order

    def __getitem__(self, i):
        """Return the source task at permuted position ``i`` (wrapping around)."""
        assert isinstance(i, int), f'{i} is not an integer'
        source_idx = self.idxes[i % len(self.idxes)]
        return super().__getitem__(source_idx)


class ImpureTasksets:
    """
    Circumvent the problem that TaskDataset in learn2learn does not support
    item assignment.

    Wraps train/validation/test tasksets. A random, disjoint subset of the
    *training* task indices is corrupted, one slice per impurity kind; the
    validation and test splits are passed through untouched.

    The indices ``range(num_tasks)`` are shuffled once with ``seed``. They are
    then handed out as consecutive slices in this order: mask, noise, shuffle,
    dark, recolor, bgr. Mask/noise/shuffle corruptions are computed once
    (``convert_now=True``). Dark/recolor/bgr are re-applied on every access.

    Raises:
        ValueError: if the requested impure task counts add up to more than
            ``num_tasks``.
    """

    # Suffix appended to ``self.name`` for each impurity count. Every kind has a
    # distinct suffix so different configurations never share a cache filename.
    _NAME_SUFFIXES = (
        ('train_mask_tasks', '_mt'),
        ('train_noise_tasks', '_nt'),
        ('train_shuffle_tasks', '_st'),
        ('train_dark_tasks', '_dt'),
        ('train_recolor_tasks', '_rt'),
        ('train_bgr_tasks', '_bt'),
    )

    def __init__(self, tasksets, num_tasks: int, ways: int, shots: int,
                 train_mask_labels: int = None, train_mask_tasks: int = None,
                 train_noise_tasks: int = None, train_shuffle_tasks: int = None,
                 train_dark_tasks: int = None, train_recolor_tasks: int = None,
                 train_bgr_tasks: int = None,
                 savedir: str = './cache', seed: int = 42
                 ):
        # Insertion order = order in which slices of the shuffled indices are taken.
        counts = {
            'train_mask_tasks': train_mask_tasks,
            'train_noise_tasks': train_noise_tasks,
            'train_shuffle_tasks': train_shuffle_tasks,
            'train_dark_tasks': train_dark_tasks,
            'train_recolor_tasks': train_recolor_tasks,
            'train_bgr_tasks': train_bgr_tasks,
        }

        dirty_idces = list(range(num_tasks))
        requested = sum(n for n in counts.values() if n is not None)
        if requested > len(dirty_idces):
            # Previously the slices were silently truncated.
            raise ValueError(
                f'requested {requested} impure train tasks, '
                f'but only {len(dirty_idces)} tasks are available'
            )
        set_seed(seed)
        random.shuffle(dirty_idces)

        # Consecutive, non-overlapping slices: each task gets at most one impurity.
        self.impurity_dict = {}
        self.mask_label_list = []
        pos = 0
        for key, count in counts.items():
            if count is None:
                self.impurity_dict[key] = []
                continue
            self.impurity_dict[key] = dirty_idces[pos: pos + count]
            pos += count

        if pos > 0:
            self.train = ImpureTaskset(tasksets.train, seed=seed)
            if train_mask_labels is not None:
                self.mask_label_list = random.sample(range(ways), train_mask_labels)
            else:
                self.mask_label_list = list(range(ways))
            # (impurity key, converter, convert_now). Registration order matters
            # because eager converters consume RNG state.
            converters = (
                ('train_mask_tasks', lambda x: mask_image(x, self.mask_label_list), True),
                ('train_noise_tasks', noise_image, True),
                ('train_shuffle_tasks', shuffle_label, True),
                ('train_dark_tasks', dark_image, False),
                ('train_recolor_tasks', recolor_image, False),
                ('train_bgr_tasks', bgr_image, False),
            )
            for key, converter, convert_now in converters:
                self.train.add_impurities(self.impurity_dict[key],
                                          converter=converter,
                                          convert_now=convert_now)
        else:
            self.train = tasksets.train

        self.test = tasksets.test
        self.validation = tasksets.validation

        self.num_tasks = num_tasks
        self.ways = ways
        self.shots = shots
        self.train_mask_labels = train_mask_labels
        self.train_mask_tasks = train_mask_tasks
        self.savedir = savedir

        self.name = f'im_tasks{num_tasks}_ways{ways}_shots{shots}'
        if train_mask_labels is not None:
            self.name += f'_ml{train_mask_labels}'
        for key, suffix in self._NAME_SUFFIXES:
            if counts[key] is not None:
                self.name += f'{suffix}{counts[key]}'

    def save(self):
        """Pickle the impurity index assignment, then the train impurities.

        NOTE: ``index_dict.pkl`` has a fixed name. Saving a second configuration
        into the same ``savedir`` overwrites it.
        """
        os.makedirs(self.savedir, exist_ok=True)
        pkl_path = os.path.join(self.savedir, 'index_dict.pkl')
        with open(pkl_path, 'wb') as f:
            pickle.dump(self.impurity_dict, f)
        self.save_tasks()

    def save_tasks(self):
        """Pickle ``self.train.impurities`` to ``<savedir>/<name>_impurity.pkl``.

        Does nothing when the train split has no ``impurities`` table, i.e. when
        no impurity count was requested.
        """
        if hasattr(self.train, "impurities"):
            os.makedirs(self.savedir, exist_ok=True)
            pkl_name = self.name + '_impurity.pkl'
            pkl_path = os.path.join(self.savedir, pkl_name)
            with open(pkl_path, 'wb') as f:
                pickle.dump(self.train.impurities, f)
            print(f'saved {pkl_path}')


def pollute_tasks(taskset, num_tasks: int,
                  test_recolor_tasks: int = None,
                  test_bgr_tasks: int = None,
                  seed: int = 42
                  ):
    """Wrap ``taskset`` in an ``ImpureTaskset`` and lazily corrupt some tasks.

    ``range(num_tasks)`` is shuffled after seeding with ``seed``. Consecutive,
    disjoint slices of it are then assigned to ``recolor_image`` (first
    ``test_recolor_tasks`` indices) and ``bgr_image`` (the next
    ``test_bgr_tasks``). Both converters are re-applied on every access
    (``convert_now=False``). A count of None skips that corruption.

    Returns:
        The wrapping ``ImpureTaskset``.
    """
    polluted = ImpureTaskset(taskset, seed=seed)

    order = list(range(num_tasks))
    set_seed(seed)
    random.shuffle(order)

    start = 0
    for count, converter in ((test_recolor_tasks, recolor_image),
                             (test_bgr_tasks, bgr_image)):
        if count is None:
            continue
        polluted.add_impurities(order[start: start + count],
                                converter=converter, convert_now=False)
        start += count
    return polluted







