import os
from shutil import move

import numpy as np
import pandas as pd
import torchvision.datasets as datasets


import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import pickle
# from preprocess import get_transform
# from utils import *
import math
from collections import Counter
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision.datasets import STL10
from PIL import Image
import os
import pdb


# Default root directory for raw datasets; each dataset lives in its own
# sub-folder ``<root>/<name>`` (see get_dataset).
__DATASETS_DEFAULT_PATH = "./data/"


def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True, datasets_path=__DATASETS_DEFAULT_PATH):
    """Construct a torchvision dataset by name.

    Args:
        name: one of 'cifar10', 'cifar100', 'mnist', 'stl10', 'imagenet', 'svhn'.
        split: for CIFAR-10/100, MNIST and ImageNet, ``'train'`` selects the
            training data and any other value selects the test/val data. For
            STL10 and SVHN the value is passed verbatim as torchvision's
            ``split`` argument (e.g. ``'train'`` / ``'test'``).
        transform: optional image transform.
        target_transform: optional label transform.
        download: forwarded to torchvision (ignored for ImageNet, which is read
            from ``<datasets_path>/imagenet/{train,val}`` via ImageFolder).
        datasets_path: root directory; the dataset lives in ``<datasets_path>/<name>``.

    Returns:
        The constructed dataset.

    Raises:
        ValueError: if ``name`` is not one of the supported datasets.
    """
    train = (split == 'train')
    root = os.path.join(datasets_path, name)
    if name == 'cifar10':
        return datasets.CIFAR10(root=root,
                                train=train,
                                transform=transform,
                                target_transform=target_transform,
                                download=download)
    elif name == 'cifar100':
        return datasets.CIFAR100(root=root,
                                 train=train,
                                 transform=transform,
                                 target_transform=target_transform,
                                 download=download)
    elif name == 'mnist':
        return datasets.MNIST(root=root,
                              train=train,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'stl10':
        return datasets.STL10(root=root,
                              split=split,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'imagenet':
        root = os.path.join(root, 'train' if train else 'val')
        return datasets.ImageFolder(root=root,
                                    transform=transform,
                                    target_transform=target_transform)
    elif name == 'svhn':
        return datasets.SVHN(root=root,
                             split=split,
                             transform=transform,
                             target_transform=target_transform,
                             download=download)
    # NOTE: a 'chexpert' branch (ChexpertSmall(root=root, mode=split, transform=transform))
    # used to live here but is disabled because its import is commented out.

    # Fail loudly instead of silently returning None for a typo'd / unsupported name.
    raise ValueError(f"Unknown dataset name: {name!r}")


def get_balanced_dataset(name, transform):
    """Return the standard (balanced) ``(train, held-out)`` datasets for ``name``.

    Args:
        name: dataset name understood by ``get_dataset``.
        transform: dict with ``'train'`` and ``'eval'`` transforms.

    Returns:
        ``(train_data, val_data)`` where ``val_data`` is the dataset's held-out
        split built with ``transform['eval']``.
    """
    train_data = get_dataset(name, 'train', transform['train'])
    # SVHN and STL10 are addressed by torchvision split name and have no 'val'
    # split; their held-out split is 'test' (same mapping as get_imbalanced_dataset).
    eval_split = 'test' if name in ('svhn', 'stl10') else 'val'
    val_data = get_dataset(name, eval_split, transform['eval'])
    return train_data, val_data

def _pickle_dump(obj, path):
    """Serialise ``obj`` to ``path`` with pickle, closing the file even on error."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def _pickle_load(path):
    """Load a pickled object from ``path``.

    NOTE(security): unpickling can execute arbitrary code. Only point this at
    cache files produced by this module, never at files from an untrusted source.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _subset_labels(subset):
    """Return the labels of a ``Subset`` as a NumPy array.

    Uses the wrapped dataset's ``targets`` (CIFAR, MNIST, ImageFolder) or
    ``labels`` (SVHN, STL10) attribute when present, which avoids decoding and
    augmenting every image. Otherwise it falls back to iterating the subset.
    Assumes the wrapped dataset has no ``target_transform`` (true for the
    datasets built in get_imbalanced_dataset).
    """
    base = subset.dataset
    all_labels = getattr(base, 'targets', None)
    if all_labels is None:
        all_labels = getattr(base, 'labels', None)
    if all_labels is not None:
        positions = np.asarray(subset.indices, dtype=np.int64)
        return np.asarray(all_labels)[positions]
    return np.array([label for _, label in subset])


def _imbalanced_indices(labels, im_ratio):
    """Return the positions in ``labels`` to keep for the imbalanced split.

    Classes are assumed to be labelled ``0..K-1``, where K is the number of
    distinct labels. For each class ``i < K // 2``, the first
    ``floor(count_i * (1 - im_ratio))`` occurrences are dropped, so roughly an
    ``im_ratio`` fraction is kept. Classes ``i >= K // 2`` are kept whole.
    Positions are grouped by class in ascending class order.
    """
    num_classes = len(np.unique(labels))
    keep = []
    for cls in range(num_classes):
        positions = np.where(labels == cls)[0]
        if cls < num_classes // 2:
            positions = positions[math.floor(len(positions) * (1 - im_ratio)):]
        keep.extend(positions.tolist())
    return keep


def get_imbalanced_dataset(name, im_ratio, transform, data_dir=__DATASETS_DEFAULT_PATH, val_ratio=0.1):
    """Return ``(train, val, test)`` datasets with the first half of the classes down-sampled.

    The first ``1 - val_ratio`` fraction of the training split (in dataset
    order) forms the training pool. The remaining tail is the validation set,
    built with ``transform['eval']``. Within the training pool, each class
    ``i < K // 2`` keeps about an ``im_ratio`` fraction of its samples (see
    ``_imbalanced_indices``).

    Results are pickled under ``<data_dir>/<name>/`` and reused on later calls.
    The cache key covers only ``im_ratio`` and ``val_ratio``, not the
    transforms, so changing ``transform`` requires deleting the cache files.

    Args:
        name: dataset name understood by ``get_dataset``.
        im_ratio: fraction of each down-sampled class to keep (1.0 keeps everything).
        transform: dict with ``'train'`` and ``'eval'`` transforms.
        data_dir: root for both the raw dataset (``<data_dir>/<name>``) and the caches.
        val_ratio: fraction of the training split held out for validation;
            ``0`` disables the validation split.

    Returns:
        ``(train, val, test)``. When ``val_ratio <= 0``, the test set is returned
        in the ``val`` position.
    """
    cache_dir = os.path.join(data_dir, name)
    os.makedirs(cache_dir, exist_ok=True)
    im_train_data_file = os.path.join(
        cache_dir, 'imb_train_data' + "_" + str(im_ratio) + "_" + str(val_ratio))
    val_data_file = os.path.join(cache_dir, 'val_data_' + str(val_ratio))
    test_data_file = os.path.join(cache_dir, 'test_data')
    use_val = val_ratio > 0.

    # Rebuild when either cache is missing (e.g. a previous run died between the two writes).
    if not os.path.isfile(im_train_data_file) or (use_val and not os.path.isfile(val_data_file)):
        train_all = get_dataset(name, 'train', transform['train'], datasets_path=data_dir)
        train_size = int(len(train_all) * (1 - val_ratio))
        train_data = Subset(train_all, list(range(train_size)))

        np_label = _subset_labels(train_data)
        imbalanced_train_data = Subset(train_data, _imbalanced_indices(np_label, im_ratio))
        print(len(imbalanced_train_data))
        _pickle_dump(imbalanced_train_data, im_train_data_file)

        if use_val:
            # Same images as the tail of train_all, but with the evaluation transform.
            eval_all = get_dataset(name, 'train', transform['eval'], datasets_path=data_dir)
            val_data = Subset(eval_all, list(range(train_size, len(eval_all))))
            _pickle_dump(val_data, val_data_file)

    if not os.path.isfile(test_data_file):
        # SVHN/STL10 name their held-out split 'test'; the others use train=False via 'val'.
        test_split = 'test' if name in ('svhn', 'stl10') else 'val'
        test_data = get_dataset(name, test_split, transform['eval'], datasets_path=data_dir)
        _pickle_dump(test_data, test_data_file)

    train_data = _pickle_load(im_train_data_file)
    test_data = _pickle_load(test_data_file)
    val_data = _pickle_load(val_data_file) if use_val else test_data
    return train_data, val_data, test_data


def imb_cifar10_dataloader(batch_size=64, data_dir=__DATASETS_DEFAULT_PATH, val_ratio=0.1, imratio=0.2):
    """Build DataLoaders for the imbalanced CIFAR-10 benchmark.

    Training images are randomly cropped (32px, padding 4) and horizontally
    flipped. Evaluation images are only converted to tensors. The splits come
    from ``get_imbalanced_dataset('cifar10', imratio, ...)``.

    Returns:
        ``(train_loader, val_loader, test_loader, dataset_normalization, num_classes)``
        with ``num_classes == 10``. ``dataset_normalization`` carries the CIFAR-10
        channel mean/std.
    """
    transform = {
        'train': transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]),
        'eval': transforms.Compose([transforms.ToTensor()]),
    }

    train_set, val_set, test_set = get_imbalanced_dataset(
        'cifar10', imratio, transform, data_dir=data_dir, val_ratio=val_ratio)
    print('OBTAINED IMBALANCED CIFAR10 DATASET')

    shared = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **shared)
    val_loader = DataLoader(val_set, shuffle=False, **shared)
    test_loader = DataLoader(test_set, shuffle=False, **shared)

    normalization = NormalizeByChannelMeanStd(
        mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
    return train_loader, val_loader, test_loader, normalization, 10


def imb_svhn_dataloader(batch_size=64, data_dir=__DATASETS_DEFAULT_PATH, val_ratio=0.1, imratio=0.2):
    """Build DataLoaders for the imbalanced SVHN benchmark.

    No augmentation is applied; both train and eval transforms are plain
    ``ToTensor``. The splits come from ``get_imbalanced_dataset('svhn', ...)``.

    Returns:
        ``(train_loader, val_loader, test_loader, dataset_normalization, num_classes)``
        with ``num_classes == 10``.
    """
    train_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    transform = {
        'train': train_transform,
        'eval': test_transform
    }

    train_set, val_set, test_set = get_imbalanced_dataset('svhn', imratio, transform, data_dir=data_dir, val_ratio=val_ratio)

    # drop_last on the training loader keeps every optimisation batch full-sized.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True, pin_memory=True)
    # Evaluation loaders: deterministic order and no dropped samples, so every
    # validation/test example is scored (consistent with the other loaders here).
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    dataset_normalization = NormalizeByChannelMeanStd(mean=[0.4377, 0.4438, 0.4728], std=[0.1201, 0.1231, 0.1052])

    num_classes = 10

    return train_loader, val_loader, test_loader, dataset_normalization, num_classes


def imb_cifar100_dataloader(batch_size=64, data_dir=__DATASETS_DEFAULT_PATH, val_ratio=0.1, imratio=0.2):
    """Build DataLoaders for the imbalanced CIFAR-100 benchmark.

    Training images get a random crop (32px, padding 4), a horizontal flip and
    a random rotation of up to 15 degrees. Evaluation images are only converted
    to tensors. The splits come from ``get_imbalanced_dataset('cifar100', ...)``.

    Returns:
        ``(train_loader, val_loader, test_loader, dataset_normalization, num_classes)``
        with ``num_classes == 100`` and CIFAR-100 channel statistics.
    """
    augment = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
    ])
    plain = transforms.Compose([transforms.ToTensor()])

    splits = get_imbalanced_dataset(
        'cifar100', imratio, {'train': augment, 'eval': plain},
        data_dir=data_dir, val_ratio=val_ratio)
    train_set, val_set, test_set = splits
    print('OBTAINED IMBALANCED CIFAR100 DATASET')

    def _make_loader(dataset, shuffle):
        # All three loaders share batch size, worker count and pinned memory.
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=4, pin_memory=True)

    train_loader = _make_loader(train_set, True)
    val_loader = _make_loader(val_set, False)
    test_loader = _make_loader(test_set, False)

    normalization = NormalizeByChannelMeanStd(
        mean=[0.5071, 0.4865, 0.4409],
        std=[0.2673, 0.2564, 0.2762])
    return train_loader, val_loader, test_loader, normalization, 100


