import numpy as np
from collections import namedtuple
import torch
from torch import nn
import random
import time
import torchvision

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

################################################################
## Components from https://github.com/davidcpage/cifar10-fast ##
################################################################

#####################
## data preprocessing
#####################

cifar10_mean = (0.4914, 0.4822, 0.4465) # equals np.mean(train_set.train_data, axis=(0,1,2))/255
cifar10_std = (0.2471, 0.2435, 0.2616) # equals np.std(train_set.train_data, axis=(0,1,2))/255

def normalise(x, mean=cifar10_mean, std=cifar10_std):
    x, mean, std = [np.array(a, np.float) for a in (x, mean, std)]
    x -= mean*255
    x *= 1.0/(255*std)
    return x

def pad(x, border=4):
    """Reflect-pad axes 1 and 2 of a 4-D array by ``border`` on each side.

    For NHWC data these are the H and W axes; axes 0 and 3 are not padded.
    """
    no_pad = (0, 0)
    both_sides = (border, border)
    widths = [no_pad, both_sides, both_sides, no_pad]
    return np.pad(x, widths, mode='reflect')

def pad_svhn(x, border=4):
    """Reflect-pad the last two axes of a 4-D array by ``border`` on each side.

    For NCHW data these are the H and W axes; axes 0 and 1 are not padded.
    """
    widths = [(0, 0)] * 2 + [(border, border)] * 2
    return np.pad(x, widths, mode='reflect')

def pad_new(x, border=4):
    """Reflect-pad axes 2 and 3 (H and W of NCHW data) by ``border`` on each side."""
    spatial = (border, border)
    return np.pad(x, ((0, 0), (0, 0), spatial, spatial), mode='reflect')

def transpose(x, source='NHWC', target='NCHW'):
    """Permute the axes of ``x`` from the ``source`` layout to the ``target`` layout.

    Layouts are strings of single-letter axis names; each letter of
    ``target`` must occur in ``source``.
    """
    axis_order = [source.index(axis) for axis in target]
    return x.transpose(axis_order)

#####################
## data augmentation
#####################

class Crop(namedtuple('Crop', ('h', 'w'))):
    """Cut an ``h`` x ``w`` window out of a ``(C, H, W)`` array.

    The window's top-left corner is at row ``y0`` and column ``x0``.
    """

    def __call__(self, x, x0, y0):
        rows = slice(y0, y0 + self.h)
        cols = slice(x0, x0 + self.w)
        return x[:, rows, cols]

    def options(self, x_shape):
        """Return every top-left corner that keeps the window inside ``x_shape``."""
        _, height, width = x_shape
        return {'x0': range(width - self.w + 1), 'y0': range(height - self.h + 1)}

    def output_shape(self, x_shape):
        """Return the ``(C, h, w)`` shape produced from an input of ``x_shape``."""
        channels, _, _ = x_shape
        return (channels, self.h, self.w)
    
class FlipLR(namedtuple('FlipLR', ())):
    """Optionally mirror the last axis of a 3-D array (W for ``(C, H, W)`` data)."""

    def __call__(self, x, choice):
        if not choice:
            return x
        mirrored = x[:, :, ::-1]
        # Materialise the negative-stride view as a fresh contiguous array.
        return mirrored.copy()

    def options(self, x_shape):
        """The flip decision is a coin toss, independent of ``x_shape``."""
        return {'choice': [True, False]}

class Cutout(namedtuple('Cutout', ('h', 'w'))):
    """Zero an ``h`` x ``w`` patch of a ``(C, H, W)`` array in every channel.

    The input is left untouched; a modified copy is returned.
    """

    def __call__(self, x, x0, y0):
        result = x.copy()
        patch = result[:, y0:y0 + self.h, x0:x0 + self.w]
        patch.fill(0.0)  # patch is a view, so this writes into result
        return result

    def options(self, x_shape):
        """Return every top-left corner that keeps the patch inside ``x_shape``."""
        _, height, width = x_shape
        return {'x0': range(width - self.w + 1), 'y0': range(height - self.h + 1)}
    
    
