from sklearn.feature_selection import mutual_info_regression
from sklearn.linear_model import LinearRegression
import numpy as np
from tqdm import tqdm
from scipy.stats import spearmanr
import torch

def continuous_mutual_info(mus, ys, random_state=None):
    """Estimate the mutual information between each latent code and one attribute.

    Args:
        mus: np.array num_points x num_codes
        ys: np.array of shape (num_points,) -- a single attribute;
            ``mutual_info_regression`` expects a 1-D target.
        random_state: optional seed / RandomState forwarded to
            ``mutual_info_regression``, which uses it for the small noise it
            adds to continuous variables. Fix it to make scores reproducible;
            ``None`` (default) leaves the estimator unseeded.
    Returns:
        np.array of shape (num_codes,): estimated MI between each column of
        ``mus`` and ``ys``.
    """
    return mutual_info_regression(mus, ys, random_state=random_state)

def continuous_entropy(ys, random_state=None):
    """Estimate an entropy-like normaliser for a single continuous attribute.

    The value is the k-NN mutual information of the attribute with itself,
    I(Y; Y), as returned by ``mutual_info_regression``.
    NOTE(review): this is a proxy for the attribute's entropy, not a
    calibrated differential-entropy estimate.

    Args:
        ys: np.array of shape (num_points,) -- one attribute. The input is
            flattened by ``reshape(-1, 1)``, so a 2-D array would produce a
            feature/target length mismatch.
        random_state: optional seed / RandomState forwarded to
            ``mutual_info_regression``; ``None`` (default) leaves it unseeded.
    Returns:
        np.array of shape (1,) holding the estimate.
    """
    return mutual_info_regression(ys.reshape(-1, 1), ys, random_state=random_state)

def compute_interpretability_metric(latent_codes_input, attributes, attr_list):
    """
    Computes the interpretability metric for each attribute.

    For every attribute, the latent code with the highest estimated mutual
    information with that attribute is selected. A one-feature linear
    regression from that code to the attribute is then fitted, and its R^2 on
    the same data is the attribute's score.

    Args:
        latent_codes_input: np.array num_points x num_codes
        attributes: np.array num_points x num_attributes
        attr_list: list of string corresponding to attribute names, one per
            column of ``attributes``
    Returns:
        dict mapping each attribute name to ``(best_code_index, r2_score)``,
        plus "mean": ``(-1, mean r2 over attributes)`` (NaN when
        ``attr_list`` is empty).
    """
    interpretability_metrics = {}
    total = 0
    for i, attr_name in tqdm(enumerate(attr_list)):
        # Choose rows from the *original* codes on every iteration so a mask
        # applied for one MIC attribute never leaks into later attributes.
        if attr_name in ('MIC E.coli', 'MIC S.aureus'):
            # MIC labels can be non-finite (missing measurements); drop those rows.
            finite_mask = np.isfinite(attributes[:, i])
            latent_codes = latent_codes_input[finite_mask, :]
            attr_labels = attributes[finite_mask, i]
        else:
            latent_codes = latent_codes_input
            attr_labels = attributes[:, i]
        mutual_info = mutual_info_regression(latent_codes, attr_labels)
        dim = np.argmax(mutual_info)

        # compute linear regression score on the single most informative code
        best_code = latent_codes[:, dim:dim + 1]
        reg = LinearRegression().fit(best_code, attr_labels)
        score = float(reg.score(best_code, attr_labels))
        interpretability_metrics[attr_name] = (int(dim), score)
        total += score
    num_attrs = len(attr_list)
    mean_score = total / num_attrs if num_attrs > 0 else float('nan')
    interpretability_metrics["mean"] = (-1, mean_score)
    return interpretability_metrics


def compute_mig(latent_codes, attributes, attr_list):
    """
    Computes the mutual information gap (MIG) metric.

    For each attribute, MIG is the difference between the largest and the
    second-largest estimated mutual information between a single latent
    code and the attribute. It is divided by ``continuous_entropy`` of the
    attribute. Requires at least two latent codes.

    Args:
        latent_codes: np.array num_points x num_codes
        attributes: np.array num_points x num_attributes
        attr_list: list of attribute names, one per column of ``attributes``
    Returns:
        dict mapping each attribute name to a 1 x 1 np.array holding its MIG,
        plus 'mean': the NaN-ignoring mean over all attributes (NaN when
        ``attr_list`` is empty).
    """
    score_dict = {}
    per_attr_scores = []
    for i, attr_name in tqdm(enumerate(attr_list)):
        if attr_name in ('MIC E.coli', 'MIC S.aureus'):
            # MIC labels can be non-finite (missing measurements); drop those rows.
            finite_mask = np.isfinite(attributes[:, i])
            codes = latent_codes[finite_mask, :]
            labels = attributes[finite_mask, i]
        else:
            codes = latent_codes
            labels = attributes[:, i]
        m = continuous_mutual_info(codes, labels)
        entropy = continuous_entropy(labels)
        sorted_m = np.sort(m, axis=0)[::-1]
        # entropy has shape (1,), so the gap is a (1,) array -> reshaped to 1 x 1.
        mig = np.divide(sorted_m[0] - sorted_m[1], entropy).reshape(-1, 1)
        score_dict[attr_name] = mig
        per_attr_scores.append(mig)
    if per_attr_scores:
        score_dict['mean'] = np.nanmean(np.column_stack(per_attr_scores))
    else:
        score_dict['mean'] = np.nan
    return score_dict


