import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import random
from torchvision import transforms
from src.datasets.celeba import CelebaGroupedDataset
from src.datasets.cifar100 import CIFAR100
import skimage
import torchvision
from skimage.filters import gaussian
import skimage as sk
from src.classification_metrics import Accuracy, MetricAggregator

import logging

# Pick the compute device once: CUDA when torch can see a GPU, CPU otherwise.
# DEVICE_TYPE keeps the plain string form; DEVICE is the torch.device built
# from it and is what the evaluation code moves tensors to.
DEVICE_TYPE = 'cuda' if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(DEVICE_TYPE)

print("Using",DEVICE)

logger = logging.getLogger(__name__)

# Placeholder (Ellipsis): replace with the directory that holds the datasets
# before running. It is passed as ``data_dir`` when the CIFAR-100 test split
# is built further down in this file.
DATA_DIR = ...

def evaluate(model, val_loader):
    """Run ``model`` over ``val_loader`` and return the aggregated accuracy.

    The model is switched to eval mode (and left in it) and all forward passes
    run under ``torch.no_grad()``.

    Each batch must be either ``(data, targets)`` or ``(data, targets, tasks)``:

    * two elements: ``targets`` is iterated and every element is moved to
      ``DEVICE`` on its own, and the model is called with ``tasks=None``.
      NOTE(review): this presumably expects a sequence with one target tensor
      per task; a single tensor would be split along its first dimension --
      confirm against the dataset.
    * three elements: ``data``, ``targets`` and ``tasks`` are each moved to
      ``DEVICE`` as single tensors.

    Args:
        model: Module called as ``model(data, tasks=tasks)``; it must return a
            3-tuple whose first item is the outputs fed to the metrics.
        val_loader: Iterable yielding batches in one of the layouts above.

    Returns:
        The value of ``MetricAggregator.get_mean()`` after every batch has
        been accumulated.

    Raises:
        ValueError: If a batch has neither two nor three elements.
    """
    model.eval()
    metrics = MetricAggregator([Accuracy()])

    with torch.no_grad():
        for batch in val_loader:
            if len(batch) == 2:
                data, targets = batch
                tasks = None
                data = data.to(DEVICE)
                targets = [elt.to(DEVICE) for elt in targets]
            elif len(batch) == 3:
                data, targets, tasks = batch
                data = data.to(DEVICE)
                targets = targets.to(DEVICE)
                tasks = tasks.to(DEVICE)
            else:
                # Previously this fell through with `tasks` unbound and failed
                # later with an UnboundLocalError; fail early and clearly.
                raise ValueError(
                    "Expected a batch of 2 (data, targets) or 3 "
                    "(data, targets, tasks) elements, got {}".format(len(batch))
                )

            # Only the first of the three model outputs is scored.
            outputs, _, _ = model(data, tasks=tasks)
            metrics.update(outputs, targets)

        # The per-task result is not used here, but compute() is still called
        # before get_mean() in case the aggregator derives the mean from it.
        metrics.compute()
        return metrics.get_mean()

def get_gaussian_noise_corruption(severity=1,dataset_class=None):
    """Build a transform that adds zero-mean Gaussian noise to an image tensor.

    Args:
        severity: Index into the noise-std table. The CIFAR table has seven
            levels (0-6); the table used for every other dataset has six
            (0-5). An index past the end raises ``IndexError``.
        dataset_class: Anything whose ``str()`` contains "cifar"
            (case-insensitive) selects the CIFAR table.

    Returns:
        A function mapping a tensor to a noisy floating-point tensor clipped
        to ``[0, 1]``.
    """
    uses_cifar_table = "cifar" in str(dataset_class).lower()
    std_levels = (
        [0,0.01,0.03,0.06,0.08,0.1,0.12]
        if uses_cifar_table
        else [0,0.01,0.05,0.08,0.1,0.12]
    )
    noise_std = std_levels[severity]

    def gauss_noise_tensor(img):
        # Non-floating inputs are promoted to float32 so noise can be added.
        img = img if img.is_floating_point() else img.to(torch.float32)
        noise = torch.normal(mean=img.new_zeros(img.shape), std=noise_std)
        return torch.clip(img + noise, 0, 1)

    return gauss_noise_tensor