class Transform():
    """Wrap a ``(data, label)`` dataset and apply randomly parameterised transforms.

    Each transform ``t`` must be callable as ``t(data, **args)`` and provide
    ``t.options(shape)`` returning ``{arg_name: candidate_values}``. It may
    also provide ``t.output_shape(shape)``, which ``set_random_choices`` uses.
    Items returned by ``dataset`` must have a ``.shape`` attribute.
    """

    def __init__(self, dataset, transforms):
        self.dataset, self.transforms = dataset, transforms
        # Per-transform arrays of pre-drawn arguments; filled by set_random_choices().
        self.choices = None

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(transformed_data, label)`` using freshly drawn random arguments.

        Only this item's arguments are sampled. Regenerating choices for the
        whole dataset on every access, and re-reading ``dataset[0]``, made
        one pass over N items cost O(N^2).
        """
        data, labels = self.dataset[index]
        for t in self.transforms:
            # Options come from the current shape, so a shape-changing
            # transform (e.g. a crop) constrains the options of later ones.
            options = t.options(data.shape)
            args = {k: np.random.choice(v) for (k, v) in options.items()}
            data = t(data, **args)
        return data, labels

    def set_random_choices(self):
        """Draw one argument value per dataset item for every transform.

        Stores in ``self.choices`` a list with one ``{arg_name: array}`` dict
        per transform; each array has ``len(self)`` entries. Option ranges are
        derived from the shape of ``dataset[0]``, propagated through
        ``output_shape`` where a transform defines it.
        """
        self.choices = []
        x_shape = self.dataset[0][0].shape
        N = len(self)
        for t in self.transforms:
            options = t.options(x_shape)
            x_shape = t.output_shape(x_shape) if hasattr(t, 'output_shape') else x_shape
            self.choices.append({k: np.random.choice(v, size=N) for (k, v) in options.items()})

class Batches():
    """Serve mini-batches by walking a shuffled permutation of dataset indices.

    ``dataset`` must support ``len()`` and integer indexing. Each item must be
    a sequence whose first two entries are ``(data, label)``; any further
    entries (e.g. a weight) are ignored.
    """

    def __init__(self, dataset):
        self.n = len(dataset)
        self.dataset = dataset
        self.idx = self.renew()
        self.pos = 0

    def renew(self):
        """Return a freshly shuffled permutation of ``range(self.n)``."""
        # Start at 0: starting at 1 meant item 0 was never sampled and the
        # permutation was one shorter than ``self.n``, producing short batches.
        idx = list(range(self.n))
        random.shuffle(idx)
        return idx

    def get_batch(self, num):
        """Return the next ``num`` items as ``(X, y, None)``.

        When fewer than ``num`` unvisited indices remain, a new permutation is
        drawn and the leftover indices are dropped. Every batch therefore has
        ``min(num, len(dataset))`` items. The third element is always ``None``.
        """
        if self.n - self.pos < num:
            self.idx = self.renew()
            self.pos = 0

        batch = self.idx[self.pos:self.pos + num]
        self.pos += num

        X = []
        y = []
        for i in batch:
            # Index once per item: wrapped datasets (e.g. Transform) redo
            # their random augmentation on every access.
            item = self.dataset[i]
            X.append(item[0])
            y.append(item[1])

        # torch.tensor on a list of ndarrays is very slow; stack into a single
        # ndarray first (same dtype). Other element types go straight through.
        if X and all(isinstance(a, np.ndarray) for a in X):
            X = np.stack(X)
        return torch.tensor(X), torch.tensor(y), None


        




class Transform_weight():
    """Wrap a ``(data, label, weight)`` dataset and apply randomly parameterised transforms.

    Only ``data`` is transformed; label and weight pass through unchanged.
    Each transform ``t`` must be callable as ``t(data, **args)`` and provide
    ``t.options(shape)`` returning ``{arg_name: candidate_values}``. It may
    also provide ``t.output_shape(shape)``, which ``set_random_choices`` uses.
    """

    def __init__(self, dataset, transforms):
        self.dataset, self.transforms = dataset, transforms
        # Per-transform arrays of pre-drawn arguments; filled by set_random_choices().
        self.choices = None

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(transformed_data, label, weight)`` using freshly drawn arguments.

        Only this item's arguments are sampled. Regenerating choices for the
        whole dataset on every access made one pass over N items O(N^2).
        """
        data, labels, weights = self.dataset[index]
        for t in self.transforms:
            # Options come from the current shape, so a shape-changing
            # transform constrains the options of later ones.
            options = t.options(data.shape)
            args = {k: np.random.choice(v) for (k, v) in options.items()}
            data = t(data, **args)
        return data, labels, weights

    def set_random_choices(self):
        """Draw one argument value per dataset item for every transform.

        Stores in ``self.choices`` a list with one ``{arg_name: array}`` dict
        per transform; each array has ``len(self)`` entries. Option ranges are
        derived from the shape of ``dataset[0]``, propagated through
        ``output_shape`` where a transform defines it.
        """
        self.choices = []
        x_shape = self.dataset[0][0].shape
        N = len(self)
        for t in self.transforms:
            options = t.options(x_shape)
            x_shape = t.output_shape(x_shape) if hasattr(t, 'output_shape') else x_shape
            self.choices.append({k: np.random.choice(v, size=N) for (k, v) in options.items()})