def compute_modularity(latent_codes, attributes, attr_list):
    """
    Computes the modularity metric.

    Per-attribute entries apply ``_modularity`` to the attribute's MI vector
    laid out as a single 1 x num_codes row. This measures how concentrated
    that attribute's information is on one latent code.
    The 'mean' entry is NOT the average of those values. It is the
    NaN-ignoring mean of ``_modularity`` over the full num_codes x
    num_attributes MI matrix, i.e. one score per latent code averaged over
    codes.

    Args:
        latent_codes: np.array num_points x num_codes
        attributes: np.array num_points x num_attributes
        attr_list: list of attribute names, one per column of ``attributes``
    Returns:
        dict mapping each attribute name to a float, plus 'mean' (NaN when
        ``attr_list`` is empty).
    """
    scores = {}
    mi_columns = []
    for i, attr_name in tqdm(enumerate(attr_list)):
        if attr_name in ('MIC E.coli', 'MIC S.aureus'):
            # MIC labels can be non-finite (missing measurements); drop those rows.
            finite_mask = np.isfinite(attributes[:, i])
            mi_partly = continuous_mutual_info(
                latent_codes[finite_mask, :], attributes[finite_mask, i]
            ).reshape(-1, 1)
        else:
            mi_partly = continuous_mutual_info(latent_codes, attributes[:, i]).reshape(-1, 1)
        mi_columns.append(mi_partly)
        # Lay the num_codes x 1 column out as one row; works for any number of
        # codes (previously hard-coded to 56).
        scores[attr_name] = _modularity(mi_partly.reshape(1, -1)).item()
    if mi_columns:
        mi = np.column_stack(mi_columns)
        scores['mean'] = np.nanmean(_modularity(mi))
    else:
        scores['mean'] = np.nan
    return scores


def _modularity(mutual_information):
    """
    Computes the per-row modularity from mutual information.

    For each row r with squared entries s_j = mi[r, j]**2:
        score_r = 1 - (sum_j s_j - max_j s_j) / (max_j s_j * (num_cols - 1))
    Rows whose largest squared MI is 0 get a score of 0.

    Args:
        mutual_information: np.array num_codes x num_attributes
    Returns:
        np.array with one score per row (not averaged).
    """
    squared_mi = np.square(mutual_information)
    max_squared_mi = np.max(squared_mi, axis=1)
    numerator = np.sum(squared_mi, axis=1) - max_squared_mi
    denominator = max_squared_mi * (squared_mi.shape[1] - 1.)
    has_signal = max_squared_mi != 0.
    # Divide only for rows with non-zero MI. All-zero rows would be 0/0,
    # which raises a spurious RuntimeWarning; they are forced to 0 below anyway.
    delta = np.divide(numerator, denominator,
                      out=np.zeros_like(denominator), where=has_signal)
    modularity_score = 1. - delta
    modularity_score[~has_signal] = 0.
    return modularity_score

def compute_correlation_score(latent_codes, attributes, attr_list):
    """
    Computes the correlation score.

    Each attribute is scored by the strongest significant |Spearman rho| it
    reaches with any single latent code (see ``_compute_correlation_matrix``).

    Args:
        latent_codes: np.array num_points x num_codes
        attributes: np.array num_points x num_attributes
        attr_list: list of attribute names, one per column of ``attributes``
    Returns:
        dict mapping each attribute name to its score, plus 'mean' over attributes.
    """
    corr = _compute_correlation_matrix(latent_codes, attributes, attr_list)
    # Column-wise maximum: best code for every attribute.
    best_per_attr = corr.max(axis=0)
    result = {}
    for col, name in tqdm(enumerate(attr_list)):
        result[name] = best_per_attr[col]
    result['mean'] = np.mean(best_per_attr)
    return result


