import libmr
import torch
import torch.nn.functional as F
from abc import abstractmethod
from torch import nn
from tqdm import tqdm
import numpy as np
import torch
import libmr


class OOD_detection_base(nn.Module):
    """Shared shell for the OOD detectors in this file.

    Holds the wrapped network on ``self.model``. Subclasses supply ``forward``
    and may override ``attackForward`` with a different objective.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    # NOTE: nn.Module does not use ABCMeta, so this decorator only tags the
    # method; instantiating the base class is not prevented.
    @abstractmethod
    def forward(self, inputs):
        """Placeholder scorer; the base implementation returns ``None``."""
        return None

    def attackForward(self, inputs):
        """Score used when attacking the detector; defaults to the class's ``forward``."""
        scorer = type(self).forward
        return scorer(self, inputs)


class OOD_ATOM(OOD_detection_base):
    """Detector whose score is the softmax probability of the model's last output column."""

    def forward(self, inputs):
        """Return, for each row of ``inputs``, the softmax mass on the final logit."""
        class_probs = F.softmax(self.model(inputs), dim=1)
        return class_probs[:, -1]

    # attackForward is not overridden: the inherited version reuses forward.


class OOD_MSP(OOD_detection_base):
    """Maximum-softmax-probability detector."""

    def forward(self, inputs):
        """Return the negated largest softmax probability of each row."""
        class_probs = F.softmax(self.model(inputs), dim=1)
        return -class_probs.max(dim=1).values

    def attackForward(self, inputs):
        """Return the per-row mean of the log-softmax over classes (attack objective)."""
        log_probs = F.log_softmax(self.model(inputs), dim=1)
        return log_probs.mean(dim=1)


