import numpy as np
import torch
import torch.nn as nn
from src.utils.set_seeds import seed_everything
from src.metrics.structure_preservation_metrics import qnx_trust_cont, dist_preservation
from src.models.simple_mlp import SimpleMLP
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score
from scipy.stats import mode



def _knn_predict(x_query, x_ref, y_ref, k):
    """Label each row of ``x_query`` by majority vote among its ``k`` nearest rows of ``x_ref`` (Euclidean)."""
    dists = torch.cdist(x_query, x_ref, p=2)
    neighbor_idx = torch.argsort(dists, dim=1)[:, :k]
    return torch.mode(y_ref[neighbor_idx], dim=1).values


def cls_importances(
    x_train, y_train, x_test, y_test, corr_matrix,
    n_neighbors=None, n_repeats=10,
    classifiers=('knn', 'mlp', 'svm'),
    device='cpu', random_state=None):
    """
    Computes classification feature importances via correlation-aware random sampling
    for selected classifiers (knn, mlp, svm), including ensemble majority vote.

    Every selected classifier is fitted once on the unperturbed training data.
    Then, for each feature and each repeat, the test rows are perturbed: the
    feature's column is overwritten with values taken from randomly drawn rows
    of the combined (test + train) data, and every other column ``j`` of a test
    row is overwritten from the same drawn row with probability
    ``|corr_matrix[feature, j]|``. The training data is never perturbed.

    Args:
        x_train: Training features, array-like of shape (n_train, n_features).
        y_train: Training labels, array-like of shape (n_train,).
        x_test: Test features, array-like of shape (n_test, n_features).
        y_test: Test labels, array-like of shape (n_test,).
        corr_matrix: Feature-by-feature matrix; its absolute values are used
            as co-replacement probabilities.
        n_neighbors: k for the kNN classifier. Defaults to
            ``int(sqrt(n_train))``; larger values are capped at that number
            and values below 1 are raised to 1.
        n_repeats: Number of random perturbations averaged per feature.
        classifiers: Any subset of ``'knn'``, ``'mlp'``, ``'svm'``.
        device: Torch device used for tensors and the MLP.
        random_state: Optional seed for torch and numpy (also passed to the SVM).

    Returns:
        Dict[classifier_name -> (baseline_accuracy, importances_list)]
        ``importances_list[i]`` is the mean accuracy on the perturbed test set
        minus the baseline accuracy, so a negative value means that perturbing
        feature ``i`` lowered accuracy. An ``'ensemble'`` entry (majority vote)
        is added when more than one classifier is selected.
    """
    if random_state is not None:
        torch.manual_seed(random_state)
        np.random.seed(random_state)

    max_neighbors = int(np.sqrt(x_train.shape[0]))
    if n_neighbors is None:
        n_neighbors = max_neighbors
    else:
        n_neighbors = min(n_neighbors, max_neighbors)
    # A majority vote over zero neighbours is undefined, so use at least one.
    n_neighbors = max(1, n_neighbors)

    use_knn = 'knn' in classifiers
    use_mlp = 'mlp' in classifiers
    use_svm = 'svm' in classifiers

    corr_matrix = torch.tensor(np.abs(corr_matrix), device=device, dtype=torch.float32)
    x_train_torch = torch.tensor(x_train, device=device, dtype=torch.float32)
    x_test_torch = torch.tensor(x_test, device=device, dtype=torch.float32)
    y_train_torch = torch.tensor(y_train, device=device)
    y_test_torch = torch.tensor(y_test, device=device)
    y_test_cpu = y_test_torch.cpu()

    combined_x = torch.cat([x_test_torch, x_train_torch], dim=0)
    n_test, n_features = x_test.shape
    n_total = combined_x.shape[0]

    # Donor-row indices are drawn before any model is created. Only the first
    # n_test entries of the last axis are consumed below; the full shape is kept
    # so that results for a given seed stay the same.
    random_indices = torch.randint(
        0, n_total, (n_features, n_repeats, n_total), device=device
    )

    # === Train classifiers ===
    if use_mlp:
        model = SimpleMLP(input_dim=n_features, output_dim=len(np.unique(y_train))).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
        criterion = nn.CrossEntropyLoss()
        # CrossEntropyLoss needs int64 class indices; labels may arrive as
        # int32 (or integer-valued floats), which would otherwise raise.
        mlp_targets = y_train_torch.long()
        model.train()
        for _ in range(20):
            optimizer.zero_grad()
            loss = criterion(model(x_train_torch), mlp_targets)
            loss.backward()
            optimizer.step()
        model.eval()

    if use_svm:
        clf_svm = LinearSVC(random_state=random_state, dual='auto')
        clf_svm.fit(x_train, y_train)

    # === Baseline accuracies on the unperturbed test set ===
    baselines = {}
    voting_preds = []

    if use_knn:
        y_pred_knn = _knn_predict(x_test_torch, x_train_torch, y_train_torch, n_neighbors)
        baselines['knn'] = (y_pred_knn == y_test_torch).float().mean().item()
        voting_preds.append(y_pred_knn.cpu().numpy())

    if use_mlp:
        with torch.no_grad():
            y_pred_mlp = model(x_test_torch).argmax(dim=1)
        baselines['mlp'] = (y_pred_mlp == y_test_torch).float().mean().item()
        voting_preds.append(y_pred_mlp.cpu().numpy())

    if use_svm:
        y_pred_svm = clf_svm.predict(x_test)
        baselines['svm'] = accuracy_score(y_test, y_pred_svm)
        voting_preds.append(y_pred_svm)

    # === Ensemble majority voting ===
    use_ensemble = len(voting_preds) > 1
    if use_ensemble:
        pred_matrix = np.vstack(voting_preds)  # shape (n_classifiers, n_test)
        y_pred_ensemble, _ = mode(pred_matrix, axis=0, keepdims=False)
        baselines['ensemble'] = accuracy_score(y_test, y_pred_ensemble.ravel())

    # Mean perturbed accuracy per feature, one accumulator per reported entry.
    perturbed_acc = {
        name: torch.zeros(n_features, device=device) for name in baselines
    }

    # === Perturbation loop ===
    for feature_idx in range(n_features):
        # Row of |corr| broadcast over test rows: chance of co-replacing each column.
        replace_prob = corr_matrix[feature_idx].unsqueeze(0)

        for repeat_idx in range(n_repeats):
            # One donor row (from test + train) per test row.
            donor_rows = combined_x[random_indices[feature_idx, repeat_idx, :n_test], :]

            x_test_sampled = x_test_torch.clone()
            x_test_sampled[:, feature_idx] = donor_rows[:, feature_idx]

            replace_mask = torch.rand(n_test, n_features, device=device) < replace_prob
            x_test_sampled[replace_mask] = donor_rows[replace_mask]

            ensemble_votes = []

            if use_knn:
                y_pred_knn_r = _knn_predict(
                    x_test_sampled, x_train_torch, y_train_torch, n_neighbors
                )
                score = (y_pred_knn_r == y_test_torch).float().mean()
                perturbed_acc['knn'][feature_idx] += score / n_repeats
                ensemble_votes.append(y_pred_knn_r.cpu())

            if use_mlp:
                with torch.no_grad():
                    y_pred_mlp_r = model(x_test_sampled).argmax(dim=1)
                score = (y_pred_mlp_r == y_test_torch).float().mean()
                perturbed_acc['mlp'][feature_idx] += score / n_repeats
                ensemble_votes.append(y_pred_mlp_r.cpu())

            if use_svm:
                y_pred_svm_r = clf_svm.predict(x_test_sampled.cpu().numpy())
                score = accuracy_score(y_test, y_pred_svm_r)
                perturbed_acc['svm'][feature_idx] += score / n_repeats
                ensemble_votes.append(torch.tensor(y_pred_svm_r))

            if use_ensemble:
                vote_stack = torch.stack(ensemble_votes)  # shape: (n_classifiers, n_test)
                y_vote, _ = torch.mode(vote_stack, dim=0)
                score = (y_vote == y_test_cpu).float().mean()
                perturbed_acc['ensemble'][feature_idx] += score / n_repeats

    # === Report the change relative to the baseline (perturbed - baseline) ===
    # A negative value means perturbing the feature lowered accuracy.
    # NOTE(review): an earlier comment here described values > 0 as "important";
    # with this formula an accuracy drop is negative - confirm how callers read the sign.
    results = {}
    for name, baseline in baselines.items():
        delta = perturbed_acc[name] - baseline
        results[name] = (baseline, delta.cpu().numpy().tolist())

    return results