def get_impulse_noise_corruption(severity=1,dataset_class=None):
    """Build a transform that applies salt-and-pepper noise to an image.

    Args:
        severity: Index into the table of noise amounts. The CIFAR table has
            seven levels (0-6); the table for other datasets has six (0-5).
            An index past the end raises ``IndexError``.
        dataset_class: Anything whose ``str()`` contains "cifar"
            (case-insensitive) selects the CIFAR table.

    Returns:
        A function that converts its input to a NumPy array, passes it through
        ``skimage.util.random_noise`` in ``'s&p'`` mode and returns the result
        as a tensor clipped to ``[0, 1]``.
    """
    if "cifar" in str(dataset_class).lower():
        amount_levels = [0,0.004,0.008,0.01,0.02,0.04,0.05]
    else:
        amount_levels = [0,0.002,0.005,0.008,0.01,0.02]
    noise_amount = amount_levels[severity]

    def impulse_noise(img):
        noisy = skimage.util.random_noise(np.array(img), mode='s&p', amount=noise_amount)
        return torch.clip(torch.from_numpy(noisy), 0, 1)

    return impulse_noise

def get_gaussian_blur_corruption(severity=1,dataset_class=None):
    """Build a transform that blurs an image tensor with a Gaussian filter.

    Args:
        severity: Index (0-6) into the table of blur sigmas; an index past the
            end raises ``IndexError``.
        dataset_class: Accepted for signature parity with the other corruption
            factories; it is not used here.

    Returns:
        A function that filters ``img.numpy()`` with skimage's ``gaussian``
        (dimension 0 is treated as the channel axis) and returns the result
        as a tensor clipped to ``[0, 1]``.
    """
    sigma_levels = [0,0.25,0.5,0.75,0.9,1,1.15]
    blur_sigma = sigma_levels[severity]

    def gaussian_blur(img):
        blurred = gaussian(img.numpy(), sigma=blur_sigma, channel_axis=0)
        clipped = np.clip(blurred, 0, 1)
        return torch.from_numpy(clipped)

    return gaussian_blur

def get_contrast_corruption(severity=1,dataset_class=None):
    """Build a transform that lowers image contrast around the per-channel mean.

    Args:
        severity: Index into the table of contrast factors (1 leaves the image
            unchanged, smaller values flatten it). The CIFAR table has seven
            levels (0-6); the table for other datasets has six (0-5). An index
            past the end raises ``IndexError``.
        dataset_class: Anything whose ``str()`` contains "cifar"
            (case-insensitive) selects the CIFAR table.

    Returns:
        A function taking a 3-D tensor whose first dimension indexes channels.
        Each channel is scaled towards its own mean by the selected factor and
        the result is clipped to ``[0, 1]`` and returned in the input layout.
    """
    dataset_name = str(dataset_class).lower()
    if "cifar" in dataset_name:
        factor_levels = [1,0.9,0.8,0.7,0.6,0.4,0.2]
    else:
        factor_levels = [1,0.85,0.7,0.6,0.5,0.4]
    factor = factor_levels[severity]

    def contrast(img):
        # Work channel-last so the mean over axes (0, 1) is one value per channel.
        channel_last = img.permute((1,2,0)).numpy()
        channel_means = np.mean(channel_last, axis=(0, 1), keepdims=True)
        rescaled = (channel_last - channel_means) * factor + channel_means
        clipped = np.clip(rescaled, 0, 1)
        # Restore the original channel-first layout.
        return torch.from_numpy(clipped).permute((2,0,1))

    return contrast

