import torch
import numpy as np
from scipy.optimize import linear_sum_assignment
import sklearn
from torch import vmap
from torch.func import jacfwd


def hungarian_algorithm(cost_matrix):
    """Solve a batch of linear assignment problems.

    Args:
        cost_matrix: Batched tensor whose leading axis indexes independent
            cost matrices; each slice is handed to SciPy's
            ``linear_sum_assignment``.

    Returns:
        A pair ``(neg_costs, indices)``. ``indices[b, 0]`` and
        ``indices[b, 1]`` hold the matched row and column indices of batch
        element ``b``; ``neg_costs[b]`` holds the costs of those matched
        entries multiplied by -1. Both are placed on ``cost_matrix``'s device.
    """
    target_device = cost_matrix.device

    # SciPy operates on host NumPy arrays, so every problem is solved on CPU.
    host_costs = cost_matrix.cpu().detach().numpy()
    assignments = [linear_sum_assignment(problem) for problem in host_costs]
    indices = torch.LongTensor(np.array(assignments))

    # Gather the cost of each matched (row, column) pair, batch by batch.
    matched_costs = torch.stack(
        [
            batch_costs[pairing[0], pairing[1]]
            for batch_costs, pairing in zip(cost_matrix, indices)
        ]
    )

    return matched_costs.to(target_device) * -1, indices.to(target_device)

def mcc(Z: torch.Tensor, hZ: torch.Tensor) -> "tuple[torch.Tensor, float]":
    """Compute the mean correlation coefficient (MCC) between two latent sets.

    Args:
        Z: 2-D tensor of ground-truth latents; samples along the first axis,
            latent dimensions along the second.
        hZ: 2-D tensor of inferred latents with the same number of samples.
            Its number of columns may differ from ``Z``'s.

    Returns:
        A pair ``(mcc_mat, mcc_score)``. ``mcc_mat[i, j]`` is the absolute
        Pearson correlation between ``Z[:, i]`` and ``hZ[:, j]``.
        ``mcc_score`` is the mean of the entries selected by the one-to-one
        matching of columns that maximises the total correlation (when the
        column counts differ, ``min(Z.shape[1], hZ.shape[1])`` pairs are
        matched).
    """
    # Latent dim
    d = Z.shape[1]

    # corrcoef treats rows as variables, so after stacking Z.T on top of hZ.T
    # the top-right block holds the cross-correlations between Z and hZ.
    mcc_mat = torch.abs(torch.corrcoef(torch.cat((Z.T, hZ.T), 0)))[:d, d:]

    # Resolve the permutation: pick the matching with the largest total
    # correlation. SciPy needs a host NumPy array.
    rows, cols = linear_sum_assignment(
        mcc_mat.detach().cpu().numpy(), maximize=True
    )
    rows = torch.as_tensor(rows, device=mcc_mat.device)
    cols = torch.as_tensor(cols, device=mcc_mat.device)

    # Average the correlations of the matched pairs.
    mcc_score = mcc_mat[rows, cols].mean().item()

    return mcc_mat, mcc_score

def standardize(X, mean=None, std=None):
    """Shift and scale the columns of ``X``.

    Args:
        X: Tensor with samples along the first axis.
        mean: Optional per-column mean to subtract. Computed from ``X``
            (over axis 0, keeping the axis) when omitted.
        std: Optional per-column standard deviation to divide by. Computed
            from ``X`` (over axis 0, keeping the axis) when omitted.

    Returns:
        A triple ``(X_standardized, mean, std)`` so the same statistics can be
        reused on other data.

    Raises:
        AssertionError: If the smallest entry of ``std`` is not positive.
    """
    mean = X.mean(dim=0, keepdim=True) if mean is None else mean
    std = X.std(dim=0, keepdim=True) if std is None else std

    # A zero (or NaN) scale would make the division below meaningless.
    assert std.min() > 0, "Standard deviation must be positive"

    centered = X - mean
    return centered / std, mean, std

def rbf_kernel(X1, X2, gamma=1.0):
    """Evaluate the RBF (Gaussian) kernel between two sets of points.

    Args:
        X1: 2-D tensor with one point per row.
        X2: 2-D tensor with one point per row and the same number of columns
            as ``X1``.
        gamma: Kernel width parameter; the kernel is
            ``exp(-gamma * ||x1 - x2||^2)``.

    Returns:
        Tensor of shape ``(X1.shape[0], X2.shape[0])`` with the kernel value
        for every pair of rows.
    """
    X1_sq = X1.pow(2).sum(1).view(-1, 1)
    X2_sq = X2.pow(2).sum(1).view(1, -1)

    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b. Floating-point cancellation can
    # push this expansion slightly below zero for (near-)identical points,
    # which would yield kernel values above 1, so clamp at zero.
    dist_sq = (X1_sq + X2_sq - 2 * X1 @ X2.T).clamp_min(0)
    return torch.exp(-gamma * dist_sq)

