import io
import os
import sys
import matlab 
import matlab.engine as mateng
from IPython.display import clear_output
import ipywidgets as widgets

import numpy as np

from math import log2
from scipy.sparse import find
from sklearn.metrics.cluster import contingency_matrix
from sklearn.metrics.cluster._supervised import entropy

from typing import List

from torch import float64

# Conversion factors between base-2 and natural logarithms:
# log2e = log2(e) = 1 / ln(2) and loge2 = ln(2)
log2e = 1.442695040888963
loge2 = 0.693147180559945

def bin_data_equisize(data, nbins = 20):
    '''
        Discretizes data into bins of equal width (as computed by
        `np.histogram`) and returns the 1-based bin label of each sample.

        Inputs:
        - data[np.ndarray]: Samples to discretize. If a masked array is
          provided, only its unmasked values are binned (and returned).

        Kwargs:
        - nbins[int|sequence|str]: Forwarded to `np.histogram` as `bins`.

        Return:
        - labels[np.ndarray]: Integer bin label of each sample. For an
          integer `nbins` the labels lie in 1..nbins.
    '''
    # Check if data is a masked array, if so use the compressed version for histogram
    if isinstance(data, np.ma.MaskedArray): data = data.compressed()

    data = np.asarray(data)

    _, edges = np.histogram(data, bins = nbins)
    labels = np.digitize(data, edges)

    # np.histogram treats its last bin as closed on the right, while
    # np.digitize does not: without this, the maximum value would end up
    # alone in a spurious extra bin (label nbins + 1).
    labels[data == edges[-1]] = edges.size - 1

    return labels

def bin_data_equipop(data, nbins = 20):
    '''
        Discretizes data into bins holding (approximately) the same number
        of samples and returns the 1-based bin label of each sample.

        Bin edges are the values found at `nbins + 1` evenly spaced
        positions of the sorted data (from its minimum to its maximum).

        Inputs:
        - data[np.ndarray]: 1-D samples to discretize. If a masked array is
          provided, only its unmasked values are binned (and returned).

        Kwargs:
        - nbins[int]: Number of bins.

        Return:
        - labels[np.ndarray]: Integer bin label of each sample, in 1..nbins.
          An empty input yields an empty label array.
    '''
    # Check if data is a masked array, if so use the compressed version for histogram
    if isinstance(data, np.ma.MaskedArray): data = data.compressed()

    data = np.asarray(data)
    if data.size == 0: return np.zeros(data.shape, dtype = np.intp)

    # Calculate the bin edges by sorting the data array and then splitting it into nbins chunks
    sorted_data = np.sort(data, axis = None)
    positions = np.linspace(0, sorted_data.size - 1, num = nbins + 1, dtype = int)
    edges = sorted_data[positions]

    labels = np.digitize(data, edges)

    # The last edge is the maximum of the data; np.digitize would give it a
    # bin of its own (label nbins + 1), so fold it into the last real bin.
    labels[data == edges[-1]] = nbins

    return labels

def _pt_bias(p : np.ndarray, N : int) -> float:
    '''
        This function computes the Panzeri-Treves (1996) bias estimate
        (due to limited sampling) of the plug-in (naive) entropy of a
        discrete distribution. It uses a bayesian estimate for computing
        the correct support (number of relevant bins) of the probability
        distribution.

        Inputs:
        - p[np.ndarray]: 1-D Probability distribution
        - N[int]: Number of trials

        Return:
        - bias[float]: Estimate of the Panzeri-Treves bias in bits, i.e.
          (R - 1) / (2 N ln 2) where R is the estimated support size.
    '''
    eps = np.finfo(float).eps

    # Non-zero probabilities (occupied bins)
    pnz = p[p > eps]

    # The naive estimate for the true support of the
    # probability is the number of non-empty bins
    R = pnz.size

    # * Refine (if necessary) the R estimate via bayesian updates
    if R < p.size:
        # Expected number of occupied bins after N draws from pnz
        R_exp = R - np.sum((1 - pnz)**N)

        ΔR_prev = p.size
        ΔR_curr = np.abs(R - R_exp)

        # Number of extra (unobserved) bins hypothesised so far
        thr = 0.

        # Keep adding hypothetical bins while doing so brings the expected
        # occupancy closer to the observed one (without exceeding p.size)
        while ΔR_curr < ΔR_prev and R + thr < p.size:
            thr += 1

            # Probability mass reassigned to the `thr` unobserved bins
            Γ = thr * (1. - (N / (N + R))**(1. / N))

            # Expected occupied bins among the observed ones...
            P = (1. - Γ) / (N + R) * (pnz * N + 1.)
            R_exp = (1. - (1. - P)**N).sum()

            # ...plus among the `thr` hypothesised ones (each of mass Γ / thr)
            P = Γ / thr
            R_exp += thr * (1 - (1 - P)**N)

            # Update the ΔR estimate
            ΔR_prev = ΔR_curr
            ΔR_curr = np.abs(R - R_exp)

        # If the loop stopped because the mismatch grew, the last hypothesised
        # bin made things worse and is discarded
        R += thr - (not ΔR_curr < ΔR_prev)

    # * Now that the support is known, compute the bias (in bits)
    return (R - 1) / (2 * N * np.log(2.))

