import numpy as np
from collections import namedtuple
import torch
from torch import nn
import torchvision

# Compute device used by Batches: the default CUDA device when CUDA is
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

################################################################
## Components from https://github.com/davidcpage/cifar10-fast ##
################################################################

#####################
## data preprocessing
#####################

# Per-channel mean and standard deviation of the CIFAR-10 training images on the
# 0-1 scale (reduced over every axis except the last one). They equal
#   np.mean(train_set.data, axis=(0, 1, 2)) / 255
#   np.std(train_set.data, axis=(0, 1, 2)) / 255
# for train_set = torchvision.datasets.CIFAR10(root, train=True), i.e. the same
# ``.data`` array that cifar10() below returns.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

def normalise(x, mean=cifar10_mean, std=cifar10_std):
    """Standardise pixel values per channel and return a new float32 array.

    ``mean`` and ``std`` are given on the 0-1 scale and are multiplied by 255
    here, so ``x`` is expected on the 0-255 scale. They broadcast against the
    trailing axis of ``x``. The input is not modified: ``np.array`` copies it.
    """
    out = np.array(x, np.float32)
    mu = np.array(mean, np.float32)
    sigma = np.array(std, np.float32)
    out -= 255 * mu
    # Multiply by the reciprocal (rather than divide) to keep float32 results
    # bit-identical to the established implementation.
    out *= 1.0 / (255 * sigma)
    return out

def pad(x, border=4):
    """Reflect-pad the two middle axes (H and W of an NHWC batch) by ``border``.

    The first and last axes are left untouched.
    """
    widths = ((0, 0), (border, border), (border, border), (0, 0))
    return np.pad(x, widths, mode='reflect')

def transpose(x, source='NHWC', target='NCHW'):
    """Permute the axes of ``x`` from the ``source`` layout to the ``target`` layout.

    Each layout is a string with one letter per axis; the result's i-th axis is
    the axis of ``x`` labelled ``target[i]`` in ``source``.
    """
    order = list(map(source.index, target))
    return x.transpose(order)

#####################
## data augmentation
#####################

class Crop(namedtuple('Crop', ('h', 'w'))):
    """Crop an ``h`` x ``w`` window from a channels-first (C, H, W) image."""

    def __call__(self, x, x0, y0):
        rows = slice(y0, y0 + self.h)
        cols = slice(x0, x0 + self.w)
        return x[:, rows, cols]

    def options(self, x_shape):
        """Every top-left corner that keeps the window fully inside the image."""
        _, height, width = x_shape
        return {
            'x0': range(width - self.w + 1),
            'y0': range(height - self.h + 1),
        }

    def output_shape(self, x_shape):
        channels, _, _ = x_shape
        return (channels, self.h, self.w)
    
class FlipLR(namedtuple('FlipLR', ())):
    """Optionally mirror a channels-first image along its last (width) axis."""

    def __call__(self, x, choice):
        if not choice:
            return x
        # .copy() turns the negative-stride view into a contiguous array.
        return x[:, :, ::-1].copy()

    def options(self, x_shape):
        return {'choice': [True, False]}

class Cutout(namedtuple('Cutout', ('h', 'w'))):
    """Zero an ``h`` x ``w`` patch of a channels-first image, working on a copy."""

    def __call__(self, x, x0, y0):
        out = x.copy()
        out[:, y0:y0 + self.h, x0:x0 + self.w] = 0.0
        return out

    def options(self, x_shape):
        """Every top-left corner that keeps the patch fully inside the image."""
        _, height, width = x_shape
        return {
            'x0': range(width + 1 - self.w),
            'y0': range(height + 1 - self.h),
        }
    
    
class Transform:
    """Dataset wrapper that applies a list of transforms with pre-drawn arguments.

    ``set_random_choices`` samples, for every transform and every dataset index,
    one value for each keyword returned by that transform's ``options``.
    ``__getitem__`` then applies the transforms in order using the stored values,
    so a sample's augmentation changes only when the choices are re-drawn.
    ``__getitem__`` requires ``set_random_choices`` to have been called first.
    """

    def __init__(self, dataset, transforms):
        self.dataset = dataset
        self.transforms = transforms
        self.choices = None  # list of {kwarg: per-index array}, one per transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        data, labels = self.dataset[index]
        for transform, per_sample in zip(self.transforms, self.choices):
            kwargs = {name: values[index] for name, values in per_sample.items()}
            data = transform(data, **kwargs)
        return data, labels

    def set_random_choices(self):
        """Draw fresh per-index arguments for every transform (uses np.random)."""
        self.choices = []
        shape = self.dataset[0][0].shape
        n_samples = len(self)
        for transform in self.transforms:
            # Options are computed for the shape *entering* this transform; later
            # transforms see the shape produced by any earlier output_shape().
            opts = transform.options(shape)
            if hasattr(transform, 'output_shape'):
                shape = transform.output_shape(shape)
            drawn = {}
            for name, candidates in opts.items():
                drawn[name] = np.random.choice(candidates, size=n_samples)
            self.choices.append(drawn)

