"""
Demonstration of the core VVP (perturbation-based) and GDT (gradient-based)
feature-extraction functions, based on scGPT.
"""

#!/usr/bin/env python3
import os
import sys
import json
import argparse
import random
from pathlib import Path
import torch
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans

from models_utils.model import scModel
from scgpt_utils.tokenizer import GeneVocab
from grn_config import get_genelink_label_dir
from dataset import load_genelink_splits, load_expression_matrix

def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # The CUDA generators are only seeded when CUDA is reported as available.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

class ModelArgs:
    """Attribute bag of model hyper-parameters passed to ``scModel``.

    Every instance starts from the built-in defaults below. When
    ``config_path`` points to an existing JSON file, a fixed subset of
    keys found in that file overrides the corresponding defaults; all
    other keys in the file are ignored.

    Args:
        vocab_len: Stored as ``ntoken``.
        config_path: Optional path to a JSON file with overrides.
    """

    def __init__(self, vocab_len, config_path=None):
        defaults = {
            "embsize": 256,
            "nlayers": 6,
            "nheads": 8,
            "d_hid": 256,
            "dropout": 0.15,
            "use_mvc": True,
            "pad_token": "<pad>",
            "pad_value": 0,
            "mask_value": -1,
            "cls_value": 0,
            "n_bins": 51,
            "input_emb_style": "continuous",
            "cell_emb_style": "cls",
            "model_structure": "transformer",
            "vocab": None,
        }
        for attr, value in defaults.items():
            setattr(self, attr, value)
        self.ntoken = vocab_len

        # Nothing to override without a readable config file.
        if not config_path or not os.path.exists(config_path):
            return

        # Only these keys may be taken from the JSON file.
        overridable = ('embsize', 'nlayers', 'nheads', 'use_mvc', 'd_hid', 'dropout')
        with open(config_path) as f:
            overrides = json.load(f)
            for key, value in overrides.items():
                if key not in overridable:
                    continue
                setattr(self, key, value)

def load_pretrained_model(ckpt_path, device):
    """Build an ``scModel`` from a checkpoint and return it in eval mode.

    The vocabulary is read from ``vocab.json`` next to the checkpoint, or
    failing that from ``vocab.json`` in the current working directory.
    Model hyper-parameters come from ``ModelArgs``, optionally overridden
    by ``args.json`` next to the checkpoint.

    Args:
        ckpt_path: Path to the checkpoint file passed to ``torch.load``.
        device: Device used as ``map_location`` and as the model's device.

    Returns:
        Tuple ``(model, vocab)`` where ``model`` has been moved to
        ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: If no ``vocab.json`` can be found in either
            searched location.
    """
    ckpt_dir = Path(ckpt_path).parent
    candidates = [ckpt_dir / "vocab.json", Path("vocab.json")]
    vocab_path = next((p for p in candidates if p.exists()), None)
    if vocab_path is None:
        # Fail with a clear message instead of handing None to the vocab loader.
        searched = ", ".join(str(p) for p in candidates)
        raise FileNotFoundError(f"vocab.json not found (searched: {searched})")

    vocab = GeneVocab.from_file(vocab_path)
    for token in ["<pad>", "<cls>", "<eoc>"]:
        if token not in vocab:
            vocab.append_token(token)

    model = scModel(vocab, args=ModelArgs(len(vocab), ckpt_dir / "args.json"))

    state_dict = torch.load(ckpt_path, map_location=device)

    # Strip only a *leading* "module." (typically added by DataParallel-style
    # wrappers). str.replace would also delete the substring from the middle
    # of a key such as "module.encoder.module.weight".
    prefix = "module."
    clean_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    # strict=False: keys that do not match the model are not treated as errors.
    model.load_state_dict(clean_state_dict, strict=False)
    return model.to(device).eval(), vocab

