from torch.nn import functional as F
import torch
from tqdm import tqdm
import numpy as np
from src.metrics import brier_multi
from torchvision import transforms
from torchvision.datasets import CIFAR100
import calibration as cal


# Corruption names for CIFAR-100-C. load_dataset_from_store iterates over this
# list to build the per-corruption cache paths
# ("../data/processed/cifar-100-c/<model>/<corruption>_<level>_*.pt").
CIFAR_CORRUPTIONS = [
    'brightness',
    'contrast',
    'defocus_blur',
    'elastic_transform',
    'fog',
    'frost',
    'gaussian_noise',
    'glass_blur',
    'impulse_noise',
    'jpeg_compression',
    'motion_blur',
    'pixelate',
    'shot_noise',
    'snow',
    'zoom_blur',
]

# Subset of CIFAR_CORRUPTIONS (omits brightness, fog, frost and snow).
# NOTE(review): not referenced elsewhere in this module as shown; presumably
# consumed by ImageNet-C scripts elsewhere — confirm before changing.
IMAGENET_CORRUPTIONS = [
    'contrast',
    'defocus_blur',
    'elastic_transform',
    'gaussian_noise',
    'glass_blur',
    'impulse_noise',
    'jpeg_compression',
    'motion_blur',
    'pixelate',
    'shot_noise',
    'zoom_blur',
]


def calc_score(scoring_metric, probs, labels):
    """Score confidences against labels with the named metric.

    Args:
        scoring_metric: One of
            "ECE-1" -> ``cal.lower_bound_scaling_ce`` with p=1, debias=False,
                       15 bins, top-label mode;
            "ECE-2" -> ``cal.lower_bound_scaling_ce`` with p=2, debias=True,
                       15 bins, top-label mode;
            "Brier" -> ``brier_multi(probs, labels)``.
        probs: Confidences, passed unchanged to the selected scorer.
        labels: Targets aligned with ``probs``, passed unchanged to the scorer
            (Dataset.measure_calibration passes 0/1 correctness indicators).

    Returns:
        Whatever the selected scorer returns.

    Raises:
        ValueError: if ``scoring_metric`` is not one of the names above.
            (Previously the ValueError *class* was returned instead of raised.)
    """
    if scoring_metric == "ECE-1":
        return cal.lower_bound_scaling_ce(
            probs, labels, p=1, debias=False, num_bins=15, mode="top-label"
        )
    if scoring_metric == "ECE-2":
        return cal.lower_bound_scaling_ce(
            probs, labels, p=2, debias=True, num_bins=15, mode="top-label"
        )
    if scoring_metric == "Brier":
        return brier_multi(probs, labels)
    raise ValueError(f"Unknown scoring metric: {scoring_metric!r}")


