import torch 
from torch.utils.data import Dataset
import torchvision.datasets as datasets
from collections import defaultdict,Counter
import numpy as np
import torch.nn as nn
from torchvision.transforms import ToTensor,Normalize,Compose,Resize
from torchvision import models
from copy import deepcopy


class ImbalancedDatasetWrapper(Dataset):
    """In-memory dataset whose classes follow a linearly decaying size profile.

    ``(image, label)`` pairs from ``original_dataset`` are grouped by label and
    the labels are sorted. The first sorted label keeps all of its samples;
    each subsequent label keeps a linearly smaller number, ending at
    ``int(first_class_size * class_ratio)`` for the last label. A label with
    fewer samples than its target keeps all it has.

    Args:
        original_dataset: iterable of ``(image, label)`` pairs (e.g. a
            torchvision dataset). It is iterated exactly once.
        class_ratio: ratio of the last label's target size to the first
            label's size.
        transform: optional callable applied eagerly, once, to every kept
            image at construction time (not lazily in ``__getitem__``).
    """

    def __init__(self,original_dataset,class_ratio,transform = None) -> None:
        super().__init__()
        self.transform = transform

        # Group in a single pass: iterating a torchvision dataset re-runs its
        # transform on every access, and one-shot iterables can't be re-read.
        data = defaultdict(list)
        for sample in original_dataset:
            data[sample[1]].append(sample[0])
        labels = sorted(data)

        # NOTE: the first sorted label's size is used as the "max" target even
        # if another label has more samples.
        max_class_size = len(data[labels[0]])
        min_class_size = int(max_class_size*class_ratio)
        class_sizes = [int(size) for size in np.linspace(max_class_size,min_class_size,num = len(labels))]
        print(f'\t\tThe data sample per labels are : {class_sizes}')
        self.images, self.labels = [],[]

        for target_size, label in zip(class_sizes, labels):
            kept = data[label][:target_size]
            self.images.extend(kept)
            # Count what was actually kept: a label smaller than its target would
            # otherwise get more labels than images and misalign every later pair.
            self.labels.extend([label] * len(kept))

        self.np_labels = np.array(self.labels)

        if self.transform:
            self.images = [self.transform(img) for img in self.images]

        self.label_set = sorted(set(self.labels))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label_tensor)`` for ``index``."""
        if torch.is_tensor(index):
            index = index.tolist()
        return (self.images[index],torch.tensor(self.labels[index]))

    def get_class_weights(self):
        """Return per-label weights ``max_count / count``, ordered like ``self.label_set``.

        The most frequent label gets weight 1.0; rarer labels get
        proportionally larger weights. For labels ``0..K-1`` the entry at
        position ``k`` belongs to label ``k``.
        """
        class_counts = Counter(self.labels)
        # Iterate the labels actually present rather than range(K), so
        # non-zero-based labels don't read zero counts and divide by zero.
        counts = np.array([class_counts[label] for label in self.label_set])
        return np.max(counts)/counts

    def get_sample_weights(self):
        """Return one weight per sample, proportional to 1/its label's count, summing to 1."""
        class_counts = Counter(self.labels)
        inverse_freq = {label: 1/count for label, count in class_counts.items()}
        unnormalised = [inverse_freq[label] for label in self.labels]
        return np.array(unnormalised)/sum(unnormalised)

    def get_sample_indexes(self,size,selected_label = None):
        """Draw ``size`` distinct sample indexes, optionally restricted to ``selected_label``.

        When ``selected_label`` is given and it has ``<= size`` samples, the
        draw is reduced to ``len(candidates) - 1`` (never below 0). This is
        presumably meant to leave at least one sample of the class unselected;
        confirm with callers.
        """
        if selected_label is None:
            return np.random.choice(len(self),size= size,replace=False)
        # flatnonzero is always 1-D; argwhere(...).squeeze() collapsed a single
        # match into a 0-d array, which broke len() and np.random.choice.
        viable_indexes = np.flatnonzero(self.np_labels==selected_label)
        if len(viable_indexes)<=size:
            size = max(len(viable_indexes) - 1, 0)
        return np.random.choice(viable_indexes,size=size,replace=False)

    def get_victim_sample(self,victim_label=None,victim_patient = None):
        """Return one random sample index (a Python int), optionally restricted to ``victim_label``.

        ``victim_patient`` is accepted only for backward compatibility and is ignored.
        """
        print('\t\tWARNING: victim_patient is not used for this set of experiments. Only for backward compatilibility.')
        # Compare with None so label 0 is honoured; fall back to every index
        # (np.ones would always have yielded the value 1.0).
        if victim_label is None:
            viable_label_idx = np.arange(len(self))
        else:
            viable_label_idx = np.flatnonzero(self.np_labels==victim_label)
        return np.random.choice(viable_label_idx,size=1).item()

    def create_poisoned_copy(self,poisoned_vectors,poisoned_indexes):
        """Return a deep copy of this dataset with ``images[poisoned_indexes[i]] = poisoned_vectors[i]``.

        The original dataset is left unmodified.
        """
        poisoned_dataset = deepcopy(self)
        for idx,pi in enumerate(poisoned_indexes):
            poisoned_dataset.images[pi] = poisoned_vectors[idx]

        return poisoned_dataset
    
    

