import numpy as np
import sklearn
import sklearn.metrics
import torch
from scipy.optimize import linear_sum_assignment


def hungarian_algorithm(cost_matrix):
    """Solve a batch of linear assignment problems and gather the assigned entries.

    Each ``cost_matrix[b]`` is solved independently with SciPy's
    ``linear_sum_assignment`` (minimizing total cost).

    Args:
        cost_matrix: (B, n, m) tensor of costs. It may live on any device and may
            require grad; the assignment itself is computed on a detached CPU copy.

    Returns:
        A tuple ``(values, indices)``:
          * ``values``: (B, k) tensor, with k = min(n, m), holding the assigned
            entries ``cost_matrix[b, row, col]`` multiplied by -1. This is useful
            when the costs are negated scores, because it recovers the maximized
            scores. It stays connected to ``cost_matrix`` in the autograd graph.
          * ``indices``: (B, 2, k) long tensor on ``cost_matrix.device``, where
            ``indices[b, 0]`` holds the row indices and ``indices[b, 1]`` the column
            indices of the optimal assignment for batch element b.
    """
    device = cost_matrix.device
    batch_size = cost_matrix.shape[0]
    n_assign = min(cost_matrix.shape[1], cost_matrix.shape[2])

    # An empty batch has nothing to solve; return correctly-shaped empty results
    # instead of failing on np/torch stacking of an empty list.
    if batch_size == 0:
        empty_indices = torch.empty((0, 2, n_assign), dtype=torch.long, device=device)
        return cost_matrix.new_empty((0, n_assign)), empty_indices

    cost_np = cost_matrix.detach().cpu().numpy()
    # Each solve returns (row_ind, col_ind); stack to (2, k) per batch element and
    # then to (B, 2, k).
    assignments = np.stack([np.stack(linear_sum_assignment(c)) for c in cost_np])
    indices = torch.as_tensor(assignments, dtype=torch.long, device=device)

    # Vectorized gather: values[b, t] = cost_matrix[b, rows[b, t], cols[b, t]].
    batch_idx = torch.arange(batch_size, device=device).unsqueeze(1)
    values = cost_matrix[batch_idx, indices[:, 0], indices[:, 1]]

    return values * -1, indices

def mcc(Z: torch.Tensor, hZ: torch.Tensor) -> "tuple[torch.Tensor, float]":
    """Mean correlation coefficient (MCC) between two sets of latent variables.

    Computes the absolute Pearson correlation between every column of ``Z`` and
    every column of ``hZ``. It then finds the one-to-one column matching that
    maximizes total absolute correlation (Hungarian algorithm) and averages the
    matched correlations.

    Args:
        Z: (n_samples, d) tensor, one latent per column.
        hZ: (n_samples, d_h) tensor with the same dtype as ``Z``. ``d_h`` may
            differ from ``d``; then only ``min(d, d_h)`` pairs are matched.

    Returns:
        A tuple ``(mcc_mat, mcc_score)``: ``mcc_mat`` is the (d, d_h) matrix of
        absolute correlations |corr(Z[:, i], hZ[:, j])|, and ``mcc_score`` is the
        mean of its matched entries.

    Raises:
        ValueError: from ``linear_sum_assignment`` if ``mcc_mat`` contains NaN,
            e.g. when a column is constant and its correlation is undefined.
    """
    d = Z.shape[1]

    # corrcoef treats rows as variables, so stack the transposed latents. The
    # off-diagonal block [:d, d:] then holds the Z-vs-hZ correlations.
    mcc_mat = torch.abs(torch.corrcoef(torch.cat((Z.T, hZ.T), dim=0)))[:d, d:]

    # Maximizing total correlation == minimizing its negation. Indexing with both
    # row and column indices (instead of permuting columns and reading the
    # diagonal) stays correct for rectangular matrices.
    row_ind, col_ind = linear_sum_assignment(-mcc_mat.detach().cpu().numpy())
    rows = torch.as_tensor(row_ind, dtype=torch.long, device=mcc_mat.device)
    cols = torch.as_tensor(col_ind, dtype=torch.long, device=mcc_mat.device)
    mcc_score = mcc_mat[rows, cols].mean().item()

    return mcc_mat, mcc_score

def standardize(X, mean=None, std=None):
    """Z-score ``X`` column-wise, optionally reusing precomputed statistics.

    Args:
        X: (n_samples, n_features) tensor.
        mean: statistic to subtract. Defaults to the per-column mean of ``X``
            (kept as shape (1, n_features)).
        std: statistic to divide by. Defaults to the per-column (unbiased)
            standard deviation of ``X`` (kept as shape (1, n_features)).

    Returns:
        ``((X - mean) / std, mean, std)``, so the statistics can be reused,
        e.g. to standardize a test set with training-set statistics.

    Raises:
        ValueError: if any entry of ``std`` is not strictly positive. This
            includes NaN, which ``torch.std`` produces for a single row.
    """
    if mean is None:
        mean = X.mean(dim=0, keepdim=True)
    if std is None:
        std = X.std(dim=0, keepdim=True)
    # Raise explicitly rather than assert, so the check survives `python -O`.
    # `> 0` is False for zeros and for NaN, matching the original intent.
    if not bool((torch.as_tensor(std) > 0).all()):
        raise ValueError("Standard deviation must be positive")
    return (X - mean) / std, mean, std