def imb_stl10_dataloader(batch_size=64, data_dir=__DATASETS_DEFAULT_PATH, val_ratio=0., imratio=0.2):
    """Build DataLoaders for the imbalanced STL-10 benchmark.

    Training images are padded by 4, randomly cropped back to 96px and
    horizontally flipped. Evaluation images are only converted to tensors.

    Returns:
        ``(train_loader, val_loader, test_loader, dataset_normalization, num_classes)``
        with ``num_classes == 10``. With the default ``val_ratio=0`` there is no
        held-out split, and ``val_loader`` is the same object as ``test_loader``.
        With ``val_ratio > 0``, ``val_loader`` iterates the real validation split
        instead of silently reusing the test set.
    """
    train_transform = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomCrop(96),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    transform = {
        'train': train_transform,
        'eval': test_transform
    }

    train_set, val_set, test_set = get_imbalanced_dataset('stl10', imratio, transform, data_dir=data_dir, val_ratio=val_ratio)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2,
                              drop_last=True, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    if val_ratio > 0.:
        val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    else:
        # No validation split was carved out; keep returning the test loader as before.
        val_loader = test_loader

    dataset_normalization = NormalizeByChannelMeanStd(mean=[0.4467, 0.4398, 0.4066], std=[0.2242, 0.2215, 0.2239])

    num_classes = 10
    return train_loader, val_loader, test_loader, dataset_normalization, num_classes


