import torch
import numpy as np
from exps.metrics import Metric
from typing import Any, Callable, Optional
from math import pi, sqrt
import torch.nn as nn
__all__ = ['DEP']

class MetricMI(nn.Module):
    """Normalized matrix-based Renyi alpha-order mutual information.

    Each variable is mapped to a Gram matrix with the kernel
    ``exp(-||x_i - x_j||^2 / sigma)``. Note that ``sigma`` divides the squared
    distance directly. Entropies are computed from the eigenvalues of the
    trace-normalized Gram matrices (Eq. (9) in the paper). The joint entropy
    uses the elementwise (Hadamard) product of the two Gram matrices
    (Eq. (10) in the paper).

    ``forward`` returns ``(H(x) + H(y) - H(x, y)) / max(H(x), H(y))``.

    Args:
        sigma_x: kernel size for ``inputs``. Defaults to 1.5, the previously
            hard-coded value.
        sigma_y: kernel size for ``target``. Defaults to 1.5.
        alpha: order of the Renyi entropy. Defaults to 1.01.

    Raises:
        ValueError: if ``alpha == 1``, where the ``1 / (1 - alpha)`` factor
            is undefined.
    """

    def __init__(self, sigma_x: float = 1.5, sigma_y: float = 1.5,
                 alpha: float = 1.01):
        super(MetricMI, self).__init__()
        if alpha == 1:
            raise ValueError("alpha must differ from 1 for the Renyi entropy")
        self.sigma_x = sigma_x
        self.sigma_y = sigma_y
        self.alpha = alpha

    @staticmethod
    def _gram_matrix(x: torch.Tensor, sigma: float) -> torch.Tensor:
        """Return the (N, N) kernel matrix of ``x`` flattened to (N, d)."""
        # reshape (not view) so non-contiguous inputs are accepted.
        x = x.reshape(x.shape[0], -1)
        # Integer inputs (e.g. label tensors) are cast so the matrix product
        # and eigen-decomposition run in floating point on every backend.
        if not torch.is_floating_point(x):
            x = x.to(torch.get_default_dtype())
        sq_norm = torch.sum(x ** 2, dim=-1, keepdim=True)
        dist = sq_norm + sq_norm.t() - 2 * torch.mm(x, x.t())
        # The expansion ||a||^2 + ||b||^2 - 2<a, b> can dip slightly below 0
        # through cancellation. Clamp so kernel values never exceed 1.
        return torch.exp(-dist.clamp_min(0) / sigma)

    def _entropy_from_gram(self, k: torch.Tensor) -> torch.Tensor:
        """Return the Renyi alpha-entropy (in bits) of a Gram matrix ``k``."""
        k = k / torch.trace(k)
        # eigvalsh replaces the removed torch.symeig. Only eigenvalues are
        # needed. abs() discards tiny negative eigenvalues caused by round-off.
        eigv = torch.abs(torch.linalg.eigvalsh(k))
        return torch.log2(torch.sum(eigv ** self.alpha)) / (1 - self.alpha)

    def forward(self, inputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the normalized mutual information between two variables.

        Args:
            inputs: tensor whose dim 0 indexes N samples. Remaining dims are
                flattened into features.
            target: tensor with the same N samples along dim 0.

        Returns:
            0-d tensor ``(Hx + Hy - Hxy) / max(Hx, Hy)``. It is NaN when both
            marginal entropies are 0 (e.g. a single sample).

        Raises:
            ValueError: if ``inputs`` and ``target`` differ in sample count.
        """
        if inputs.shape[0] != target.shape[0]:
            raise ValueError(
                f"inputs and target must have the same number of samples, "
                f"got {inputs.shape[0]} and {target.shape[0]}"
            )
        # Build each Gram matrix once. The joint entropy reuses both.
        k_x = self._gram_matrix(inputs, self.sigma_x)
        k_y = self._gram_matrix(target, self.sigma_y)
        h_x = self._entropy_from_gram(k_x)
        h_y = self._entropy_from_gram(k_y)
        h_xy = self._entropy_from_gram(k_x * k_y)
        return (h_x + h_y - h_xy) / torch.max(h_x, h_y)


####################################################### main ###################################################
class DEP(Metric):
    """Normalized mutual information between accumulated ``z`` and ``s`` batches.

    Each ``update`` call buffers its batch. ``compute`` concatenates all
    buffered batches along dim 0 and evaluates :class:`MetricMI` on the
    whole set at once, rather than averaging per-batch values.
    """

    def __init__(self,
                 compute_on_step: bool = False,
                 dist_sync_on_step: bool = False,
                 process_group: Optional[Any] = None,
                 dist_sync_fn: Optional[Callable] = None,
                 ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        # List-valued states. dist_reduce_fx=None leaves cross-process
        # handling to the Metric base class (defined outside this file).
        self.add_state("zz", default=[], dist_reduce_fx=None)
        self.add_state("ss", default=[], dist_reduce_fx=None)

    def update(self, z, s):
        """Buffer one batch; ``z`` and ``s`` should have equal size along dim 0."""
        self.zz.append(z)
        self.ss.append(s)

    def compute(self):
        """Return MetricMI over every buffered sample.

        Raises:
            RuntimeError: if no batch has been buffered via ``update``.
        """
        if len(self.zz) == 0 or len(self.ss) == 0:
            raise RuntimeError("DEP.compute() called before any update()")
        zz = torch.cat(self.zz, dim=0)
        ss = torch.cat(self.ss, dim=0)
        return MetricMI()(zz, ss)