from sklearn.metrics import mutual_info_score
import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde, entropy
from sklearn.feature_selection import mutual_info_regression
from sklearn.metrics import mutual_info_score
from sklearn.neighbors import KernelDensity
from sklearn.metrics.cluster import normalized_mutual_info_score

import matplotlib.pyplot as plt
from scipy.stats import spearmanr, pearsonr, ConstantInputWarning
import warnings
import json
from models.entity import Entity
from models.relationship import Relationship 


def get_prior_matrix(root_path):
    """Build a symmetric prior matrix from a serialized node/edge graph.

    Reads ``nodes.json`` and ``edges.json`` under ``root_path``, converts each
    entry with ``Entity.from_dict`` / ``Relationship.from_dict``, and returns a
    square DataFrame whose rows and columns are the entity names.

    Every cell starts at 1. For each edge, ``float(weight)`` is then written
    to both ``(entity_1_name, entity_2_name)`` and the mirrored cell.
    """
    with open(f"{root_path}/nodes.json", "r") as nodes_file:
        raw_nodes = json.load(nodes_file)
    entities = {key: Entity.from_dict(payload) for key, payload in raw_nodes.items()}

    with open(f"{root_path}/edges.json", "r") as edges_file:
        raw_edges = json.load(edges_file)
    relationships = {
        key: Relationship.from_dict(payload) for key, payload in raw_edges.items()
    }

    names = [entity.name for entity in entities.values()]
    prior = pd.DataFrame(
        data=np.zeros((len(names), len(names))),
        index=names,
        columns=names,
    )
    # Default prior for every pair (including the diagonal) is 1.
    prior.loc[names, names] = 1

    # Edge weights override the default, symmetrically.
    for relationship in relationships.values():
        weight = float(relationship.weight)
        first = relationship.entity_1_name
        second = relationship.entity_2_name
        prior.loc[first, second] = weight
        prior.loc[second, first] = weight

    return prior

def fisher_transform(r, eps=1e-8):
    """Return the Fisher z-transform ``0.5 * ln((1 + r) / (1 - r))`` of ``r``.

    ``r`` is first clipped to ``[-1 + eps, 1 - eps]`` so that inputs at
    exactly +/-1 give a large finite value rather than dividing by zero.
    Works element-wise on arrays.
    """
    lower, upper = -1 + eps, 1 - eps
    bounded = np.clip(r, lower, upper)
    ratio = (1 + bounded) / (1 - bounded)
    return 0.5 * np.log(ratio)


def freedman_diaconis_bins(data):
    """Choose a histogram bin count with the Freedman-Diaconis rule.

    The bin width is ``h = 2 * IQR / n ** (1/3)``, where ``IQR`` is the
    inter-quartile range of the non-NaN values and ``n`` their count. The bin
    count is ``ceil(range / h)``.

    Args:
        data: Array-like of numbers; NaN entries are ignored.

    Returns:
        int: ``1`` when fewer than two non-NaN values remain or all values
        are equal; otherwise the bin count clipped to ``[5, 50]``. A zero IQR
        with non-zero range falls back to the lower clip bound.
    """
    data = np.asarray(data, dtype=float)
    data = data[~np.isnan(data)]
    if len(data) < 2:
        return 1
    data_range = np.max(data) - np.min(data)
    if data_range == 0:
        return 1
    q25, q75 = np.percentile(data, [25, 75])
    # Parenthesize the IQR: the width must depend on the spread (q75 - q25),
    # not on the absolute location of the quartiles.
    h = 2 * (q75 - q25) / (len(data) ** (1 / 3))
    bins = int(np.ceil(data_range / h)) if h > 0 else 1
    return int(np.clip(bins, 5, 50))  # Limit between 5-50 bins

def estimate_entropy_kde(x, grid_size=1000):
    """Entropy of a Gaussian-KDE fit of ``x``, discretized on a regular grid.

    The KDE is evaluated at ``grid_size`` evenly spaced points between
    ``min(x)`` and ``max(x)``, the values are normalized to sum to 1, and the
    Shannon entropy of that discrete distribution is returned. The result
    therefore depends on ``grid_size``.
    """
    density = gaussian_kde(x)
    lo, hi = np.min(x), np.max(x)
    sample_points = np.linspace(lo, hi, grid_size)
    weights = density(sample_points)
    weights = weights / weights.sum()  # turn densities into a discrete pmf
    return entropy(weights)


