import os
import copy
import random
from re import A
from tqdm import tqdm
import pandas
import numpy as np
from PIL import Image
import skimage
from sklearn.model_selection import train_test_split

from torchvision import datasets, transforms
from torchvision.datasets import CIFAR100, ImageFolder, Caltech256, StanfordCars
from torch.utils.data import ConcatDataset

from lib.dataset import StanfordCars as SC
from lib.dataset.imagenetr_utils import imagenet_r_mask, imagenet_a_mask, all_wnids
from lib.dataset.objectnet_dataset import ObjectNetDataset
from lib.dataset.imagenet_v2 import ImageNetV2
from lib.dataset.imagenet_caption import ImageNetCaptions
from lib.dataset.cub200 import CUB200
from utils.utils import set_seed


def split_dataset(dataset, ratio=0.2, seed=0):
    """
    Stratified split of a dataset into a train part and a held-out part.

    The held-out part is a deep copy of ``dataset``; ``dataset`` itself is
    narrowed in place and returned as the train part. Selected positions are
    sorted, so both parts keep the original sample order.

    Args:
        dataset: dataset exposing either ``y`` and ``index`` (Caltech256) or
            ``targets`` and ``imgs`` (optionally also ``samples``)
        ratio: fraction of samples placed in the held-out part
        seed: random state forwarded to ``train_test_split``

    Returns:
        (dataset, val_dataset): the narrowed input and the held-out copy
    """
    def take(seq, positions):
        return [seq[i] for i in positions]

    val_dataset = copy.deepcopy(dataset)
    has_y = hasattr(dataset, 'y')  # Caltech256
    labels = dataset.y if has_y else dataset.targets
    train_idx, val_idx = train_test_split(
        list(range(len(dataset))), test_size=ratio, random_state=seed, stratify=labels)
    train_idx.sort()
    val_idx.sort()

    if has_y:
        # Callers may have aliased `targets` to `y`; remember that before `y`
        # is rebound so the alias does not keep the stale, full-length list.
        targets_alias_y = hasattr(dataset, 'targets') and dataset.targets is dataset.y
        dataset.y = take(dataset.y, train_idx)
        val_dataset.y = take(val_dataset.y, val_idx)
        dataset.index = take(dataset.index, train_idx)
        val_dataset.index = take(val_dataset.index, val_idx)
        if targets_alias_y:
            dataset.targets = dataset.y
            val_dataset.targets = val_dataset.y
    else:
        if hasattr(dataset, 'samples'):
            dataset.samples = take(dataset.samples, train_idx)
            val_dataset.samples = take(val_dataset.samples, val_idx)
        dataset.imgs = take(dataset.imgs, train_idx)
        dataset.targets = take(dataset.targets, train_idx)
        val_dataset.imgs = take(val_dataset.imgs, val_idx)
        val_dataset.targets = take(val_dataset.targets, val_idx)
    return dataset, val_dataset

def reindex_dataset(dataset, mask):
    """
    Remap a dataset's targets through the non-zero positions of ``mask``.

    Target ``k`` becomes the position of the ``k``-th non-zero entry of
    ``mask`` (e.g. an imagenet-a/r class id becomes an imagenet class id).
    ``dataset.targets`` is replaced and every entry of ``dataset.samples`` is
    rewritten in place with its new target. The same dataset is returned.
    """
    kept_positions = np.nonzero(mask)[0]
    # subset class id -> full class id
    lookup = dict(enumerate(kept_positions))
    dataset.targets = [lookup[label] for label in dataset.targets]
    # Assign element-wise so any other reference to the same list sees the update.
    for pos, entry in enumerate(dataset.samples):
        dataset.samples[pos] = (entry[0], dataset.targets[pos])
    return dataset