class NormalizeByChannelMeanStd(torch.nn.Module):
    """Differentiable per-channel ``(x - mean) / std`` normalisation.

    ``mean`` and ``std`` are stored as registered buffers. Non-tensor inputs,
    such as lists, are converted with ``torch.tensor``, so the buffers follow
    the module across ``.to(device)`` calls and appear in ``state_dict``.
    """

    def __init__(self, mean, std):
        super().__init__()
        mean_t = mean if isinstance(mean, torch.Tensor) else torch.tensor(mean)
        std_t = std if isinstance(std, torch.Tensor) else torch.tensor(std)
        self.register_buffer("mean", mean_t)
        self.register_buffer("std", std_t)

    def forward(self, tensor):
        return self.normalize_fn(tensor, self.mean, self.std)

    def extra_repr(self):
        return 'mean={}, std={}'.format(self.mean, self.std)

    def normalize_fn(self, tensor, mean, std):
        """Normalise ``tensor`` channel-wise, keeping autograd intact.

        The 1-D statistics are reshaped to ``(1, C, 1, 1)``, so the input is
        expected to be 4-D with the channel axis at dim 1.
        """
        as_nchw = (None, slice(None), None, None)
        return (tensor - mean[as_nchw]) / std[as_nchw]



class LT(Dataset):
    '''
    Long-tailed dataset (ImageNet, ImageNet-LT, iNaturalist2018, iNaturalist2019)
    read from a list file.

    Each line of ``txt`` is ``<relative_image_path> <integer_label>``; paths are
    joined onto ``root``. On hosts whose name does not contain ``'amax'``, list
    files whose path contains both ``'test'`` and ``'ImageNet'`` have their
    middle path component dropped (``a/b/c`` -> ``a/c``).
    '''
    def __init__(self, root, txt, transform=None):
        self.img_path = []
        self.labels = []
        self.transform = transform

        print("LT:", txt)
        drop_middle = ('amax' not in os.uname()[1]
                       and 'test' in txt and 'ImageNet' in txt)
        with open(txt) as listing:
            for entry in listing:
                fields = entry.split()
                rel_path = fields[0]
                if drop_middle:
                    parts = rel_path.split('/')
                    rel_path = '/'.join([parts[0], parts[2]])
                self.img_path.append(os.path.join(root, rel_path))
                self.labels.append(int(fields[1]))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        target = self.labels[index]
        with open(self.img_path[index], 'rb') as handle:
            image = Image.open(handle).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, target  # , index