def compute_gradient_forward(model, vocab, src_ids, src_val, device):
    """Run the model on a batch with a ``<cls>`` column prepended.

    No ``no_grad`` guard is applied here, so autograd can track the
    result back to ``src_val``.

    Args:
        model: Callable invoked as ``model(ids, values,
            src_key_padding_mask=...)`` and returning a mapping.
        vocab: Mapping that provides the id of ``"<cls>"``.
        src_ids: 2-D tensor of token ids (rows are batch entries).
        src_val: 2-D tensor of input values aligned with ``src_ids``.
        device: Device for the tensors created here.

    Returns:
        The model's ``"mvc_output"`` entry (falling back to
        ``"mlm_output"``) with the first column dropped.
    """
    n_rows = src_val.shape[0]

    cls_ids = torch.full((n_rows, 1), vocab["<cls>"], device=device)
    cls_vals = torch.full((n_rows, 1), 0.0, device=device)
    tokens = torch.cat([cls_ids, src_ids], 1)
    values = torch.cat([cls_vals, src_val], 1)

    # All-False mask: no position is flagged as padding.
    no_padding = torch.zeros((n_rows, src_ids.shape[1] + 1), device=device, dtype=torch.bool)

    out = model(tokens, values, src_key_padding_mask=no_padding)
    prediction = out.get("mvc_output", out.get("mlm_output"))
    # Drop the column that corresponds to the prepended <cls> token.
    return prediction[:, 1:]

@torch.no_grad()
def compute_perturbation_forward(model, vocab, gene_ids, values, device):
    """Run the model without gradient tracking and return a NumPy array.

    Args:
        model: Callable invoked as ``model(ids, values,
            src_key_padding_mask=...)`` and returning a mapping.
        vocab: Mapping that provides the id of ``"<cls>"``.
        gene_ids: Token ids for one row; shared by every row in the batch.
        values: 2-D array-like of input values, one row per batch entry.
        device: Device for the tensors created here.

    Returns:
        NumPy array holding the model's ``"mvc_output"`` entry (falling
        back to ``"mlm_output"``) with the first column dropped.
    """
    n_rows = values.shape[0]

    # The same id row is broadcast across the batch.
    id_row = torch.tensor(gene_ids, device=device)
    src_ids = id_row.expand(n_rows, -1)
    src_val = torch.tensor(values, device=device).float()

    cls_ids = torch.full((n_rows, 1), vocab["<cls>"], device=device)
    cls_vals = torch.full((n_rows, 1), 0.0, device=device)
    tokens = torch.cat([cls_ids, src_ids], 1)
    inputs = torch.cat([cls_vals, src_val], 1)

    # All-False mask: no position is flagged as padding.
    no_padding = torch.zeros((n_rows, src_ids.shape[1] + 1), device=device, dtype=torch.bool)

    out = model(tokens, inputs, src_key_padding_mask=no_padding)
    prediction = out.get("mvc_output", out.get("mlm_output"))
    # Drop the column that corresponds to the prepended <cls> token.
    trimmed = prediction[:, 1:]
    return trimmed.cpu().numpy()

def _load_id_name_map(path):
    df = pd.read_csv(path)
    id_col = "index" if "index" in df else next((c for c in df if np.issubdtype(df[c].dtype, np.number)), "_row_id")
    if id_col == "_row_id":
        df["_row_id"] = np.arange(len(df))

    name_col = next((c for c in ["Gene", "TF", "Target", "gene", "tf", "target"] if c in df), None)
    if not name_col:
        name_col = next((c for c in df if not np.issubdtype(df[c].dtype, np.number)), str(id_col))

    return dict(zip(df[id_col].astype(int), df[name_col].astype(str)))