def get_concated_imagenet_c(root, preprocess=None, is_split=False, joint_only=False, shot=None):
    """
    Load every imagenet-c corruption at severities 1-5 and concatenate them.

    Each ``root/imagenet-c/<corruption>/<severity>`` folder is loaded as an
    ImageFolder, then optionally filtered to the joint classes, split 80/20
    (seed 0), and reduced to ``shot`` samples per class, in that order.

    Returns:
        (train, val): concatenated datasets; ``val`` is None unless ``is_split``
    """
    corruption_names = (
        'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur',
        'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog', 'brightness', 'contrast',
        'elastic_transform', 'pixelate', 'jpeg_compression',
    )
    severity_levels = range(1, 6)
    train_parts, val_parts = [], []
    print("Load imagenet-c")
    for name in tqdm(corruption_names):
        for level in severity_levels:
            folder = os.path.join(root, 'imagenet-c', name, str(level))
            part = ImageFolder(folder, transform=preprocess)
            if joint_only:
                part = filter_joint(part)
            if is_split:
                part, held_out = split_dataset(part, 0.2, 0)
                val_parts.append(held_out)
            if shot is not None:
                # applied after the split, so only the train part is reduced
                part = build_shot(part, shot)
            train_parts.append(part)

    merged_train = ConcatDataset(train_parts)
    merged_val = ConcatDataset(val_parts) if is_split else None
    return merged_train, merged_val


def filter_joint(dataset, indices=(462, 470, 951, 954)):
    """
    Filter out the dataset to only include joint classes (in place).

    Samples whose target is not one of the joint classes are dropped.
    Surviving targets keep their original values; only ``classes`` (and, for
    folder-style datasets, ``class_to_idx``) are narrowed to the joint classes.

    Args:
        dataset: dataset to filter; ObjectNet-style datasets (class name
            contains 'ObjectNet') are handled via ``imgs``/``class_to_label``/
            ``pathDict``, everything else via ``samples``/``imgs``/``classes``
        indices: indices of joint classes

    Returns:
        the same dataset object, filtered
    """
    class_name = dataset.__class__.__name__
    if 'ObjectNet' in class_name:
        # Positions of the samples to keep, computed once and reused below.
        keep = [i for i, t in enumerate(dataset.targets) if t in indices]
        dataset.samples = [dataset.imgs[i] for i in keep]
        dataset.imgs = list(dataset.samples)
        labels = [dataset.class_to_label[t] for t in indices]
        dataset.pathDict = {k: v for k, v in dataset.pathDict.items() if v[1] in labels}
        dataset.targets = [dataset.targets[i] for i in keep]
        dataset.classes = labels
    else:
        if 'ImageNetV2' in class_name:
            # ImageNetV2 class names are the stringified class indices.
            wnids = [str(e) for e in indices]
        else:
            wnids = [all_wnids[e] for e in indices]
        # Translate the joint classes into this dataset's own target ids.
        dataset_indices = [dataset.classes.index(w) for w in wnids]
        keep = [i for i, t in enumerate(dataset.targets) if t in dataset_indices]
        dataset.samples = [dataset.samples[i] for i in keep]
        dataset.imgs = [dataset.imgs[i] for i in keep]
        dataset.targets = [dataset.targets[i] for i in keep]
        # NOTE(review): class_to_idx is renumbered 0..n-1 while targets keep
        # their original ids — confirm downstream code expects that.
        dataset.class_to_idx = {w: i for i, w in enumerate(wnids)}
        dataset.classes = wnids
    return dataset

       
def build_shot(dataset, shot):
    """
    Keep at most ``shot`` samples per class (in place).

    Selection is deterministic: for every distinct target the first ``shot``
    occurrences, by position, are kept.

    Args:
        dataset: dataset exposing ``targets`` and ``imgs`` (plus ``samples``
            unless its class name contains 'ObjectNet')
        shot: number of samples to keep per class

    Returns:
        the same dataset object, reduced. ObjectNet-style datasets keep their
        original sample order; other datasets end up grouped by class.
    """
    targets = np.asarray(dataset.targets)
    selected = []
    for label in np.unique(targets):
        positions = np.where(targets == label)[0]
        selected.extend(positions[:shot].tolist())
    if 'ObjectNet' in dataset.__class__.__name__:
        # Set membership keeps this filter linear in the dataset size.
        keep = set(selected)
        dataset.imgs = [img for i, img in enumerate(dataset.imgs) if i in keep]
        dataset.targets = [t for i, t in enumerate(dataset.targets) if i in keep]
    else:
        dataset.samples = [dataset.samples[i] for i in selected]
        dataset.imgs = [dataset.imgs[i] for i in selected]
        dataset.targets = [dataset.targets[i] for i in selected]
    return dataset