def mnist_transform():
    """Return the MNIST preprocessing pipeline: ``ToTensor`` then ``Normalize``.

    Normalisation subtracts a mean of 0.5 and divides by a std of 1.0.
    """
    # Both mean and std are 1-tuples (one entry per channel); the previous
    # ``(1.0)`` was a bare float, not a tuple.
    return Compose([ToTensor(), Normalize((0.5,), (1.0,))])

def cifar10_transform():
    """Return the CIFAR-10 pipeline: resize to 224, convert to tensor, per-channel normalise."""
    channel_mean = (0.4915, 0.4823, 0.4468)
    channel_std = (0.2470, 0.2435, 0.2616)
    steps = [Resize(224), ToTensor(), Normalize(channel_mean, channel_std)]
    return Compose(steps)

def get_transform(dataset,modelname):
    """Build the preprocessing pipeline for ``dataset`` and ``modelname``.

    Images are converted with ``ToTensor`` and normalised with per-dataset
    mean/std. When ``modelname == 'alexnet'`` they are first resized with
    ``Resize(224)``.

    Raises:
        ValueError: if ``dataset`` is not ``'cifar10'`` or ``'mnist'``
            (previously this surfaced as an ``UnboundLocalError``).
    """
    normalization = {
        'cifar10': ((0.4915, 0.4823, 0.4468), (0.2470, 0.2435, 0.2616)),
        'mnist': ((0.5,), (1.0,)),
    }
    if dataset not in normalization:
        raise ValueError(f'Unsupported dataset {dataset!r}; expected one of {sorted(normalization)}')
    mean, std = normalization[dataset]

    steps = [Resize(224)] if modelname == 'alexnet' else []
    steps += [ToTensor(), Normalize(mean, std)]
    return Compose(steps)

# Per-dataset settings: torchvision dataset class, input channel count, class count.
DATA_CONFIG = {
    'mnist': {'dataset': datasets.MNIST, 'n_channels': 1, 'num_classes': 10},
    'cifar10': {'dataset': datasets.CIFAR10, 'n_channels': 3, 'num_classes': 10},
}
# Supported architectures: name -> torchvision model constructor.
MODEL_CONFIG = {
    'resnet18': {'model': models.resnet18},
    'resnet34': {'model': models.resnet34},
    'resnet50': {'model': models.resnet50},
    'alexnet': {'model': models.alexnet},
}

