import torch
import torch.nn as nn
import torch.nn.functional as Fnn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import os

import numpy as np
import sys
from matplotlib import pyplot as plt
import warnings
from scipy.linalg import qr, sqrtm
import seaborn as sns
from tqdm import tqdm
from pytorch_metric_learning import losses
from sklearn.decomposition import PCA
import argparse
import math
import pandas as pd
from scipy.linalg import block_diag

from utils import *



# Prefer CUDA when a GPU is visible to torch; otherwise run everything on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[run.py] Using device: {DEVICE}")

def load_on_device(path):
    """Load a ``torch.save`` checkpoint from *path*, mapping every storage onto ``DEVICE``.

    ``weights_only=False`` lets ``torch.load`` unpickle arbitrary objects (not only
    tensors), so only call this on checkpoints from a trusted source.
    """
    load_opts = {"map_location": DEVICE, "weights_only": False}
    return torch.load(path, **load_opts)


import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from pathlib import Path
import json


def data_gen_citeseq_dtm(root_dir, n, *, shuffle=False):
    """Load the CITE-seq RNA/ADT PCA features, labels and RNA weights, and split them.

    Reads ``{root_dir}/citeseq/{rna_pca,adt_pca,lab1,lab2,rna_wts}.csv``. The first
    column of every PCA / label file is dropped (presumably a row-name column from
    the export step -- verify). ``rna_wts.csv`` must contain a column named ``x``.

    Args:
        root_dir: directory that contains the ``citeseq`` sub-folder.
        n: number of rows in the training split; all remaining rows form the test split.
        shuffle: if True, rows are permuted with the global NumPy RNG before the split,
            using a single permutation for every returned array so rows stay aligned.
            The default (False) keeps file order: the first ``n`` rows train, the rest test.

    Returns:
        ``(X, Y, X_test, Y_test, lab1_train, lab1_test, lab2_train, lab2_test,
        wts_train, wts_test)``: NumPy arrays for the features and labels, and
        pandas Series (column ``x`` of ``rna_wts.csv``) for the weights.

    Raises:
        ValueError: with ``shuffle=True``, if the files do not all have the same row count.
    """
    base = f'{root_dir}/citeseq'

    def _read_matrix(name):
        # Drop the leading (row-name) column and return the remaining values.
        df = pd.read_csv(f'{base}/{name}.csv')
        return np.array(df.drop(df.columns[0], axis=1))

    X_all = _read_matrix('rna_pca')
    Y_all = _read_matrix('adt_pca')
    lab1 = _read_matrix('lab1').ravel()
    lab2 = _read_matrix('lab2').ravel()
    wts = pd.read_csv(f'{base}/rna_wts.csv')['x']

    # Always draw the permutation so the global RNG advances exactly the same way
    # whether or not it is applied (downstream random draws stay reproducible).
    perm = np.random.permutation(X_all.shape[0])

    if shuffle:
        lengths = {len(X_all), len(Y_all), len(lab1), len(lab2), len(wts)}
        if len(lengths) != 1:
            raise ValueError(f"cannot shuffle: input files have differing row counts {sorted(lengths)}")
        sel_tr, sel_te = perm[:n], perm[n:]
    else:
        # Plain slices: head/tail split in file order.
        sel_tr, sel_te = slice(None, n), slice(n, None)

    return (X_all[sel_tr], Y_all[sel_tr], X_all[sel_te], Y_all[sel_te],
            lab1[sel_tr], lab1[sel_te], lab2[sel_tr], lab2[sel_te],
            wts.iloc[sel_tr], wts.iloc[sel_te])



import numpy as np
import torch
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import NearestNeighbors

def _to_np(x):
    if isinstance(x, np.ndarray):
        return x
    if hasattr(x, "detach"):  # torch tensor
        return x.detach().cpu().numpy()
    return np.asarray(x)