def kernel_ridge_regression(X_train, y_train, X_test, alpha=1e-3, gamma=1.0, standardize_y=True):
    """Fit an RBF kernel ridge regressor and predict on ``X_test``.

    Args:
        X_train: Training inputs, one sample per row.
        y_train: Training targets, one sample per row.
        X_test: Inputs to predict for, one sample per row.
        alpha: Ridge penalty added to the diagonal of the kernel matrix.
        gamma: Width parameter forwarded to ``rbf_kernel``.
        standardize_y: When true, targets are standardized per column before
            fitting and predictions are mapped back to the original scale.

    Returns:
        Predicted targets for ``X_test``.
    """
    # Scale the inputs with statistics estimated on the training split only.
    train_inputs, x_mean, x_std = standardize(X_train)
    test_inputs = standardize(X_test, mean=x_mean, std=x_std)[0]

    # Optionally scale the targets; the clamp guards against constant columns.
    if standardize_y:
        y_mean = y_train.mean(dim=0, keepdim=True)
        y_std = y_train.std(dim=0, keepdim=True).clamp_min(1e-6)
        targets = (y_train - y_mean) / y_std
    else:
        targets = y_train

    # Regularized Gram matrix of the training inputs.
    ridge = alpha * torch.eye(X_train.shape[0], device=X_train.device)
    gram = rbf_kernel(train_inputs, train_inputs, gamma=gamma) + ridge

    # Dual coefficients of the ridge problem.
    dual_coef = torch.linalg.solve(gram, targets)

    # Predict through the cross-kernel between test and training inputs.
    cross_gram = rbf_kernel(test_inputs, train_inputs, gamma=gamma)
    prediction = cross_gram @ dual_coef

    if not standardize_y:
        return prediction
    # Undo the target scaling.
    return prediction * y_std + y_mean

def r2(Z: torch.Tensor, hZ: torch.Tensor) -> "tuple[torch.Tensor, float]":
    """Score how well each inferred latent is predicted from each true latent.

    For every ground-truth column ``Z[:, i]`` a kernel ridge regressor is fit
    on the first 90% of the samples to predict the inferred latents ``hZ``,
    and its R2 is measured on the remaining 10%.

    Args:
        Z: 2-D tensor of ground-truth latents, one sample per row.
        hZ: 2-D tensor of inferred latents with the same number of rows.

    Returns:
        A pair ``(r2_mat, r2_score)``. ``r2_mat`` is a CPU float64 tensor of
        shape ``(Z.shape[1], hZ.shape[1])`` whose entry ``(i, j)`` is the
        held-out R2 of predicting ``hZ[:, j]`` from ``Z[:, i]``, with negative
        values clipped to zero. ``r2_score`` is the mean of the entries chosen
        by the one-to-one matching that maximises the total R2.
    """
    # ``import sklearn`` alone does not guarantee that the ``metrics``
    # submodule is loaded, so import the scorer explicitly.
    from sklearn.metrics import r2_score as sk_r2_score

    d = Z.shape[1]

    # Split data into train and eval sets (first 90% train, last 10% eval).
    # The split does not depend on the latent index, so do it once.
    n_train = int(0.9 * len(Z))
    Z_train, Z_eval = Z[:n_train], Z[n_train:]
    hZ_train, hZ_eval = hZ[:n_train], hZ[n_train:]
    hz_eval = hZ_eval.detach().cpu().numpy()

    # Initialize matrix of R2 scores
    r2_mat = np.zeros((d, hZ.shape[1]))

    # The kernel system depends only on the input column Z[:, i], so solve it
    # once per i with every inferred latent as a separate target column.
    for i in range(d):
        hz_pred = kernel_ridge_regression(
            Z_train[:, i:i + 1], hZ_train, Z_eval[:, i:i + 1]
        )
        r2_mat[i] = sk_r2_score(
            hz_eval, hz_pred.detach().cpu().numpy(), multioutput="raw_values"
        )

    # Get matrix of R2 scores (non-negative entries only)
    r2_mat = torch.nn.functional.relu(torch.from_numpy(r2_mat))

    # Resolve permutation: matching with the largest total R2.
    row_ind, col_ind = linear_sum_assignment(r2_mat.numpy(), maximize=True)
    matched_score = r2_mat[row_ind, col_ind].mean().item()

    return r2_mat, matched_score

