import numpy as np
import functools
import random
from tqdm import tqdm
import torch
import inspect
from log import make_logger

# Module-level logger used by the entropy routines below to report their results.
# NOTE(review): presumably the arguments are the log directory and the logger
# name -- confirm against `log.make_logger`.
logger = make_logger('logs','conditional-eval')


def gaussian_kernel_decorator(function):
    """Wrap a method so its ``X`` / ``Y`` arguments are turned into kernel matrices.

    The wrapped method must live on an object exposing ``gaussian_kernel``.
    On every call the wrapper looks at the ``compute_kernel`` argument
    (treated as ``True`` when the caller does not pass it):

    * if it is exactly ``True``, a supplied ``X`` is replaced by
      ``self.gaussian_kernel(X)`` and a supplied, non-``None`` ``Y`` by
      ``self.gaussian_kernel(Y)``;
    * otherwise all arguments are forwarded untouched.

    Arguments are resolved by name against the method's signature, so ``X``,
    ``Y`` and ``compute_kernel`` may be passed positionally or by keyword.

    function:
        the method to wrap; its first parameter is expected to be ``self``.

    return:
        the wrapping callable, carrying the metadata of ``function``.
    """
    # The signature never changes, so inspect it once at decoration time.
    sig = inspect.signature(function)

    @functools.wraps(function)
    def wrap_kernel(self, *args, **kwargs):
        # `self` must take part in the binding: the signature includes it, and
        # leaving it out would shift every positional argument by one slot.
        bound = sig.bind_partial(self, *args, **kwargs)
        arguments = bound.arguments

        if arguments.get('compute_kernel', True) is True:
            if 'X' in arguments:
                arguments['X'] = self.gaussian_kernel(arguments['X'])

            # `Y` is optional for some callers; leave an explicit None alone.
            if arguments.get('Y') is not None:
                arguments['Y'] = self.gaussian_kernel(arguments['Y'])

        # `bound.args` / `bound.kwargs` reflect the edits made to `arguments`.
        return function(*bound.args, **bound.kwargs)

    return wrap_kernel


def entropy_q(p, q=1):
    """Return the order-``q`` entropy (in bits) of the weights in ``p``.

    p:
        tensor of weights; entries that are not strictly positive are ignored
        by the ``q == 1`` and general-order formulas.
    q:
        entropy order.
        ``1``                      -> Shannon entropy  ``-sum(p * log2(p))``
        ``"inf"`` / ``float("inf")`` -> min-entropy      ``-log2(max(p))``
        anything else              -> Renyi entropy    ``log2(sum(p**q)) / (1 - q)``

    return:
        0-dim tensor holding the entropy value.
    """
    # Accept the numeric infinity as well as the historical "inf" string:
    # plugging float("inf") into the general formula would yield NaN.
    if q == "inf" or q == float("inf"):
        return -torch.log2(torch.max(p))

    positive = p[p > 0]  # zero / negative weights would break the logarithm
    if q == 1:
        return -(positive * torch.log2(positive)).sum()
    return torch.log2((positive ** q).sum()) / (1 - q)