def per_class_accuracy(y_true, y_pred, labels=None):
    """Per-class recall: for each label ``k``, the fraction of samples whose true
    label is ``k`` that were also predicted as ``k``.

    Args:
        y_true: true labels (array-like).
        y_pred: predicted labels, aligned with ``y_true``.
        labels: labels to report; defaults to ``np.unique(y_true)``.

    Returns:
        dict mapping each label to a float, or ``np.nan`` when the label never
        occurs in ``y_true``.
    """
    truth = np.asarray(y_true)
    guess = np.asarray(y_pred)
    classes = np.unique(truth) if labels is None else labels

    def _recall(cls):
        members = truth == cls
        if members.sum() == 0:
            return np.nan
        return float((guess[members] == cls).mean())

    return {cls: _recall(cls) for cls in classes}

def knn_same_label_scores_testset(X_test, y_test, K=10, metric="euclidean", exclude_self=True):
    """Score how label-homogeneous each point's K-nearest neighbourhood is.

    Neighbours are searched within the test set itself. A point's score is the
    fraction of its K neighbours that share its label. With ``exclude_self`` the
    query asks for K+1 neighbours and removes the point itself; if the point is not
    among them (e.g. exact duplicates), the last of the K+1 is dropped instead.

    Returns:
        ``(overall, per_class, macro)``: the mean score over all points; a dict
        label -> mean score over that label's points; and the unweighted mean of
        the per-class values.
    """
    feats = _to_np(X_test)
    labels = _to_np(y_test)

    n_query = K + 1 if exclude_self else K
    finder = NearestNeighbors(n_neighbors=n_query, metric=metric).fit(feats)
    _, neigh = finder.kneighbors(feats, return_distance=True)

    if exclude_self:
        rows = np.arange(neigh.shape[0])
        hit = neigh == rows[:, None]
        # Column to discard per row: the self position if present, else the last one.
        drop_at = np.where(hit.any(axis=1), hit.argmax(axis=1), K)
        keep = np.ones(neigh.shape, dtype=bool)
        keep[rows, drop_at] = False
        # Exactly one column is removed per row, so the remainder reshapes to (n, K).
        neigh = neigh[keep].reshape(neigh.shape[0], K)
    else:
        neigh = neigh[:, :K]

    agree = labels[neigh] == labels[:, None]
    point_scores = agree.mean(axis=1)

    overall = float(point_scores.mean())
    per_class = {
        cls: (np.nan if (labels == cls).sum() == 0 else float(point_scores[labels == cls].mean()))
        for cls in np.unique(labels)
    }
    macro = float(np.nanmean(list(per_class.values())))
    return overall, per_class, macro

def _embed_with(model, arr):
    """Apply ``model`` to ``arr`` (cast to float32) without gradients; return a NumPy array.

    For ``torch.nn.Module`` models the input is placed on the device of the model's
    first parameter (CPU if it has none), and the model runs in eval mode (so
    dropout / batch-norm act deterministically). Its previous train/eval mode is
    restored afterwards.
    """
    device, was_training = torch.device("cpu"), None
    if isinstance(model, torch.nn.Module):
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        was_training = model.training
        model.eval()
    try:
        with torch.no_grad():
            out = model(torch.tensor(_to_np(arr), dtype=torch.float32, device=device))
    finally:
        if was_training is not None:
            model.train(was_training)
    return out.detach().cpu().numpy()


def _linear_side_scores(train_reps, test_reps, y_tr, y_te, lr_kwargs):
    """Fit one LogisticRegression per representation; score overall / per-class / macro test accuracy."""
    overall, per_class = {}, {}
    for name, X_tr in train_reps.items():
        clf = LogisticRegression(**lr_kwargs).fit(_to_np(X_tr), y_tr)
        pred = clf.predict(_to_np(test_reps[name]))
        overall[name] = accuracy_score(y_te, pred)
        per_class[name] = per_class_accuracy(y_te, pred)
    macro = {name: float(np.nanmean(list(pc.values()))) for name, pc in per_class.items()}
    return {"overall": overall, "per_class": per_class, "macro_avg": macro}