#####################
## dataset
#####################

def select_class_new(train_set, classes, num_per_class = 5000, order=None):
    """Split a labelled dataset into a per-class quota, its overflow, and other classes.

    Args:
        train_set: object supporting ``len()`` with indexable ``data`` and
            ``targets`` attributes (``data[i]`` must support ``.astype``).
        classes: labels to select.
        num_per_class: maximum number of samples per selected class in the
            first output.
        order: optional explicit visiting order of sample indices, e.g. the
            ``idx`` returned by a previous call. When ``None`` or empty, a
            random permutation of all indices is used.

    Returns:
        ``(train, train_extra, train_other, idx)``. The first three are
        ``{'data': [...], 'labels': [...]}`` dicts:
          - train: up to ``num_per_class`` samples of each class in ``classes``;
          - train_extra: the remaining samples of those classes;
          - train_other: the samples of every other class.
        ``idx`` is the visiting order used. Data items are
        ``train_set.data[i].astype(float)``; labels are the original targets.
    """
    if order is not None and len(order) > 0:
        idx = order
    else:
        # Start at 0: starting at 1 silently dropped the first sample.
        idx = list(range(len(train_set)))
        random.shuffle(idx)

    cnt = {c: 0 for c in classes}

    train_x, train_y = [], []
    train_extra_x, train_extra_y = [], []
    train_other_x, train_other_y = [], []

    for i in idx:
        target = train_set.targets[i]
        data = train_set.data[i].astype(float)
        if target not in cnt:
            train_other_x.append(data)
            train_other_y.append(target)
        elif cnt[target] < num_per_class:
            train_x.append(data)
            train_y.append(target)
            cnt[target] += 1
        else:
            train_extra_x.append(data)
            train_extra_y.append(target)

    return ({'data': train_x, 'labels': train_y},
            {'data': train_extra_x, 'labels': train_extra_y},
            {'data': train_other_x, 'labels': train_other_y},
            idx)