def H_p(p : np.ndarray) -> np.ndarray:
    '''
        Computes the entropy (in bits) of a (collection of) probability
        distribution(s) p.

        NOTE: This function normalizes the input, so unnormalized
              histograms (e.g. raw counts) are accepted as well.

        Inputs:
        - p[np.ndarray]: Array whose last axis indexes the outcomes; any
          leading axes are treated as a batch of distributions.

        Return:
        - H[np.ndarray]: Entropy of each distribution, with singleton
          dimensions squeezed (0-d for a single distribution). Entries
          <= machine epsilon are ignored, and distributions whose total
          mass is below epsilon are masked out during the computation.
    '''
    eps = np.finfo(float).eps

    # Add the batch dimension if missing
    p = np.atleast_2d(p)

    # Normalize the provided distribution (masking empty distributions to
    # avoid a division by zero)
    p_sum = np.sum(p, axis = -1, keepdims = True)
    p_sum = np.ma.array(p_sum, mask = p_sum < eps)

    # Mask (near-)zero probabilities: they contribute nothing to the entropy
    # and would otherwise produce log2(0) = -inf
    nzp = np.ma.array(p, copy = False, mask = p <= eps) / p_sum

    out = -np.ma.sum(nzp * np.ma.log2(nzp), axis = -1).squeeze()

    # Return the unmasked version of the array
    return np.ma.getdata(out)

def H(X : np.ndarray, bias : str = None) -> float:
    '''
        Computes the entropy H(X) of a random variable X.

        Input:
        - X[np.ndarray]: Array of shape (n_samples) collecting
          the recorded occurrences of the random variable X
          (any array-like is accepted).
        
        Kwargs:
        - bias [str|None]: The type of bias correction to use. If
          None no correction will be applied. Currently supported
          choice for bias are: ('pt',)

        NOTE: For proper use the X array should take discrete
              (integer) values (use one of the binning strategies to
              discretize the raw continuous observations of X). The
              number of bins handed to the bias correction is the larger
              of the number of distinct values and max(X) + 1.

        Return:
        H(X) [float]: Entropy of the random variable X (in bits), or NaN
        for an empty input.

        Raises:
        - ValueError: if `bias` is not a supported correction.
    '''
    if len(X) == 0: return np.nan

    X = np.asanyarray(X)

    # * Compute the naive entropy
    # Occurrence count of each distinct value, padded with empty bins up to
    # max(X) + 1 (presumably so that the PT correction accounts for
    # unoccupied values; assumes non-negative integer labels — TODO confirm)
    label_idx = np.unique(X, return_inverse = True)[1].ravel()
    pi = np.bincount(label_idx, minlength = max(X.max() + 1, 0)).astype(np.float64)

    pnz = pi[pi > np.finfo(float).eps]

    pi_sum = np.sum(pnz)

    ent = -np.sum((pnz / pi_sum) * (np.log2(pnz) - log2(pi_sum)))

    if bias is None: bias = 0
    elif bias == 'pt':
        bias = _pt_bias(pi / pi_sum, X.size)
    else:
        raise ValueError(f'Unsupported bias correction {bias} for entropy.')

    return ent + bias