def _knn_side_scores(test_reps, y_te, K, metric, exclude_self):
    """Score each test representation with the within-test-set kNN same-label metric."""
    overall, per_class, macro = {}, {}, {}
    for name, X_te in test_reps.items():
        overall[name], per_class[name], macro[name] = knn_same_label_scores_testset(
            _to_np(X_te), y_te, K=K, metric=metric, exclude_self=exclude_self
        )
    return {"overall": overall, "per_class": per_class, "macro_avg": macro}


def evaluate_classification(
    method,  # "linear" or "knn"
    XX, YY, FF, GG,
    XX_test, YY_test, FF_test, GG_test,  # test features
    result_X, result_Y,
    lab1_train, lab1_test, lab2_train, lab2_test,
    *,
    K=10, metric="euclidean", exclude_self=True,
    lr_kwargs=None
):
    """Compare how well several representations predict the lab1 / lab2 labels.

    For each side (X side: XX / FF, Y side: YY / GG) four representations are scored:
      "X"            the raw input (XX, or YY on the Y side),
      "F"            the precomputed features passed in (FF, or GG on the Y side),
      "W"            ``result_*['models']['w']`` applied to the raw input,
      "C_x" / "C_y"  the raw input concatenated column-wise with "W".

    method="linear": fit ``LogisticRegression(**lr_kwargs)`` per representation on the
    training split and evaluate on the test split. method="knn": for every test point,
    the fraction of its K nearest test-set neighbours sharing its label (see
    ``knn_same_label_scores_testset``); training labels are not used.

    Returns:
        ``(accs_x, accs_y)``: dicts mapping "lab1"/"lab2" to
        ``{"overall": {rep: float}, "per_class": {rep: {label: float}}, "macro_avg": {rep: float}}``.

    Raises:
        ValueError: if ``method`` is neither "linear" nor "knn" (case-insensitive).
    """
    lr_kwargs = lr_kwargs or {"max_iter": 2000}
    mode = method.lower()
    # Validate before running any model so a typo fails fast.
    if mode not in ("linear", "knn"):
        raise ValueError("method must be 'linear' or 'knn'")

    XX, YY, XX_test, YY_test = (_to_np(a) for a in (XX, YY, XX_test, YY_test))

    w_x, w_y = result_X['models']['w'], result_Y['models']['w']
    W_x, W_x_test = _embed_with(w_x, XX), _embed_with(w_x, XX_test)
    W_y, W_y_test = _embed_with(w_y, YY), _embed_with(w_y, YY_test)

    # Key order matters: it is the order results are written out by callers.
    reps_x_train = {"X": XX, "F": FF, "W": W_x, "C_x": np.concatenate([XX, W_x], axis=1)}
    reps_x_test = {"X": XX_test, "F": FF_test, "W": W_x_test,
                   "C_x": np.concatenate([XX_test, W_x_test], axis=1)}
    # The Y side reuses the keys "X"/"F"/"W" (meaning YY / GG / W_y); callers read them by these names.
    reps_y_train = {"X": YY, "F": GG, "W": W_y, "C_y": np.concatenate([YY, W_y], axis=1)}
    reps_y_test = {"X": YY_test, "F": GG_test, "W": W_y_test,
                   "C_y": np.concatenate([YY_test, W_y_test], axis=1)}

    label_sets = {
        "lab1": (lab1_train, lab1_test),
        "lab2": (lab2_train, lab2_test),
    }

    accs_x, accs_y = {}, {}
    for label_name, (lab_tr, lab_te) in label_sets.items():
        y_te = _to_np(lab_te)
        if mode == "linear":
            y_tr = _to_np(lab_tr)
            accs_x[label_name] = _linear_side_scores(reps_x_train, reps_x_test, y_tr, y_te, lr_kwargs)
            accs_y[label_name] = _linear_side_scores(reps_y_train, reps_y_test, y_tr, y_te, lr_kwargs)
        else:
            accs_x[label_name] = _knn_side_scores(reps_x_test, y_te, K, metric, exclude_self)
            accs_y[label_name] = _knn_side_scores(reps_y_test, y_te, K, metric, exclude_self)

    return accs_x, accs_y