def knn_emb_accuracy(x_train, y_train, x_test, y_test, device='cpu', k_values=None):
    """
    Compute average kNN accuracy in 2D embedding space across a range of k values.

    Test points are classified by majority vote among their k nearest training
    points (Euclidean distance); the accuracies obtained for the individual k
    values are averaged.

    Args:
        x_train (np.ndarray): Training data of shape (n_train, 2)
        y_train (np.ndarray): Training labels of shape (n_train,)
        x_test (np.ndarray): Test data of shape (n_test, 2)
        y_test (np.ndarray): Test labels of shape (n_test,)
        device (str): 'cpu' or 'cuda'
        k_values (iterable of int, optional): Neighbourhood sizes to average
            over. Defaults to ``5, 15, 25, ...`` up to ``int(sqrt(n_train))``.
            When that grid is empty (fewer than 25 training points), the single
            value ``max(1, int(sqrt(n_train)))`` is used instead. A k larger
            than n_train uses all training points.

    Returns:
        float: Average kNN accuracy over the k values (NaN if there are no
        training points).

    Raises:
        ValueError: If ``k_values`` is given but empty or contains a value < 1.
    """
    assert x_train.shape[1] == 2 and x_test.shape[1] == 2, "Embedding must be 2D"

    n_train = len(y_train)
    if n_train == 0:
        # No neighbours to vote with; keep returning NaN as before.
        return float('nan')

    if k_values is None:
        max_k = int(np.sqrt(n_train))
        k_values = list(range(5, max_k + 1, 10))
        if not k_values:
            # The default grid starts at k=5 and is empty for small training
            # sets, which used to produce the mean of an empty list (NaN).
            k_values = [max(1, max_k)]
    else:
        k_values = [int(k) for k in k_values]
        if not k_values or min(k_values) < 1:
            raise ValueError("k_values must be a non-empty collection of integers >= 1")

    # Convert to tensors
    x_train_torch = torch.tensor(x_train, dtype=torch.float32, device=device)
    x_test_torch = torch.tensor(x_test, dtype=torch.float32, device=device)
    y_train_torch = torch.tensor(y_train, device=device)
    y_test_torch = torch.tensor(y_test, device=device)

    # Sort the neighbours once; every k reuses a prefix of the same ordering.
    dists = torch.cdist(x_test_torch, x_train_torch, p=2)
    sorted_indices = torch.argsort(dists, dim=1)

    accs = []
    for k in k_values:
        knn_preds = y_train_torch[sorted_indices[:, :k]]
        y_pred = torch.mode(knn_preds, dim=1).values
        accs.append((y_pred == y_test_torch).float().mean().item())

    return float(np.mean(accs))


