import torch
import numpy as np
import time
from sklearn.metrics import roc_curve, auc
import tensorly as tl
from tensorly.decomposition import tucker
import matplotlib.pyplot as plt


# Select NumPy as the tensorly backend for every tensorly call in this module.
tl.set_backend('numpy')


def simulate_cp_tensor(dim, sparse_modes, d=100.0, noise_level=0.1, sparsity=0.5):
    """Simulate a noisy rank-one three-way tensor ``d * (u o v o w) + noise``.

    Args:
        dim: Iterable of exactly three mode sizes ``(n, p, q)``.
        sparse_modes: Mapping from mode index (0, 1, 2) to a flag; a truthy
            flag makes that mode's factor sparse. Missing keys mean dense.
        d: Scale applied to the rank-one signal.
        noise_level: Multiplier applied to the standard-normal noise tensor.
        sparsity: Threshold for the uniform draw used to build the sparse
            mask; an entry is kept when its draw exceeds this value.

    Returns:
        Tuple ``(X, u, v, w)`` with the noisy tensor and the three factor
        vectors that generated it.
    """
    n, p, q = dim

    # Draw the factors mode by mode; the mask (if any) is drawn right after
    # the Gaussian vector of the same mode.
    factors = []
    for mode, length in enumerate((n, p, q)):
        vec = torch.randn(length)
        if sparse_modes.get(mode, False):
            keep = (torch.rand(length) > sparsity).float()
            vec = vec * keep
        factors.append(vec)
    u, v, w = factors

    rank_one = torch.einsum('i,j,k->ijk', u, v, w)
    X = d * rank_one + noise_level * torch.randn(n, p, q)
    return X, u, v, w


def evaluate_tp_fp(pred_vector, true_vector):
    """Compute true-positive and false-positive rates for support recovery.

    An entry counts as predicted when its absolute value exceeds ``1e-6`` and
    as truly active when it is non-zero.

    Args:
        pred_vector: Array-like of predicted scores or indicators.
        true_vector: Array-like ground-truth vector of the same length.

    Returns:
        Tuple ``(tp, fp)``: the fraction of truly active entries that were
        predicted, and the fraction of truly inactive entries that were
        predicted. A rate whose denominator would be zero is reported as 0.0.
    """
    pred_support = np.abs(np.asarray(pred_vector)) > 1e-6
    true_support = np.asarray(true_vector) != 0

    n_active = np.sum(true_support)
    n_inactive = np.sum(~true_support)

    # Guard both denominators the same way so a ground truth with no active
    # (or no inactive) entries yields 0.0 instead of NaN.
    tp = np.sum(pred_support & true_support) / max(n_active, 1)
    fp = np.sum(pred_support & ~true_support) / max(n_inactive, 1)
    return tp, fp


def run_tucker_with_metrics(X_np, factors_true, rank=None):
    """Fit a Tucker decomposition and score how well it recovers the true support.

    Args:
        X_np: Tensor handed to ``tensorly.decomposition.tucker``.
        factors_true: Sequence of three torch tensors holding the ground-truth
            factor vectors for modes U, V and W.
        rank: Tucker rank passed to ``tucker``; defaults to ``[1, 1, 1]``.

    Returns:
        Tuple ``(tp_fp, mse, duration, roc_results)`` where ``tp_fp`` is a list
        of ``(tp, fp)`` pairs (one per mode), ``mse`` is the mean squared
        reconstruction error, ``duration`` is the elapsed time in seconds for
        decomposition, reconstruction and MSE computation, and ``roc_results``
        maps ``"U"``/``"V"``/``"W"`` to ``(fpr, tpr, auc)``.
    """
    # Build the default per call rather than sharing one mutable list object
    # across all invocations.
    if rank is None:
        rank = [1, 1, 1]

    # perf_counter is monotonic, so the interval is immune to system clock
    # adjustments that can skew time.time() differences.
    start = time.perf_counter()
    core, factors = tucker(X_np, rank=rank)
    recon = tl.tucker_to_tensor((core, factors))
    mse = np.mean((X_np - recon) ** 2)
    duration = time.perf_counter() - start

    roc_results = {}
    tp_fp = []
    mode_names = ["U", "V", "W"]

    for i, mode_name in enumerate(mode_names):
        # One score per row of the factor matrix: its Euclidean norm.
        pred_vector = np.linalg.norm(factors[i], axis=1)
        true_vector = factors_true[i].detach().cpu().numpy()
        tp_fp.append(evaluate_tp_fp(pred_vector, true_vector))

        # The ROC is computed from the thresholded 0/1 support indicator, not
        # from the continuous row norms.
        binarized_pred = (pred_vector > 1e-6).astype(int)
        true_labels = (true_vector != 0).astype(int)
        fpr, tpr, _ = roc_curve(true_labels, binarized_pred)
        roc_results[mode_name] = (fpr, tpr, auc(fpr, tpr))

    return tp_fp, mse, duration, roc_results