class OOD_maha_dist(OOD_detection_base):
    """Mahalanobis-distance OOD detector on the wrapped model's embeddings.

    On construction the in-distribution training loader is embedded once and
    Gaussian statistics are fitted. The results are stored on the wrapped
    model as ``class_means``, ``class_cov_invs``, ``training_mean`` and
    ``training_cov_inv``.
    """

    def __init__(self, model: nn.Module, trainloader_in, device):
        """
        Args:
            model: network; must provide ``embedding(model, inputs)``.
            trainloader_in: ``(dataset_name, data_loader, num_classes)`` triple.
            device: device that training images and fitted statistics are moved to.
        """
        super().__init__(model)
        self.device = device
        dataset_name, data_loader, self.indist_classes = trainloader_in
        self.preprocess_maha((dataset_name, data_loader))

    def forward(self, inputs):
        """Return, per sample, the smallest squared Mahalanobis distance to any class mean."""
        embedding = self.model.embedding(self.model, inputs).double()
        indist_dists = []
        for class_mean, class_cov_inv in zip(self.model.class_means, self.model.class_cov_invs):
            offset = embedding - class_mean.reshape([1, -1])
            # Row-wise quadratic form offset @ cov_inv @ offset^T.
            indist_dists.append(torch.sum(torch.matmul(offset, class_cov_inv) * offset, dim=1))
        indist_dists_byclass = torch.stack(indist_dists, dim=1)
        indist_min, _ = torch.min(indist_dists_byclass, dim=1)
        return indist_min

    @torch.no_grad()
    def preprocess_maha(self, trainloader_in):
        """Embed the training loader and attach the fitted statistics to ``self.model``.

        Args:
            trainloader_in: ``(dataset_name, data_loader)`` pair; the loader
                yields ``(images, labels)`` batches.
        """
        _, data_loader = trainloader_in
        total_batches = len(data_loader)

        train_embeds_in, y_train = [], []
        for images, labels in tqdm(data_loader, total=total_batches, leave=False):
            images = images.to(self.device)
            embedding = self.model.embedding(self.model, images)
            train_embeds_in.append(embedding.cpu())
            y_train.append(labels)

        y_train = torch.cat(y_train).cpu().numpy()
        train_embeds_in = torch.cat(train_embeds_in).numpy()
        # Report the real sample count instead of assuming a batch size of 128.
        print("TOTAL TRAINING SAMPLES:", len(y_train))

        _, maha_intermediate_dict = OOD_maha_dist.get_maha_distance_scores(
            train_embeds_in,
            y_train,
            indist_classes=self.indist_classes,
            subtract_mean=False,
            normalize_to_unity=False,
            subtract_train_distance=False,
        )

        self.model.class_means = [
            torch.tensor(class_mean).to(self.device)
            for class_mean in maha_intermediate_dict["class_means"]
        ]
        self.model.class_cov_invs = [
            torch.tensor(class_cov_inv).to(self.device)
            for class_cov_inv in maha_intermediate_dict["class_cov_invs"]
        ]
        self.model.training_mean = torch.tensor(maha_intermediate_dict["mean"]).to(self.device)
        self.model.training_cov_inv = torch.tensor(maha_intermediate_dict["cov_inv"]).to(self.device)

    @staticmethod
    def _invert_covariance(cov):
        """Return the inverse of ``cov``, or its pseudo-inverse when it is singular."""
        try:
            return np.linalg.inv(cov)
        except np.linalg.LinAlgError:
            return np.linalg.pinv(cov)

    @staticmethod
    def get_maha_distance_scores(
            indist_train_embeds_in,
            indist_train_labels_in,
            subtract_mean=True,
            normalize_to_unity=True,
            subtract_train_distance=True,
            indist_classes=100,
            norm_name="L2",
    ):
        """Fit the Gaussian statistics used by the Mahalanobis detectors.

        Args:
            indist_train_embeds_in: 2-D array of training embeddings, one row
                per sample. It is not modified.
            indist_train_labels_in: array of integer labels aligned with the rows.
            subtract_mean: centre the embeddings on the global mean first.
            normalize_to_unity: scale every (possibly centred) row to unit L2 norm.
            subtract_train_distance: accepted for compatibility; unused here.
            indist_classes: number of classes; labels ``0 .. indist_classes - 1``
                are fitted.
            norm_name: accepted for compatibility; unused here.

        Returns:
            ``(description, stats)``. ``description`` lists the preprocessing
            steps applied. ``stats`` has the keys ``"class_means"`` (list of
            per-class means), ``"class_cov_invs"`` (the same inverse of the
            class-averaged covariance, repeated once per class), ``"mean"`` and
            ``"cov_inv"`` (single Gaussian fitted on all samples).
        """
        maha_intermediate_dict = dict()
        description = ""

        embeds = np.asarray(indist_train_embeds_in)

        if subtract_mean:
            # Build a new array; an in-place "-=" would silently alter the caller's data.
            embeds = embeds - np.mean(embeds, axis=0, keepdims=True)
            description = description + " subtract mean,"

        if normalize_to_unity:
            embeds = embeds / np.linalg.norm(embeds, axis=1, keepdims=True)
            description = description + " unit norm,"

        # Single Gaussian over the whole training set.
        mean = np.mean(embeds, axis=0)
        cov = np.cov((embeds - mean.reshape([1, -1])).T)
        cov_inv = OOD_maha_dist._invert_covariance(cov)

        # Per-class means and covariances. Only the averaged covariance is
        # inverted below, so no per-class inverse is computed.
        class_means = []
        class_covs = []
        for c in range(indist_classes):
            class_embeds = embeds[indist_train_labels_in == c]
            mean_now = np.mean(class_embeds, axis=0)
            class_covs.append(np.cov((class_embeds - mean_now.reshape([1, -1])).T))
            class_means.append(mean_now)

        # One shared covariance: the mean of the per-class covariances.
        cov_arr = np.mean(np.stack(class_covs, axis=0), axis=0)
        inv_cov_arr = OOD_maha_dist._invert_covariance(cov_arr)
        class_cov_invs = [inv_cov_arr] * len(class_covs)

        maha_intermediate_dict["class_cov_invs"] = class_cov_invs
        maha_intermediate_dict["class_means"] = class_means
        maha_intermediate_dict["cov_inv"] = cov_inv
        maha_intermediate_dict["mean"] = mean

        return description, maha_intermediate_dict


class OOD_rel_maha_dist(OOD_maha_dist):
    """Relative Mahalanobis detector.

    The score is the nearest-class quadratic-form distance minus the distance
    to the single Gaussian stored as ``training_mean`` / ``training_cov_inv``.
    """

    def forward(self, inputs):
        feats = self.model.embedding(self.model, inputs).double()

        def _quadratic_form(center, cov_inv):
            # Row-wise (x - center) @ cov_inv @ (x - center)^T.
            centered = feats - center.reshape([1, -1])
            return torch.sum(torch.matmul(centered, cov_inv) * centered, dim=1)

        per_class = [
            _quadratic_form(self.model.class_means[c], self.model.class_cov_invs[c])
            for c in range(len(self.model.class_means))
        ]
        nearest_class = torch.stack(per_class, dim=1).min(dim=1).values
        background = _quadratic_form(self.model.training_mean, self.model.training_cov_inv)

        return nearest_class - background