def _prepare_data_indices(label_dir, net, ds, vocab, seed, pos_neg_ratio):
    """Collect the gene pairs and index tables the feature extractors need.

    A gene id is kept only if its upper-cased name is present both in the
    model vocabulary and in the expression matrix's gene names.

    Args:
        label_dir: Directory with the label files (``TF.csv``,
            ``Target.csv``) and the expression data.
        net: Not used in this function.
        ds: Not used in this function.
        vocab: Vocabulary exposing ``get_stoi()``.
        seed: Forwarded to ``load_genelink_splits``.
        pos_neg_ratio: Forwarded to ``load_genelink_splits``.

    Returns:
        ``None`` if ``label_dir`` does not exist or no gene id is kept.
        Otherwise a dict with:
            ``valid_pairs``: rows of the pair array whose two ids are kept;
            ``vocab_ids``: int64 array of vocabulary ids of the kept genes;
            ``dataset_pos_map``: gene id -> position in ``vocab_ids``;
            ``expr_indices``: position of each kept gene in the gene names;
            ``expr_matrix``: the expression matrix, transposed when its
                first dimension does not equal the number of gene names;
            ``unique_genes``: distinct ids appearing in ``valid_pairs``.
    """
    if not os.path.exists(label_dir):
        return None

    splits = load_genelink_splits(
        label_dir,
        seed=seed,
        train_ratio=1.0,
        val_ratio=0.0,
        pos_neg_ratio=pos_neg_ratio,
    )
    # NOTE(review): presumably the pair array of the first split - confirm
    # against load_genelink_splits.
    pairs = splits[0][0]

    # Target.csv entries take precedence over TF.csv entries on id clashes.
    id_to_name = _load_id_name_map(os.path.join(label_dir, "TF.csv"))
    id_to_name.update(_load_id_name_map(os.path.join(label_dir, "Target.csv")))

    gene_names, expr_matrix = load_expression_matrix(label_dir)
    if expr_matrix.shape[0] != len(gene_names):
        expr_matrix = expr_matrix.T

    row_of_name = {}
    for row, name in enumerate(gene_names):
        row_of_name[name.upper()] = row
    token_of_name = vocab.get_stoi()

    kept_ids = []
    token_ids = []
    expr_rows = []
    position_of = {}

    for raw_id in sorted(set(pairs.flatten())):
        gene_id = int(raw_id)
        symbol = id_to_name.get(gene_id, "").upper()
        if symbol not in token_of_name or symbol not in row_of_name:
            continue
        position_of[gene_id] = len(kept_ids)
        kept_ids.append(gene_id)
        token_ids.append(token_of_name[symbol])
        expr_rows.append(row_of_name[symbol])

    if not kept_ids:
        return None

    # Keep a pair only when both of its ids survived the filter above.
    source_ok = np.isin(pairs[:, 0], kept_ids)
    target_ok = np.isin(pairs[:, 1], kept_ids)
    kept_pairs = pairs[source_ok & target_ok]

    return {
        "valid_pairs": kept_pairs,
        "vocab_ids": np.array(token_ids, dtype=np.int64),
        "dataset_pos_map": position_of,
        "expr_indices": expr_rows,
        "expr_matrix": expr_matrix,
        "unique_genes": list(set(kept_pairs.flatten()))
    }

def _extract_gdt_features(model, vocab, context, args, device):
    """Build gradient-based features for every pair in ``valid_pairs``.

    For each base value in ``args.grad_bases`` (``[0.0]`` when empty) all
    inputs are set to that value, and the gradient of each target gene's
    output with respect to the inputs is taken. Gradients are summed over
    all base values.

    Args:
        model: Model passed through to ``compute_gradient_forward``.
        vocab: Vocabulary passed through to ``compute_gradient_forward``.
        context: Dict from ``_prepare_data_indices``; reads
            ``valid_pairs``, ``vocab_ids`` and ``dataset_pos_map``.
        args: Namespace providing ``grad_bases`` and ``integ_batch_size``.
        device: Device for the tensors created here.

    Returns:
        float32 array with one row per pair ``(s, t)``:
        ``[grad(s -> t), grad(t -> s), base, base]``, where ``base`` is
        the first base value and ``grad(t -> s)`` is ``0.0`` unless
        ``(t, s)`` is itself one of the pairs.
    """
    pairs = context["valid_pairs"]
    token_ids = context["vocab_ids"]
    position_of = context["dataset_pos_map"]

    bases = args.grad_bases or [0.0]
    n_bases = len(bases)
    chunk_size = args.integ_batch_size
    n_genes = len(token_ids)

    grad_sums = {}
    targets = list(set(pairs[:, 1]))

    token_row = torch.tensor(token_ids, device=device).unsqueeze(0)

    # Base values are processed in chunks of ``integ_batch_size`` rows.
    for start in range(0, n_bases, chunk_size):
        base_chunk = bases[start:min(start + chunk_size, n_bases)]

        src_ids_batch = token_row.repeat(len(base_chunk), 1)
        constant_rows = [torch.full((1, n_genes), b, device=device) for b in base_chunk]
        inp_val = torch.cat(constant_rows, dim=0).requires_grad_(True)

        output = compute_gradient_forward(model, vocab, src_ids_batch, inp_val, device)

        for target_id in targets:
            target_column = output[:, position_of[target_id]]

            # retain_graph keeps the graph alive for the remaining targets.
            (grads,) = torch.autograd.grad(
                outputs=target_column,
                inputs=inp_val,
                grad_outputs=torch.ones_like(target_column),
                retain_graph=True,
                create_graph=False,
            )

            # Collapse the base dimension: one gradient value per input gene.
            per_gene = grads.sum(dim=0).detach().cpu().numpy()

            for row in np.where(pairs[:, 1] == target_id)[0]:
                source_id = pairs[row, 0]
                key = (source_id, target_id)
                grad_sums[key] = grad_sums.get(key, 0.0) + per_gene[position_of[source_id]]

        del output, inp_val, src_ids_batch
        torch.cuda.empty_cache()

    base_ref = bases[0]
    rows = [
        [grad_sums.get((s, t), 0.0), grad_sums.get((t, s), 0.0), base_ref, base_ref]
        for s, t in pairs
    ]

    return np.array(rows, dtype=np.float32)