def H_XY(X : np.ndarray, Y : np.ndarray, kind : str = 'x|y', bias : str = None) -> float:
    '''
        Computes an entropy of two random variables X and Y.
        The variables are assumed to be discretized appropriately
        so that a contingency matrix can be computed.

        This function offers the calculation of three types of entropies:
        - [x|y]: Entropy of the first variable, conditioned on the second
        - [y|x]: Entropy of the second variable, conditioned on the first
        - [x,y]: Joint entropy of the two variables

        Inputs:
        - X[np.ndarray]: 1D array of discretized occurrences of random variable X
        - Y[np.ndarray]: 1D array of discretized occurrences of random variable Y

        Kwargs:
        - kind[str]: String representing which type of entropy of two variables to
                     compute. Can be one of: ('x|y', 'y|x', 'x,y')
        - bias[str|None]: String representing which type of bias correction to use.
                          Currently supported biases are: ('pt')

        Return:
        - H_XY[float]: Scalar number representing the computed entropy (in bits),
                       or NaN if either input is empty.

        Raises:
        - ValueError: for an unsupported `kind` or `bias`, or if X and Y differ in size.
    '''
    if len(X) == 0 or len(Y) == 0: return np.nan
    
    # * Input sanitization
    valid_entr = ('x|y', 'y|x', 'x,y')
    valid_bias = (None, 'pt')
    ermsg_entr = f'Unsupported entropy kind {kind} in H_XY. Kind should be one of {valid_entr}'
    ermsg_bias = f'Unsupported bias correction {bias} in H_XY. Bias should be one of {valid_bias}'
    ermsg_size = f'X and Y arrays should have equal size. Got X:{X.size} and Y:{Y.size}'
    if kind not in valid_entr: raise ValueError(ermsg_entr)
    if bias not in valid_bias: raise ValueError(ermsg_bias)
    if X.size != Y.size:       raise ValueError(ermsg_size)

    # Contingency matrix of X and Y (sklearn convention: entry [i, j] counts
    # samples with the i-th value of X and the j-th value of Y), from which
    # all the entropies can be derived
    cmXY = contingency_matrix(X, Y, sparse = True)

    _, _, nzp = find(cmXY)   # non-zero joint counts
    pi_sum = np.sum(nzp)     # total number of samples
    pXY = cmXY.toarray()

    # * Select which entropy to compute
    if kind == 'x|y':
        # H(X|Y) = sum_y p(y) H(X | Y = y)
        Ns = pXY.sum(axis = 0)  # number of samples for each value of Y
        P = np.sum(pXY / pi_sum, axis = 0)
        ent = (P * H_p(pXY.T / pi_sum).squeeze()).sum()

        pAB = (pXY / Ns).T      # row j: p(x | y_j)

    elif kind == 'y|x':
        # H(Y|X) = sum_x p(x) H(Y | X = x)
        Ns = pXY.sum(axis = 1)  # number of samples for each value of X
        P = np.sum(pXY / pi_sum, axis = 1)
        ent = (P * H_p(pXY / pi_sum).squeeze()).sum()

        pAB = pXY / Ns[:, None] # row i: p(y | x_i)

    else:  # kind == 'x,y' (validated above)
        ent = -np.sum((nzp / pi_sum) * (np.log2(nzp) - log2(pi_sum)))

        # A single distribution (the flattened joint) estimated from all samples
        P, pAB, Ns = [1], [pXY.reshape(-1) / pi_sum], [X.size]

    # * Compute bias: the PT correction of each conditional distribution,
    #   weighted by the probability of the conditioning value
    if bias == 'pt':
        bias = np.dot(P, [_pt_bias(p, N) for p, N in zip(pAB, Ns)])
    else:
        bias = 0

    return ent + bias