def select_class(train_set, classes, num_per_class = 5000, order=None):
    """Like ``select_class_new``, but relabel the quota split to ``0..k-1``.

    Args:
        train_set: object supporting ``len()`` with indexable ``data`` and
            ``targets`` attributes (``data[i]`` must support ``.astype``).
        classes: labels to select. In the first output they are relabelled
            to consecutive integers in the order they first appear here.
        num_per_class: maximum number of samples per selected class in the
            first output.
        order: optional explicit visiting order of sample indices. When
            ``None`` or empty, a random permutation of all indices is used.

    Returns:
        ``(train, train_extra, train_other, idx)``. The first three are
        ``{'data': [...], 'labels': [...]}`` dicts:
          - train: up to ``num_per_class`` samples per selected class, with
            remapped labels;
          - train_extra: the remaining samples of those classes, with
            original labels;
          - train_other: the samples of every other class, with original labels.
        ``idx`` is the visiting order used.
    """
    if order is not None and len(order) > 0:
        idx = order
    else:
        # Start at 0: starting at 1 silently dropped the first sample.
        idx = list(range(len(train_set)))
        random.shuffle(idx)

    cnt = {c: 0 for c in classes}
    # Consecutive new label per selected class; iterating ``cnt`` keeps
    # first-occurrence order and ignores duplicates in ``classes``.
    new_class = {c: i for i, c in enumerate(cnt)}

    train_x, train_y = [], []
    train_extra_x, train_extra_y = [], []
    train_other_x, train_other_y = [], []

    for i in idx:
        target = train_set.targets[i]
        data = train_set.data[i].astype(float)
        if target not in cnt:
            train_other_x.append(data)
            train_other_y.append(target)
        elif cnt[target] < num_per_class:
            # Count against the ORIGINAL label. Counting under the remapped
            # label used the wrong key (KeyError or wrong quotas).
            cnt[target] += 1
            train_x.append(data)
            train_y.append(new_class[target])
        else:
            train_extra_x.append(data)
            train_extra_y.append(target)

    return ({'data': train_x, 'labels': train_y},
            {'data': train_extra_x, 'labels': train_extra_y},
            {'data': train_other_x, 'labels': train_other_y},
            idx)



def append_dataset(dataset1, dataset2, classes, num_per_class , order = None):
    """Return ``dataset1`` extended with up to ``num_per_class`` samples per class of ``dataset2``.

    Datasets are dicts of parallel lists: ``'data'``, ``'labels'`` and,
    optionally, ``'weight'``. If ``dataset1`` has ``'weight'``, then
    ``dataset2`` must have it too. The appended weights are rescaled so that,
    per class, they sum to ``num_per_class``.

    Args:
        dataset1: base dataset; its lists are concatenated with, not mutated.
        dataset2: pool to draw samples from.
        classes: labels eligible for appending.
        num_per_class: maximum number of samples taken per class.
        order: optional explicit visiting order of ``dataset2`` indices. When
            ``None`` or empty, a random permutation of all indices is used.

    Returns:
        ``(merged_dataset, idx)``, where ``idx`` is the visiting order used.
    """
    if order is not None and len(order) > 0:
        idx = order
    else:
        # Start at 0: starting at 1 silently dropped the first sample.
        idx = list(range(len(dataset2['labels'])))
        random.shuffle(idx)

    with_weight = 'weight' in dataset1
    cnt = {c: 0 for c in classes}
    weight_classes = {c: 0 for c in classes}

    train_x, train_y, train_w = [], [], []

    for i in idx:
        target = dataset2['labels'][i]
        if target not in cnt or cnt[target] >= num_per_class:
            continue
        train_x.append(dataset2['data'][i])
        train_y.append(target)
        cnt[target] += 1
        if with_weight:
            weight = dataset2['weight'][i]
            weight_classes[target] += weight
            train_w.append(weight)

    if not with_weight:
        return {'data': dataset1['data'] + train_x, 'labels': dataset1['labels'] + train_y}, idx

    # Rescale so each class's appended weights sum to num_per_class. A class
    # whose weights sum to zero keeps its raw weights instead of dividing by 0.
    train_w = [w / weight_classes[t] * num_per_class if weight_classes[t] else w
               for w, t in zip(train_w, train_y)]
    return {'data': dataset1['data'] + train_x,
            'labels': dataset1['labels'] + train_y,
            'weight': dataset1['weight'] + train_w}, idx

    