def _extract_vvp_features(model, vocab, context, args, device):
    """Build perturbation-based features for every pair in ``valid_pairs``.

    For each baseline input vector, every gene in ``unique_genes`` is
    overwritten in turn with a given input value, and the change of the
    model output relative to the unperturbed baseline is recorded.

    Two ways of choosing baselines:
      * ``args.pert_bases`` non-empty: each base value is paired (via
        ``zip``) with one entry of ``args.input_factors``; the baseline is
        a constant vector of that base value.
      * otherwise: baselines are derived from the expression matrix,
        either its mean (``num_clusters <= 1``) or KMeans cluster centres
        ordered by their mean; every entry of ``args.input_factors`` is
        applied to each baseline.

    Args:
        model: Model passed through to ``compute_perturbation_forward``.
        vocab: Vocabulary passed through to ``compute_perturbation_forward``.
        context: Dict from ``_prepare_data_indices``.
        args: Namespace providing ``pert_bases``, ``input_factors``,
            ``num_clusters`` and ``seed``.
        device: Device passed through to the forward helper.

    Returns:
        float32 array with one row per pair; the per-baseline feature
        blocks are concatenated along the column axis.
    """
    pairs = context["valid_pairs"]
    token_ids = context["vocab_ids"]
    position_of = context["dataset_pos_map"]
    genes = context["unique_genes"]

    # Number of perturbed genes sent through the model per forward call.
    step = 32
    blocks = []

    if not args.pert_bases:
        expr_rows = context["expr_indices"]
        # NOTE(review): presumably (cells, genes) after the transpose - confirm.
        samples = context["expr_matrix"][expr_rows, :].T

        if args.num_clusters > 1:
            kmeans = KMeans(n_clusters=args.num_clusters, n_init=10, random_state=args.seed)
            centers = kmeans.fit(samples).cluster_centers_
            # Order the centres by their mean value.
            order = np.argsort(centers.mean(axis=1))
            baselines = centers[order].astype(np.float32)
        else:
            baselines = samples.mean(axis=0, keepdims=True).astype(np.float32)

        for baseline in baselines:
            reference = compute_perturbation_forward(model, vocab, token_ids, baseline[None, :], device)[0]
            responses = {}

            for start in range(0, len(genes), step):
                gene_batch = genes[start:start + step]
                row_ids = np.arange(len(gene_batch))

                for factor in args.input_factors:
                    perturbed = np.tile(baseline, (len(gene_batch), 1))
                    # Row k overwrites only the k-th gene of the batch.
                    perturbed[row_ids, [position_of[g] for g in gene_batch]] = factor

                    delta = compute_perturbation_forward(model, vocab, token_ids, perturbed, device) - reference
                    for gene, response in zip(gene_batch, delta):
                        responses.setdefault(gene, {})[factor] = response

            rows = []
            for s, t in pairs:
                forward = [responses[s][f][position_of[t]] for f in args.input_factors]
                backward = [responses[t][f][position_of[s]] for f in args.input_factors]
                rows.append(forward + backward + [baseline[position_of[s]], baseline[position_of[t]]])

            blocks.append(np.array(rows, dtype=np.float32))

    else:
        for base_value, factor in zip(args.pert_bases, args.input_factors):
            baseline = np.full(len(token_ids), base_value, dtype=np.float32)
            reference = compute_perturbation_forward(model, vocab, token_ids, baseline[None, :], device)[0]

            responses = {}
            for start in range(0, len(genes), step):
                gene_batch = genes[start:start + step]
                perturbed = np.tile(baseline, (len(gene_batch), 1))
                # Row k overwrites only the k-th gene of the batch.
                perturbed[np.arange(len(gene_batch)), [position_of[g] for g in gene_batch]] = factor

                delta = compute_perturbation_forward(model, vocab, token_ids, perturbed, device) - reference
                responses.update(zip(gene_batch, delta))

            rows = [
                [
                    responses[s][position_of[t]],
                    responses[t][position_of[s]],
                    baseline[position_of[s]],
                    baseline[position_of[t]],
                ]
                for s, t in pairs
            ]
            blocks.append(np.array(rows, dtype=np.float32))

    return np.concatenate(blocks, axis=1)