def I_XY(
    X : np.ndarray,
    Y : np.ndarray,
    bias : str = None,
    norm : str = None,
    order: str = 'x,y') ->float:
    '''
        This function computes the (normalized) Mutual Information.
        Several possible normalization techniques are available:
        [Following Wikipedia naming conventions]

        - Cyx: Proficiency (norm with respect to H(X))
        - Cxy: Proficiency (norm with respect to H(Y))
        - R  : Redundancy (norm with respect to H(X) + H(Y))
        - U  : Symmetric uncertainty (U = 2R)
        - TC : Total Correlation (norm with respect to min[H(X), H(Y)])
        - IQR: Information Quality Ratio (norm with respect to H(X, Y))
        - PMI: Pearson Mutual Information (norm with respect to sqrt(H(X)H(Y)))

        Inputs:
        - X[np.ndarray]: 1D array of discretized occurrences of random variable X
        - Y[np.ndarray]: 1D array of discretized occurrences of random variable Y
        
        Kwargs:
        - bias[str|None]: String representing which type of bias correction to use.
                          Currently supported biases are: ('pt')
        - norm[str|None]: String representing which type of information normalization
                          to use. If None, no normalization will be applied.
        - order[str]: Which expression to use for the computation of I. It can be one of:
                                    'x,y' => I(X,Y) = H(X) - H(X|Y)
                                    'y,x' => I(X,Y) = H(Y) - H(Y|X)
                      In theory these two should be equal, but bias-correction might
                      introduce an asymmetry.

        Returns:
        - I[float]: (Normalized) Mutual Information in bits (NaN for empty input).

        Raises:
        - ValueError: for an unsupported `bias`, `norm` or `order`, or if X and Y
                      differ in size.
    '''
    # * Input sanitization
    valid_bias = (None, 'pt')
    valid_norm = (None, 'Cyx', 'Cxy', 'R', 'U', 'TC', 'IQR', 'PMI')
    valid_ordr = ('x,y', 'y,x')
    ermsg_bias = f'Unsupported bias correction {bias} in I_XY. Bias should be one of {valid_bias}'
    ermsg_norm = f'Unsupported normalization {norm} in I_XY. Norm should be one of {valid_norm}'
    ermsg_size = f'X and Y arrays should have equal size. Got X:{X.size} and Y:{Y.size}'
    ermsg_ordr = f'Order should be one of {valid_ordr}. Got {order}'
    if norm  not in valid_norm: raise ValueError(ermsg_norm)
    if bias  not in valid_bias: raise ValueError(ermsg_bias)
    if order not in valid_ordr: raise ValueError(ermsg_ordr)
    if X.size != Y.size:        raise ValueError(ermsg_size)

    # * Compute the entropies needed for the mutual information
    HX, HY  = H(X, bias = bias), H(Y, bias = bias)

    # Conditional entropy matching the requested expression: H(X|Y) or H(Y|X)
    H_cond = H_XY(X, Y, kind = 'x|y' if order == 'x,y' else 'y|x', bias = bias)

    I = (HX if order == 'x,y' else HY) - H_cond

    # Guards the denominators below against division by zero
    eps = np.finfo(float).eps

    # * Normalize the mutual information
    if   norm is None:   out = I
    elif norm == 'Cyx' : out = I / (HX + eps)
    elif norm == 'Cxy' : out = I / (HY + eps)
    elif norm == 'R'   : out = I / (HX + HY)
    elif norm == 'U'   : out = 2 * I / (HX + HY)
    elif norm == 'TC'  : out = I / (np.minimum(HX, HY) + eps)
    elif norm == 'IQR' : out = I / H_XY(X, Y, kind = 'x,y', bias = bias)
    elif norm == 'PMI' : out = I / (np.sqrt(HX * HY) + eps)

    else: raise ValueError(f'Unknown normalization {norm}.')

    return out

def nMI(
    HX : np.ndarray, 
    HY : np.ndarray, 
    HXY : np.ndarray,
    kind : str = 'Cyx'
    ) -> np.ndarray:
    '''
        Normalized Mutual Information computed from precomputed entropies.

        The mutual information is obtained as I = H(X) + H(Y) - H(X, Y)
        and then divided by the quantity selected through `kind`
        (names follow Wikipedia conventions):

        - Cyx : Proficiency, normalized by H(X)
        - Cxy : Proficiency, normalized by H(Y)
        - R   : Redundancy, normalized by H(X) + H(Y)
        - U   : Symmetric uncertainty (U = 2R)
        - TC  : Total Correlation, normalized by min[H(X), H(Y)]
        - IQR : Information Quality Ratio, normalized by H(X, Y)
        - PMI : Pearson Mutual Information, normalized by sqrt(H(X)H(Y))
        - none: no normalization, the raw I is returned

        A tiny offset (1e-20) is added to the Cyx, Cxy, TC and PMI
        denominators to avoid dividing by zero; R, U and IQR are unguarded.

        Inputs:
        - HX [np.ndarray]: Entropy of first random variable X
        - HY [np.ndarray]: Entropy of second random variable Y
        - HXY[np.ndarray]: Joint Entropy of X and Y
        - kind [str]: Name of the desired normalization (see list above)

        Returns:
        - nMI[np.ndarray]: Normalized Mutual Information

        Raises:
        - ValueError: if `kind` is not one of the names listed above.
    '''
    mutual_info = HX + HY - HXY
    tiny = 1e-20

    # Each entry is evaluated lazily, so only the selected denominator is computed
    normalizers = {
        'Cyx' : lambda i: i / (HX + tiny),
        'Cxy' : lambda i: i / (HY + tiny),
        'R'   : lambda i: i / (HX + HY),
        'U'   : lambda i: 2 * i / (HX + HY),
        'TC'  : lambda i: i / (np.minimum(HX, HY) + tiny),
        'IQR' : lambda i: i / HXY,
        'PMI' : lambda i: i / (np.sqrt(HX * HY) + tiny),
        'none': lambda i: i,
    }

    # Linear scan with `==` (instead of a hashed lookup) so any value of
    # `kind` is compared exactly as a chain of equality tests would
    normalize = next((fn for name, fn in normalizers.items() if kind == name), None)
    if normalize is None:
        raise ValueError(f'Unknown normalization {kind}.')

    return normalize(mutual_info)