class OOD_openMax(OOD_detection_base):
    """OpenMax-style detector: logits are re-weighted by per-class Weibull scores.

    On construction a Weibull is fitted, per class, to the distances between
    correctly classified training logits and that class's mean activation
    vector. The fitted objects are stored on the wrapped model as
    ``mean_activation_vectors`` and ``weibulls``, keyed by class label.
    """

    def __init__(self, model: nn.Module, trainloader_in, device, robust_model=True):
        """
        Args:
            model: classifier returning logits.
            trainloader_in: ``(dataset_name, data_loader, num_classes)`` triple.
            device: device that training batches and fitted tensors are moved to.
            robust_model: when true, statistics are collected on PGD-perturbed images.
        """
        super().__init__(model)
        self.device = device
        dataset_name, data_loader, self.indist_classes = trainloader_in
        self.preprocess_openmax((dataset_name, data_loader), robust_model)

    def forward(self, inputs):
        """Score ``inputs`` using libmr's own ``w_score`` (not differentiable w.r.t. the distance)."""
        return self._weibull_weighted_score(inputs, differentiable=False)

    def attackForward(self, inputs):
        """Score ``inputs`` using the torch Weibull CDF so gradients reach the distance term."""
        return self._weibull_weighted_score(inputs, differentiable=True)

    def _weibull_weighted_score(self, inputs, differentiable):
        """Return ``-logsumexp(logits * (1 - weibull_score))`` per sample.

        Raises:
            KeyError: if an output class has no fitted Weibull, i.e. no
                training sample of that class was classified correctly.
        """
        # One forward pass: the logits are both the weighted quantity and the
        # activation vector whose distance to the class mean is scored.
        logits = self.model(inputs).double()
        num_samples, num_classes = logits.shape

        weibull_scores = torch.ones_like(logits)
        for j in range(num_classes):
            mav = self.model.mean_activation_vectors[j].double()
            weibull = self.model.weibulls[j]
            # Distance of every sample to this class mean, computed once per class.
            dists = torch.norm(logits - mav, p=2, dim=1)
            if differentiable:
                for i in range(num_samples):
                    weibull_scores[i, j] = 1 - weibull.getpDerPass(dists[i])
            else:
                # libmr works on plain floats; fetch them in a single transfer.
                for i, dist in enumerate(dists.detach().cpu().tolist()):
                    weibull_scores[i, j] = 1 - weibull.getp(dist)

        # logsumexp is the overflow-safe form of log(sum(exp(.))).
        return -torch.logsumexp(logits * weibull_scores, dim=1)

    def preprocess_openmax(self, trainloader_in, robust_model):
        """Fit the per-class statistics and attach them to ``self.model``."""
        _, mean_activation_vectors, weibulls = OOD_openMax.our_precalc_weibull(
            trainloader_in, self.model, robust_model, device=self.device)

        self.model.mean_activation_vectors = {
            class_id: torch.tensor(vector, device=self.device)
            for class_id, vector in mean_activation_vectors.items()
        }
        self.model.weibulls = {
            class_id: TorchWeibulls(weibull) for class_id, weibull in weibulls.items()
        }

    @staticmethod
    def our_precalc_weibull(dataloader_train, model, robust_model, device=None):
        """Collect logits of correctly classified samples and fit one Weibull per class.

        Args:
            dataloader_train: ``(dataset_name, data_loader)`` pair; the loader
                yields ``(images, labels)`` batches.
            model: classifier returning logits.
            robust_model: when true, images are first perturbed with a
                10-iteration ``PGDAttack`` against ``model``.
            device: where to move each batch; ``None`` keeps the previous
                behaviour of calling ``.cuda()``.

        Returns:
            ``(activation_vectors, mean_activation_vectors, weibulls)``, three
            dicts keyed by class label. Classes without any correctly
            classified sample are absent from all three.
        """
        dataset_name, dataloader_train = dataloader_train
        total_batches = len(dataloader_train)

        WEIBULL_TAIL_SIZE = 20

        # First generate pre-softmax 'activation vectors' for all training examples
        print("Weibull: computing features for all correctly-classified training data")
        activation_vectors = {}

        # The attack lives in a project-local package; only load it when needed.
        adversary = None
        if robust_model:
            from modelZoo.ECCV2020OSAD.advertorch.attacks import PGDAttack
            adversary = PGDAttack(predict1=model, predict2=None, nb_iter=10)

        for step, (images, labels) in enumerate(tqdm(dataloader_train, total=total_batches, leave=False)):
            # Stop if the loader yields more batches than len() reported.
            if step == total_batches:
                break

            labels = labels.long()
            if device is None:
                images, labels = images.cuda(), labels.cuda()
            else:
                images, labels = images.to(device), labels.to(device)

            advimg = images if adversary is None else adversary.perturb(images, labels)

            with torch.no_grad():
                logits = model(advimg)

            # Move everything to NumPy once instead of indexing device tensors per sample.
            correct_np = (logits.max(1)[1] == labels).cpu().numpy()
            labels_np = labels.cpu().numpy()
            logits_np = logits.detach().cpu().numpy()
            for label, vector in zip(labels_np[correct_np], logits_np[correct_np]):
                activation_vectors.setdefault(label, []).append(vector)

        print("Computed activation_vectors for {} known classes".format(len(activation_vectors)))
        for class_idx in activation_vectors:
            print("Class {}: {} images".format(class_idx, len(activation_vectors[class_idx])))

        # Compute a mean activation vector for each class
        print("Weibull computing mean activation vectors...")
        mean_activation_vectors = {
            class_idx: np.array(vectors).mean(axis=0)
            for class_idx, vectors in activation_vectors.items()
        }

        # Fit one libMR Weibull per class on the largest distances to the class mean
        print("Fitting Weibull to distance distribution of each class")
        weibulls = {}
        for class_idx, vectors in activation_vectors.items():
            mav = mean_activation_vectors[class_idx]
            distances = [np.linalg.norm(v - mav) for v in vectors]
            mr = libmr.MR()
            tail_size = min(len(distances), WEIBULL_TAIL_SIZE)
            mr.fit_high(distances, tail_size)
            weibulls[class_idx] = mr
            print("Weibull params for class {}: {}".format(class_idx, mr.get_params()))

        return activation_vectors, mean_activation_vectors, weibulls