def structure_importances(
    x_train, emb_train, x_test, emb_test, corr_matrix,
    n_repeats=10, device='cpu', random_state=None):
    """
    Computes per-feature importances for structure-preservation metrics via
    correlation-aware random sampling of the test rows.

    Baseline metrics compare test-to-train distances in the original feature
    space with test-to-train distances in the embedding. For each feature and
    each repeat the test rows are perturbed (the feature's column is replaced
    with values from randomly drawn rows of the combined test + train data,
    and every other column ``j`` is replaced from the same drawn row with
    probability ``|corr_matrix[feature, j]|``), the feature-space distances are
    recomputed against the unperturbed training rows, and the metrics are
    evaluated again against the unchanged embedding distances.

    Args:
        x_train: Training features, array-like of shape (n_train, n_features).
        emb_train: Embedding of the training rows.
        x_test: Test features, array-like of shape (n_test, n_features).
        emb_test: Embedding of the test rows.
        corr_matrix: Feature-by-feature matrix; absolute values are used as
            co-replacement probabilities.
        n_repeats: Number of random perturbations averaged per feature.
        device: Torch device used for all tensors.
        random_state: Optional seed passed to ``seed_everything``.

    Returns:
        Dict[metric_name -> (baseline_value, importances_array)] for the keys
        'qnx', 'trust', 'cont', 'pearson', 'spear' and 'stress'. Importances
        are ``mean_perturbed - baseline`` for every metric except 'stress',
        which is reported as ``baseline - mean_perturbed``.
    """
    if random_state is not None:
        seed_everything(random_state)

    abs_corr = torch.tensor(np.abs(corr_matrix), device=device, dtype=torch.float32)
    hd_train = torch.tensor(x_train, device=device, dtype=torch.float32)
    hd_test = torch.tensor(x_test, device=device, dtype=torch.float32)
    ld_train = torch.tensor(emb_train, device=device)
    ld_test = torch.tensor(emb_test, device=device)

    # Pool from which replacement rows are drawn: test rows first, then train rows.
    donor_pool = torch.cat([hd_test, hd_train], dim=0)
    n_test, n_features = hd_test.shape
    n_train = hd_train.shape[0]

    # Embedding distances never change; only the feature-space side is perturbed.
    ld_dist = torch.cdist(ld_test, ld_train, p=2)

    # === Baseline (non-perturbed) metrics ===
    hd_dist_ref = torch.cdist(hd_test, hd_train, p=2)
    base_qnx, base_trust, base_cont = qnx_trust_cont(hd_dist_ref, ld_dist, device=device)
    base_pearson, base_spear, base_stress = dist_preservation(hd_dist_ref, ld_dist)

    metric_names = ('qnx', 'trust', 'cont', 'pearson', 'spear', 'stress')
    baselines = {
        'qnx': base_qnx,
        'trust': base_trust,
        'cont': base_cont,
        'pearson': base_pearson,
        'spear': base_spear,
        'stress': base_stress,
    }

    # === Per-feature importance (perturbed) ===
    importances = {
        name: torch.zeros(n_features, device=device) for name in metric_names
    }

    # Only the first n_test entries of the last axis are used below
    # (the training rows are kept fixed as the reference set).
    random_indices = torch.randint(
        0, donor_pool.shape[0], (n_features, n_repeats, n_train + n_test), device=device
    )

    for feature_idx in range(n_features):
        totals = dict.fromkeys(metric_names, 0.0)

        # Shape (1, n_features): chance of co-replacing each column.
        co_replace_prob = abs_corr[feature_idx].unsqueeze(0)

        for repeat_idx in range(n_repeats):
            # One donor row per test row for this (feature, repeat) pair.
            donor_rows = donor_pool[random_indices[feature_idx, repeat_idx, :n_test], :]

            perturbed_test = hd_test.clone()
            # Always replace the feature under study ...
            perturbed_test[:, feature_idx] = donor_rows[:, feature_idx]
            # ... and replace the other columns at random, weighted by |corr|.
            swap_mask = torch.rand(n_test, n_features, device=device) < co_replace_prob
            perturbed_test[swap_mask] = donor_rows[swap_mask]

            hd_dist_pert = torch.cdist(perturbed_test, hd_train, p=2)

            qnx_r, trust_r, cont_r = qnx_trust_cont(hd_dist_pert, ld_dist, device=device)
            pearson_r, spear_r, stress_r = dist_preservation(hd_dist_pert, ld_dist)

            repeat_values = (qnx_r, trust_r, cont_r, pearson_r, spear_r, stress_r)
            for name, value in zip(metric_names, repeat_values):
                totals[name] += value

        # Average over repeats, then express relative to the baseline.
        for name in metric_names:
            mean_value = totals[name] / n_repeats
            if name == 'stress':
                # Sign is flipped for stress compared with the other metrics.
                importances[name][feature_idx] = baselines[name] - mean_value
            else:
                importances[name][feature_idx] = mean_value - baselines[name]

    return {
        name: (baselines[name], importances[name].cpu().numpy())
        for name in metric_names
    }