def rbf_kernel(X1, X2, gamma=1.0):
    """Gaussian (RBF) kernel matrix ``K[i, j] = exp(-gamma * ||X1[i] - X2[j]||^2)``.

    Args:
        X1: (n1, p) tensor.
        X2: (n2, p) tensor.
        gamma: inverse length-scale factor applied to the squared distance.

    Returns:
        (n1, n2) kernel matrix.
    """
    X1_sq = X1.pow(2).sum(1).view(-1, 1)
    X2_sq = X2.pow(2).sum(1).view(1, -1)
    # The expanded form ||a||^2 + ||b||^2 - 2 a.b suffers cancellation and can
    # round to small negative values. Those would give kernel entries > 1, so
    # clamp the squared distances at zero.
    dist_sq = (X1_sq + X2_sq - 2 * X1 @ X2.T).clamp_min(0.0)
    return torch.exp(-gamma * dist_sq)

def kernel_ridge_regression(X_train, y_train, X_test, alpha=1e-3, gamma=1.0, standardize_y=True):
    """Fit RBF kernel ridge regression on the training data and predict at ``X_test``.

    Features are z-scored with training-set statistics via ``standardize``, and the
    test features reuse those statistics. When ``standardize_y`` is true, the
    targets are z-scored as well (std floored at 1e-6), and predictions are mapped
    back to the original target scale.

    Args:
        X_train: (n, p) training features.
        y_train: training targets with n rows.
        X_test: (m, p) features to predict at.
        alpha: ridge penalty added to the kernel diagonal.
        gamma: RBF kernel parameter forwarded to ``rbf_kernel``.
        standardize_y: whether to z-score the targets before solving.

    Returns:
        Predictions for ``X_test`` in the original target units.
    """
    # Training statistics are shared with the test features.
    feats_train, feat_mean, feat_scale = standardize(X_train)
    feats_test = standardize(X_test, mean=feat_mean, std=feat_scale)[0]

    if standardize_y:
        target_mean = y_train.mean(dim=0, keepdim=True)
        target_scale = y_train.std(dim=0, keepdim=True).clamp_min(1e-6)
        targets = (y_train - target_mean) / target_scale
    else:
        targets = y_train

    # Dual coefficients c solve (K + alpha * I) c = y.
    gram = rbf_kernel(feats_train, feats_train, gamma=gamma)
    ridge = alpha * torch.eye(X_train.shape[0], device=X_train.device)
    dual_coef = torch.linalg.solve(gram + ridge, targets)

    cross_gram = rbf_kernel(feats_test, feats_train, gamma=gamma)
    predictions = cross_gram @ dual_coef

    if not standardize_y:
        return predictions
    return predictions * target_scale + target_mean

def r2(Z: torch.Tensor, hZ: torch.Tensor) -> "tuple[torch.Tensor, float]":
    """Permutation-resolved kernel-ridge R^2 score between two sets of latents.

    For every pair (i, j) with i, j < d (d = ``Z.shape[1]``), this fits a 1-D RBF
    kernel ridge regression that predicts ``hZ[:, j]`` from ``Z[:, i]`` on the
    first 90% of rows, and records its R^2 on the remaining rows. Negative R^2
    values are clipped to 0. The one-to-one pairing that maximizes total R^2 is
    then found with ``linear_sum_assignment``.

    The split is positional (no shuffling). Rows are presumably expected to be
    in random order -- TODO confirm with callers. If the evaluation split has
    fewer than two rows, sklearn's R^2 is undefined and the assignment step will
    likely reject the resulting NaNs.

    Args:
        Z: (n_samples, d) tensor.
        hZ: (n_samples, >= d) tensor; only its first d columns are used.

    Returns:
        ``(r2_mat, r2_score)``: the (d, d) float64 tensor of clipped R^2 values,
        and the mean R^2 over the matched pairs.
    """
    d = Z.shape[1]

    def _train_eval_split(x):
        # First 90% of rows for fitting, the rest for evaluation.
        n_train = int(0.9 * len(x))
        return x[:n_train], x[n_train:]

    # The split does not depend on (i, j), so compute it once, not d^2 times.
    Z_train, Z_eval = _train_eval_split(Z)
    hZ_train, hZ_eval = _train_eval_split(hZ)
    hZ_eval_np = hZ_eval.detach().cpu().numpy()

    r2_mat = np.zeros((d, d))
    for i in range(d):
        # Column slices keep the (n, 1) shape the regression expects.
        z_train = Z_train[:, i:i + 1]
        z_eval = Z_eval[:, i:i + 1]
        for j in range(d):
            # NOTE(review): this regresses the inferred latent hZ[:, j] on the
            # ground-truth Z[:, i]. A previous comment described the opposite
            # direction; confirm which one is intended.
            pred = kernel_ridge_regression(z_train, hZ_train[:, j:j + 1], z_eval)
            r2_mat[i, j] = sklearn.metrics.r2_score(
                hZ_eval_np[:, j:j + 1], pred.detach().cpu().numpy()
            )

    # Clip negative scores to 0 (non-negative entries only).
    r2_mat = torch.nn.functional.relu(torch.from_numpy(r2_mat))

    # Maximize total R^2 by minimizing its negation.
    row_ind, col_ind = linear_sum_assignment(-r2_mat.numpy())
    rows = torch.as_tensor(row_ind, dtype=torch.long)
    cols = torch.as_tensor(col_ind, dtype=torch.long)
    r2_score = r2_mat[rows, cols].mean().item()

    return r2_mat, r2_score