# --- Command-line arguments ---------------------------------------------------
parser = argparse.ArgumentParser(description="Run experiment with chosen dataset")
parser.add_argument("--idx", type=int, required=True, help="index")
parser.add_argument("--lam", type=float, required=True, help="parameter")
args = parser.parse_args()
idx, lam = args.idx, args.lam

dataset_nam = "citeseq"

# --- Reproducibility: fix the global torch and NumPy seeds ---------------------
seed = 2025
torch.manual_seed(seed)
np.random.seed(seed)

# === Dataset ===
n = 15000      # rows used for training (the loader takes the first n rows)
n_test = 1000  # NOTE: not passed to the loader -- every row after the first n is test data

root_dir = "../citeseq_data"
data_dir = f"{root_dir}/data_unnormalized"

(XX, YY, XX_test, YY_test,
 lab1_train, lab1_test, lab2_train, lab2_test,
 wts_train, wts_test) = data_gen_citeseq_dtm(data_dir, n)



# Coarse cell-type groups: group name -> set of fine-grained labels it contains.
group_defs = {
    name: set(members)
    for name, members in (
        ("Mono/DC", ("CD14 Mono", "cDC2", "pDC", "CD16 Mono")),
        ("B cell", ("Memory B", "Naive B", "Plasmablast")),
        ("T cell", ("CD8 Naive", "CD8 Memory_2", "gdT", "CD8 Memory_1", "Treg",
                    "CD4 Memory", "CD8 Effector_1", "CD8 Effector_2", "CD4 Naive", "MAIT")),
        ("NK", ("NK", "CD56 bright NK")),
        ("Progenitor cells", ("HSC", "Prog_B 1", "LMPP", "Prog_B 2",
                              "Prog_RBC", "Prog_DC", "Prog_Mk", "GMP")),
    )
}


from train import *



import json
import pathlib

from scipy.stats import spearmanr, kendalltau


def _clip_embed(model, arr):
    """Encode ``arr`` (cast to float32) with a CLIP-side encoder in eval mode, without gradients."""
    if isinstance(model, torch.nn.Module):
        model.eval()
    with torch.no_grad():
        return model(torch.as_tensor(np.asarray(arr), dtype=torch.float32)).cpu().numpy()


def _to_py(o):
    """Recursively convert NumPy / pandas values into JSON-serialisable Python objects.

    Dict keys are converted as well, so NumPy-scalar labels (e.g. ``np.int64``)
    become valid JSON keys.
    """
    if isinstance(o, np.generic):  # np.float32, np.int64, np.str_, ...
        return o.item()
    if isinstance(o, np.ndarray):
        return o.tolist()
    if isinstance(o, pd.Series):
        return _to_py(o.to_dict())
    if isinstance(o, pd.DataFrame):
        return _to_py(o.to_dict(orient="list"))
    if isinstance(o, dict):
        return {_to_py(k): _to_py(v) for k, v in o.items()}
    if isinstance(o, (list, tuple)):
        return [_to_py(v) for v in o]
    return o


