# adapted from https://github.com/lukas-kuhn/MMD_calc/blob/master/calib/density_aware_calib.py
import numpy as np
from scipy import optimize

import time
import gc


def np_softmax(x, axis=1):
    """Numerically stable softmax of ``x`` along ``axis``.

    Args:
        x: array-like of logits. With the default ``axis=1`` each row of a
            2-D (N, C) array is normalised independently.
        axis: axis along which the returned probabilities sum to one.

    Returns:
        ndarray with the same shape as ``x``.
    """
    # Subtracting the per-slice max does not change the result but keeps
    # np.exp from overflowing on large logits.
    slice_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - slice_max)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)


class DATS(object):
    """Density-aware temperature scaling calibrator.

    The per-sample temperature is a linear function of K OOD scores::

        T = sum_i(w_i * ood_score_i) + w0
        p = softmax(logits / T)

    The weights ``w = [w_1, ..., w_K, w0]`` are fitted by :meth:`optimize`
    with scipy's L-BFGS-B and stored in ``self.w``.
    """

    def __init__(
        self,
        ood_values_num=1,
        tol=1e-12,
        eps=1e-7,
        disp=False,  # to print optimization process
        mmd_bounds=False,
    ):
        """
        Args:
            ood_values_num: number K of OOD score vectors feeding the
                temperature. With 0 the model reduces to a single global
                temperature ``T = w0``.
            tol: tolerance passed to ``scipy.optimize.minimize``.
            eps: finite-difference step(s) forwarded as ``options["eps"]``;
                a scalar or a sequence.
            disp: forwarded as ``options["disp"]`` to print optimizer progress.
            mmd_bounds: if True, the OOD-score weights are bounded below by
                -10 instead of 0.
        """
        self.method = "L-BFGS-B"

        self.ood_values_num = ood_values_num
        print("ood_values_num: ", self.ood_values_num)

        self.tol = tol
        self.eps = eps
        self.disp = disp

        # Lower bound of every OOD-score weight; the bias w0 is always in [-100, 100].
        weight_low = -10.0 if mmd_bounds else 0
        # Build independent [low, high] pairs: `[[a, b]] * n` would alias a
        # single inner list n times, so editing one bound would edit them all.
        self.bnds = [[weight_low, 10000.0] for _ in range(self.ood_values_num)] + [
            [-100.0, 100.0]
        ]

        self.init = [1.0] * self.ood_values_num + [1.0]

    def get_temperature(self, w, ood_score):
        """Compute the temperature ``T = sum_i w[i] * ood_score[i] + w[-1]``.

        Args:
            w: weights ``[w_1, ..., w_K, w0]`` (length ``ood_values_num + 1``).
            ood_score: K score vectors of length N, given as a list of 1-D
                arrays or a (K, N) array. When ``ood_values_num == 1`` a bare
                1-D array of length N is also accepted. When
                ``ood_values_num == 0`` pass an empty sequence.

        Returns:
            ndarray of shape (N,), or (1,) when there are no scores. Values are
            clipped to at least 1e-20 so the temperature stays positive.
        """
        if self.ood_values_num == 1 and isinstance(ood_score, np.ndarray):
            # Normalise a bare (N,) vector or a (1, N) matrix into a list of rows.
            ood_score = [ood_score] if ood_score.ndim == 1 else list(ood_score)

        assert len(ood_score) == self.ood_values_num, (
            ood_score,
            len(ood_score),
            self.ood_values_num,
        )

        if len(ood_score) == 0:
            # No OOD scores: plain (global) temperature scaling.
            t = np.full(1, w[-1], dtype=float)
        else:
            t = np.zeros(len(ood_score[0]))
            for i, score in enumerate(ood_score):
                # asarray lets plain Python lists of scores work too
                t += w[i] * np.asarray(score, dtype=float)
            t += w[-1]

        # temperature should be a positive value
        return np.clip(t, 1e-20, None)

    def _tempered_probs(self, w, logit, ood_score):
        """Softmax of ``logit`` with each row divided by its temperature."""
        t = self.get_temperature(w, ood_score)
        return np_softmax(logit / t[:, None])

    def _fitted_weights(self):
        """Return ``self.w``; raise a clear AttributeError if not fitted yet."""
        try:
            return self.w
        except AttributeError:
            raise AttributeError(
                "DATS weights are not fitted yet; call optimize() first"
            ) from None

    def mse_lf(self, w, *args):
        """Mean squared error between calibrated probabilities and labels.

        ``args`` is ``(logit, label, ood_score)`` as supplied by :meth:`optimize`.
        """
        logit, label, ood_score = args
        p = self._tempered_probs(w, logit, ood_score)
        return np.mean((p - label) ** 2)

    def ll_lf(self, w, *args):
        """Cross-entropy of calibrated probabilities, averaged over samples.

        ``args`` is ``(logit, label, ood_score)`` as supplied by :meth:`optimize`.
        """
        logit, label, ood_score = args
        p = self._tempered_probs(w, logit, ood_score)
        # The 1e-12 floor keeps log() finite if a probability underflows to 0.
        return -np.sum(label * np.log(p + 1e-12)) / p.shape[0]

    def optimize(self, logit, label, ood_score, loss="ce"):
        """Fit the temperature weights on held-out data.

        Args:
            logit (N, C): classifier's outputs before softmax.
            label (N, C): true labels, one-hot.
            ood_score: K OOD-score vectors of length N (a list of (N,) arrays
                or a (K, N) array; see :meth:`get_temperature`). Each value
                represents how far the sample is in the feature space.
            loss: ``"ce"`` (cross-entropy) or ``"mse"``.

        Returns:
            The fitted weights ``[w_1, ..., w_K, w0]``, also stored in ``self.w``.

        Raises:
            NotImplementedError: for an unknown ``loss``.
        """
        if loss == "ce":
            func = self.ll_lf
        elif loss == "mse":
            func = self.mse_lf
        else:
            raise NotImplementedError

        # L-BFGS-B accepts a scalar or per-parameter step; normalise locally
        # instead of mutating self.eps.
        eps = np.atleast_1d(np.asarray(self.eps, dtype=float))

        start = time.perf_counter()
        params = optimize.minimize(
            func,
            self.init,
            args=(logit, label, ood_score),
            method=self.method,
            bounds=self.bnds,
            tol=self.tol,
            options={"eps": eps, "disp": self.disp},
        )
        elapsed = time.perf_counter() - start

        w = params.x
        print("DAC Optimization done!: ({} sec)".format(elapsed))
        print(f"T = {w[:-1]} * ood_score_i + {w[-1]}")

        self.w = w
        return self.get_optim_params()

    def calibrate(self, logits, ood_score):
        """Return calibrated probabilities using the fitted weights."""
        return self._tempered_probs(self._fitted_weights(), logits, ood_score)

    def calibrate_before_softmax(self, logits, ood_score):
        """Return ``(logits / T, T)`` using the fitted weights."""
        t = self.get_temperature(self._fitted_weights(), ood_score)
        return logits / t[:, None], t

    def get_optim_params(self):
        """Return the fitted weights ``[w_1, ..., w_K, w0]``."""
        return self.w