# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
import torch.nn as nn
from godin.threshold import find_threshold_ind_data
import numpy as np
from godin.utils_godin import show_stats
from godin.utils_godin import generate_scores
from torch.autograd import Variable


def norm(x):
    """Scale each sample of ``x`` to (approximately) unit L2 norm along dim 1.

    A small constant (1e-4) is added to the norm to avoid division by zero,
    so all-zero rows stay zero. The norm keeps a broadcastable singleton
    dimension, so any input with rank >= 2 (e.g. (batch, features) or
    (batch, channels, H, W)) is normalized along dim 1; for 2-D input the
    result is the same as normalizing each row.

    :param x: tensor with at least 2 dimensions.
    :return: tensor with the same shape as ``x``.
    """
    # keepdim=True yields shape (batch, 1, ...) which broadcasts against x for
    # any rank, instead of the 2-D-only expand(1, -1).t() trick.
    row_norms = torch.norm(x, p=2, dim=1, keepdim=True)
    return x / (row_norms + .0001)


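# Alternative manual initialization (unused); it draws from the same
# distribution as the kaiming_normal_(nonlinearity="relu") call in init_weights:
# self.weights = torch.nn.Parameter(torch.randn(size = (num_classes, in_features)) * math.sqrt(2 / (in_features)))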
class CosineDeconf(nn.Module):
    """Cosine-similarity head, usable as the numerator ``h`` of DeconfNet.

    Each class owns one weight vector (a row of ``self.h.weight``). The output
    is the similarity between every normalized input row and every normalized
    class vector; it is only approximately the cosine because ``norm`` adds a
    small epsilon to each norm.
    """

    def __init__(self, in_features, num_classes):
        super().__init__()
        # The Linear layer only serves as a container for the class-vector
        # matrix, so no bias is allocated.
        self.h = nn.Linear(in_features, num_classes, bias=False)
        self.init_weights()

    def init_weights(self):
        """Re-draw the class vectors with Kaiming-normal (ReLU gain) init."""
        nn.init.kaiming_normal_(self.h.weight, nonlinearity="relu")

    def forward(self, x):
        """Return similarities of each input row to each class vector.

        Assumes ``x`` is (batch, in_features); the result is then
        (batch, num_classes).
        """
        unit_inputs = norm(x)
        unit_classes = norm(self.h.weight)
        return unit_inputs @ unit_classes.t()


class EuclideanDeconf(nn.Module):
    """Distance-based head, usable as the numerator ``h`` of DeconfNet.

    Each class has a centre (a row of ``self.h.weight``). The output for a
    sample and a class is the negative mean squared difference between the
    sample and that centre, so it is always <= 0 and closer centres score
    higher.
    """

    def __init__(self, in_features, num_classes):
        super().__init__()
        # Only the weight matrix (the class centres) is used; no bias.
        self.h = nn.Linear(in_features, num_classes, bias=False)
        self.init_weights()

    def init_weights(self):
        """Re-draw the class centres with Kaiming-normal (ReLU gain) init."""
        nn.init.kaiming_normal_(self.h.weight, nonlinearity="relu")

    def forward(self, x):
        """Return negative mean squared distances to every class centre.

        With ``x`` of shape (batch, latent) the result is (batch, num_classes).
        """
        centres = self.h.weight.t()[None, :, :]  # (1, latent, num_classes)
        diffs = x[:, :, None] - centres          # (batch, latent, num_classes)
        return -diffs.pow(2).mean(1)


class InnerDeconf(nn.Module):
    """Plain affine (inner-product) head, usable as the numerator ``h`` of DeconfNet."""

    def __init__(self, in_features, num_classes):
        super(InnerDeconf, self).__init__()

        self.h = nn.Linear(in_features, num_classes)
        self.init_weights()

    def init_weights(self):
        """Reset to Kaiming-normal (ReLU gain) weights and an all-zero bias.

        Both tensors are filled in place, so the parameters keep their current
        device and dtype even when this is called after ``.to(...)`` or
        ``.double()``.
        """
        nn.init.kaiming_normal_(self.h.weight.data, nonlinearity="relu")
        # Fill in place rather than rebinding ``.data`` to a brand-new tensor,
        # which would silently put the bias on the default device/dtype.
        nn.init.zeros_(self.h.bias)

    def forward(self, x):
        """Apply the affine map ``x @ W.T + b``."""
        return self.h(x)


class DeconfNet(nn.Module):
    """Decomposed-confidence classifier: logits = h(z) / g(z), z = underlying_model(x).

    :param underlying_model: feature extractor applied to the input.
    :param in_features: size of the feature vector fed to the ``g`` network.
    :param num_classes: number of classes (stored as an attribute only).
    :param h: module producing the numerators from the features.
    :param baseline: if True, the denominator is a single learnable scalar
        (initialized to 1) instead of the learned ``g`` network.
    """

    def __init__(self, underlying_model, in_features, num_classes, h, baseline):
        super(DeconfNet, self).__init__()

        self.num_classes = num_classes

        self.underlying_model = underlying_model

        self.h = h

        self.baseline = baseline

        if baseline:
            self.ones = nn.Parameter(torch.Tensor([1]), requires_grad=True)
        else:
            # g: features -> one value in (0, 1) per sample.
            self.g = nn.Sequential(
                nn.Linear(in_features, 1),
                nn.BatchNorm1d(1),
                nn.Sigmoid()
            )

        # Explicit class dimension: nn.Softmax() without ``dim`` is deprecated
        # and warns on use; for the 2-D (batch, num_classes) logits this network
        # produces, the implicit choice was dim=1 as well.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return ``(quotients, numerators, denominators, features)``.

        ``denominators`` has shape (batch, 1) and is broadcast across the
        class dimension of ``numerators`` by the division.
        """
        output = self.underlying_model(x)
        numerators = self.h(output)

        if self.baseline:
            # One shared scalar, repeated per sample as a (batch, 1) column.
            denominators = torch.unsqueeze(self.ones.expand(len(numerators)), 1)
        else:
            denominators = self.g(output)

        # Now, broadcast the denominators per image across the numerators by division
        quotients = numerators / denominators

        # logits, numerators, denominators, and the underlying features
        return quotients, numerators, denominators, output


class DeconfNetOOD(nn.Module):
    """DeconfNet variant that produces G-ODIN out-of-distribution scores.

    Logits are ``h(z) / g(z)`` with ``z = underlying_model(x)``; see
    https://arxiv.org/pdf/2002.11297.pdf.
    """

    # Per-channel divisors applied to the sign of the input gradient so the
    # perturbation lives in the same (normalized) space as the images.
    # NOTE(review): these look like CIFAR-style per-channel stds in [0, 1]
    # units; confirm they match the dataset's input normalization.
    GRADIENT_CHANNEL_STD = (63.0 / 255.0, 62.1 / 255.0, 66.7 / 255.0)

    def __init__(self, underlying_model, in_features, num_classes, h, baseline,
                 noise_magnitude):
        """
        Check the paper: https://arxiv.org/pdf/2002.11297.pdf

        :param underlying_model: the underlying main model, e.g., ResNet-34;
            its output is fed to both ``h`` and ``g``.
        :param in_features: size of the underlying model's output features
            (the input size of the ``g`` network).
        :param num_classes: number of classes
        :param h: the h function (module) computing the numerators, which the
            paper interprets as p(y, d_in | x).
        :param baseline: if True, the denominator is a single learnable scalar
            (initialized to 1) instead of the learned ``g`` network.
        :param noise_magnitude: for GODIN, the pre-processing strategy that
            modifies the input; this is the perturbation magnitude, which in
            this case does not rely on the out-of-distribution data.
        """
        super(DeconfNetOOD, self).__init__()

        self.num_classes = num_classes

        self.underlying_model = underlying_model

        self.h = h

        self.baseline = baseline

        self.noise_magnitude = noise_magnitude

        # Lower bound for the scores returned by get_scores: any score below
        # the threshold is raised to the threshold. None disables clamping.
        self.threshold = None

        if baseline:
            self.ones = nn.Parameter(torch.Tensor([1]), requires_grad=True)
        else:
            # g: features -> one value in (0, 1) per sample.
            self.g = nn.Sequential(
                nn.Linear(in_features, 1),
                nn.BatchNorm1d(1),
                nn.Sigmoid()
            )

    def set_threshold(self, percentile_threshold, val_set, device):
        """Find, store and return a lower bound for the scores.

        :param percentile_threshold: we find the threshold for which the
            given percentile of the data points have lower scores than the
            threshold (computed by ``find_threshold_ind_data``).
        :param val_set: data loader with in-distribution validation data,
            passed to ``generate_scores``.
        :param device: passed to ``generate_scores`` as ``CUDA_DEVICE``.
        :return: threshold
        """
        # Clear any previous threshold *before* scoring so that the scores
        # used to pick the new threshold are raw, not clamped by an old one.
        self.threshold = None
        scores = generate_scores(
            model=self, CUDA_DEVICE=device, data_loader=val_set,
            title='Scores for thresholding from below of the scores.')
        show_stats(scores)

        threshold = find_threshold_ind_data(
            ind_test_scores=scores, percentile=percentile_threshold)
        # Set the newly found threshold.
        self.threshold = threshold
        return threshold

    def forward(self, x):
        """Return only the logits (quotients h / g)."""
        logits, _, _ = self.forward_for_scores(images=x)
        return logits

    def forward_for_scores(self, images):
        """Run the network and return ``(logits, h, g)``.

        Also stores the underlying model's output on ``self.output``.
        ``g`` has shape (batch, 1) and is broadcast over the class dimension.
        """
        self.output = self.underlying_model(images)
        numerators = self.h(self.output)

        if self.baseline:
            # One shared scalar, repeated per sample as a (batch, 1) column.
            denominators = torch.unsqueeze(self.ones.expand(len(numerators)), 1)
        else:
            denominators = self.g(self.output)

        # Now, broadcast the denominators per image across the numerators by division
        quotients = numerators / denominators

        # logits, numerators, and denominators
        return quotients, numerators, denominators

    @staticmethod
    def _select_scores(score_func, logits, h, g):
        """Return the tensor named by ``score_func`` ('h', 'g' or 'logit')."""
        if score_func == 'h':
            return h
        if score_func == 'g':
            return g
        if score_func == 'logit':
            return logits
        raise Exception(f"Unsupported score function: {score_func}.")

    def get_scores(self, images, score_func='h'):
        """Compute per-sample scores ``1 / max(score)`` for a batch.

        If ``images`` receives a gradient (e.g. it is a leaf tensor with
        ``requires_grad=True``), the input is first perturbed by
        ``noise_magnitude`` along the sign of the gradient of the max score,
        and the scores are recomputed on the perturbed input.

        :param images: input batch; the gradient rescaling assumes at least 3
            channels on dim 1 — TODO confirm for non-RGB data.
        :param score_func: 'h' (numerators), 'g' (denominators) or 'logit'
            (quotients).
        :return: numpy array with ``1 / max_score`` per sample, clamped from
            below by ``self.threshold`` when it is set.
        :raises Exception: if ``score_func`` is not supported.
        """
        logits, h, g = self.forward_for_scores(images=images)
        scores = self._select_scores(score_func, logits, h, g)

        # The perturbation is the sign of the gradient of the max score
        # w.r.t. the input. NOTE: backward() also accumulates .grad on the
        # model parameters.
        max_scores, _ = torch.max(scores, dim=1)
        max_scores.backward(torch.ones_like(max_scores))

        if images.grad is not None:
            # Binarize the gradient to {-1, 1} (zero entries map to +1).
            gradient = ((images.grad >= 0).float() - 0.5) * 2
            # Rescale per channel to the same space as the normalized images.
            for channel, std in enumerate(self.GRADIENT_CHANNEL_STD):
                gradient[:, channel] = gradient[:, channel] / std
            # Adding small perturbations to images
            perturbed = torch.add(images.data, gradient,
                                  alpha=self.noise_magnitude)

            # The perturbed pass is only read, never differentiated, so skip
            # building an autograd graph for it.
            with torch.no_grad():
                logits, h, g = self.forward_for_scores(perturbed)
            scores = self._select_scores(score_func, logits, h, g)

        raw_scores = scores.detach().cpu().numpy()
        # NOTE(review): 1 / max assumes the max score is positive; heads whose
        # outputs are <= 0 or cross zero give sign-flipped or infinite values.
        results = 1 / np.max(raw_scores, axis=1)

        if self.threshold is not None:
            results = np.maximum(results, self.threshold)

        return results
