import os
import json
import numpy as np
import torchvision.transforms as transforms
import torch
import torchvision.datasets as datasets
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10, CIFAR100

# packages for tsne
# import pandas as pd
# from sklearn.manifold import TSNE
# import seaborn as sn
# import matplotlib.pyplot as plt


# def tsne(train_f_cross, train_y_cross, name='tsne',
#          n_components=2, verbose=0, learning_rate=1, perplexity=5, n_iter=1000, logger=None):
#     """ train_f_cross: X, numpy array. train_y_cross: y, numpy array """
#     num_y = len(list(set(train_y_cross)))

#     tsne = TSNE(n_components=n_components, verbose=verbose,
#                 learning_rate=learning_rate, perplexity=perplexity,
#                 n_iter=n_iter)
#     tsne_results = tsne.fit_transform(train_f_cross)

#     df_subset = pd.DataFrame(data={'tsne-2d-one': tsne_results[:, 0],
#                                     'tsne-2d-two': tsne_results[:, 1]})
#     df_subset['y'] = train_y_cross

#     plt.figure(figsize=(16,10))
#     sn.scatterplot(
#         x="tsne-2d-one", y="tsne-2d-two",
#         hue="y",
#         style="y",
#         s=100,
#         palette=sn.color_palette("hls", num_y),
#         data=df_subset,
#         legend="full",
#         alpha=0.3
#     )

#     dir = '' if logger is None else logger.dir()

#     plt.savefig(dir + name)
#     plt.close()
#     quit()


