import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data.distributed import DistributedSampler
from torch.autograd import Variable
import numpy as np
import math
import utils
import sys

# Corruption names of the CIFAR-{10, 100}-C datasets. Each entry is used as
# the stem of a `<name>.npy` file when the corrupted data is loaded.
CORRUPTIONS = [
    'gaussian_noise',
    'shot_noise',
    'impulse_noise',
    'defocus_blur',
    'glass_blur',
    'motion_blur',
    'zoom_blur',
    'snow',
    'frost',
    'fog',
    'brightness',
    'contrast',
    'elastic_transform',
    'pixelate',
    'jpeg_compression',
    'speckle_noise',
    'gaussian_blur',
    'saturate',
    'spatter',
]

def get_transforms(dataset, augment=True):
    """Return the ``(transform_train, transform_test)`` pair for ``dataset``.

    Args:
        dataset: One of 'cifar10', 'cifar100', 'cifar100c' or 'imagenet'.
        augment: If True, the training transform applies random augmentation
            (reflect-pad + random crop + horizontal flip for the CIFAR
            variants, random resized crop + horizontal flip for ImageNet).
            If False, the training transform is built from the same steps as
            the test transform.

    Returns:
        A tuple ``(transform_train, transform_test)`` of
        ``transforms.Compose`` objects.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset in ['cifar10', 'cifar100', 'cifar100c']:
        # Per-channel statistics are given on the 0-255 scale and rescaled to
        # the [0, 1] range produced by ToTensor.
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        if augment:
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                # Reflect-pad 4 pixels on every side. F.pad in 'reflect' mode
                # needs a batch dimension, so one is added and removed again.
                transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                                  (4, 4, 4, 4), mode='reflect').squeeze()),
                # Back to PIL so the random crop / flip operate on an image.
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    elif dataset in ['imagenet']:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if augment:
            transform_train = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            transform_train = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        # Fail with a clear message instead of an UnboundLocalError at the
        # return statement below.
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected one of "
            "'cifar10', 'cifar100', 'cifar100c', 'imagenet'")
    return transform_train, transform_test


def _make_train_loader(data, batch_size, ddp, kwargs):
    """Build a training DataLoader: sampled by DistributedSampler if ``ddp``, else shuffled."""
    if ddp:
        # `sampler` and `shuffle=True` are mutually exclusive in DataLoader.
        return torch.utils.data.DataLoader(
            data, batch_size=batch_size, sampler=DistributedSampler(data), **kwargs)
    return torch.utils.data.DataLoader(
        data, batch_size=batch_size, shuffle=True, **kwargs)


def get_loaders(dataset, batch_size=128, val_batch_size=128,
                augment=True, split_size=-1, split_size_2=-1,
                distill_temps=None, ddp=False, shift=1):
    """Build the train, train-val and validation data loaders for ``dataset``.

    Args:
        dataset: One of 'cifar10', 'cifar100', 'cifar100c' or 'imagenet'.
        batch_size: Batch size of the train and train-val loaders.
        val_batch_size: Batch size of the validation loader.
        augment: Forwarded to ``get_transforms`` to select the training
            transform.
        split_size: If > 0, the training set is randomly split with a fixed
            seed (42). The first part (``split_size`` samples) feeds the
            train loader and the second part feeds the train-val loader.
        split_size_2: Only used for the split when ``split_size > 0``. If
            > 0, the split has three parts of sizes ``split_size``,
            ``split_size_2`` and the remainder; the third part, read with the
            test transform, replaces the validation loader.
        distill_temps: Currently unused; kept for backward compatibility.
        ddp: If True, train / train-val loaders use a ``DistributedSampler``
            instead of shuffling. The unsplit 'cifar100c' train loader always
            shuffles and ignores this flag.
        shift: 1-based index of the 10000-row slice taken from every
            corruption file for 'cifar100c' (presumably the severity level --
            TODO confirm against the dataset layout).

    Returns:
        A tuple ``(train_loader, train_val_loader, val_loader)``.
        ``train_val_loader`` is None unless ``split_size > 0``.
    """
    assert dataset in ['cifar10', 'cifar100', 'cifar100c', 'imagenet'], \
        f'unsupported dataset: {dataset!r}'
    train_val_loader = None
    # Copy of the training data read with the test transform; only needed for
    # the three-way split below.
    train_dataset_notransform = None
    transform_train, transform_test = get_transforms(dataset, augment=augment)

    if dataset in ['cifar10', 'cifar100']:
        data_path = '../data'
        kwargs = {'num_workers': 1, 'pin_memory': True}
        dataset_cls = getattr(datasets, dataset.upper())
        train_dataset = dataset_cls(data_path, train=True, download=True,
                                    transform=transform_train)
        if split_size_2 > 0:
            train_dataset_notransform = dataset_cls(
                data_path, train=True, download=True, transform=transform_test)
        train_loader = _make_train_loader(train_dataset, batch_size, ddp, kwargs)
        val_loader = torch.utils.data.DataLoader(
            dataset_cls(data_path, train=False, transform=transform_test),
            batch_size=val_batch_size, shuffle=False, **kwargs)
    elif dataset in ['cifar100c']:
        data_path = '/export/home/data/CIFAR-100-C'
        kwargs = {'num_workers': 1, 'pin_memory': True}
        # The CIFAR100 objects only serve as containers: their data and
        # targets are overwritten with the corrupted samples below.
        train_dataset = datasets.CIFAR100('../data', train=True, download=True,
                                          transform=transform_train)
        val_dataset = datasets.CIFAR100('../data', train=True, download=True,
                                        transform=transform_test)
        n_per_file = 10000  # rows taken from each corruption file
        lo, hi = (shift - 1) * n_per_file, shift * n_per_file
        data_all = np.zeros((n_per_file * len(CORRUPTIONS), *train_dataset.data.shape[1:]),
                            dtype=train_dataset.data.dtype)
        targets_all = [0] * (n_per_file * len(CORRUPTIONS))
        targets = np.load(os.path.join(data_path, 'labels.npy'))
        for idx, corruption in enumerate(CORRUPTIONS):
            start = idx * n_per_file
            stop = start + n_per_file
            data = np.load(os.path.join(data_path, f'{corruption}.npy'))
            data_all[start:stop, :] = data[lo:hi, :]
            targets_all[start:stop] = targets[lo:hi]
        targets_all = torch.LongTensor(targets_all)
        train_dataset.data, train_dataset.targets = data_all, targets_all
        val_dataset.data, val_dataset.targets = data_all, targets_all
        # val_dataset is exactly the training data with the test transform,
        # so it doubles as the no-augmentation copy for the three-way split
        # (previously this combination raised NameError).
        train_dataset_notransform = val_dataset
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=val_batch_size, shuffle=False, **kwargs)
    elif dataset in ['imagenet']:
        data_path = '/export/share/datasets/vision/imagenet'
        # One worker fewer under DDP; this also applies to the val loader.
        kwargs = {'num_workers': 7 if ddp else 8, 'pin_memory': True}
        train_dataset = datasets.ImageNet(data_path, split='train',
                                          transform=transform_train)
        if split_size_2 > 0:
            train_dataset_notransform = datasets.ImageNet(
                data_path, split='train', transform=transform_test)
        train_loader = _make_train_loader(train_dataset, batch_size, ddp, kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageNet(data_path, split='val', transform=transform_test),
            batch_size=val_batch_size, shuffle=False, **kwargs)

    # (optionally) split the dataset using fixed seed
    if split_size > 0:
        n_train = len(train_dataset)
        if split_size_2 > 0:
            lengths = [split_size, split_size_2, n_train - split_size - split_size_2]
        else:
            lengths = [split_size, n_train - split_size]
        dataset_split = torch.utils.data.random_split(
            train_dataset, lengths,
            generator=torch.Generator().manual_seed(42))
        train_loader = _make_train_loader(dataset_split[0], batch_size, ddp, kwargs)
        train_val_loader = _make_train_loader(dataset_split[1], batch_size, ddp, kwargs)
        if split_size_2 > 0:
            # Same seed and lengths, so the indices match the augmented split.
            dataset_split_notransform = torch.utils.data.random_split(
                train_dataset_notransform, lengths,
                generator=torch.Generator().manual_seed(42))
            # Overwrite the val loader with split 3 of the training dataset.
            # It is never shuffled, with or without DDP, like every other
            # validation loader built here.
            val_loader = torch.utils.data.DataLoader(
                dataset_split_notransform[2], batch_size=val_batch_size,
                shuffle=False, **kwargs)

    return train_loader, train_val_loader, val_loader