def _load_imagenet_variant(root, name, preprocess):
    """Load ``root/name`` as an ImageFolder, reindexing imagenet-a/r targets through their class masks."""
    dataset = ImageFolder(os.path.join(root, name), transform=preprocess)
    if name == 'imagenet-a':
        dataset = reindex_dataset(dataset, imagenet_a_mask)
    elif name == 'imagenet-r':
        dataset = reindex_dataset(dataset, imagenet_r_mask)
    return dataset


def _merge_label_spaces(parts):
    """Relabel every dataset in ``parts`` in place so the union of their targets becomes 0..K-1."""
    merged = np.unique(np.concatenate([part.targets for part in parts]), return_inverse=True)[1]
    boundaries = np.cumsum([len(part) for part in parts[:-1]])
    for part, new_targets in zip(parts, np.split(merged, boundaries)):
        if isinstance(part.imgs[0], str):
            # imgs holds bare paths, so labels cannot be rewritten in the
            # entries; attach an old-label -> new-label mapping instead
            # (presumably consumed by the dataset itself — verify).
            part.reindex = True
            part.label2reindex = dict(zip(np.unique(part.targets), np.unique(new_targets)))
        else:
            part.imgs = [(entry[0], t) for entry, t in zip(part.imgs, new_targets)]
            part.samples = part.imgs
        part.targets = new_targets