#####################
## dataset
#####################

def cifar10(root):
    """Download CIFAR-10 under ``root`` if needed and return both splits as raw arrays.

    Returns:
        ``{'train': {'data': ..., 'labels': ...}, 'test': {...}}`` where ``data``
        is the torchvision dataset's ``.data`` and ``labels`` its ``.targets``.
    """
    splits = {}
    for split_name, is_train in (('train', True), ('test', False)):
        dataset = torchvision.datasets.CIFAR10(root=root, train=is_train, download=True)
        splits[split_name] = {'data': dataset.data, 'labels': dataset.targets}
    return splits

#####################
## data loading
#####################

class Batches:
    """Iterate a dataset in batches, moving each batch to the module-level ``device``.

    Each item yielded is ``{'input': float tensor, 'target': long tensor}``. When
    ``set_random_choices`` is true, the wrapped dataset's ``set_random_choices()``
    is called at the start of every iteration (typically once per epoch), so new
    augmentation arguments are drawn each pass.
    """

    def __init__(self, dataset, batch_size, shuffle, set_random_choices=False, num_workers=0, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.set_random_choices = set_random_choices
        loader_options = dict(
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            shuffle=shuffle,
            drop_last=drop_last,
        )
        self.dataloader = torch.utils.data.DataLoader(dataset, **loader_options)

    def __iter__(self):
        if self.set_random_choices:
            self.dataset.set_random_choices()
        # Create the loader iterator now (as a generator expression over the
        # loader would), not lazily on the first next().
        source = iter(self.dataloader)
        return (
            {'input': inputs.to(device).float(), 'target': targets.to(device).long()}
            for inputs, targets in source
        )

    def __len__(self):
        return len(self.dataloader)


########################################################################################
## Calculate the values and the corresponding probability for imitation
########################################################################################

def probability(sign, p=2/3.):
    """Return the probability of the sign sequence ``sign`` under a two-state chain.

    The first element takes either value with probability 1/2. Every later
    element equals its predecessor with probability ``p`` and differs from it
    with probability ``1 - p``.

    Args:
        sign: sequence of signs (e.g. +1/-1 as produced by ``backtrace``).
        p: probability of repeating the previous sign.

    Returns:
        The product of those factors. An empty sequence has probability 1 (the
        empty product), so the probabilities of all sequences of any fixed
        length, including 0, sum to 1.
    """
    if len(sign) == 0:
        return 1.0
    res = 1/2.
    for prev, cur in zip(sign, sign[1:]):
        res *= p if cur == prev else (1 - p)
    return res

def backtrace(n, p, sign=None, prob=None):
    """Enumerate every +1/-1 continuation of ``sign`` and bin its probability by |sum|.

    Recursively appends ``n`` more signs in all 2**n ways. For each completed
    sequence it adds ``probability(sequence, p)`` to ``prob[abs(sum(sequence))]``.

    Args:
        n: number of signs still to append.
        p: repeat probability forwarded to ``probability``.
        sign: prefix already fixed (a list); defaults to a fresh empty list.
        prob: accumulator dict updated in place; defaults to a fresh dict.

    Returns:
        ``prob``, mapping |sum| (a NumPy scalar from ``np.abs(np.sum(...))``) to
        accumulated probability. It is returned for every ``n``, including
        ``n <= 0``.
    """
    # None sentinels: a literal dict default would be created once and shared,
    # so repeated default calls would keep accumulating into the same dict.
    if sign is None:
        sign = []
    if prob is None:
        prob = {}
    if n <= 0:
        val = np.abs(np.sum(sign))
        prob[val] = prob.get(val, 0) + probability(sign, p)
        return prob
    backtrace(n-1, p, sign + [1], prob)
    backtrace(n-1, p, sign + [-1], prob)
    return prob

def imitation(n, p=2/3.):
    """Return the distribution of |sum| over all length-``n`` sign sequences as a CDF.

    Probabilities come from ``backtrace`` (and hence ``probability`` with
    parameter ``p``).

    Returns:
        ``(vals, cdf)``: ``vals`` holds the distinct |sum| values in ascending
        order. ``cdf[i]`` is the running total of probability for all values up
        to and including ``vals[i]``.
    """
    mass = backtrace(n, p, [], dict())
    vals = sorted(mass.keys())
    cdf = []
    running = 0
    for v in vals:
        running = running + mass[v]
        cdf.append(running)
    return vals, cdf

"""
Copied from https://github.com/Harry24k/catastrophic-overfitting/blob/main/defenses/loaders/datasets.py
"""
import zipfile
import os
from urllib.request import urlretrieve
from shutil import copyfile

class TinyImageNet:
    """Tiny ImageNet-200 loader that downloads and prepares the data on first use.

    After construction, ``self.data`` is a ``torchvision.datasets.ImageFolder``
    over one of two directories:

    - ``<root>/tiny-imagenet-200/train`` when ``train=True``;
    - ``<root>/tiny-imagenet-200/val_fixed`` when ``train=False``.

    ``val_fixed`` is built from the archive's ``val`` folder. That folder keeps
    its images in a single flat ``images`` directory and lists their labels in
    ``val_annotations.txt``. The images are copied into one sub-folder per class
    so that ``ImageFolder`` can label them.
    """

    def __init__(self, root="data",
                 train=True,
                 transform=None):
        """Prepare the dataset under ``root`` and load the requested split.

        Args:
            root: directory that holds (or will receive) the archive and its
                extracted ``tiny-imagenet-200`` folder. A trailing "/" is fine.
            train: load the training split when true, else the rearranged
                validation split.
            transform: passed unchanged to ``ImageFolder``.
        """
        self._ensure_dataset_loaded(root)
        split = 'train' if train else 'val_fixed'
        self.data = torchvision.datasets.ImageFolder(
            os.path.join(root, 'tiny-imagenet-200', split),
            transform=transform)

    def _download_dataset(self, path,
                          url='http://cs231n.stanford.edu/tiny-imagenet-200.zip',
                          tar_name='tiny-imagenet-200.zip'):
        """Download ``url`` to ``<path>/<tar_name>`` if missing, then extract if needed.

        The archive is first written to ``<tar_name>.part`` and renamed only
        after the download finishes, so an interrupted download is not mistaken
        for a complete one later. Extraction is skipped only when the folder
        named after the archive (``tar_name`` without its extension) already
        exists. An archive that was downloaded but never unpacked therefore
        still gets extracted.
        """
        os.makedirs(path or os.curdir, exist_ok=True)
        archive = os.path.join(path, tar_name)
        extracted = os.path.join(path, os.path.splitext(tar_name)[0])

        if os.path.exists(archive):
            if os.path.isdir(extracted):
                print("Files already downloaded and verified")
                return
        else:
            print("Downloading Files...")
            partial = archive + ".part"
            urlretrieve(url, partial)
            os.replace(partial, archive)

        print("Un-zip Files...")
        with zipfile.ZipFile(archive, 'r') as zip_ref:
            zip_ref.extractall(path=path)

    def _ensure_dataset_loaded(self, root):
        """Make sure the archive is extracted and ``val_fixed`` exists under ``root``.

        Each non-blank line of ``val/val_annotations.txt`` is split on whitespace.
        The first field is the image file name and the second its class id. The
        image is copied from ``val/images`` into ``val_fixed/<class id>/``. The
        tree is assembled under ``val_fixed.tmp`` and renamed at the end, so an
        interrupted run cannot leave a half-filled ``val_fixed`` that later runs
        would accept.
        """
        self._download_dataset(root)

        base = os.path.join(root, "tiny-imagenet-200")
        val_fixed_folder = os.path.join(base, "val_fixed")
        if os.path.exists(val_fixed_folder):
            return

        staging = val_fixed_folder + ".tmp"
        os.makedirs(staging, exist_ok=True)  # may survive from an interrupted run

        with open(os.path.join(base, "val", "val_annotations.txt"), encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines (e.g. a trailing newline)

                file_name = fields[0]
                clazz = fields[1]

                class_folder = os.path.join(staging, clazz)
                os.makedirs(class_folder, exist_ok=True)
                copyfile(os.path.join(base, "val", "images", file_name),
                         os.path.join(class_folder, file_name))

        os.rename(staging, val_fixed_folder)