class GTSRB(Dataset):
    """German Traffic Sign Recognition Benchmark, loaded eagerly into memory.

    Reads ``<root_dir>/GTSRB/<trainingset|testset>/<training.csv|test.csv>``.
    The first CSV column is the image path relative to that sub-directory, and
    the second is the class id. At construction time every image is opened,
    passed through ``transform`` (when given) and stacked into ``self.imgs``;
    ``self.labels`` holds the class ids as a tensor.
    """
    base_folder = 'GTSRB'

    def __init__(self, root_dir, train=False, transform=None):
        self.root_dir = root_dir

        if train:
            self.sub_directory, self.csv_file_name = 'trainingset', 'training.csv'
        else:
            self.sub_directory, self.csv_file_name = 'testset', 'test.csv'
        split_dir = os.path.join(root_dir, self.base_folder, self.sub_directory)

        print("Reading GTSRB data......")
        self.csv_data = pd.read_csv(os.path.join(split_dir, self.csv_file_name))

        self.transform = transform

        images = []
        class_ids = []

        print("Processing GTSRB data......")
        for row in range(len(self.csv_data)):
            picture = Image.open(os.path.join(split_dir, self.csv_data.iloc[row, 0]))
            class_ids.append(self.csv_data.iloc[row, 1])
            if self.transform is not None:
                picture = self.transform(picture)
            images.append(picture)

        self.imgs = torch.stack(images)
        self.labels = torch.tensor(class_ids)

    def __len__(self):
        return len(self.csv_data)

    def __getitem__(self, idx):
        sample, target = self.imgs[idx], self.labels[idx]
        return sample, target