def extract_features_core(net, ds, args, device, model, vocab, mode):
    """Extract pair features for one network/dataset combination.

    Args:
        net: Network name, forwarded to ``get_genelink_label_dir``.
        ds: Dataset name, forwarded to ``get_genelink_label_dir``.
        args: Namespace providing ``genelink_root``, ``num_tfs``, ``seed``,
            ``pos_neg_ratio`` and the options read by the selected
            extractor.
        device: Device used for model evaluation.
        model: Model to evaluate.
        vocab: Vocabulary matching ``model``.
        mode: ``"grad"`` for gradient-based features or ``"perturbation"``
            for perturbation-based features.

    Returns:
        ``None`` when no usable data could be prepared; otherwise a
        DataFrame whose columns are named ``feat_0``, ``feat_1``, ...

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    # Validate before any data loading: an unknown mode previously fell
    # through both branches and surfaced as an UnboundLocalError.
    if mode not in ("grad", "perturbation"):
        raise ValueError(f"Unknown mode {mode!r}; expected 'grad' or 'perturbation'")

    label_dir = get_genelink_label_dir(args.genelink_root, net, ds, args.num_tfs)
    context = _prepare_data_indices(label_dir, net, ds, vocab, args.seed, args.pos_neg_ratio)

    if context is None:
        return None

    if mode == "grad":
        feature_array = _extract_gdt_features(model, vocab, context, args, device)
    else:
        feature_array = _extract_vvp_features(model, vocab, context, args, device)

    return pd.DataFrame(feature_array).add_prefix("feat_")

def main(argv=None):
    """Run the VVP and GDT feature extraction demo on one example dataset.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the
            default) makes ``argparse`` read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_ckpt", default="./checkpoints/model.pt", type=str)
    parser.add_argument("--genelink_root", default="./Dataset", type=str)
    parser.add_argument("--input_factors", nargs="+", type=float, default=[0.0, 5.0, 10.0, 0.5, 2.0])
    parser.add_argument("--pert_bases", nargs="+", type=float, default=[])
    parser.add_argument("--grad_bases", nargs="+", type=float, default=[])
    parser.add_argument("--integ_batch_size", type=int, default=8)
    parser.add_argument("--num_tfs", type=int, default=500)
    parser.add_argument("--num_clusters", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--pos_neg_ratio", type=float, default=1.0)

    args = parser.parse_args(argv)
    seed_everything(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, vocab = load_pretrained_model(args.model_ckpt, device)

    example_net, example_ds = "STRING", "hESC"

    # Same order as before: perturbation-based (VVP) first, then gradient-based (GDT).
    for label, mode in (("VVP", "perturbation"), ("GDT", "grad")):
        features = extract_features_core(example_net, example_ds, args, device, model, vocab, mode=mode)
        # Report the outcome so the demo does not silently discard its results.
        if features is None:
            print(f"[{label}] no features extracted for {example_net}/{example_ds}")
        else:
            print(
                f"[{label}] {example_net}/{example_ds}: "
                f"{features.shape[0]} pairs x {features.shape[1]} features"
            )


if __name__ == "__main__":
    main()