class Dataset:
    """In-memory container for a classifier's outputs on one data split.

    Holds per-example logits (or probabilities), optional features and labels,
    and offers helpers to measure accuracy / top-label calibration and to
    subsample or split the examples.
    """

    def __init__(self, logits, labels, features=None, name=None, probs_only=False):
        """Wrap model outputs.

        Args:
            logits: Per-example scores; the first axis indexes examples. Used
                directly as probabilities when ``probs_only`` is True.
            labels: Ground-truth class indices aligned with ``logits`` (they are
                compared against ``argmax(probs, -1)``).
            features: Optional per-example array aligned with ``logits``.
            name: Optional identifier for the dataset.
            probs_only: If True, ``logits`` already holds probabilities and no
                softmax is applied.
        """
        self.name = name
        self.logits = logits
        self.features = features
        self.labels = labels
        # Remembered so that random_split can build a sibling of the same kind.
        self.probs_only = probs_only
        self.n = self.logits.shape[0]

        if probs_only:
            self.probs = self.logits
        else:
            # Softmax over the last axis; torch.Tensor(...) casts to float32.
            self.probs = F.softmax(torch.Tensor(logits), dim=-1).numpy()

    @staticmethod
    def _shape(array):
        """Return ``array.shape``, or None when the array is absent."""
        return None if array is None else array.shape

    def _correctness(self):
        """Return an int array: 1 where argmax(probs) equals the label, else 0."""
        y_hat = np.argmax(self.probs, -1)
        return np.array((y_hat == self.labels), dtype=int)

    def _take(self, idx):
        """Restrict every per-example array to rows ``idx`` (in that order)."""
        self.logits = self.logits[idx]
        if self.features is not None:
            self.features = self.features[idx]
        self.labels = self.labels[idx]
        self.probs = self.probs[idx]
        self.n = self.logits.shape[0]

    def measure_calibration(self, h=None, verbose=True, print_plot=False, h_type="platt"):
        """Compute top-label calibration metrics, optionally after recalibration.

        Args:
            h: Optional calibrator. When None, the raw max probability is scored.
            verbose: Print the three metrics.
            print_plot: Show a histogram of the (calibrated) confidences.
            h_type: How ``h`` is applied:
                "platt"     -> ``h.calibrate(probs)``;
                "histogram" -> ``h.calibrate(max(probs))``;
                "MLP-p" / "MLP-l" / "MLP-f" -> column 1 of ``h.predict_proba``
                    on probabilities / logits / features;
                "temp"      -> ``h`` applied to the logits on CUDA, then the max
                    softmax probability.

        Returns:
            dict with keys "ECE-1", "ECE-2" and "brier".

        Raises:
            ValueError: if ``h`` is given with an unrecognised ``h_type``
                (previously this silently scored the uncalibrated model).
        """
        if h is None:
            probs_cal = np.max(self.probs, -1)
        elif h_type == "platt":
            probs_cal = h.calibrate(self.probs)
        elif h_type == "MLP-p":
            probs_cal = h.predict_proba(self.probs)[:, 1]
        elif h_type == "MLP-l":
            probs_cal = h.predict_proba(self.logits)[:, 1]
        elif h_type == "MLP-f":
            probs_cal = h.predict_proba(self.features)[:, 1]
        elif h_type == "histogram":
            probs_cal = h.calibrate(np.max(self.probs, -1))
        elif h_type == "temp":
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                logits_cal = h(torch.Tensor(self.logits).cuda())
            probs_cal = F.softmax(logits_cal, -1).detach().cpu().numpy()
            probs_cal = np.max(probs_cal, -1)
        else:
            raise ValueError(f"Unknown calibrator type: {h_type!r}")

        # The prediction is always the argmax of the uncalibrated probabilities.
        y_correct = self._correctness()

        ece_1 = calc_score("ECE-1", probs_cal, y_correct)
        ece_2 = calc_score("ECE-2", probs_cal, y_correct)
        brier = calc_score("Brier", probs_cal, y_correct)

        if verbose:
            print("ECE-1", ece_1)
            print("ECE-2", ece_2)
            print("Brier", brier)

        if print_plot:
            # NOTE(review): `plt` is not imported in this module, so this branch
            # raises NameError unless matplotlib.pyplot is bound as `plt`
            # elsewhere — confirm and add the import.
            plt.hist(probs_cal)
            plt.show()

        return {
            "ECE-1": ece_1,
            "ECE-2": ece_2,
            "brier": brier
        }

    def measure_accuracy(self, verbose=True):
        """Return the fraction of examples whose argmax prediction equals the label."""
        y_correct = self._correctness()

        acc = np.sum(y_correct) / self.logits.shape[0]

        if verbose:
            print("Accuracy:", acc)
        return acc

    def random_sample(self, n_samples):
        """Keep a uniformly random subset of at most ``n_samples`` examples, in place.

        The permutation comes from ``torch.randperm``, so it follows torch's
        RNG seed. Works when ``features`` is None.
        """
        keep = torch.randperm(self.n)[:n_samples].numpy()
        self._take(keep)

        print(self.logits.shape, self._shape(self.features), self.labels.shape)

    def random_split(self, n_samples):
        """Randomly split off all but ``n_samples`` examples into a new Dataset.

        ``self`` keeps ``n_samples`` random examples (in place); the remaining
        examples are returned as a new Dataset. The returned Dataset inherits
        ``probs_only`` so stored probabilities are not softmaxed a second time.
        Works when ``features`` is None.
        """
        perm = torch.randperm(self.n).numpy()
        keep, rest = perm[:n_samples], perm[n_samples:]

        split_dataset = Dataset(
            logits=self.logits[rest],
            labels=self.labels[rest],
            features=None if self.features is None else self.features[rest],
            probs_only=self.probs_only,
        )

        self._take(keep)

        print(self.logits.shape, self._shape(self.features), self.labels.shape)
        print(split_dataset.logits.shape, self._shape(split_dataset.features), split_dataset.labels.shape)

        return split_dataset