def estimate_mi_knn(x, y, n_neighbors=10):
    """Histogram-based (unnormalized) mutual information between ``x`` and ``y``.

    Despite the name, this variant bins the data: the bin count is the larger
    of the Freedman-Diaconis counts for ``x`` and ``y``, and the resulting 2-D
    histogram is passed to ``mutual_info_score`` as a contingency table.
    ``n_neighbors`` is accepted but not used.

    NOTE(review): this module defines ``estimate_mi_knn`` again further down,
    so at import time the later definition rebinds the name.
    """
    n_bins = max(freedman_diaconis_bins(x), freedman_diaconis_bins(y))
    joint_counts = np.histogram2d(x, y, bins=n_bins)[0]
    return mutual_info_score(None, None, contingency=joint_counts)


def estimate_mi_knn(x, y, n_neighbors=10):
    """Mutual information between ``x`` and ``y`` via ``mutual_info_regression``.

    ``x`` is reshaped into a single feature column and ``y`` is the target.
    ``random_state`` is fixed at 0.
    """
    feature_column = x.reshape(-1, 1)
    scores = mutual_info_regression(
        feature_column, y, n_neighbors=n_neighbors, random_state=0
    )
    # One feature in, one score out.
    return scores[0]


def estimate_mi_knn(x, y, n_neighbors=10):
    """Estimate the mutual information of ``x`` and ``y`` with scikit-learn.

    Delegates to ``mutual_info_regression`` with ``x`` as the only feature
    (reshaped to one column), ``y`` as the target, and ``random_state=0``.
    Returns the single score for that feature.
    """
    as_single_feature = x.reshape(-1, 1)
    (score,) = mutual_info_regression(
        as_single_feature, y, n_neighbors=n_neighbors, random_state=0
    )[:1]
    return score

def estimate_mi_kde(x, y, bw=0.3):
    """Estimate mutual information of ``x`` and ``y`` from Gaussian KDEs.

    Three ``KernelDensity`` models with the same bandwidth ``bw`` are fitted:
    one on the joint ``(x, y)`` samples and one on each marginal. The estimate
    is the mean over the samples of ``log p(x, y) - (log p(x) + log p(y))``,
    with every density evaluated at the points it was fitted on.
    """
    joint_points = np.vstack([x, y]).T
    x_column = x.reshape(-1, 1)
    y_column = y.reshape(-1, 1)

    def _log_density_at_samples(points):
        # Fit on the points and score the very same points.
        model = KernelDensity(kernel='gaussian', bandwidth=bw).fit(points)
        return model.score_samples(points)

    log_joint = _log_density_at_samples(joint_points)
    log_marginal_x = _log_density_at_samples(x_column)
    log_marginal_y = _log_density_at_samples(y_column)
    return np.mean(log_joint - (log_marginal_x + log_marginal_y))

def estimate_nmi_kde(x, y):
    """KDE mutual information normalized by the geometric mean of entropies.

    Returns ``estimate_mi_kde(x, y) / sqrt(H(x) * H(y))`` where each ``H`` is
    ``estimate_entropy_kde`` with its default grid size.
    """
    mutual_information = estimate_mi_kde(x, y)
    entropy_product = estimate_entropy_kde(x) * estimate_entropy_kde(y)
    return mutual_information / np.sqrt(entropy_product)


def bootstrap_mi(x, y, bw=0.3, n_boot=50):
    """Bootstrap variance of the KDE mutual-information estimate.

    Draws ``n_boot`` resamples of the paired data (indices sampled with
    replacement via ``np.random.randint``), evaluates ``estimate_mi_kde`` on
    each, and returns ``np.var`` of the resulting estimates.
    """
    sample_count = len(x)

    def _resampled_estimate():
        # Same indices for x and y so the pairing is preserved.
        picks = np.random.randint(0, sample_count, sample_count)
        return estimate_mi_kde(x[picks], y[picks], bw)

    estimates = [_resampled_estimate() for _ in range(n_boot)]
    return np.var(estimates)

