import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch import nn, optim
from torchbearer import Trial
from torchbearer.callbacks import Mixup, MultiStepLR
from torchvision.datasets import VisionDataset
import matplotlib.pyplot as plt
from models import resnet
import torch.nn.functional as F
import augs_alter
import random
import numpy as np
import cv2
import math
import argparse

# Command-line options.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--device', default='cuda', type=str, help='Device on which to run')
# Required: the checkpoint-loading loop concatenates a run index and '.pt' onto this
# prefix, which fails with an opaque TypeError when the option is omitted (None).
parser.add_argument('--path', type=str, required=True,
                    help='Checkpoint path prefix; the run index and ".pt" are appended')
# torchvision's CIFAR10 needs a real root directory (os.path.join on None raises),
# so fall back to ./data rather than None.
parser.add_argument('--dataset-path', type=str, default='./data', help='Optional dataset path')
parser.add_argument('--batch-size', default=128, type=int, help='batch size')
args = parser.parse_args()

# Seed every RNG the script draws from: `random`, NumPy (scipy's beta.rvs in
# AugmentedDataset uses the global NumPy state) and torch (DataLoader shuffling
# and the base seed handed to loader workers).
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Identical preprocessing for both splits: tensor conversion + per-channel normalization.
cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root=args.dataset_path, train=True, download=True,
                                        transform=cifar_transform)
valset = torchvision.datasets.CIFAR10(root=args.dataset_path, train=False, download=True,
                                      transform=cifar_transform)

# Per-class error counts do not depend on iteration order, so the clean
# validation pass needs no shuffling.
valloader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size, shuffle=False, num_workers=8)
# NOTE(review): trainset/trainloader are not used by the code visible here; kept in
# case other code relies on these module-level names.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=8)
device = args.device if torch.cuda.is_available() else "cpu"

n_cls = len(valset.classes)
from scipy.stats import beta
from torchvision.datasets import VisionDataset
class AugmentedDataset(VisionDataset):
    """Wrap a dataset and pass every image through a function from ``augs_alter``.

    For each item, a value ``lam`` is drawn from ``Beta(alpha + 1, alpha)``.
    ``augs_alter.<method>(img, None, lam)`` is then called, and its result is
    returned together with the unchanged target.

    Args:
        dataset: indexable dataset whose items expose ``[0]`` (image) and
            ``[1]`` (target).
        method: name of the function to look up in the ``augs_alter`` module.
        alpha: parameter of the Beta distribution ``lam`` is drawn from.
    """

    def __init__(self, dataset, method, alpha):
        # Initialise VisionDataset's own attributes (root, transforms, ...) so that
        # its __repr__ works; this wrapper has no root directory of its own.
        super().__init__(root=None)
        self.alpha = alpha
        self.dataset = dataset
        self.method = method

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # Fetch the sample once. Indexing twice ran the wrapped dataset's
        # transform pipeline twice per item.
        sample = self.dataset[index]
        img1, target = sample[0], sample[1]
        # scipy's rvs uses NumPy's global RNG when no random_state is given.
        lam = beta.rvs(self.alpha + 1, self.alpha)
        # NOTE(review): the meaning of the None second argument is defined in augs_alter — confirm.
        img = getattr(augs_alter, self.method)(img1, None, lam)
        return img, target

def get_wrong_pred(net, loader, num_classes=None, dev=None):
    """Count misclassified samples per *predicted* class.

    Puts ``net`` in eval mode and runs it over every ``(images, labels)`` batch
    that ``loader`` yields. For each sample whose arg-max prediction differs
    from its label, the counter of the predicted class is incremented.

    Args:
        net: model mapping a batch of images to per-class scores (dim 1).
        loader: iterable of ``(images, labels)`` batches.
        num_classes: length of the returned counter; defaults to the
            module-level ``n_cls``.
        dev: device the batches are moved to; defaults to the module-level
            ``device``.

    Returns:
        CPU tensor of shape ``(num_classes,)`` (default float dtype) holding
        the number of wrong predictions that landed on each class.
    """
    if num_classes is None:
        num_classes = n_cls
    if dev is None:
        dev = device
    net.eval()
    cnt = torch.zeros(num_classes)

    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(dev)
            labels = labels.to(dev)
            _, predicted = torch.max(net(images), 1)
            wrong_preds = predicted[predicted != labels]
            # One vectorized histogram per batch, instead of a Python loop that
            # synchronised with the device on every element.
            cnt += torch.bincount(wrong_preds.cpu(), minlength=num_classes).to(cnt.dtype)
    return cnt

def get_index(dif):
    """Summarise how strongly the positive error shift concentrates on one column.

    ``dif`` is a 2-D tensor. In this script it is ``distorted - base``, so rows
    correspond to checkpoints and columns to classes.

    Steps:
      1. Keep only the positive entries and normalise each row to sum to 1.
      2. Pick ``cls``, the column with the largest mean normalised share, and
         print it.
      3. Score each row as ``(share of cls / row total) * share of cls``.

    Rows without any positive entry score 0 (rather than producing NaN).

    Returns:
        ``(mean, std)`` of the per-row scores, as 0-dim tensors (``std`` is
        ``torch.std``'s default unbiased estimate).
    """
    dif_base = torch.clamp(dif, min=0)
    row_total = dif_base.sum(dim=1, keepdim=True)
    # An all-non-positive row would give 0/0 = NaN and poison mean()/std();
    # divide such rows by 1 instead so they stay all-zero.
    dif_base = dif_base / torch.where(row_total > 0, row_total, torch.ones_like(row_total))
    cls = torch.argmax(dif_base.mean(dim=0))
    print(cls)
    share = dif_base[:, cls]
    norm = dif_base.sum(dim=1)
    # For non-zero rows, norm is 1 up to rounding, so perc ~= share and score ~= share**2.
    perc = share / torch.where(norm > 0, norm, torch.ones_like(norm))
    scores = perc * share
    return scores.mean(), scores.std()

# Evaluate each saved checkpoint on the clean and the cutout-distorted validation
# sets. For each one, record the normalised distribution of its wrong predictions
# over classes, then print the summary of (distorted - clean).
NUM_CHECKPOINTS = 5
valloader_aug = torch.utils.data.DataLoader(
    AugmentedDataset(valset, "cutout", 1.),
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=8,
)
base = torch.zeros((NUM_CHECKPOINTS, n_cls))
distorted = torch.zeros((NUM_CHECKPOINTS, n_cls))
for run in range(NUM_CHECKPOINTS):
    net = resnet.ResNet18().to(device)
    ckpt_file = args.path + str(run) + '.pt'
    # Keep tensors on the storage they were deserialised to.
    checkpoint = torch.load(ckpt_file, map_location=lambda storage, loc: storage)
    net.load_state_dict(checkpoint['model'])

    clean_errors = get_wrong_pred(net, valloader)
    base[run] = clean_errors / torch.sum(clean_errors)

    aug_errors = get_wrong_pred(net, valloader_aug)
    distorted[run] = aug_errors / torch.sum(aug_errors)
print(get_index(distorted - base))