def run_geospca_with_metrics(X, factors_true, epsilon=0.1, maxiter=50):
    """Run ``geospca_solver`` on every mode unfolding of ``X`` and score the result.

    Args:
        X: torch tensor; each of its modes is unfolded and solved in turn.
        factors_true: Sequence of torch tensors with the ground-truth factor
            vector for each mode.
        epsilon: Forwarded to ``geospca_solver`` unchanged.
        maxiter: Forwarded to ``geospca_solver`` unchanged.

    Returns:
        Tuple ``(tp_fp, duration, roc_results)`` where ``tp_fp`` is a list of
        ``(tp, fp)`` pairs (one per mode), ``duration`` is the elapsed time in
        seconds for the whole loop (solver plus metric computation), and
        ``roc_results`` maps ``"U"``/``"V"``/``"W"`` to ``(fpr, tpr, auc)``.
    """
    # Imported lazily so the rest of the module loads without geospca.
    from geospca import geospca_solver

    # perf_counter is monotonic, so the interval is immune to system clock
    # adjustments that can skew time.time() differences.
    start = time.perf_counter()
    roc_results = {}
    tp_fp = []
    mode_names = ["U", "V", "W"]

    for mode in range(X.dim()):
        n_mode = X.size(mode)

        # Mode unfolding: bring `mode` to the front and flatten the rest, then
        # transpose so the mode's entries index the columns.
        other_modes = [i for i in range(X.dim()) if i != mode]
        unfolded_A = X.permute(mode, *other_modes).reshape(n_mode, -1)
        A_trans = unfolded_A.T

        # NOTE(review): k=n_mode // 2 presumably caps the number of selected
        # entries at half the mode size; confirm against geospca.
        result = geospca_solver(A_trans, nc=1, k=n_mode // 2, epsilon=epsilon,
                                maxiter=maxiter, device='cpu')

        # 0/1 indicator built from the positions listed in result["Bindices"].
        pred_vector = np.zeros(n_mode)
        if result["Bindices"] is not None:
            pred_vector[result["Bindices"]] = 1

        true_vector = factors_true[mode].detach().cpu().numpy()
        tp_fp.append(evaluate_tp_fp(pred_vector, true_vector))

        true_labels = (true_vector != 0).astype(int)
        fpr, tpr, _ = roc_curve(true_labels, pred_vector)
        roc_results[mode_names[mode]] = (fpr, tpr, auc(fpr, tpr))

    duration = time.perf_counter() - start
    return tp_fp, duration, roc_results