def get_brightness_corruption(severity=1,dataset_class=None):
    """Build a transform that brightens an RGB image tensor.

    Args:
        severity: Index (0-6) into the table of brightness offsets; an index
            past the end raises ``IndexError``.
        dataset_class: Anything whose ``str()`` contains "cifar"
            (case-insensitive) selects the CIFAR table of offsets.

    Returns:
        A function taking a 3-D tensor whose first dimension indexes channels.
        The image is converted to HSV, the offset is added to the V channel
        (clipped to ``[0, 1]``), and the image is converted back to RGB,
        clipped to ``[0, 1]`` and returned in the input layout.
    """
    if "cifar" in str(dataset_class).lower():
        offset_levels = [0,0.2,0.3,0.4,0.5,0.6,0.7]
    else:
        offset_levels = [0,0.1,0.2, 0.3,0.4, 0.5,0.6]
    value_offset = offset_levels[severity]

    def brightness(img):
        rgb = img.permute((1,2,0)).numpy()
        hsv = sk.color.rgb2hsv(rgb)
        # Channel 2 of HSV is "value", i.e. the brightness component.
        hsv[:, :, 2] = np.clip(hsv[:, :, 2] + value_offset, 0, 1)
        brightened = np.clip(sk.color.hsv2rgb(hsv), 0, 1)
        return torch.from_numpy(brightened).permute((2,0,1))

    return brightness


def get_elastic_transform_corruption(severity=1,dataset_class=None):
    """Build a transform that applies torchvision's ``ElasticTransform``.

    Args:
        severity: Index (0-6) into the table of ``alpha`` values passed to
            ``ElasticTransform``; ``sigma`` is fixed at 5.0. An index past the
            end raises ``IndexError``.
        dataset_class: Accepted for signature parity with the other corruption
            factories; it is not used here.

    Returns:
        A function that passes its input through the configured transform.
    """
    alpha_levels = [0.,30.,50.,65., 80., 95., 110.]
    alpha = alpha_levels[severity]
    # Built once here so every call of the returned function reuses it.
    warp = torchvision.transforms.ElasticTransform(alpha=alpha, sigma=5.0)

    def elastic_transform(img):
        return warp(img)

    return elastic_transform

class NoisyActivation(nn.Module):
    """Activation wrapper that perturbs its input before activating it.

    ``forward(x)`` computes ``activation(noise_fn(x))``.
    """

    def __init__(self,activation,noise_fn):
        """Store the wrapped activation and the perturbation function.

        Args:
            activation: Callable applied after the noise; stored as
                ``self.activation``.
            noise_fn: Callable applied to the raw input first; stored as
                ``self.noise_fn``.
        """
        super().__init__()
        self.activation = activation
        self.noise_fn = noise_fn

    def forward(self,x):
        """Return the activation of the perturbed input."""
        perturbed = self.noise_fn(x)
        return self.activation(perturbed)
    
def get_modulated_layer_corruption(severity=1):
    """Build a function that injects Gaussian noise ahead of one activation.

    Args:
        severity: Index (0-5) into the table of noise standard deviations; an
            index past the end raises ``IndexError``.

    Returns:
        A function taking a model. It replaces
        ``model.network.get_layer_to_modulate().conv.act`` with a
        ``NoisyActivation`` wrapping the current activation, then returns the
        same (mutated) model. Each call wraps whatever activation is installed
        at that moment, so calling it repeatedly nests the wrappers.
    """
    std_levels = [0,0.04, 0.08, .12, .15, .18]
    std = std_levels[severity]

    def add_noise(x):
        # Zero-mean noise with the same shape as x.
        noise = torch.normal(mean=x.new_zeros(x.shape),std=std)
        return x + noise

    def model_corruption(model):
        noisy_act = NoisyActivation(model.network.get_layer_to_modulate().conv.act, add_noise)
        model.network.get_layer_to_modulate().conv.act = noisy_act
        return model

    return model_corruption