def get_image_model(model_name,dataset_name):
    """Build an untrained torchvision classifier adapted to ``dataset_name``.

    ResNets: when the dataset is not 3-channel, the first conv is replaced to
    accept ``n_channels``. The final ``fc`` is always replaced to output
    ``num_classes``, and the new layers are Kaiming-initialised.

    AlexNet: when the dataset is not 3-channel, the first conv is replaced to
    accept ``n_channels``. Classifier layers 1, 4 and 6 are replaced with
    9216->4096, 4096->1024 and 1024->``num_classes``.

    Raises:
        KeyError: unknown ``dataset_name``, or an unknown ``resnet*`` name.
        ValueError: ``model_name`` is neither a ResNet nor ``'alexnet'``.
    """
    dataset_config = DATA_CONFIG[dataset_name]
    n_channels = dataset_config['n_channels']
    num_classes = dataset_config['num_classes']

    if 'resnet' in model_name:
        model_resnet = MODEL_CONFIG[model_name]['model'](pretrained=False)
        if n_channels!=3:
            model_resnet.conv1 = nn.Conv2d(n_channels,
                                            64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            nn.init.kaiming_normal_(model_resnet.conv1.weight, mode="fan_out", nonlinearity="relu")

        num_ftrs = model_resnet.fc.in_features
        model_resnet.fc = nn.Linear(num_ftrs, num_classes)
        nn.init.kaiming_normal_(model_resnet.fc.weight, mode="fan_out", nonlinearity="relu")
        return model_resnet

    if model_name=='alexnet':
        model_alexnet = MODEL_CONFIG[model_name]['model'](pretrained=False)
        if n_channels!=3:
            # Use the configured channel count instead of a hard-coded 1.
            model_alexnet.features[0] = nn.Conv2d(n_channels, 64, kernel_size=5, stride=1, padding=1)
        # 9216 presumably = 256 * 6 * 6 from AlexNet's pooled feature map -- confirm
        # against the installed torchvision version.
        model_alexnet.classifier[1] = nn.Linear(9216,4096)
        model_alexnet.classifier[4] = nn.Linear(4096,1024)
        # Output size comes from the dataset config rather than a hard-coded 10.
        model_alexnet.classifier[6] = nn.Linear(1024,num_classes)
        return model_alexnet

    # Explicit raise: `assert False` is stripped under `python -O`, which made
    # this function silently return None.
    raise ValueError(f'Unsupported Neural Model type {model_name}')

def get_dataset(dataset_name,model_name,class_ratio = None):
    """Load (downloading into ``./data`` if needed) the train/test splits of ``dataset_name``.

    Both splits use ``get_transform(dataset_name, model_name)``.

    Returns:
        ``(trainset, testset)`` when ``class_ratio`` is falsy; otherwise
        ``(trainset, testset, imbalanced_testset, imbalanced_trainset)``.
        Note that in the 4-tuple the imbalanced *test* split comes before the
        imbalanced *train* split.
    """
    dataset_cls = DATA_CONFIG[dataset_name]['dataset']

    def _load_split(is_train):
        # A fresh transform per split, matching one get_transform call each.
        return dataset_cls(root='./data', train=is_train, download=True,
                           transform=get_transform(dataset_name, model_name))

    trainset = _load_split(True)
    testset = _load_split(False)

    if not class_ratio:
        return trainset, testset

    imbalanced_trainset = ImbalancedDatasetWrapper(trainset, class_ratio=class_ratio)
    imbalanced_testset = ImbalancedDatasetWrapper(testset, class_ratio=class_ratio)
    return trainset, testset, imbalanced_testset, imbalanced_trainset

def get_modelstamp(args):
    """Return the run identifier ``<modelname>_<dataset>_lr<learningrate>_ep<epochs>`` from ``args``."""
    model, dataset = args['modelname'], args['dataset']
    lr, epochs = args['learningrate'], args['epochs']
    return f"{model}_{dataset}_lr{lr}_ep{epochs}"

if __name__=="__main__":
    # Smoke test: wrap the MNIST and CIFAR-10 train splits with class_ratio=0.01,
    # print each size before and after, then print a ResNet-34 built for CIFAR-10.
    for dataset_cls in (datasets.MNIST, datasets.CIFAR10):
        full_trainset = dataset_cls(root='./data', train=True, download=True)
        imbalanced = ImbalancedDatasetWrapper(full_trainset, .01, transform=ToTensor())
        print(len(full_trainset))
        print(len(imbalanced))

    print(get_image_model('resnet34', 'cifar10'))