import pandas as pd
import numpy as np
from scipy.stats import spearmanr, sem, kendalltau
import config.config as c

def compute_rank_correlation(values_dict, utility_names, rank_metric="kendall"):
    """Build a symmetric matrix of pairwise rank correlations between utilities.

    Args:
        values_dict: Mapping keyed by 1-based utility position
            (``1 .. len(utility_names)``); ``values_dict[k + 1]`` holds the value
            vector for ``utility_names[k]``, or ``None`` when it is unavailable.
        utility_names: Labels used for both the rows and columns of the result.
        rank_metric: ``"kendall"`` (Kendall's tau) or ``"spearman"``
            (Spearman's rho).

    Returns:
        A square ``pd.DataFrame`` indexed and columned by ``utility_names``.
        The diagonal is 1. A pair where either value vector is ``None`` is left
        at 0 -- note this is indistinguishable from a genuine zero correlation.

    Raises:
        ValueError: if ``rank_metric`` is not a supported metric name.
    """
    metrics = {"kendall": kendalltau, "spearman": spearmanr}
    if rank_metric not in metrics:
        # Previously an unknown metric left ``corr`` unbound (or stale from the
        # previous pair); fail loudly instead.
        raise ValueError(
            f"Unsupported rank_metric {rank_metric!r}; "
            f"expected one of {sorted(metrics)}"
        )
    correlate = metrics[rank_metric]

    num_utilities = len(utility_names)
    corr_matrix = pd.DataFrame(
        np.zeros((num_utilities, num_utilities)),
        columns=utility_names,
        index=utility_names,
    )

    for i in range(num_utilities):
        corr_matrix.iloc[i, i] = 1
        # The matrix is symmetric, so only the upper triangle is computed and
        # mirrored into the lower triangle.
        for j in range(i + 1, num_utilities):
            values_i = values_dict[i + 1]
            values_j = values_dict[j + 1]
            if values_i is None or values_j is None:
                continue
            corr, _ = correlate(values_i, values_j)
            corr_matrix.iloc[i, j] = corr
            corr_matrix.iloc[j, i] = corr

    return corr_matrix


def compute_rank_correlation_from_values(
    all_values, changing_param=None, rank_metric="kendall"
):
    """Average per-run rank-correlation matrices for every dataset and semivalue.

    Args:
        all_values: ``{dataset_name: {semivalue_name: [run_values, ...]}}``,
            where each ``run_values`` is passed as ``values_dict`` to
            ``compute_rank_correlation`` (keyed 1..n by utility position).
        changing_param: Config mapping whose ``"utility"`` entry lists dicts
            with a ``"utility_name"`` and an optional ``"threshold"``. Defaults
            to ``c.CHANGING_PARAMS``, looked up at call time (not import time).
        rank_metric: Forwarded to ``compute_rank_correlation``
            (``"kendall"`` or ``"spearman"``).

    Returns:
        ``{dataset_name: {semivalue_name: {"mean": DataFrame, "error": DataFrame}}}``.
        ``"mean"`` is the element-wise mean over runs and ``"error"`` is the
        standard error of the mean (``scipy.stats.sem``). With a single run the
        standard error is NaN.
    """
    if changing_param is None:
        changing_param = c.CHANGING_PARAMS

    # Labels are "<name>_<threshold>" when a threshold is configured.
    utility_names = [
        f"{utility['utility_name']}_{utility['threshold']}"
        if "threshold" in utility
        else utility["utility_name"]
        for utility in changing_param["utility"]
    ]

    results = {}
    for dataset_name, values_runs in all_values.items():
        dataset_results = {}
        for semivalue_name, runs in values_runs.items():
            # Iterate this semivalue's own runs. The run count was previously
            # taken from the first semivalue only, which crashed on an empty
            # mapping or on shorter run lists, and dropped extra runs.
            corrs = np.array(
                [
                    compute_rank_correlation(
                        values, utility_names, rank_metric=rank_metric
                    ).values
                    for values in runs
                ]
            )
            mean_corr = np.mean(corrs, axis=0)
            sem_corr = sem(corrs, axis=0)

            dataset_results[semivalue_name] = {
                "mean": pd.DataFrame(
                    mean_corr, columns=utility_names, index=utility_names
                ),
                "error": pd.DataFrame(
                    sem_corr, columns=utility_names, index=utility_names
                ),
            }
        results[dataset_name] = dataset_results

    return results


def display_results(results):
    """Print every mean-correlation matrix annotated with its standard error.

    ``results`` maps dataset name -> semivalue name -> ``{"mean": DataFrame,
    "error": DataFrame}``. Each cell is rendered as ``"<mean> (±<error>)"``
    with two decimals. For each matrix this prints a header line, the table
    and a ``=`` separator.
    """
    separator = "\n" + "=" * 50 + "\n"
    for dataset_name, per_semivalue in results.items():
        for semi_value_name, stats in per_semivalue.items():
            means = stats["mean"]
            errors = stats["error"]

            rows = [
                [
                    f"{means.loc[label, col]:.2f} (±{errors.loc[label, col]:.2f})"
                    for col in means.columns
                ]
                for label in means.index
            ]
            table = pd.DataFrame(rows, index=means.index, columns=means.columns)

            header = (
                f"Rank correlations with standard error for {dataset_name} ({semi_value_name}):"
            )
            print(header)
            print(table.to_string())
            print(separator)