def _compute_correlation_matrix(mus, ys, attr_list, alpha=0.05):
    """
    Compute correlation matrix for correlation score metric.

    Entry [i, j] is |Spearman rho| between latent code i and attribute j when
    the test's p-value is <= ``alpha``, else 0. NaN p-values (e.g. from a
    constant input) fail the comparison and give 0. For MIC attributes, rows
    with non-finite labels are excluded.

    Args:
        mus: np.array num_points x num_codes
        ys: np.array num_points x num_attributes
        attr_list: attribute names, one per column of ``ys``
        alpha: significance level; 0.05 by default.
    Returns:
        np.array num_codes x num_attributes
    """
    num_latent_codes = mus.shape[1]
    num_attributes = ys.shape[1]
    # Row selection per attribute is independent of the latent code, so build it once.
    row_masks = [
        np.isfinite(ys[:, j]) if attr_list[j] in ('MIC E.coli', 'MIC S.aureus') else slice(None)
        for j in range(num_attributes)
    ]
    score_matrix = np.zeros([num_latent_codes, num_attributes])
    for i in tqdm(range(num_latent_codes)):
        for j, rows in enumerate(row_masks):
            rho, p = spearmanr(mus[rows, i], ys[rows, j])
            score_matrix[i, j] = np.abs(rho) if p <= alpha else 0.
    return score_matrix

def _compute_avg_diff_top_two(matrix):
    """Return, for every column, the largest entry minus the second-largest.

    Args:
        matrix: 2-D np.array with at least two rows.
    Returns:
        1-D np.array with one gap per column.
    """
    top_two = np.sort(matrix, axis=0)[-2:]
    runner_up, largest = top_two[0], top_two[1]
    return largest - runner_up

def compute_sap_score(latent_codes, attributes, attr_list):
    """
    Computes the separated attribute predictability (SAP) score.

    Each attribute's score is the gap between its two best-scoring latent
    codes in ``_compute_score_matrix``.

    Args:
        latent_codes: np.array num_points x num_codes
        attributes: np.array num_points x num_attributes
        attr_list: list of attribute names, one per column of ``attributes``
    Returns:
        dict mapping each attribute name to its SAP score, plus 'mean'.
    """
    matrix = _compute_score_matrix(latent_codes, attributes, attr_list)
    # Expected layout: one row per latent code, one column per attribute.
    assert matrix.shape[0] == latent_codes.shape[1]
    assert matrix.shape[1] == attributes.shape[1]
    gaps = _compute_avg_diff_top_two(matrix)
    result = {name: gaps[col] for col, name in tqdm(enumerate(attr_list))}
    result['mean'] = np.mean(gaps)
    return result


def _compute_score_matrix(mus, ys, attr_list):
    """
    Compute score matrix for sap score computation.

    Attributes are treated as continuous. Entry [i, j] is
    cov(mu_i, y_j)**2 / (var(mu_i) * var(y_j)), the squared Pearson
    correlation between latent code i and attribute j. It is 0 when either
    variance is (numerically) zero, so a constant code or a constant attribute
    yields 0 instead of NaN. For MIC attributes, rows with non-finite labels
    are excluded.

    Args:
        mus: np.array num_points x num_codes
        ys: np.array num_points x num_attributes
        attr_list: attribute names, one per column of ``ys``
    Returns:
        np.array num_codes x num_attributes
    """
    num_latent_codes = mus.shape[1]
    num_attributes = ys.shape[1]
    # Row selection per attribute is independent of the latent code, so build it once.
    row_masks = [
        np.isfinite(ys[:, j]) if attr_list[j] in ('MIC E.coli', 'MIC S.aureus') else slice(None)
        for j in range(num_attributes)
    ]
    score_matrix = np.zeros([num_latent_codes, num_attributes])
    for i in tqdm(range(num_latent_codes)):
        for j, rows in enumerate(row_masks):
            cov = np.cov(mus[rows, i], ys[rows, j], ddof=1)
            var_mu = cov[0, 0]
            var_y = cov[1, 1]
            # Guard both variances: a constant attribute used to give 0/0 = NaN,
            # which then propagated into the SAP gap and the mean.
            if var_mu > 1e-12 and var_y > 1e-12:
                score_matrix[i, j] = cov[0, 1] ** 2 / (var_mu * var_y)
            else:
                score_matrix[i, j] = 0.
    return score_matrix

def extract_relevant_attributes(labels, reg_dim):
    """Select a subset of attribute columns together with their names.

    Args:
        labels: 2-D np.array whose columns follow the name list below.
            NOTE(review): column order is assumed to match that list --
            confirm against the data loader.
        reg_dim: sequence of column indices to keep.
    Returns:
        (labels[:, reg_dim], list of the corresponding attribute names)
    """
    all_names = ['Length', 'Charge', 'Hydrophobicity', 'MIC E.coli', 'MIC S.aureus', 'Nontoxicity']
    selected_names = [all_names[idx] for idx in reg_dim]
    selected_labels = labels[:, reg_dim]
    return selected_labels, selected_names