def average_roc(results_list, n_points=100):
    """Average per-trial ROC curves on a common false-positive-rate grid.

    Args:
        results_list: List of dicts, one per trial, each mapping
            ``"U"``/``"V"``/``"W"`` to a ``(fpr, tpr, auc)`` tuple.
        n_points: Number of evenly spaced FPR grid points in ``[0, 1]`` on
            which every curve is linearly interpolated before averaging.

    Returns:
        Dict mapping each mode name to ``(fpr_grid, mean_tpr, mean_auc)``.
    """
    mode_names = ["U", "V", "W"]
    # The grid is the same for every trial and mode, so build it once.
    fpr_grid = np.linspace(0, 1, n_points)

    mean_roc = {}
    for mode in mode_names:
        tpr_curves, auc_values = [], []
        for res in results_list:
            fpr, tpr, auc_val = res[mode]
            # Resample this trial's TPR onto the shared FPR grid.
            tpr_curves.append(np.interp(fpr_grid, fpr, tpr))
            auc_values.append(auc_val)
        mean_tpr = np.mean(tpr_curves, axis=0)
        mean_auc = np.mean(auc_values)
        # Hand each mode its own copy so callers can mutate one grid safely.
        mean_roc[mode] = (fpr_grid.copy(), mean_tpr, mean_auc)
    return mean_roc

def plot_roc_comparison(roc_dict, scenario_name="Average over Trials"):
    """Show one ROC figure per mode, overlaying the curve of every method.

    Args:
        roc_dict: Mapping from method label to a per-mode dict whose
            ``"U"``/``"V"``/``"W"`` entries are ``(fpr, tpr, auc)`` tuples.
        scenario_name: Prefix used in each figure title.
    """
    # Internal mode key -> label rendered in the figure title.
    mode_titles = {"U": r"$u_1$", "V": r"$v_1$", "W": r"$w_1$"}

    for mode_key, mode_title in mode_titles.items():
        plt.figure(figsize=(6, 5))

        for method, per_mode in roc_dict.items():
            fpr, tpr, auc_val = per_mode[mode_key]
            legend_text = f'{method} (AUC={auc_val:.2f})'
            plt.plot(fpr, tpr, label=legend_text, linewidth=2)

        # Dashed diagonal as the chance-level reference line.
        plt.plot([0, 1], [0, 1], 'k--', linewidth=1)

        plt.title(f"{scenario_name}: {mode_title}")
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.show()


def run_trials_with_summary(n_trials=5, dim=(100, 100, 100), rank=None,
                            sparse_modes=None, d=100.0,
                            noise_level=0.1, sparsity=0.5):
    """Run repeated simulations and collect metrics for both methods.

    Each trial simulates a fresh tensor, runs the Tucker baseline (stored
    under ``"HOPCA"``) and the GeoSPCA-based method (stored under
    ``"sparseGeoHOPCA"``) on it, and appends their metrics.

    Args:
        n_trials: Number of independent simulated tensors.
        dim: Tensor dimensions passed to ``simulate_cp_tensor``.
        rank: Tucker rank for the baseline; defaults to ``[1, 1, 1]``.
        sparse_modes: Mapping from mode index to a sparsity flag; defaults to
            all three modes sparse.
        d: Signal scale passed to ``simulate_cp_tensor``.
        noise_level: Noise multiplier passed to ``simulate_cp_tensor``.
        sparsity: Sparsity threshold passed to ``simulate_cp_tensor``.

    Returns:
        Tuple ``(tp_fp_storage, mse_storage, time_storage, roc_storage)`` of
        dicts keyed by method name, each holding one entry per trial.
        ``mse_storage`` only has a ``"HOPCA"`` key.
    """
    # Resolve defaults per call rather than sharing mutable default objects
    # across invocations.
    if rank is None:
        rank = [1, 1, 1]
    if sparse_modes is None:
        sparse_modes = {0: True, 1: True, 2: True}

    method_names = ["sparseGeoHOPCA", "HOPCA"]
    roc_storage = {method: [] for method in method_names}
    tp_fp_storage = {method: [] for method in method_names}
    mse_storage = {"HOPCA": []}
    time_storage = {method: [] for method in method_names}

    for _ in range(n_trials):
        X, u_true, v_true, w_true = simulate_cp_tensor(dim, sparse_modes, d, noise_level, sparsity)
        X_np = X.numpy()
        factors_true = [u_true, v_true, w_true]

        # Tucker baseline on the NumPy view of the tensor.
        tp_fp_tucker, mse, tucker_time, roc_tucker = run_tucker_with_metrics(X_np, factors_true, rank)
        tp_fp_storage["HOPCA"].append(tp_fp_tucker)
        mse_storage["HOPCA"].append(mse)
        time_storage["HOPCA"].append(tucker_time)
        roc_storage["HOPCA"].append(roc_tucker)

        # GeoSPCA-based method on the torch tensor.
        tp_fp_geo, geo_time, roc_geo = run_geospca_with_metrics(X, factors_true)
        tp_fp_storage["sparseGeoHOPCA"].append(tp_fp_geo)
        time_storage["sparseGeoHOPCA"].append(geo_time)
        roc_storage["sparseGeoHOPCA"].append(roc_geo)

    return tp_fp_storage, mse_storage, time_storage, roc_storage