def eval_model(model, data_loader, device, estimated_rho):
    """Evaluate ``model`` on ``data_loader`` and score the recovered latents.

    Every batch ``(x, z)`` is moved to ``device`` and passed through
    ``model``, which must return ``(sh, xh)``: the inferred latents and the
    reconstruction of ``x``. Jacobian statistics are computed from
    ``model.decoder`` evaluated at ``sh.flatten(1)``.
    NOTE(review): the Jacobian code assumes the decoder maps a flat latent
    vector to a flat output vector (3-D batched Jacobian) - confirm.

    Args:
        model: Module exposing ``decoder`` and returning ``(sh, xh)`` when
            called on a batch.
        data_loader: Iterable of ``(x, z)`` batches, where ``z`` are the
            ground-truth latents.
        device: Device on which the evaluation runs.
        estimated_rho: Threshold of the softplus penalty on the Jacobian
            norm; the penalty is skipped (reported as 0) when it equals 0.

    Returns:
        A 10-tuple of per-batch averages followed by latent scores:
        reconstruction MSE (float), mean ``det(J^T J)`` (float), mean
        ``trace(G J^T J)`` with ``G = n I - 1 1^T`` (tensor), mean entrywise
        L1 norm of the Jacobian (CPU tensor), softplus penalty on that norm
        (tensor, or 0 when disabled), mean of
        ``-0.5 logdet(J^T J) + sum_k log ||J[:, k]||`` (tensor), then
        ``r2_score``, ``r2_mat``, ``mcc_score`` and ``mcc_mat``.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()

    run_recon_loss, run_vol, run_trace_vol = 0, 0, 0
    run_jacnorm, run_jacnorm_reg, run_ima_contrast = 0, 0, 0
    true_latents, est_latents = [], []

    with torch.no_grad():
        for x, z in data_loader:
            x = x.to(device)
            z = z.to(device)
            sh, xh = model(x)

            # Get reconstruction loss
            run_recon_loss += ((x - xh).square().mean()).item()

            # Per-sample decoder Jacobian and its Gram matrix J^T J, which is
            # shared by the volume, trace and contrast terms below.
            jac = vmap(jacfwd(model.decoder))(sh.flatten(1))
            gram = jac.transpose(-1, -2) @ jac
            n_latent = jac.shape[-1]

            # Jacobian volume
            run_vol += torch.mean(torch.det(gram)).cpu().item()

            # Trace term; accumulated over batches like the other statistics
            # (it was previously overwritten on every batch).
            G = n_latent * torch.eye(n_latent, device=device) - \
                torch.ones(n_latent, n_latent, device=device)
            run_trace_vol += torch.diagonal(G @ gram, dim1=-2, dim2=-1).sum(-1).mean()

            # Entrywise L1 norm of the Jacobian, averaged over the batch.
            batch_jacnorm = jac.abs().sum(dim=(1, 2)).mean().cpu()
            run_jacnorm += batch_jacnorm

            # Column-norm contrast: sum of log column norms minus half the
            # log-determinant of J^T J.
            run_ima_contrast += (-0.5 * torch.logdet(gram) + torch.sum(torch.log(jac.norm(dim=1)), dim=1)).mean()

            # Penalise this batch's norm, not the running sum over batches.
            if estimated_rho != 0:
                run_jacnorm_reg += torch.nn.functional.softplus(batch_jacnorm - estimated_rho).mean()

            # Save latents; concatenated once after the loop.
            true_latents.append(z)
            est_latents.append(sh)

    b_it = len(true_latents)
    if b_it == 0:
        raise ValueError("data_loader yielded no batches")

    Z = torch.cat(true_latents)
    Zh = torch.cat(est_latents)

    # MCC score
    mcc_mat, mcc_score = mcc(Z, Zh)

    # R2 score
    r2_mat, r2_score = r2(Z, Zh)

    return (run_recon_loss / b_it), (run_vol / b_it), (run_trace_vol / b_it), (run_jacnorm / b_it), (run_jacnorm_reg / b_it), (run_ima_contrast / b_it),\
                r2_score, r2_mat, mcc_score, mcc_mat