def get_mi(x, y, method='knn'):
    """Estimate the mutual information between ``x`` and ``y``.

    Args:
        x, y: Paired samples, forwarded unchanged to the chosen estimator.
        method: ``'knn'`` to use ``estimate_mi_knn`` or ``'kde'`` to use
            ``estimate_mi_kde``.

    Returns:
        The value returned by the selected estimator.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == 'knn':
        return estimate_mi_knn(x, y)
    if method == 'kde':
        return estimate_mi_kde(x, y)
    # Previously an unknown method fell through to `return mi` with `mi`
    # never assigned, surfacing as an opaque UnboundLocalError.
    raise ValueError(f"Unsupported method: {method!r} (expected 'knn' or 'kde')")



def update_prior(mu0, Sigma0, r1, V1, reg_strength=1e-8):
    """Combine a Gaussian prior with a Gaussian observation.

    With prior mean/covariance ``(mu0, Sigma0)`` and observation
    mean/covariance ``(r1, V1)``, computes

        Sigma1 = pinv(pinv(Sigma0) + pinv(V1))
        mu1    = Sigma1 @ (pinv(Sigma0) @ mu0 + pinv(V1) @ r1)

    Pseudo-inverses are used throughout. ``reg_strength`` is accepted but
    currently has no effect.

    Returns:
        tuple: ``(mu1, Sigma1)``.
    """
    prior_precision = np.linalg.pinv(Sigma0)
    observation_precision = np.linalg.pinv(V1)
    posterior_cov = np.linalg.pinv(prior_precision + observation_precision)
    precision_weighted_sum = prior_precision @ mu0 + observation_precision @ r1
    posterior_mean = posterior_cov @ precision_weighted_sum
    return posterior_mean, posterior_cov

def compute_y_and_rank(x, S, stretch=False):
    """Normalize ``x``, blend it with its reverse, and rank the result.

    Steps:
        1. Min-max normalize ``x`` to ``z`` in [0, 1]; if all values are
           equal, every ``z`` is 0.5.
        2. Mix forward and reversed scores: ``y = S * z + (1 - S) * (1 - z)``.
        3. If ``stretch`` is true, min-max rescale ``y`` to [0, 1]; if all
           ``y`` are equal (constant ``x``, or e.g. ``S == 0.5``), every ``y``
           becomes 0.5 instead of dividing by zero.
        4. Rank ``y`` descending with ties sharing the lowest rank
           (1 = highest ``y``).

    Args:
        x: Non-empty sequence of numbers.
        S: Mixing weight; 1 keeps the original order, 0 reverses it.
        stretch: Whether to rescale ``y`` to span [0, 1].

    Returns:
        tuple: ``(z, y, ranks)`` as three lists of ``len(x)`` items.

    Raises:
        ValueError: If ``x`` is empty (from ``min``/``max``).
    """
    xmin, xmax = min(x), max(x)
    # 1) normalize to [0,1]
    if xmax == xmin:
        z = [0.5] * len(x)
    else:
        z = [(xi - xmin) / (xmax - xmin) for xi in x]
    # 2) mix forward/reverse
    y = [S * zi + (1 - S) * (1 - zi) for zi in z]
    # 3) optional stretch to full [0,1]
    if stretch:
        ymin, ymax = min(y), max(y)
        if ymax == ymin:
            # Degenerate spread: mirror the fallback used in step 1 rather
            # than dividing by zero.
            y = [0.5] * len(y)
        else:
            y = [(yi - ymin) / (ymax - ymin) for yi in y]
    # 4) ranking: 1 = highest y
    ranks = pd.Series(y).rank(ascending=False, method='min').astype(int).tolist()
    return z, y, ranks



def freedman_diaconis_bins(data):
    """Choose a histogram bin count with the Freedman-Diaconis rule.

    The bin width is ``h = 2 * IQR / n ** (1/3)``, where ``IQR`` is the
    inter-quartile range of the non-NaN values and ``n`` their count. The bin
    count is ``ceil(range / h)``.

    Args:
        data: Array-like of numbers; NaN entries are ignored.

    Returns:
        int: ``1`` when fewer than two non-NaN values remain or all values
        are equal; otherwise the bin count clipped to ``[5, 50]``. A zero IQR
        with non-zero range falls back to the lower clip bound.
    """
    data = np.asarray(data, dtype=float)
    data = data[~np.isnan(data)]
    if len(data) < 2:
        return 1
    data_range = np.max(data) - np.min(data)
    if data_range == 0:
        return 1
    q25, q75 = np.percentile(data, [25, 75])
    # Parenthesize the IQR: the width must depend on the spread (q75 - q25),
    # not on the absolute location of the quartiles.
    h = 2 * (q75 - q25) / (len(data) ** (1 / 3))
    bins = int(np.ceil(data_range / h)) if h > 0 else 1
    return int(np.clip(bins, 5, 50))  # Limit between 5-50 bins




def compute_mutual_info_stats2(df, metrics, focus_metrics, bin_method='fd'):
    """
    Bootstrap normalized mutual information for every pair of metric columns.

    For each pair ``i < j`` of names in ``metrics``, rows where either column
    of ``df`` is NaN are dropped. If any rows remain, 50 bootstrap resamples
    (with replacement, via the global NumPy RNG) are drawn; each resample's
    2-D histogram is scored with ``mutual_info_score`` and divided by the
    smaller of the two marginal entropies. The marginal entropies are taken
    from histograms of the un-resampled data.

    Parameters:
        df: DataFrame holding one column per entry of ``metrics``.
        metrics: Sequence of column names to pair up.
        focus_metrics: Not used by the function body.
        bin_method: An int (used directly as the bin count), ``'fd'``
            (larger of the two ``freedman_diaconis_bins`` counts) or
            ``'sqrt'`` (``int(sqrt(n))``). Anything else raises ValueError.

    Returns:
        ``(mi_df, mi_var_df, mi_sample_size_df, boot_matrix)`` where the
        three DataFrames are indexed by ``metrics`` on both axes:
        mean bootstrap value (cells not computed, including the diagonal,
        stay at 1), ``np.cov`` of the bootstrap values (default 1), and the
        number of valid rows (default 0). ``boot_matrix`` maps ``(i, j)`` and
        ``(j, i)`` to the list of 50 bootstrap values.

    Side effects:
        Prints each pair's list of bootstrap values to stdout.
    """

    user_data = df
    # Result matrices; pairs that are never computed keep these defaults.
    n_metrics = len(metrics)
    mi_matrix = np.ones((n_metrics, n_metrics))
    mi_var_matrix = np.ones((n_metrics, n_metrics))
    mi_sample_size_matrix = np.zeros((n_metrics, n_metrics))
    boot_matrix = {}
    # Visit each unordered pair once; results are mirrored into (j, i).
    for i in range(n_metrics):
        for j in range(i+1, n_metrics):
            metric1 = metrics[i]
            metric2 = metrics[j]
            
            # Get data for both metrics
            data1 = user_data[metric1].values
            data2 = user_data[metric2].values
            
            # Keep only rows where both metrics are present, so the two
            # arrays stay aligned and equally long.
            valid_mask = ~(np.isnan(data1) | np.isnan(data2))
            data1 = data1[valid_mask]
            data2 = data2[valid_mask]
            if len(data1) > 0 and len(data2) > 0:
                # Determine bins dynamically
                if isinstance(bin_method, int):
                    bins = bin_method
                elif bin_method == 'fd':
                    bins1 = freedman_diaconis_bins(data1)
                    bins2 = freedman_diaconis_bins(data2)
                    bins = max(bins1, bins2)
                elif bin_method == 'sqrt':
                    bins = int(np.sqrt(len(data1)))
                else:
                    raise ValueError("Invalid bin_method")
                    
                # Bootstrap: 50 resamples of the paired rows.
                mi_list = []
                for _ in range(50):
                    min_len = min(len(data1), len(data2))
                    # Indices drawn with replacement; the same indices are
                    # applied to both arrays to preserve the pairing.
                    random_indices = np.random.choice(min_len, min_len)   
                    hist_2d, _, _ = np.histogram2d(data1[random_indices], data2[random_indices], bins=bins)
                    mi_unnormalized = mutual_info_score(None, None, contingency=hist_2d)

                    # Marginal histograms use the full (un-resampled) data,
                    # so h1 and h2 are the same on every iteration.
                    hist1, _ = np.histogram(data1, bins=bins)
                    hist2, _ = np.histogram(data2, bins=bins)
                    
                    prob1 = hist1 / hist1.sum()
                    prob2 = hist2 / hist2.sum()
                    
                    h1 = entropy(prob1, base=np.e)
                    h2 = entropy(prob2, base=np.e)
                    
                    # Normalize by the smaller marginal entropy; 0 when that
                    # entropy is not positive.
                    normalized_mi = mi_unnormalized / min(h1, h2) if min(h1, h2) > 0 else 0
                
                    mi_list.append(normalized_mi) 
                print(mi_list)
                # NOTE(review): np.cov on a 1-D sequence presumably yields the
                # sample variance with NumPy's default N-1 normalization,
                # which differs from np.var's default - confirm intended.
                var = np.cov(mi_list)*1 
                boot_matrix[(i,j)] = mi_list
                boot_matrix[(j,i)] = mi_list
                mi = np.mean(mi_list)

                mi_matrix[i, j] = mi
                mi_matrix[j, i] = mi
                mi_var_matrix[i,j] = var
                mi_var_matrix[j,i] = var
                mi_sample_size_matrix[i,j] = min(len(data1),len(data2))
                mi_sample_size_matrix[j,i] = min(len(data1),len(data2))
    
    mi_df = pd.DataFrame(mi_matrix, index=metrics, columns=metrics)
    mi_var_df = pd.DataFrame(mi_var_matrix, index=metrics, columns=metrics)
    mi_sample_size_df = pd.DataFrame(mi_sample_size_matrix, index=metrics, columns=metrics)
    
    return mi_df, mi_var_df,mi_sample_size_df, boot_matrix


def compute_mutual_info_stats(df, metrics, focus_metrics, bin_method='fd'):
    """
    Bootstrap normalized mutual information for every pair of metric columns.

    For each pair ``i < j`` of names in ``metrics``, rows where either column
    of ``df`` is NaN are dropped. Pairs with fewer than 3 valid rows or a
    constant column are reported as NaN. Otherwise 50 bootstrap resamples
    (with replacement, via the global NumPy RNG) are drawn; each resample's
    2-D histogram is scored with ``mutual_info_score`` and divided by the
    smaller of the two marginal entropies. The marginal entropies come from
    histograms of the un-resampled data, so they are computed once per pair.

    Parameters:
        df: DataFrame holding one column per entry of ``metrics``.
        metrics: Sequence of column names to pair up.
        focus_metrics: Not used by the function body.
        bin_method: An int (used directly as the bin count), ``'fd'``
            (larger of the two ``freedman_diaconis_bins`` counts) or
            ``'sqrt'`` (``int(sqrt(n))``).

    Returns:
        ``(mi_df, mi_var_df, mi_sample_size_df, boot_matrix)`` where the
        three DataFrames are indexed by ``metrics`` on both axes: mean
        bootstrap value (diagonal stays 1), ``np.var`` of the bootstrap
        values (diagonal stays 1), and the number of valid rows (diagonal
        stays 0). ``boot_matrix`` maps ``(i, j)`` and ``(j, i)`` to the list
        of 50 bootstrap values for every pair that was computed.

    Raises:
        ValueError: If ``bin_method`` is not recognized and at least one pair
            has enough data to be computed.
    """
    n_bootstrap = 50
    n_metrics = len(metrics)
    mi_matrix = np.ones((n_metrics, n_metrics))
    mi_var_matrix = np.ones((n_metrics, n_metrics))
    mi_sample_size_matrix = np.zeros((n_metrics, n_metrics))
    boot_matrix = {}

    # Visit each unordered pair once; results are mirrored into (j, i).
    for i in range(n_metrics):
        for j in range(i + 1, n_metrics):
            metric1, metric2 = metrics[i], metrics[j]
            data1 = df[metric1].values
            data2 = df[metric2].values

            # Keep only rows where both metrics are present; one shared mask
            # guarantees the two arrays stay aligned and equally long.
            valid_mask = ~(np.isnan(data1) | np.isnan(data2))
            data1 = data1[valid_mask]
            data2 = data2[valid_mask]
            n_samples = len(data1)
            mi_sample_size_matrix[i, j] = n_samples
            mi_sample_size_matrix[j, i] = n_samples

            if n_samples < 3 or np.nanstd(data1) == 0 or np.nanstd(data2) == 0:
                print("One of the inputs is constant or too small to compute mutual info",metric1,metric2)
                mi_matrix[i, j] = np.nan
                mi_matrix[j, i] = np.nan
                mi_var_matrix[i, j] = np.nan
                mi_var_matrix[j, i] = np.nan
                continue

            # Determine bins dynamically
            if isinstance(bin_method, int):
                bins = bin_method
            elif bin_method == 'fd':
                bins = max(freedman_diaconis_bins(data1), freedman_diaconis_bins(data2))
            elif bin_method == 'sqrt':
                bins = int(np.sqrt(n_samples))
            else:
                raise ValueError("Invalid bin_method")

            # The marginal entropies use the full (un-resampled) data, so
            # they are loop-invariant: compute them once instead of 50 times.
            hist1, _ = np.histogram(data1, bins=bins)
            hist2, _ = np.histogram(data2, bins=bins)
            h1 = entropy(hist1 / hist1.sum(), base=np.e)
            h2 = entropy(hist2 / hist2.sum(), base=np.e)
            min_entropy = min(h1, h2)

            mi_list = []
            for _ in range(n_bootstrap):
                # Indices drawn with replacement; the same indices are applied
                # to both arrays to preserve the pairing.
                random_indices = np.random.choice(n_samples, n_samples)
                hist_2d, _, _ = np.histogram2d(data1[random_indices], data2[random_indices], bins=bins)
                mi_unnormalized = mutual_info_score(None, None, contingency=hist_2d)
                # Normalize by the smaller marginal entropy; 0 when that
                # entropy is not positive.
                normalized_mi = mi_unnormalized / min_entropy if min_entropy > 0 else 0
                mi_list.append(normalized_mi)

            mi = np.mean(mi_list)
            mi_var = np.var(mi_list)
            boot_matrix[(i, j)] = mi_list
            boot_matrix[(j, i)] = mi_list
            mi_matrix[i, j] = mi
            mi_matrix[j, i] = mi
            mi_var_matrix[i, j] = mi_var
            mi_var_matrix[j, i] = mi_var

    mi_df = pd.DataFrame(mi_matrix, index=metrics, columns=metrics)
    mi_var_df = pd.DataFrame(mi_var_matrix, index=metrics, columns=metrics)
    mi_sample_size_df = pd.DataFrame(mi_sample_size_matrix, index=metrics, columns=metrics)

    return mi_df, mi_var_df, mi_sample_size_df, boot_matrix


def compute_correlation_stats(df, metrics, type = 'spearman'):
    """
    Compute pairwise correlations between metric columns.

    For each pair ``i < j`` of names in ``metrics``, rows where either column
    of ``df`` is NaN are dropped and the remaining values are correlated.
    Pairs with fewer than 3 valid rows or a constant column are reported as
    NaN (and a message is printed).

    Parameters:
        df: DataFrame holding one column per entry of ``metrics``.
        metrics: Sequence of column names to pair up.
        type: ``'spearman'`` (``scipy.stats.spearmanr``) or ``'pearson'``
            (``scipy.stats.pearsonr``).

    Returns:
        ``(cor_df, p_df, sample_size_df)``: three DataFrames indexed by
        ``metrics`` on both axes, holding the correlation coefficient, its
        p-value and the number of valid rows. All three are symmetric and
        their diagonals are left at 0.

    Raises:
        ValueError: If ``type`` is not a supported correlation name.
    """
    # Resolve the estimator up front. Previously an unknown `type` left
    # rho/p_val unassigned and failed later with an UnboundLocalError.
    if type == 'spearman':
        correlate = spearmanr
    elif type == 'pearson':
        correlate = pearsonr
    else:
        raise ValueError(f"Unsupported correlation type: {type}")

    n_metrics = len(metrics)
    cor_matrix = np.zeros((n_metrics, n_metrics))
    p_matrix = np.zeros((n_metrics, n_metrics))
    sample_size_matrix = np.zeros((n_metrics, n_metrics))

    # Visit each unordered pair once; results are mirrored into (j, i).
    for i in range(n_metrics):
        for j in range(i + 1, n_metrics):
            metric1, metric2 = metrics[i], metrics[j]
            data1 = df[metric1].values
            data2 = df[metric2].values

            # Keep only rows where both metrics are present; one shared mask
            # guarantees the two arrays stay aligned and equally long.
            valid_mask = ~(np.isnan(data1) | np.isnan(data2))
            data1 = data1[valid_mask]
            data2 = data2[valid_mask]
            n_samples = len(data1)

            if n_samples < 3 or np.nanstd(data1) == 0 or np.nanstd(data2) == 0:
                print("One of the inputs is constant or too small to compute correlation",metric1,metric2)
                rho, p_val = np.nan, np.nan
            else:
                rho, p_val = correlate(data1, data2)

            cor_matrix[i, j] = rho
            cor_matrix[j, i] = rho
            p_matrix[i, j] = p_val
            p_matrix[j, i] = p_val
            sample_size_matrix[i, j] = n_samples
            sample_size_matrix[j, i] = n_samples

    cor_df = pd.DataFrame(cor_matrix, index=metrics, columns=metrics)
    p_df = pd.DataFrame(p_matrix, index=metrics, columns=metrics)
    sample_size_df = pd.DataFrame(sample_size_matrix, index=metrics, columns=metrics)

    return cor_df, p_df, sample_size_df