def run_and_plot_roc(n_trials=50, dim=(1000, 20, 20), rank=None,
                     sparse_modes=None, d=100.0,
                     noise_level=0.1, sparsity=0.5, scenario_name="Scenario 2"):
    """Run the trials, average the ROC curves per method and plot them.

    Args:
        n_trials: Number of simulated trials.
        dim: Tensor dimensions for the simulation.
        rank: Tucker rank for the baseline; defaults to ``[1, 1, 1]``.
        sparse_modes: Mapping from mode index to a sparsity flag; defaults to
            only mode 0 being sparse.
        d: Signal scale for the simulation.
        noise_level: Noise multiplier for the simulation.
        sparsity: Sparsity threshold for the simulation.
        scenario_name: Title prefix used in the ROC figures.

    Returns:
        Tuple ``(tp_fp_storage, mse_storage, time_storage, averaged_rocs)``;
        the first three come straight from ``run_trials_with_summary`` and
        ``averaged_rocs`` maps each method to its ``average_roc`` result.
    """
    # Resolve defaults per call rather than sharing mutable default objects
    # across invocations; concrete values are always passed downstream.
    if rank is None:
        rank = [1, 1, 1]
    if sparse_modes is None:
        sparse_modes = {0: True, 1: False, 2: False}

    tp_fp_storage, mse_storage, time_storage, roc_storage = run_trials_with_summary(
        n_trials=n_trials,
        dim=dim,
        rank=rank,
        sparse_modes=sparse_modes,
        d=d,
        noise_level=noise_level,
        sparsity=sparsity,
    )

    averaged_rocs = {
        method: average_roc(per_trial_rocs)
        for method, per_trial_rocs in roc_storage.items()
    }

    plot_roc_comparison(averaged_rocs, scenario_name=scenario_name)

    return tp_fp_storage, mse_storage, time_storage, averaged_rocs


if __name__ == "__main__":
    # Experiment: 100x100x100 tensor with all three factor vectors sparse.
    tp_fp_storage, mse_storage, time_storage, averaged_rocs = run_and_plot_roc(
        n_trials=50,
        dim=(100, 100, 100),
        rank=[1, 1, 1],
        sparse_modes={0: True, 1: True, 2: True},
        d=100.0,
        noise_level=0.1,
        sparsity=0.5,
        scenario_name="Scenario 2",
    )

    # Support-recovery rates: mean and std over trials, per method and mode.
    for method, per_trial in tp_fp_storage.items():
        stats = np.array(per_trial)
        mean = stats.mean(axis=0)
        std = stats.std(axis=0)
        print(f"\n📊 {method} TP/FP:")
        for i, mode in enumerate(["U", "V", "W"]):
            tp_mean, fp_mean = mean[i]
            tp_std, fp_std = std[i]
            print(f"  Mode {mode}: TP={tp_mean:.3f} ± {tp_std:.3f}, FP={fp_mean:.3f} ± {fp_std:.3f}")

    # Reconstruction error is only recorded for the Tucker baseline.
    if "HOPCA" in mse_storage:
        mse_arr = np.array(mse_storage["HOPCA"])
        print(f"\n📉 HOPCA MSE: {mse_arr.mean():.4f} ± {mse_arr.std():.4f}")

    # Elapsed time per method, summarised over trials.
    for method, durations in time_storage.items():
        times = np.array(durations)
        print(f"\n⏱️ {method} Time: {times.mean():.3f} ± {times.std():.3f} seconds")