def test_model(model=None,dataset_class=None,img_corruption=None, task_groups=None):
    """Evaluate ``model`` on a corrupted test split of CelebA or CIFAR-100.

    The dataset is chosen by looking for "celeba" or "cifar" (case-insensitive)
    in ``str(dataset_class)``:

    * CelebA: images are resized to 64, converted to tensors and corrupted.
      Only a random fifth of the test split is evaluated; the subset is drawn
      with the global ``random`` module on every call, so it differs between
      calls unless the caller seeds ``random``. The data directory is
      hard-coded to ``"/celeba"``.
    * CIFAR: images are resized to 32, converted to tensors, corrupted and
      then normalised with mean/std 0.5 per channel. The full test split from
      ``DATA_DIR`` is evaluated.

    Args:
        model: Model forwarded to ``evaluate``.
        dataset_class: Dataset class or name used to select the branch above.
        img_corruption: Callable applied to each image tensor right after
            ``ToTensor``. ``None`` means no corruption is applied.
        task_groups: Passed through to ``CelebaGroupedDataset``; unused for
            CIFAR.

    Returns:
        The result of ``evaluate`` on the corrupted data.

    Raises:
        ValueError: If ``dataset_class`` matches neither CelebA nor CIFAR.
    """
    dataset_name = str(dataset_class).lower()
    is_celeba = "celeba" in dataset_name
    if not is_celeba and "cifar" not in dataset_name:
        # Previously this fell through and failed on an unbound `corrupt_ds`.
        raise ValueError(
            "Unsupported dataset_class {!r}: expected a CelebA or CIFAR dataset".format(dataset_class)
        )

    # A None corruption would not be callable inside Compose; leave it out.
    corruption_steps = [] if img_corruption is None else [img_corruption]

    if is_celeba:
        test_dataset = CelebaGroupedDataset(data_dir="/celeba", split='test', image_size=64, task_groups=task_groups)
        test_dataset.transform = transforms.Compose(
            [transforms.Resize(64), transforms.ToTensor(), *corruption_steps]
        )
        num_samples = len(test_dataset)
        subset_indices = random.sample(range(num_samples), num_samples // 5)
        corrupt_ds = torch.utils.data.Subset(test_dataset, subset_indices)
    else:
        # The corruption runs on [0, 1] tensors, before normalisation.
        custom_transform = transforms.Compose(
            [
                transforms.Resize(32),
                transforms.ToTensor(),
                *corruption_steps,
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        corrupt_ds = CIFAR100(data_dir=DATA_DIR,split="test", add_augmentations=False,image_size=32,coarse_labels=False,randomized_super_classes=None,custom_transform=custom_transform)

    test_loader = DataLoader(corrupt_ds, batch_size=64, shuffle=False, num_workers=4)

    return evaluate(model, test_loader)


def compute_corruptions_results_model(model,dataset_class,task_groups,seed,device):
    """Score ``model`` under every image corruption at increasing severity.

    For each corruption factory, severities 0 through 6 are tried in order and
    the model is evaluated with ``test_model`` on the corrupted data.

    Some factories define fewer severity levels for some datasets (their
    tables are indexed by severity). When a factory raises ``IndexError`` for
    a severity, the remaining severities of that corruption are skipped with a
    warning instead of aborting the whole sweep, so that corruption's result
    list is shorter than the others.

    Args:
        model: Model to evaluate; it is put into eval mode.
        dataset_class: Dataset class or name, forwarded to the corruption
            factories and to ``test_model``.
        task_groups: Forwarded to ``test_model``.
        seed: Currently unused by this function.
        device: Currently unused by this function.

    Returns:
        Dict mapping a corruption name (e.g. ``"gaussian_noise"``) to the list
        of scores, as floats, in increasing order of severity.
    """
    model.eval()
    corruptions = [
        get_brightness_corruption,
        get_contrast_corruption,
        get_elastic_transform_corruption,
        get_gaussian_noise_corruption,
        get_impulse_noise_corruption,
        get_gaussian_blur_corruption,
    ]
    severities = [0,1,2,3,4,5,6]
    results_dict = {}
    for corruption in corruptions:
        # "get_<name>_corruption" -> "<name>"
        corruption_name = corruption.__name__.split("get_")[1].split("_corruption")[0]
        scores = []
        results_dict[corruption_name] = scores
        for sev in severities:
            try:
                corrup = corruption(severity=sev,dataset_class=dataset_class)
            except IndexError:
                # This corruption has no such severity level for this dataset.
                logger.warning(
                    "No severity %s defined for %s corruption; skipping remaining severities",
                    sev,
                    corruption_name,
                )
                break
            logger.info("Testing model with %s corruption with severity %s", corruption_name, sev)
            corrup_scores = test_model(model=model,dataset_class=dataset_class,img_corruption=corrup,task_groups=task_groups)
            scores.append(float(corrup_scores))
    return results_dict