def gtsrb_dataloader(batch_size=128, data_dir='./data/', val_ratio=0.1):
    """
    Build train/val/test DataLoaders for GTSRB (43 classes).

    Download dataset from https://onedrive.live.com/?authkey=%21AKNpIXu0xpmVm1I&cid=25B382439BAD237F&id=25B382439BAD237F%21224763&parId=25B382439BAD237F%21224762&action=locate
    Unzip the zip file and make the path the data_dir below.

    Args:
        batch_size: batch size for all three loaders.
        data_dir: see ABOVE
        val_ratio: fraction of the training CSV rows (the tail, in file order)
            held out for validation.
    Returns:
        (train_loader, val_loader, test_loader, dataset_normalization, num_classes).
        When the validation split is empty (e.g. val_ratio == 0), the test
        loader is returned as val_loader.
    """
    # NOTE(review): GTSRB applies its transform once, inside __init__, so the
    # RandomCrop/RandomHorizontalFlip below are sampled once per image and then
    # frozen; they do not re-augment each epoch.
    train_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    test_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])

    full_train = GTSRB(data_dir, train=True, transform=train_transform)
    # Derive the size from the data itself rather than a hard-coded image count.
    number_train_images = len(full_train)

    train_size = int(number_train_images * (1 - val_ratio))
    val_size = number_train_images - train_size
    train_set = Subset(full_train, list(range(train_size)))
    test_set = GTSRB(data_dir, train=False, transform=test_transform)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    if val_size > 0:
        # Same training rows, re-read with the deterministic evaluation transform.
        val_set = Subset(GTSRB(data_dir, train=True, transform=test_transform),
                         list(range(train_size, number_train_images)))
        val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    else:
        val_loader = test_loader

    dataset_normalization = NormalizeByChannelMeanStd(
        mean=[0.3403, 0.3121, 0.3214], std=[0.2724, 0.2608, 0.2669])
    num_classes = 43
    return train_loader, val_loader, test_loader, dataset_normalization, num_classes

def stl10_dataloader(batch_size=64, data_dir='./data/'):
    """Build DataLoaders for the standard (balanced) STL-10 train/test splits.

    Training images are padded by 4, randomly cropped back to 96px and
    horizontally flipped. Test images are only converted to tensors. Both
    splits are downloaded into ``data_dir`` if missing.

    Returns:
        ``(train_loader, test_loader, test_loader, dataset_normalization, num_classes)``.
        There is no validation split, so the test loader fills both the val and
        test positions; ``num_classes == 10``.
    """
    augment = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomCrop(96),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    plain = transforms.Compose([transforms.ToTensor()])

    train_set, test_set = (
        STL10(data_dir, split=part, download=True, transform=tf)
        for part, tf in (('train', augment), ('test', plain))
    )

    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        num_workers=2, drop_last=True, pin_memory=True)
    test_loader = DataLoader(
        test_set, batch_size=batch_size, shuffle=False,
        num_workers=2, pin_memory=True)

    normalization = NormalizeByChannelMeanStd(
        mean=[0.4467, 0.4398, 0.4066], std=[0.2242, 0.2215, 0.2239])
    return train_loader, test_loader, test_loader, normalization, 10


if __name__ == '__main__':
    # Smoke test: build (or load from cache) CIFAR-10 with im_ratio=1.0, which
    # keeps every sample, and a 50% validation hold-out. Then print per-split
    # label counts to check the class balance of each split.
    transform = {
        'train': transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]),
        'eval': transforms.Compose([
            transforms.ToTensor(),
        ]),
    }

    print("CIFAR10 IMRATIO 1")
    train_data, val_data, test_data = get_imbalanced_dataset('cifar10', 1., transform, val_ratio=0.5)
    for split_name, split_data in (('Train', train_data), ('Val', val_data), ('Test', test_data)):
        split_labels = np.array([target for _, target in split_data])
        print(f'{split_name} label stats: {Counter(split_labels)}')