def extract_dataset(model, data_loader, loader_type="torch", device="cuda"):
    """Run ``model`` over ``data_loader`` and collect logits, features and labels.

    Args:
        model: Object whose ``forward(images)`` returns ``(logits, features)``.
        data_loader: Iterable of batches. Batches are ``(images, target)`` pairs,
            or dicts with "pixel_values" / "label" keys when ``loader_type == "hf"``.
        loader_type: "hf" for dict batches; anything else for pair batches.
        device: Device the images are moved to before the forward pass
            (default "cuda", matching the previous hard-coded behaviour).

    Returns:
        Dataset with numpy logits (cast to float32), features and labels.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    all_logits, all_features, all_labels = [], [], []
    with torch.no_grad():
        for batch in tqdm(data_loader):
            if loader_type == "hf":
                images, target = batch["pixel_values"], batch["label"]
            else:
                images, target = batch

            # Only the images feed the model; labels stay on the host instead
            # of making a host -> device -> host round trip.
            logits, features = model.forward(images.to(device))
            all_logits.append(logits.cpu())
            all_features.append(features.cpu())
            all_labels.append(target.cpu())

    if not all_logits:
        raise ValueError("data_loader yielded no batches")

    logits = torch.vstack(all_logits).float().numpy()
    features = torch.vstack(all_features).numpy()
    labels = torch.concat(all_labels).numpy()

    print(logits.shape, features.shape, labels.shape)
    return Dataset(
        logits=logits,
        labels=labels,
        features=features
    )
    
    
def load_wilds_dataset(args, split):
    """Load cached logits/features/labels for one WILDS split into a Dataset.

    Reads ``../data/processed/<args.dataset>/<args.model>/<split>_{logits,features,labels}.pt``.
    Tensors are mapped to CPU at load time, so caches written from GPU tensors
    can also be read on machines without CUDA.
    """
    save_root = "../data/processed/{}/{}/{}_".format(args.dataset, args.model, split)
    logits, features, labels = (
        torch.load(save_root + key + ".pt", map_location="cpu").numpy()
        for key in ("logits", "features", "labels")
    )

    print(logits.shape, features.shape, labels.shape)
    return Dataset(
        logits=logits,
        labels=labels,
        features=features
    )

def get_wilds_dataloader(dataset, args):
    """Return a shuffling DataLoader over a WILDS Dataset's cached outputs.

    Behaves exactly like ``dataloader_from_dataset`` (batches of
    ``(features, logits, labels)`` float32 tensors, ``args.batch_size`` each)
    and delegates to it, so the two implementations cannot drift apart.
    """
    return dataloader_from_dataset(dataset, args)


def dataloader_from_dataset(dataset, args):
    """Wrap a Dataset's features, logits and labels in a shuffling DataLoader.

    Each array is converted with ``torch.Tensor`` (so all three, labels
    included, become float32), and every batch is a
    ``(features, logits, labels)`` triple of ``args.batch_size`` rows
    (the last batch may be smaller).
    """
    columns = (dataset.features, dataset.logits, dataset.labels)
    tensors = [torch.Tensor(column) for column in columns]
    tensor_dataset = torch.utils.data.TensorDataset(*tensors)
    return torch.utils.data.DataLoader(
        tensor_dataset,
        batch_size=args.batch_size,
        shuffle=True,
    )


def get_dataloader(dataset, batch_size, pre_process, mix=False, n_samples=None, num_workers=8):
    """Build a shuffled DataLoader for one of the supported evaluation image sets.

    Args:
        dataset: "cifar-100" (torchvision CIFAR100, train=False),
            "imagenet-v2" (ImageNetV2Dataset "matched-frequency") or
            "imagenet-sketch" (HF ``imagenet_sketch``, split="train").
        batch_size: Batch size for the DataLoader.
        pre_process: Transform applied to every image (after AugMix when ``mix``).
        mix: Prepend ``transforms.AugMix(severity=3, mixture_width=3)``.
        n_samples: CIFAR-100 only — keep a random subset of this many images.
            Ignored for the other datasets.
        num_workers: DataLoader worker processes (default 8, as before).

    Returns:
        torch.utils.data.DataLoader with ``shuffle=True``. For ImageNet-Sketch
        each example dict gains a "pixel_values" entry and loses "image".

    Raises:
        ValueError: if ``dataset`` is not one of the supported names (checked
            before any transform or dataset is constructed).
    """
    supported = ("cifar-100", "imagenet-v2", "imagenet-sketch")
    if dataset not in supported:
        raise ValueError(f"Unknown dataset {dataset!r}; expected one of {supported}")

    steps = [transforms.AugMix(severity=3, mixture_width=3)] if mix else []
    val_transforms = transforms.Compose(steps + [pre_process])

    if dataset == "cifar-100":
        data = CIFAR100("../data/raw/cifar-100", transform=val_transforms, download=True, train=False)
        if n_samples is not None:
            # Split against the real dataset length instead of a hard-coded 10000.
            data, _ = torch.utils.data.random_split(data, [n_samples, len(data) - n_samples])
    elif dataset == "imagenet-v2":
        from imagenetv2_pytorch import ImageNetV2Dataset
        data = ImageNetV2Dataset("matched-frequency", transform=val_transforms)
    else:  # "imagenet-sketch"
        from datasets import load_dataset

        def ds_transforms(examples):
            examples["pixel_values"] = [val_transforms(image.convert("RGB")) for image in examples["image"]]
            del examples["image"]
            return examples

        data = load_dataset("imagenet_sketch", split="train")
        data = data.with_format("torch")
        data = data.with_transform(ds_transforms)

    return torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=num_workers)


def _load_cached_split(save_root):
    """Load ``<save_root>logits.pt``, ``features.pt``, ``labels.pt`` as CPU tensors."""
    # map_location="cpu" lets caches written from GPU tensors load without CUDA.
    return tuple(
        torch.load(save_root + key + ".pt", map_location="cpu").cpu()
        for key in ("logits", "features", "labels")
    )


def load_dataset_from_store(dataset, model, split=None):
    """Load cached model outputs for ``dataset`` / ``model`` into a Dataset.

    Args:
        dataset: "cifar-100-c", "cifar-100", "imagenet-sketch", "imagenet-v2"
            or "imagenet".
        model: Model name; a path component under ``../data/processed/<dataset>/``.
        split: Split prefix of the cache files. Must be "test" for "cifar-100-c".

    For "cifar-100-c" every corruption in CIFAR_CORRUPTIONS at severity levels
    1-5 is loaded and concatenated (logits cast to float32). For the other
    datasets ``../data/processed/<dataset>/<model>/<split>_*.pt`` is loaded.

    Raises:
        ValueError: for an unknown ``dataset`` (previously the ValueError class
            was returned instead of raised), or a non-"test" split for
            "cifar-100-c" (previously an ``assert``).
    """
    if dataset == "cifar-100-c":
        if split != "test":
            raise ValueError(f"cifar-100-c only provides the 'test' split, got {split!r}")

        print("Load OOD CIFAR-100-c")

        logits, features, labels = [], [], []
        for corruption in tqdm(CIFAR_CORRUPTIONS):
            for level in (1, 2, 3, 4, 5):
                save_root = "../data/processed/cifar-100-c/{}/{}_{}_".format(model, corruption, level)
                domain_logits, domain_features, domain_labels = _load_cached_split(save_root)
                logits.append(domain_logits)
                features.append(domain_features)
                labels.append(domain_labels)

        logits = torch.vstack(logits).float().numpy()
        features = torch.vstack(features).numpy()
        labels = torch.concat(labels).numpy()

    elif dataset in ("cifar-100", "imagenet-sketch", "imagenet-v2", "imagenet"):
        # Each of these caches lives under a directory named after the dataset.
        save_root = "../data/processed/{}/{}/{}_".format(dataset, model, split)
        logits, features, labels = (t.numpy() for t in _load_cached_split(save_root))

    else:
        raise ValueError(f"Unknown dataset: {dataset!r}")

    print(logits.shape, features.shape, labels.shape)
    return Dataset(
        logits=logits,
        labels=labels,
        features=features
    )