def MSPAT(model, inputs):
    """Return the negated largest raw output of ``model(inputs)`` along dim 1."""
    top_logit = model(inputs).max(dim=1).values
    return -top_logit


def UNIFORM(model, inputs):
    """Return the negated cross-entropy between ``model(inputs)`` and a constant target.

    Every target entry is ``1 / K``. ``K`` is ``len(model.class_means)`` when
    the model carries that attribute, and otherwise the number of logit
    columns. The fallback lets models that were never given ``class_means``
    be scored instead of raising ``AttributeError``.
    """
    logits = model(inputs)
    class_means = getattr(model, "class_means", None)
    num_classes = logits.shape[1] if class_means is None else len(class_means)
    target = torch.ones_like(logits) / num_classes
    scores = F.cross_entropy(logits, target, reduction="none")
    return -scores


class TorchWeibulls:
    """Torch re-implementation of a fitted Weibull's score, alongside the fitted object.

    ``getp`` delegates to the wrapped object's ``w_score``. ``getpDerPass``
    evaluates the CDF with torch operations, so gradients flow back to its
    input.
    """

    def __init__(self, weibull: "libmr.MR"):
        """Read the five parameters returned by ``weibull.get_params()``.

        They are unpacked, in order, as scale, shape, sign, translate amount
        and small score.
        """
        self.weibull = weibull
        self.scale, self.shape, self.sign, self.translate_amount, self.small_score = weibull.get_params()
        assert self.scale > 0, "Weibull scale must be positive"
        assert self.shape > 0, "Weibull shape must be positive"

    def getp(self, x):
        """Return the wrapped object's ``w_score(x)``."""
        return self.weibull.w_score(x)

    def getpDerPass(self, x):
        """Return the torch CDF of ``x`` after applying sign, translation and small-score offset."""
        translated_x = x * self.sign + self.translate_amount - self.small_score
        return self.weibull_cdf(translated_x)

    def weibull_cdf(self, x):
        """Evaluate ``1 - exp(-(x / scale) ** shape)`` element-wise, with 0 where ``x < 0``.

        Accepts a scalar or a tensor of any shape and always returns a tensor.
        """
        x = torch.as_tensor(x)
        negative = x < 0
        # Feed a harmless positive value through pow() at masked positions so
        # the gradient there is an exact 0 rather than NaN.
        safe_x = torch.where(negative, torch.ones_like(x), x)
        cdf = 1 - torch.exp(-1 * torch.pow(safe_x / self.scale, self.shape))
        return torch.where(negative, torch.zeros_like(cdf), cdf)