def get_dataset(args, preprocess):
    """
    Build the train / validation / test datasets selected by ``args.dataset``.

    ``args.dataset`` may carry a ``-shot<N>`` suffix, in which case the train
    dataset is reduced to N samples per class. Names joined with ``+`` are
    loaded individually, relabelled into one dense label space and
    concatenated; ``all`` concatenates a fixed list of ImageNet variants.

    Args:
        args: namespace providing ``dataset``, ``root``, ``no_split``,
            ``joint_only`` and ``seed``
        preprocess: transform handed to the dataset constructors

    Returns:
        (train_dataset, val_dataset, test_dataset). The test dataset is the
        ``imagenet/val`` ImageFolder unless ``args.dataset`` is 'cars' or
        'caltech256'.

    Raises:
        ValueError: for a caption dataset name that is neither a
            'description' nor a 'title' variant
    """
    if 'shot' in args.dataset:
        dname = args.dataset.split('-shot')[0]
        shot = int(args.dataset.split('-shot')[-1])
    else:
        dname = args.dataset
        shot = None

    if dname == 'caltech256':
        dataset = Caltech256(root=f"{args.root}/caltech256", transform=preprocess)
        dataset.targets = dataset.y
    elif dname == 'cifar100':
        # NOTE(review): no transform is passed here, unlike the other
        # branches — confirm this is intended.
        dataset = CIFAR100(root=args.root, train=True, download=True)
    elif dname == 'cars':
        dataset = SC(root=args.root, split='train', transform=preprocess)
    elif dname == 'cub200':
        dataset = CUB200(root=args.root, split='train', transform=preprocess, download=False)
    elif dname == 'objectnet':
        dataset = ObjectNetDataset(root=f"{args.root}/objectnet-1.0", transform=preprocess)
    elif dname == 'objectnet-v2':
        dataset = ObjectNetDataset(root=f"{args.root}/objectnet-1.0", transform=preprocess, reindex=True)
    elif dname == 'imagenet-c':
        # Splitting, joint filtering and shots are all handled per sub-dataset.
        dataset, val_dataset = get_concated_imagenet_c(args.root, preprocess, not args.no_split, args.joint_only, shot=shot)
    elif dname == 'imagenet-v2':
        dataset = ImageNetV2(os.path.join(args.root, 'imagenet-v2'), transform=preprocess)
    elif 'caption' in dname:
        if 'description' in dname:
            caption_file = f"{args.root}/imagenet_captions_description.csv"
        elif 'title' in dname:
            caption_file = f"{args.root}/imagenet_captions_title.csv"
        else:
            # Previously fell through and failed later on an unbound `dataset`.
            raise ValueError(f"Unknown caption dataset: {dname}")
        dataset = ImageNetCaptions(root=f"{args.root}/imagenet/train", transform=preprocess, filename=caption_file)
    elif dname == 'all':
        parts = [ImageNetV2(os.path.join(args.root, 'imagenet-v2'), transform=preprocess)]
        for name in ('imagenet-a', 'imagenet-r', 'imagenet-sketch', 'imagenet-cartoon', 'imagenet-drawing'):
            parts.append(_load_imagenet_variant(args.root, name, preprocess))
        parts.append(ObjectNetDataset(root=f"{args.root}/objectnet-1.0", transform=preprocess))
        # get_concated_imagenet_c returns (train, val); only the train part is
        # a dataset that can be concatenated, so the held-out part is dropped.
        imagenet_c_train, _ = get_concated_imagenet_c(args.root, preprocess, not args.no_split, args.joint_only)
        parts.append(imagenet_c_train)
        dataset = ConcatDataset(parts)
    elif '+' in dname:
        parts = []
        for name in dname.split('+'):
            if name == 'imagenet-v2':
                part = ImageNetV2(os.path.join(args.root, 'imagenet-v2'), transform=preprocess)
            elif name == 'objectnet-v2':
                part = ObjectNetDataset(root=f"{args.root}/objectnet-1.0", transform=preprocess, reindex=False)
            else:
                part = _load_imagenet_variant(args.root, name, preprocess)
            # TODO reindex for objectnet
            parts.append(part)
        _merge_label_spaces(parts)
        dataset = ConcatDataset(parts)
    elif dname == 'imagenet':
        # Also reached for 'imagenet-shot<N>'; the per-class reduction is
        # applied at the end of this function. (Two later legacy branches that
        # also required dname == 'imagenet' could never run and were removed.)
        dataset = ImageFolder(os.path.join(args.root, 'imagenet/train'), transform=preprocess)
    elif dname == 'imagenet-val':
        dataset = ImageFolder(os.path.join(args.root, 'imagenet/val'), transform=preprocess)
    else:
        dataset = ImageFolder(os.path.join(args.root, dname), transform=preprocess)

    # imagenet-c is already filtered per sub-dataset; the concatenated result
    # has no `classes` attribute for filter_joint to work with.
    if args.joint_only and dname != 'imagenet-c':
        dataset = filter_joint(dataset)

    # set seed
    np.random.seed(args.seed)
    random.seed(args.seed)

    train_dataset = dataset
    test_dataset = ImageFolder(os.path.join(args.root, 'imagenet/val'), transform=preprocess)
    if args.dataset == 'cars':  # swap
        test_dataset = SC(root=args.root, split='test', transform=preprocess)
    elif args.dataset == 'caltech256':
        # NOTE(review): this is the StanfordCars test split — looks like a
        # copy-paste from the branch above; confirm the intended test set.
        test_dataset = SC(root=args.root, split='test', transform=preprocess)

    if args.dataset == 'imagenet':  # exception
        val_dataset = test_dataset
    elif args.dataset == 'cars':
        val_dataset = StanfordCars(root=args.root, split='test', transform=preprocess)
    elif dname == 'cub200':
        val_dataset = CUB200(root=args.root, split='test', transform=preprocess, download=False)
    elif args.no_split:
        val_dataset = copy.deepcopy(dataset)
    elif dname == 'cifar100':
        val_dataset = CIFAR100(root=args.root, train=False, download=True)
    elif dname != 'imagenet-c':
        # Compare the suffix-free name so 'imagenet-c-shot<N>' keeps the
        # validation set built by get_concated_imagenet_c.
        train_dataset, val_dataset = split_dataset(dataset, 0.2, args.seed)

    if shot is not None and dname != 'imagenet-c':
        # Runs after the split, so only the train part is reduced.
        train_dataset = build_shot(train_dataset, shot)

    return train_dataset, val_dataset, test_dataset