for arch in ['transformer', 'deep']:

    for norm_ in [True]:

        save_dir = f'./results/{dataset_nam}_norm{norm_}'
        os.makedirs(save_dir, exist_ok=True)

        # Map checkpoints to the CPU: every input tensor below is created on the CPU,
        # so models restored onto a GPU would fail on their first forward pass.
        clip_results = torch.load(f'{save_dir}/clip_results_{arch}_{idx}.pt',
                                  map_location="cpu", weights_only=False)
        clip_x, clip_y = clip_results['model_x'], clip_results['model_y']

        # CLIP embeddings of both splits; they do not depend on `objective`,
        # so they are computed once per architecture.
        FF, GG = _clip_embed(clip_x, XX), _clip_embed(clip_y, YY)
        FF_test, GG_test = _clip_embed(clip_x, XX_test), _clip_embed(clip_y, YY_test)

        for objective in ['disen', 'fact', 'recons', 'clip']:

            # The 'clip' baseline has no W-models of its own: it reuses the 'recons'
            # checkpoints and is scored on the CLIP features ("F") instead of "C_*".
            source = 'recons' if objective == 'clip' else objective
            result_X = torch.load(f'{save_dir}/result_X_{source}_{arch}_{idx}_{lam}.pt',
                                  map_location="cpu", weights_only=False)
            result_Y = torch.load(f'{save_dir}/result_Y_{source}_{arch}_{idx}_{lam}.pt',
                                  map_location="cpu", weights_only=False)

            class_method = "knn"  # "linear" or "knn"

            accs_x, accs_y = evaluate_classification(
                class_method,
                XX, YY, FF, GG, XX_test, YY_test, FF_test, GG_test,
                result_X, result_Y,
                lab1_train, lab1_test, lab2_train, lab2_test,
                lr_kwargs={"max_iter": 2000}
            )

            rep_x, rep_y = ("F", "F") if objective == 'clip' else ("C_x", "C_y")
            acc_x_by_class = accs_x["lab2"]["per_class"][rep_x]
            acc_y_by_class = accs_y["lab2"]["per_class"][rep_y]
            labs = list(acc_y_by_class.keys())
            # Align both sides on the same label order explicitly.
            x_lab2 = np.array([acc_x_by_class[lab] for lab in labs])
            y_lab2 = np.array([acc_y_by_class[lab] for lab in labs])

            # Per-class negative log ratio of Y-side to X-side accuracy (eps guards x == 0).
            rt_lab2 = pd.DataFrame({
                "value": -np.log(y_lab2 / (x_lab2 + 1e-5)),
                "labs": labs,
            })
            # Stable sort so tied classes get a reproducible order.
            df_sorted = rt_lab2.sort_values("value", kind="mergesort").reset_index(drop=True)

            val_map = dict(zip(df_sorted["labs"], df_sorted["value"]))
            lab2_test_arr = np.asarray(lab2_test)
            mask_present = np.array([lbl in val_map for lbl in lab2_test_arr], dtype=bool)

            df_obs = pd.DataFrame({
                "label": lab2_test_arr[mask_present],
                "RNA.weights": np.asarray(wts_test)[mask_present],
            })

            # Two class orderings: by log accuracy ratio, and by mean RNA weight.
            order_log_acc = df_sorted["labs"].tolist()
            order_rna = df_obs.groupby("label")["RNA.weights"].mean().sort_values().index.tolist()

            rank_log = {lab: i for i, lab in enumerate(order_log_acc)}
            rank_rna = {lab: i for i, lab in enumerate(order_rna)}

            # Only compare classes that are ranked under both orderings.
            labels_common = [lab for lab in order_log_acc if lab in rank_rna]
            ranks_a = [rank_log[lab] for lab in labels_common]
            ranks_b = [rank_rna[lab] for lab in labels_common]

            rho, _ = spearmanr(ranks_a, ranks_b)
            tau, _ = kendalltau(ranks_a, ranks_b)

            # Top-k agreement between the two orderings; k shrinks if there are fewer classes.
            k = min(10, len(order_log_acc), len(order_rna))
            overlap_topk = (len(set(order_log_acc[:k]) & set(order_rna[:k])) / k) if k else float("nan")

            print(f"Spearman rho: {rho:.3f}")
            print(f"Kendall tau: {tau:.3f}")
            print(f"Top-{k} overlap: {overlap_topk:.2f}")

            out_path = pathlib.Path(save_dir) / f"results_{arch}_{objective}_{idx}_{lam}.json"

            out_obj = {
                "accs_x": accs_x,
                "accs_y": accs_y,
                "rho": float(rho),
                "tau": float(tau),
            }

            with open(out_path, "w", encoding="utf-8") as fh:
                json.dump(_to_py(out_obj), fh, indent=2)

            print(f"Saved compact results to {out_path}")