def id_ood_accuracy(output, target, id_classes):
    """Compute top-1 accuracy separately for in-distribution and out-of-distribution samples.

    Args:
        output: score tensor whose dim 1 indexes classes; the arg-top-1 class of
            each row (via ``topk``) is the prediction.
        target: per-sample ground-truth labels (tensor or sequence of ints).
        id_classes: collection of class ids counted as in-distribution (ID);
            every other label is treated as out-of-distribution (OOD).

    Returns:
        ``[id_acc, ood_acc, overall_acc]`` as percentages. A group that has no
        samples reports ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    _, pred = output.topk(1, 1, True, True)
    # Move to plain Python ints once, instead of one tensor comparison
    # (and device sync) per element.
    preds = pred.t()[0].tolist()
    labels = torch.as_tensor(target).tolist()
    # ints hash by value (tensors hash by identity), so set membership is valid.
    id_set = {int(c) for c in id_classes}

    total_id_count = total_ood_count = 0
    correct_id_count = correct_ood_count = 0
    for p, y in zip(preds, labels):
        hit = 1 if p == y else 0
        if int(y) in id_set:
            total_id_count += 1
            correct_id_count += hit
        else:
            total_ood_count += 1
            correct_ood_count += hit

    def _pct(correct, total):
        # Empty groups (e.g. a batch with no OOD samples) yield 0.0.
        return correct * 100.0 / total if total else 0.0

    return [
        _pct(correct_id_count, total_id_count),
        _pct(correct_ood_count, total_ood_count),
        _pct(correct_id_count + correct_ood_count, total_id_count + total_ood_count),
    ]
      

def accuracy(output, target, topk=(1,), output_has_class_ids=False):
    """Return top-k accuracy (in percent) for every k in ``topk``.

    Args:
        output: per-sample class scores (converted with ``torch.Tensor``) or,
            when ``output_has_class_ids`` is True, per-sample ranked class ids
            (converted with ``torch.LongTensor``).
        target: per-sample ground-truth labels (converted with ``torch.LongTensor``).
        topk: the k values to evaluate.
        output_has_class_ids: treat ``output[:, :k]`` as the top-k predictions
            directly instead of ranking scores.

    Returns:
        A list of 1-element tensors, one per k, in the order of ``topk``.
    """
    scores = torch.LongTensor(output) if output_has_class_ids else torch.Tensor(output)
    labels = torch.LongTensor(target)
    with torch.no_grad():
        max_k = max(topk)
        num_samples = scores.shape[0]
        if output_has_class_ids:
            ranked = scores[:, :max_k].t()
        else:
            ranked = scores.topk(max_k, 1, True, True)[1].t()
        # hits[j, i] is True when the j-th ranked guess for sample i is its label.
        hits = ranked.eq(labels.view(1, -1).expand_as(ranked))
        scale = 100.0 / num_samples
        return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale) for k in topk]

class MyDataset(Dataset):
    """Thin wrapper exposing ``data`` with a length equal to ``len(indices)``.

    ``__getitem__`` does NOT remap through ``indices``: ``index`` is used
    directly as an index into ``data``. The loaders in this file pair this
    class with samplers built over the same ``indices``, and those samplers
    already yield raw ``data`` indices. ``indices`` therefore only determines
    ``len(self)``.
    """

    def __init__(self, data, indices):
        self.data, self.indices = data, indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        # Unpack to enforce a (sample, label) pair from the underlying dataset.
        sample, label = self.data[index]
        return sample, label

def setup_cifar10_dataloader(dirname, training, min_class, max_class, batch_size=256, augment=False, shuffle=False,
                             sampler=None, batch_sampler=None, num_workers=8, open_world=False):
    """Build a DataLoader over a filtered subset of torchvision CIFAR-10.

    The dataset is always read from (and downloaded to, if missing) ``~/data``;
    ``dirname`` and ``augment`` are accepted for interface compatibility but unused.
    ``training`` selects the train or test split.

    Index selection:
      * ``open_world``: for each of the 10 classes, up to 100 images taken from
        the last 20% of that class's indices (``min_class``/``max_class`` ignored).
      * ``training``: the first 80% of each class in ``[min_class, max_class)``.
      * otherwise: every image whose label is in ``[min_class, max_class)``.

    Returns:
        A ``DataLoader``, or a one-element list containing it when ``open_world``.
    """
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])
    # Train and test splits use the same preprocessing (no augmentation).
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    dataset = CIFAR10(root="~/data", train=bool(training), download=True, transform=transform)

    training_portion = 0.8
    # Convert the label list once rather than once per class.
    targets = np.asarray(dataset.targets)

    if open_world:
        curr_idxs = []
        for i in range(10):
            class_idxs = np.where(targets == i)[0]
            held_out = class_idxs[int(len(class_idxs) * training_portion):]
            curr_idxs.extend(held_out[:100])
    elif training:
        curr_idxs = []
        for i in range(10):
            if min_class <= i < max_class:
                class_idxs = np.where(targets == i)[0]
                curr_idxs.extend(class_idxs[:int(len(class_idxs) * training_portion)])
    else:
        curr_idxs = np.where((targets >= min_class) & (targets < max_class))[0]

    if batch_sampler is None and sampler is None:
        # NOTE(review): ``shuffle`` is not consulted; a random subset sampler is
        # always used. Switching to IndexSampler(curr_idxs) when shuffle is False
        # would change ordering for callers using the default -- confirm first.
        sampler = torch.utils.data.sampler.SubsetRandomSampler(curr_idxs)

    ds = MyDataset(dataset, curr_idxs)
    loader = torch.utils.data.DataLoader(
        ds,
        # DataLoader rejects batch_size != 1 together with a batch_sampler.
        batch_size=batch_size if batch_sampler is None else 1,
        num_workers=num_workers,
        pin_memory=True,
        batch_sampler=batch_sampler,
        sampler=sampler,
    )

    return [loader] if open_world else loader



def setup_cifar100_dataloader(dirname, training, min_class, max_class, batch_size=256, augment=False, shuffle=False,
                              sampler=None, batch_sampler=None, num_workers=8, open_world=False):
    """Build a DataLoader over a filtered subset of torchvision CIFAR-100.

    The dataset is always read from (and downloaded to, if missing) ``~/data``;
    ``dirname`` and ``augment`` are accepted for interface compatibility but unused.
    ``training`` selects the train or test split.

    Index selection:
      * ``open_world``: for each of the 100 classes, every image in the last 20%
        of that class's indices (``min_class``/``max_class`` ignored).
      * ``training``: the first 80% of each class in ``[min_class, max_class)``.
      * otherwise: every image whose label is in ``[min_class, max_class)``.

    Returns:
        A ``DataLoader``, or a one-element list containing it when ``open_world``.
    """
    normalize = transforms.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2009, 0.1984, 0.2023])
    # Train and test splits use the same preprocessing (no augmentation).
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    dataset = CIFAR100(root="~/data", train=bool(training), download=True, transform=transform)

    training_portion = 0.8
    # Convert the label list once rather than once per class.
    targets = np.asarray(dataset.targets)

    if open_world:
        curr_idxs = []
        for i in range(100):
            class_idxs = np.where(targets == i)[0]
            curr_idxs.extend(class_idxs[int(len(class_idxs) * training_portion):])
    elif training:
        curr_idxs = []
        for i in range(100):
            if min_class <= i < max_class:
                class_idxs = np.where(targets == i)[0]
                curr_idxs.extend(class_idxs[:int(len(class_idxs) * training_portion)])
    else:
        curr_idxs = np.where((targets >= min_class) & (targets < max_class))[0]

    if batch_sampler is None and sampler is None:
        # NOTE(review): ``shuffle`` is not consulted; a random subset sampler is
        # always used. Switching to IndexSampler(curr_idxs) when shuffle is False
        # would change ordering for callers using the default -- confirm first.
        sampler = torch.utils.data.sampler.SubsetRandomSampler(curr_idxs)

    ds = MyDataset(dataset, curr_idxs)
    loader = torch.utils.data.DataLoader(
        ds,
        # DataLoader rejects batch_size != 1 together with a batch_sampler.
        batch_size=batch_size if batch_sampler is None else 1,
        num_workers=num_workers,
        pin_memory=True,
        batch_sampler=batch_sampler,
        sampler=sampler,
    )

    return [loader] if open_world else loader




def setup_tinyimagenet_dataloader(dirname, training, min_class, max_class, batch_size=256, augment=False, shuffle=False,
                                  sampler=None, batch_sampler=None, num_workers=8, open_world=False):
    """Build a DataLoader over a filtered subset of Tiny-ImageNet (ImageFolder layout).

    Images come from ``~/data/tiny-imagenet-200/train`` when ``training`` or
    ``open_world`` is set, otherwise from ``~/data/tiny-imagenet-200/val``.
    The ``val`` folder is presumably arranged one sub-folder per class, as
    ``ImageFolder`` requires (TODO confirm). ``dirname`` and ``augment`` are
    accepted for interface compatibility but unused.

    Index selection:
      * ``open_world``: for each of the 200 classes, every image in the last 20%
        of that class's indices (``min_class``/``max_class`` ignored).
      * ``training``: the first 80% of each class in ``[min_class, max_class)``.
      * otherwise: every image whose label is in ``[min_class, max_class)``.

    Returns:
        A ``DataLoader``, or a one-element list containing it when ``open_world``.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    split = 'train' if (open_world or training) else 'val'
    data_path = f'~/data/tiny-imagenet-200/{split}'

    dataset = datasets.ImageFolder(
        root=data_path,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    training_portion = 0.8
    # Convert the label list once rather than once per class.
    targets = np.asarray(dataset.targets)

    if open_world:
        curr_idxs = []
        for i in range(200):
            class_idxs = np.where(targets == i)[0]
            curr_idxs.extend(class_idxs[int(len(class_idxs) * training_portion):])
    elif training:
        curr_idxs = []
        for i in range(200):
            if min_class <= i < max_class:
                class_idxs = np.where(targets == i)[0]
                curr_idxs.extend(class_idxs[:int(len(class_idxs) * training_portion)])
    else:
        curr_idxs = np.where((targets >= min_class) & (targets < max_class))[0]

    if batch_sampler is None and sampler is None:
        # NOTE(review): ``shuffle`` is not consulted; a random subset sampler is
        # always used. Switching to IndexSampler(curr_idxs) when shuffle is False
        # would change ordering for callers using the default -- confirm first.
        sampler = torch.utils.data.sampler.SubsetRandomSampler(curr_idxs)

    ds = MyDataset(dataset, curr_idxs)
    loader = torch.utils.data.DataLoader(
        ds,
        # DataLoader rejects batch_size != 1 together with a batch_sampler.
        batch_size=batch_size if batch_sampler is None else 1,
        num_workers=num_workers,
        pin_memory=True,
        batch_sampler=batch_sampler,
        sampler=sampler,
    )

    return [loader] if open_world else loader




def setup_cifar10_dataloader_cil(dirname, training, min_class, max_class, batch_size=256, augment=False, shuffle=False,
                                 sampler=None, batch_sampler=None, num_workers=8, open_world=False):
    """Build class-incremental CIFAR-10 loader(s) from an ImageFolder tree under ``dirname``.

    Images come from ``{dirname}/train`` if ``training`` else ``{dirname}/test``.
    ``augment`` is accepted for interface compatibility but unused.

    Index selection:
      * ``open_world``: returns a list of 5 loaders (one per task). The first
        ``max_class * 5 // 10`` classes are "closed-world". Task t takes, for
        each closed-world class, the t-th fifth of its last 20% of indices
        (capped at 20 images). It also takes the last 20% (capped at 100
        images) of each of the task's own new classes, which are the remaining
        classes split evenly over the 5 tasks. ``max_class`` should equal the
        total number of classes. A fresh random subset sampler is used per
        task; the ``sampler`` argument is ignored.
      * ``training``: the first 80% of each class in ``[min_class, max_class)``.
      * otherwise: every image whose label is in ``[min_class, max_class)``.

    Returns:
        A ``DataLoader``, or a list of 5 of them when ``open_world``.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    data_path = f'{dirname}/train' if training else f'{dirname}/test'

    dataset = datasets.ImageFolder(
        root=data_path,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    training_portion = 0.8
    # Convert the label list once rather than once per (task, class).
    targets = np.asarray(dataset.targets)

    if open_world:
        num_classes_cw = max_class * 5 // 10
        num_task = 5
        num_cls_per_task = (max_class - num_classes_cw) // num_task
        # Per-class index arrays do not depend on the task; compute them once.
        class_idxs = [np.where(targets == i)[0] for i in range(max_class)]
        loaders = []
        for curr_task in range(num_task):
            first_new = num_classes_cw + curr_task * num_cls_per_task
            new_classes = range(first_new, first_new + num_cls_per_task)
            curr_idxs = []
            for i, idxs in enumerate(class_idxs):
                start = int(len(idxs) * training_portion)
                if i < num_classes_cw:
                    # 0.2 is the held-out fraction; kept as a literal so the
                    # int() truncation matches the existing split exactly.
                    num_img_per_task = int(len(idxs) * (0.2)) // num_task
                    lo = start + curr_task * num_img_per_task
                    curr_idxs.extend(idxs[lo:lo + num_img_per_task][:20])
                elif i in new_classes:
                    curr_idxs.extend(idxs[start:][:100])

            task_sampler = torch.utils.data.sampler.SubsetRandomSampler(curr_idxs)
            ds = MyDataset(dataset, curr_idxs)
            loaders.append(torch.utils.data.DataLoader(ds, batch_size=batch_size, num_workers=num_workers,
                                                       pin_memory=True, batch_sampler=batch_sampler,
                                                       sampler=task_sampler))
        return loaders

    if training:
        curr_idxs = []
        for i in range(max_class):
            if i >= min_class:
                class_idxs = np.where(targets == i)[0]
                curr_idxs.extend(class_idxs[:int(len(class_idxs) * training_portion)])
    else:
        curr_idxs = np.where((targets >= min_class) & (targets < max_class))[0]

    if batch_sampler is None and sampler is None:
        # NOTE(review): ``shuffle`` is not consulted; a random subset sampler is
        # always used. Switching to IndexSampler(curr_idxs) when shuffle is False
        # would change ordering for callers using the default -- confirm first.
        sampler = torch.utils.data.sampler.SubsetRandomSampler(curr_idxs)

    ds = MyDataset(dataset, curr_idxs)
    return torch.utils.data.DataLoader(
        ds,
        # DataLoader rejects batch_size != 1 together with a batch_sampler.
        batch_size=batch_size if batch_sampler is None else 1,
        num_workers=num_workers,
        pin_memory=True,
        batch_sampler=batch_sampler,
        sampler=sampler,
    )


def setup_cifar100_dataloader_cil(dirname, training, min_class, max_class, batch_size=256, augment=False, shuffle=False,
                                  sampler=None, batch_sampler=None, num_workers=8, open_world=False):
    """Build class-incremental CIFAR-100 loader(s) from an ImageFolder tree under ``dirname``.

    Images come from ``{dirname}/train`` if ``training`` else ``{dirname}/test``.
    ``augment`` is accepted for interface compatibility but unused.

    Index selection:
      * ``open_world``: returns a list of 5 loaders (one per task). The first
        ``max_class * 5 // 10`` classes are "closed-world". Task t takes, for
        each closed-world class, the t-th fifth of its last 20% of indices. It
        also takes the last 20% of each of the task's own new classes, which
        are the remaining classes split evenly over the 5 tasks. ``max_class``
        should equal the total number of classes. A fresh random subset
        sampler is used per task; the ``sampler`` argument is ignored.
      * ``training``: the first 80% of each class in ``[min_class, max_class)``.
      * otherwise: every image whose label is in ``[min_class, max_class)``.

    Returns:
        A ``DataLoader``, or a list of 5 of them when ``open_world``.
    """
    normalize = transforms.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2009, 0.1984, 0.2023])

    data_path = f'{dirname}/train' if training else f'{dirname}/test'

    dataset = datasets.ImageFolder(
        root=data_path,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    training_portion = 0.8
    # Convert the label list once rather than once per (task, class).
    targets = np.asarray(dataset.targets)

    if open_world:
        num_classes_cw = max_class * 5 // 10
        num_task = 5
        num_cls_per_task = (max_class - num_classes_cw) // num_task
        # Per-class index arrays do not depend on the task; compute them once.
        class_idxs = [np.where(targets == i)[0] for i in range(max_class)]
        loaders = []
        for curr_task in range(num_task):
            first_new = num_classes_cw + curr_task * num_cls_per_task
            new_classes = range(first_new, first_new + num_cls_per_task)
            curr_idxs = []
            for i, idxs in enumerate(class_idxs):
                start = int(len(idxs) * training_portion)
                if i < num_classes_cw:
                    # 0.2 is the held-out fraction; kept as a literal so the
                    # int() truncation matches the existing split exactly.
                    num_img_per_task = int(len(idxs) * (0.2)) // num_task
                    lo = start + curr_task * num_img_per_task
                    curr_idxs.extend(idxs[lo:lo + num_img_per_task])
                elif i in new_classes:
                    curr_idxs.extend(idxs[start:])

            task_sampler = torch.utils.data.sampler.SubsetRandomSampler(curr_idxs)
            ds = MyDataset(dataset, curr_idxs)
            loaders.append(torch.utils.data.DataLoader(ds, batch_size=batch_size, num_workers=num_workers,
                                                       pin_memory=True, batch_sampler=batch_sampler,
                                                       sampler=task_sampler))
        return loaders

    if training:
        curr_idxs = []
        for i in range(max_class):
            if i >= min_class:
                class_idxs = np.where(targets == i)[0]
                curr_idxs.extend(class_idxs[:int(len(class_idxs) * training_portion)])
    else:
        curr_idxs = np.where((targets >= min_class) & (targets < max_class))[0]

    if batch_sampler is None and sampler is None:
        # NOTE(review): ``shuffle`` is not consulted; a random subset sampler is
        # always used. Switching to IndexSampler(curr_idxs) when shuffle is False
        # would change ordering for callers using the default -- confirm first.
        sampler = torch.utils.data.sampler.SubsetRandomSampler(curr_idxs)

    ds = MyDataset(dataset, curr_idxs)
    return torch.utils.data.DataLoader(
        ds,
        # DataLoader rejects batch_size != 1 together with a batch_sampler.
        batch_size=batch_size if batch_sampler is None else 1,
        num_workers=num_workers,
        pin_memory=True,
        batch_sampler=batch_sampler,
        sampler=sampler,
    )


def setup_tinyimagenet_dataloader_cil(dirname, training, min_class, max_class, batch_size=256, augment=False,
                                      shuffle=False, sampler=None, batch_sampler=None, num_workers=8,
                                      open_world=False):
    """Build class-incremental Tiny-ImageNet loader(s) from an ImageFolder tree under ``dirname``.

    Images come from ``{dirname}/train`` if ``training`` else ``{dirname}/val``.
    ``augment`` is accepted for interface compatibility but unused.

    Index selection:
      * ``open_world``: returns a list of 5 loaders (one per task). The first
        ``max_class * 5 // 10`` classes are "closed-world". Task t takes, for
        each closed-world class, the t-th fifth of its last 20% of indices. It
        also takes the last 20% of each of the task's own new classes, which
        are the remaining classes split evenly over the 5 tasks. ``max_class``
        should equal the total number of classes. A fresh random subset
        sampler is used per task; the ``sampler`` argument is ignored.
      * ``training``: the first 80% of each class in ``[min_class, max_class)``.
      * otherwise: every image whose label is in ``[min_class, max_class)``.

    Returns:
        A ``DataLoader``, or a list of 5 of them when ``open_world``.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    data_path = f'{dirname}/train' if training else f'{dirname}/val'

    dataset = datasets.ImageFolder(
        root=data_path,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    training_portion = 0.8
    # Convert the label list once rather than once per (task, class).
    targets = np.asarray(dataset.targets)

    if open_world:
        num_classes_cw = max_class * 5 // 10
        num_task = 5
        num_cls_per_task = (max_class - num_classes_cw) // num_task
        # Per-class index arrays do not depend on the task; compute them once.
        class_idxs = [np.where(targets == i)[0] for i in range(max_class)]
        loaders = []
        for curr_task in range(num_task):
            first_new = num_classes_cw + curr_task * num_cls_per_task
            new_classes = range(first_new, first_new + num_cls_per_task)
            curr_idxs = []
            for i, idxs in enumerate(class_idxs):
                start = int(len(idxs) * training_portion)
                if i < num_classes_cw:
                    # 0.2 is the held-out fraction; kept as a literal so the
                    # int() truncation matches the existing split exactly.
                    num_img_per_task = int(len(idxs) * (0.2)) // num_task
                    lo = start + curr_task * num_img_per_task
                    curr_idxs.extend(idxs[lo:lo + num_img_per_task])
                elif i in new_classes:
                    curr_idxs.extend(idxs[start:])

            task_sampler = torch.utils.data.sampler.SubsetRandomSampler(curr_idxs)
            ds = MyDataset(dataset, curr_idxs)
            loaders.append(torch.utils.data.DataLoader(ds, batch_size=batch_size, num_workers=num_workers,
                                                       pin_memory=True, batch_sampler=batch_sampler,
                                                       sampler=task_sampler))
        return loaders

    if training:
        curr_idxs = []
        for i in range(max_class):
            if i >= min_class:
                class_idxs = np.where(targets == i)[0]
                curr_idxs.extend(class_idxs[:int(len(class_idxs) * training_portion)])
    else:
        curr_idxs = np.where((targets >= min_class) & (targets < max_class))[0]

    if batch_sampler is None and sampler is None:
        # NOTE(review): ``shuffle`` is not consulted; a random subset sampler is
        # always used. Switching to IndexSampler(curr_idxs) when shuffle is False
        # would change ordering for callers using the default -- confirm first.
        sampler = torch.utils.data.sampler.SubsetRandomSampler(curr_idxs)

    ds = MyDataset(dataset, curr_idxs)
    return torch.utils.data.DataLoader(
        ds,
        # DataLoader rejects batch_size != 1 together with a batch_sampler.
        batch_size=batch_size if batch_sampler is None else 1,
        num_workers=num_workers,
        pin_memory=True,
        batch_sampler=batch_sampler,
        sampler=sampler,
    )


class IndexSampler(torch.utils.data.Sampler):
    """Sampler that replays ``indices`` in their stored order on every pass.

    No permutation is applied, so iteration order is fully deterministic.
    """

    def __init__(self, indices):
        # Held by reference: later changes to the sequence are seen on the next pass.
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        yield from self.indices

def get_cifar10_loader(images_path, min_class, max_class, training, batch_size=256, shuffle=False, open_world=False, is_cil=False):
    """Dispatch to the class-incremental (``is_cil``) or plain CIFAR-10 loader builder."""
    build = setup_cifar10_dataloader_cil if is_cil else setup_cifar10_dataloader
    return build(images_path, training, min_class, max_class,
                 batch_size=batch_size, shuffle=shuffle, open_world=open_world)

def get_cifar100_loader(images_path, min_class, max_class, training, batch_size=256, shuffle=False, open_world=False, is_cil=False):
    """Dispatch to the class-incremental (``is_cil``) or plain CIFAR-100 loader builder."""
    build = setup_cifar100_dataloader_cil if is_cil else setup_cifar100_dataloader
    return build(images_path, training, min_class, max_class,
                 batch_size=batch_size, shuffle=shuffle, open_world=open_world)

def get_tinyimagenet_loader(images_path, min_class, max_class, training, batch_size=256, shuffle=False, open_world=False, is_cil=False):
    """Dispatch to the class-incremental (``is_cil``) or plain Tiny-ImageNet loader builder."""
    build = setup_tinyimagenet_dataloader_cil if is_cil else setup_tinyimagenet_dataloader
    return build(images_path, training, min_class, max_class,
                 batch_size=batch_size, shuffle=shuffle, open_world=open_world)


def save_predictions(y_pred, min_class_trained, max_class_trained, save_path, suffix=''):
    """Save ``y_pred`` with ``torch.save`` to
    ``<save_path>/preds_min_trained_<min>_max_trained_<max><suffix>.pt``.

    Uses ``os.path.join`` (like ``save_accuracies``), so an empty ``save_path``
    writes to the current directory rather than the filesystem root.
    """
    name = 'preds_min_trained_' + str(min_class_trained) + '_max_trained_' + str(max_class_trained) + suffix
    torch.save(y_pred, os.path.join(save_path, name + '.pt'))


def save_accuracies(accuracies, min_class_trained, max_class_trained, save_path, suffix=''):
    """Write ``accuracies`` as JSON to
    ``<save_path>/accuracies_min_trained_<min>_max_trained_<max><suffix>.json``.

    The file is opened in a ``with`` block so it is flushed and closed
    deterministically instead of relying on garbage collection.
    """
    name = 'accuracies_min_trained_' + str(min_class_trained) + '_max_trained_' + str(
        max_class_trained) + suffix + '.json'
    with open(os.path.join(save_path, name), 'w') as f:
        json.dump(accuracies, f)


def safe_load_dict(model, new_model_state):
    """
    Copy matching tensors from ``new_model_state`` into ``model`` in place.

    A leading ``model.`` component is stripped from each checkpoint key before
    lookup. Keys absent from ``model.state_dict()`` are skipped silently; keys
    with a shape mismatch are reported and skipped. Raises ``AssertionError``
    if no checkpoint key matched the model at all.
    """
    target_state = model.state_dict()

    matched = 0
    for key, tensor in new_model_state.items():
        prefix, _, remainder = key.partition('.')
        if prefix == 'model':
            key = remainder
        if key not in target_state:
            continue
        # Shape mismatches still count as a match (the name was found).
        matched += 1
        if target_state[key].shape == tensor.shape:
            # state_dict tensors share storage with the model, so this updates it.
            target_state[key].copy_(tensor)
        else:
            print('Shape mismatch...ignoring %s' % key)
    if not matched:
        raise AssertionError('No previous ckpt names matched and the ckpt was not loaded properly.')