class ConditionalEvaluation:
    """Kernel-matrix entropy, joint entropy and mutual-information estimates.

    The entropy methods are wrapped by ``gaussian_kernel_decorator``: unless
    they are called with ``compute_kernel=False`` their ``X`` / ``Y`` inputs
    are first converted to Gaussian kernel matrices with ``gaussian_kernel``.

    similarity_function:
        stored as ``self.similarity``; no active method of this class reads it.
    sigma:
        default Gaussian bandwidth used by ``gaussian_kernel``.
    """

    def __init__(self, similarity_function=None, sigma=None):
        if similarity_function is None and sigma is None:
            raise ValueError("Both similarity_function and sigma can not be None")
        self.similarity = similarity_function
        self.sigma = sigma
        # All kernel computations run on the GPU when one is available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # def predict_n_clusters_v2(self, X, n_samples=1024):
    #     similarity_holder = []
    #     for i in tqdm(range(n_samples)):
    #         i, j = np.random.randint(X.shape[0], size=2)
    #         similarity_holder.append(self.similarity(X[i], X[j])**2)
    #     return 1 / np.mean(similarity_holder)

    def _to_device(self, data):
        '''Return `data` as a tensor on `self.device`; NumPy arrays are wrapped first.'''
        if isinstance(data, np.ndarray):
            data = torch.from_numpy(data)
        return data.to(self.device)

    def gaussian_kernel(self, x, y=None, sigma=None, batchsize=256, normalize=True):
        '''
        Calculate the Gaussian kernel matrix exp(-||x_i - y_j||^2 / (2 * sigma^2)).
        The shape of x and y should be equal except for the batch dimension.

        x:
            input (tensor or NumPy array), dim: [batch, dims]
        y:
            input (tensor or NumPy array), dim: [batch, dims].
            If y is `None` then y = x and it will compute k(x, x).
        sigma:
            bandwidth parameter; falls back to `self.sigma` when `None`
        batchsize:
            number of rows of y processed per step, trade time for memory
        normalize:
            when exactly `True`, divide the matrix by sqrt(len(x) * len(y))

        return:
            tensor of shape [len(x), len(y)] holding the kernel values

        raises:
            ValueError if neither `sigma` nor `self.sigma` provides a bandwidth
        '''
        x = self._to_device(x)
        y = x if y is None else self._to_device(y)

        batch_num = (y.shape[0] // batchsize) + 1  # the last batch may be empty
        assert (x.shape[1:] == y.shape[1:])

        if sigma is None:
            sigma = self.sigma
        if sigma is None:
            raise ValueError("sigma must be given either here or at construction time")

        scale = -1 / (2 * sigma * sigma)
        x_expanded = x.unsqueeze(1)  # [len(x), 1, dims], broadcasts against a y batch

        # Collect the column blocks and concatenate once at the end instead of
        # re-copying the growing matrix on every batch. The zero-width seed is
        # kept so dtype promotion matches the incremental construction.
        blocks = [torch.zeros((x.shape[0], 0), device=x.device)]
        for batchidx in range(batch_num):
            y_slice = y[batchidx * batchsize:(batchidx + 1) * batchsize]
            sq_dist = torch.norm(x_expanded - y_slice, dim=2, p=2).pow(2)
            blocks.append(torch.exp(scale * sq_dist))
            del sq_dist, y_slice
        total_res = torch.hstack(blocks)

        if normalize is True:
            total_res = total_res / np.sqrt(x.shape[0] * y.shape[0])

        return total_res

    def cosine_kernel(self, x, y=None, batchsize=256, normalize=True):
        '''
        Calculate the cosine similarity kernel matrix. The shape of x and y should be equal except for the batch dimension.

        x:
            Input (tensor or NumPy array), dim: [batch, dims]
        y:
            Input (tensor or NumPy array), dim: [batch, dims].
            If y is `None`, then y = x and it will compute cosine similarity k(x, x).
        batchsize:
            Number of rows of y processed per step, trade time for memory.
        normalize:
            When exactly `True`, divide the matrix by len(x) * len(y).

        return:
            Tensor of shape [len(x), len(y)] holding the cosine similarities
        '''
        x = self._to_device(x)
        y = x if y is None else self._to_device(y)

        batch_num = (y.shape[0] // batchsize) + 1  # the last batch may be empty
        assert (x.shape[1:] == y.shape[1:])

        # Row-normalising x does not depend on the batch, so do it once.
        x_norm = x / x.norm(dim=1, keepdim=True)

        # Same single-concatenation scheme as in gaussian_kernel.
        blocks = [torch.zeros((x.shape[0], 0), device=x.device)]
        for batchidx in range(batch_num):
            y_slice = y[batchidx * batchsize:(batchidx + 1) * batchsize]
            y_norm = y_slice / y_slice.norm(dim=1, keepdim=True)
            blocks.append(torch.mm(x_norm, y_norm.T))
            del y_norm, y_slice
        total_res = torch.hstack(blocks)

        if normalize is True:
            # NOTE(review): gaussian_kernel divides by sqrt(len(x) * len(y));
            # here the divisor is the plain product -- confirm this is intended.
            total_res = total_res / (x.shape[0] * y.shape[0])

        return total_res

    @gaussian_kernel_decorator
    def compute_entropy(self, X, order, compute_kernel=True):
        '''
        Entropy of the given order computed from a square matrix.

        X:
            square matrix; when `compute_kernel` is `True` the decorator has
            already replaced the caller's input by `self.gaussian_kernel(input)`
        order:
            1 or 1.5 -> `entropy_q` applied to the eigenvalues of X
            2        -> -log2 of the squared Frobenius norm of X
        compute_kernel:
            consumed by the decorator, unused in the body

        return:
            0-dim tensor with the entropy value

        raises:
            NotImplementedError for any other order
        '''
        assert X.shape[0] == X.shape[1]

        if order in (1, 1.5):
            # TODO check if we want exp or not. also check log or log2!
            eigenvalues = torch.linalg.eigvalsh(X)
            return entropy_q(eigenvalues, q=order)

        if order == 2:
            # The squared Frobenius norm stands in for trace(X^2)
            # (the two coincide when X is symmetric).
            trace_X_squared = torch.linalg.norm(X, 'fro')**2
            return -torch.log2(trace_X_squared)

        raise NotImplementedError()

    @gaussian_kernel_decorator
    def compute_joint_entropy(self, X, Y, order, compute_kernel=True):
        '''
        Entropy of the trace-normalised Hadamard product of X and Y.

        X, Y:
            square matrices of equal shape (kernel matrices after the decorator ran)
        order:
            entropy order, forwarded to `compute_entropy`
        compute_kernel:
            consumed by the decorator, unused in the body

        return:
            0-dim tensor with the joint entropy value
        '''
        assert X.shape[0] == X.shape[1]
        XoY = X * Y  # Hadamard product
        S_AB = XoY / torch.trace(XoY)  # rescale to unit trace
        return self.compute_entropy(X=S_AB, order=order, compute_kernel=False)

    @gaussian_kernel_decorator
    def conditional_entropy(self, X, Y,  order, n_samples=10_000, compute_kernel=True):  # H_a(X|Y)
        '''
        Compute H(X|Y) together with the quantities it is derived from, and log them.

        X, Y:
            square matrices of equal shape (kernel matrices after the decorator ran)
        order:
            entropy order, forwarded to `compute_entropy`
        n_samples:
            currently unused
        compute_kernel:
            consumed by the decorator, unused in the body

        return:
            tuple (H(X|Y), I(X; Y), H(X, Y), H(X), H(Y))
        '''
        # The inputs are already kernel matrices here, so skip the conversion below.
        entropy_Y = self.compute_entropy(Y, order=order, compute_kernel=False)
        entropy_X = self.compute_entropy(X, order=order, compute_kernel=False)
        entropy_joint = self.compute_joint_entropy(X, Y, order=order, compute_kernel=False)

        logger.info(f'joint: {entropy_joint}, X: {entropy_X}, Y: {entropy_Y}')
        logger.info(f'H(X|Y) = {entropy_joint - entropy_Y}, I(X; Y) = {entropy_X + entropy_Y - entropy_joint}')

        conditional = entropy_joint - entropy_Y
        mutual_information = entropy_X + entropy_Y - entropy_joint
        # TODO: You should only return `entropy_joint - entropy_Y`
        return conditional, mutual_information, entropy_joint, entropy_X, entropy_Y  # H(X|Y), I(X, Y), H(X, Y), H(X), H(Y)

    @gaussian_kernel_decorator
    def mutual_information_order_2(self, X, Y, order, compute_kernel=True, entropy='shannon'):
        '''
        Compute and log I(X, Y) = H(X) + H(Y) - H(X, Y) for the given order.

        X, Y:
            square matrices of equal shape (kernel matrices after the decorator ran)
        order:
            entropy order, forwarded to `compute_entropy`
        compute_kernel:
            consumed by the decorator, unused in the body
        entropy:
            currently unused

        return:
            0-dim tensor with the mutual information value
        '''
        entropy_X = self.compute_entropy(X, order=order, compute_kernel=False)
        entropy_Y = self.compute_entropy(Y, order=order, compute_kernel=False)
        entropy_joint = self.compute_joint_entropy(X, Y, order=order, compute_kernel=False)
        i_xy = entropy_X + entropy_Y - entropy_joint
        logger.info(f'I(X, Y) = {i_xy}')
        return i_xy

    # def visulize_modes()