#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import pickle
import random
import argparse
import ast
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import  TransformerConv
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, f1_score, jaccard_score


# -----------------------------
# 0. Device setup
# -----------------------------
# Run on the first CUDA device when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# -----------------------------
# 1. Data loading & preprocessing
# -----------------------------
def load_node_data(graphs_path, labels_csv, num_classes, data_seed=42):
    """Load the node-task graph split and the per-gene multilabel targets.

    Args:
        graphs_path: pickle file holding a sequence of graphs.
        labels_csv: CSV with columns ``Gene_Index`` and ``Category_Label``.
            A label cell may be a list/tuple literal (``"[0, 2]"``), a single
            number (``2`` or ``"2"``), or empty/NaN (that row is skipped).
        num_classes: number of target classes. Rows whose labels contain the
            value ``num_classes`` itself are dropped from the task.
        data_seed: seed for the graph shuffle; torch's global RNG is seeded
            with it as well.

    Returns:
        graphs: the last 20% of the graphs after a seeded shuffle.
        labels_dict: ``{gene_index: np.ndarray}`` of 0/1 ints, length
            ``num_classes``.
        gene_list: sorted keys of ``labels_dict``.

    Raises:
        ValueError: if a kept row carries a label outside ``[0, num_classes)``
            (previously negative labels silently flagged the last class).
    """
    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    with open(graphs_path, "rb") as f:
        all_graphs = pickle.load(f)

    random.seed(data_seed)
    # Also seeds torch's global RNG, which later model initialisation draws from.
    torch.manual_seed(data_seed)
    graphs_copy = list(all_graphs)
    random.shuffle(graphs_copy)
    N = len(graphs_copy)
    n_train = int(0.8 * N)
    # Keep only the trailing 20%; the first 80% are not used by this task
    # (presumably reserved for pretraining -- confirm with the pretrain script).
    graphs = graphs_copy[n_train:]
    print(f"Using {len(graphs)} graphs for node-level task (20% of {N}).")

    def parse_labels(x):
        """Return one Category_Label cell as a list of ints, or None if empty."""
        if isinstance(x, str):
            if not x.strip():
                return None
            # literal_eval (not eval): cells are data, never code.
            x = ast.literal_eval(x)
        if x is None:
            return None
        if isinstance(x, (int, float, np.integer, np.floating)):
            # Scalars come from numeric columns or bare "2"-style strings;
            # NaN marks a missing label.
            return None if np.isnan(x) else [int(x)]
        return [int(c) for c in x]

    ml_df = pd.read_csv(labels_csv)
    ml_df["parsed"] = ml_df["Category_Label"].apply(parse_labels)
    # Drop rows without labels and rows tagged with the extra class index
    # ``num_classes`` (excluded from the task).
    keep = ml_df["parsed"].apply(
        lambda labs: labs is not None and num_classes not in labs
    ).astype(bool)
    ml_df = ml_df[keep].reset_index(drop=True)

    labels_dict = {}
    for gene_idx, labs in zip(ml_df["Gene_Index"], ml_df["parsed"]):
        bad = [c for c in labs if not 0 <= c < num_classes]
        if bad:
            raise ValueError(
                f"Gene {int(gene_idx)}: label(s) {bad} outside [0, {num_classes})."
            )
        vec = np.zeros(num_classes, dtype=int)
        for c in labs:
            vec[c] = 1
        labels_dict[int(gene_idx)] = vec
    gene_list = sorted(labels_dict.keys())
    return graphs, labels_dict, gene_list

# -----------------------------
# 2. Checkpoint loader config
# -----------------------------
# How to extract encoder weights from each pretraining method's checkpoint:
#   state_key -- top-level checkpoint key holding the state dict
#                (None: the checkpoint object itself is the state dict)
#   prefix    -- key prefix marking encoder parameters; stripped before loading
# A value of None means there is no checkpoint at all (train from scratch).
# Key order is also the order of the --method choices on the CLI.
LOAD_CONFIG = {
    "supgcl": dict(state_key="model", prefix="encoder."),
    "gae": dict(state_key=None, prefix="encoder."),
    "grace": dict(state_key=None, prefix="encoder."),
    "graphcl": dict(state_key=None, prefix="encoder."),
    "sgrl": dict(state_key=None, prefix="online_encoder.backbone."),
    "w_o_pretrain": None,
}

def load_encoder_weights(encoder, ckpt_path, method):
    """Copy the encoder part of a pretraining checkpoint into ``encoder``.

    For ``method == "w_o_pretrain"`` the encoder is returned untouched.
    Otherwise the checkpoint is read with ``torch.load``. If the method's
    ``state_key`` is set, the state dict is taken from that key. Entries whose
    name starts with the method's ``prefix`` are kept, with the prefix
    stripped, and loaded with ``strict=False``.

    Args:
        encoder: module to load into (modified in place).
        ckpt_path: checkpoint file path (unused for ``w_o_pretrain``).
        method: a key of ``LOAD_CONFIG``.

    Returns:
        The same ``encoder`` object.

    Raises:
        ValueError: if no checkpoint entry carries the expected prefix. Without
            this check the encoder would silently stay randomly initialised
            while results are reported as "pretrained".
    """
    if method == "w_o_pretrain":
        return encoder
    cfg = LOAD_CONFIG[method]
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state = torch.load(ckpt_path, map_location=device)
    raw = state[cfg["state_key"]] if cfg["state_key"] else state
    prefix = cfg["prefix"]
    enc_state = {k[len(prefix):]: v for k, v in raw.items() if k.startswith(prefix)}
    if not enc_state:
        raise ValueError(
            f"Checkpoint {ckpt_path!r} has no entries with prefix {prefix!r} "
            f"(method={method!r}); refusing to continue with an untrained encoder."
        )
    # strict=False tolerates partial matches; report them so a prefix or
    # architecture mismatch is visible instead of silently ignored.
    result = encoder.load_state_dict(enc_state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(
            f"[{method}] encoder load: {len(result.missing_keys)} missing key(s), "
            f"{len(result.unexpected_keys)} unexpected key(s)."
        )
    return encoder


# -----------------------------
# 3. Device transfer helper
# -----------------------------
def to_device(data, device):
    """Move the tensors consumed by the encoder onto ``device``, in place.

    ``x`` and ``edge_index`` are always moved. ``edge_attr`` is moved when
    present, and a 1-D edge-feature vector is reshaped to ``[E, 1]``.
    ``batch`` is moved when the attribute exists and is not None. The same
    object is returned for convenience.
    """
    for attr in ("x", "edge_index"):
        setattr(data, attr, getattr(data, attr).to(device))

    edge_feats = data.edge_attr
    if edge_feats is not None:
        edge_feats = edge_feats.to(device)
        # GraphEncoder's convolutions are built with edge_dim=1, so give
        # scalar edge weights an explicit feature axis.
        data.edge_attr = edge_feats.unsqueeze(-1) if edge_feats.dim() == 1 else edge_feats

    batch_vec = getattr(data, "batch", None)
    if batch_vec is not None:
        data.batch = batch_vec.to(device)
    return data


# -----------------------------
# 4. Dataset & collate_fn
# -----------------------------
class GraphNodeDataset(torch.utils.data.Dataset):
    """Pair each graph with per-node multilabel targets and a supervision mask.

    Node index ``gi`` of every graph is treated as gene ``gi``. Genes in
    ``gene_subset`` whose index is below the graph's node count get their
    label vector from ``labels_dict`` and ``mask == True``. All other nodes
    keep all-zero labels and ``mask == False``.

    Args:
        graphs: sequence of graph objects exposing ``x`` (a tensor whose first
            dimension is the node count).
        labels_dict: mapping gene index -> length-``num_classes`` 0/1 vector.
        gene_subset: gene indices to supervise (e.g. one CV split).
        num_classes: number of label columns.

    Items are ``(graph, labels[N, num_classes] float, mask[N] bool)``.
    """

    def __init__(self, graphs, labels_dict, gene_subset, num_classes):
        self.graphs = graphs
        self.labels = []
        self.masks = []
        # Targets depend only on the node count N, so build them once per
        # distinct N instead of re-looping over every gene for every graph.
        # Each graph still receives its own copy.
        templates = {}
        for g in graphs:
            n_nodes = g.x.size(0)
            if n_nodes not in templates:
                templates[n_nodes] = self._build_targets(
                    n_nodes, labels_dict, gene_subset, num_classes
                )
            labels, mask = templates[n_nodes]
            self.labels.append(labels.clone())
            self.masks.append(mask.clone())

    @staticmethod
    def _build_targets(n_nodes, labels_dict, gene_subset, num_classes):
        """Return (labels, mask) for a graph with ``n_nodes`` nodes."""
        labels = torch.zeros((n_nodes, num_classes), dtype=torch.float)
        mask = torch.zeros(n_nodes, dtype=torch.bool)
        genes = [gi for gi in gene_subset if gi < n_nodes]
        if genes:
            idx = torch.as_tensor(genes, dtype=torch.long)
            rows = np.asarray([labels_dict[gi] for gi in genes])
            labels[idx] = torch.as_tensor(rows, dtype=torch.float)
            mask[idx] = True
        return labels, mask

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx], self.labels[idx], self.masks[idx]

def collate_graphnodes(batch):
    """Collate ``(graph, labels, mask)`` items into one mini-batch.

    Graphs are merged with ``Batch.from_data_list``. Per-graph label and mask
    tensors are stacked along a new leading batch dimension, so every graph in
    the batch must have the same node count.

    NOTE(review): ``torch_geometric.loader.DataLoader`` replaces any
    ``collate_fn`` it is given with its own Collater, so this function only
    runs when passed to a plain ``torch.utils.data.DataLoader``.
    """
    graph_list = [item[0] for item in batch]
    label_list = [item[1] for item in batch]
    mask_list = [item[2] for item in batch]
    merged = Batch.from_data_list(graph_list)
    return merged, torch.stack(label_list), torch.stack(mask_list)


# -----------------------------
# 5. Model definitions
# -----------------------------
class GraphEncoder(nn.Module):
    """Five-layer TransformerConv node encoder over 1-dimensional edge features.

    Layers 1-4 use ``heads`` attention heads; the next layer consumes
    ``hid_ch * heads`` features, and ReLU is applied after each. The fifth
    layer uses a single head and returns ``out_ch`` features per node with no
    activation. All layers use attention dropout 0.1.
    """

    def __init__(self, in_ch, hid_ch, out_ch, heads):
        super().__init__()
        wide = hid_ch * heads
        # (in_channels, out_channels, heads) for each layer, in order; the
        # module order fixes the state-dict keys convs.0 ... convs.4.
        layer_specs = [(in_ch, hid_ch, heads)]
        layer_specs += [(wide, hid_ch, heads)] * 3
        layer_specs.append((wide, out_ch, 1))
        self.convs = nn.ModuleList(
            TransformerConv(c_in, c_out, heads=n_heads, edge_dim=1, dropout=0.1)
            for c_in, c_out, n_heads in layer_specs
        )

    def forward(self, x, edge_index, edge_attr, batch):
        # ``batch`` is accepted for interface compatibility but not used.
        *hidden_layers, output_layer = self.convs
        h = x
        for conv in hidden_layers:
            h = F.relu(conv(h, edge_index, edge_attr))
        return output_layer(h, edge_index, edge_attr)


class NodeClassifierAll(nn.Module):
    """Per-node multilabel classifier: encoder embeddings -> 2-layer MLP logits.

    Args:
        encoder: callable invoked as ``encoder(x, edge_index, edge_attr, batch)``;
            expected to return one ``emb_dim``-wide embedding row per node.
        emb_dim: width of the encoder output.
        hid_dim: width of the MLP hidden layer.
        num_classes: number of output logits per node.
    """

    def __init__(self, encoder, emb_dim, hid_dim, num_classes):
        super().__init__()
        self.encoder = encoder
        layers = [
            nn.Linear(emb_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, num_classes),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, data):
        """Return raw (pre-sigmoid) logits, one row per encoder output row."""
        node_emb = self.encoder(data.x, data.edge_index, data.edge_attr, data.batch)
        logits = self.head(node_emb)
        return logits


# -----------------------------
# 6. Training & evaluation
# -----------------------------
def _train_one_epoch(model, loader, opt):
    """Run one optimisation pass; the loss covers only masked (supervised) nodes."""
    model.train()
    for data, labels, mask in loader:
        data = to_device(data, device)
        # Labels/masks were stacked per graph, so every graph in the batch has
        # the same node count N; flattening to [B*N] is assumed to match the
        # node order of the batched graph.
        B, N, C = labels.shape
        labels = labels.view(B * N, C).to(device)
        mask = mask.view(B * N).to(device)
        if not bool(mask.any()):
            # No supervised node here: the masked mean would be NaN and its
            # gradients would corrupt the weights, so skip the step.
            continue

        opt.zero_grad()
        logits = model(data)
        # Same value as a 'none'-reduced loss masked and then averaged.
        loss = F.binary_cross_entropy_with_logits(logits[mask], labels[mask])
        loss.backward()
        opt.step()


def _evaluate(model, loader):
    """Return (subset_acc, macro_f1, micro_f1, samples_jaccard) on masked nodes."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for data, labels, mask in loader:
            data = to_device(data, device)
            B, N, C = labels.shape
            keep = mask.view(B * N).cpu().numpy()
            probs = torch.sigmoid(model(data)).cpu().numpy()
            y_pred.append((probs > 0.5).astype(int)[keep])
            y_true.append(labels.view(B * N, C).cpu().numpy()[keep])

    y_true = np.vstack(y_true)
    y_pred = np.vstack(y_pred)
    return (
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred, average="macro", zero_division=0),
        f1_score(y_true, y_pred, average="micro", zero_division=0),
        jaccard_score(y_true, y_pred, average="samples", zero_division=0),
    )


def run_finetune(graphs, labels_dict, gene_list, args):
    """K-fold cross-validated fine-tuning of a (pre)trained encoder + node head.

    Genes, not graphs, are split into folds. Every fold trains on all graphs,
    supervising only the training genes' nodes. After each epoch it evaluates
    on the held-out genes' nodes and early-stops on validation macro-F1
    (patience ``args.patience``). The fold's best-epoch metrics are printed,
    followed by mean ± std over folds.

    NOTE: the same validation fold drives epoch selection and supplies the
    reported metrics, so the CV scores are optimistically biased.
    """
    kf = KFold(n_splits=args.n_splits, shuffle=True, random_state=args.seed)
    all_subset, all_macro, all_micro, all_jaccard = [], [], [], []
    in_ch = graphs[0].x.size(1)

    for fold, (tr_idx, va_idx) in enumerate(kf.split(gene_list), 1):
        print(f"\n=== Fold {fold} ===")
        train_genes = [gene_list[i] for i in tr_idx]
        val_genes = [gene_list[i] for i in va_idx]

        train_ds = GraphNodeDataset(graphs, labels_dict, train_genes, args.num_classes)
        val_ds = GraphNodeDataset(graphs, labels_dict, val_genes, args.num_classes)
        # torch_geometric.loader.DataLoader discards a user-supplied
        # collate_fn; the plain torch DataLoader actually uses ours.
        tr_loader = torch.utils.data.DataLoader(
            train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_graphnodes
        )
        va_loader = torch.utils.data.DataLoader(
            val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_graphnodes
        )

        encoder = GraphEncoder(in_ch, args.hid_ch, args.out_ch, args.heads).to(device)
        encoder = load_encoder_weights(encoder, args.ckpt, args.method)
        model = NodeClassifierAll(encoder, args.out_ch, args.out_ch // 2, args.num_classes).to(device)
        opt = torch.optim.AdamW(model.parameters(), lr=args.lr)

        # Start below any attainable macro-F1 so the first epoch is always
        # recorded (otherwise a fold stuck at macro-F1 == 0 reports all zeros).
        best_macro, no_imp = -1.0, 0
        best_metrics = (0, 0, 0, 0)

        for _ in range(args.epochs):
            _train_one_epoch(model, tr_loader, opt)
            metrics = _evaluate(model, va_loader)
            macro = metrics[1]
            if macro > best_macro:
                best_macro, best_metrics, no_imp = macro, metrics, 0
            else:
                no_imp += 1
                if no_imp >= args.patience:
                    break

        s, m, mi, ja = best_metrics
        print(f"Fold{fold} → Subset:{s:.4f}, Macro:{m:.4f}, Micro:{mi:.4f}, Jaccard:{ja:.4f}")
        all_subset.append(s)
        all_macro.append(m)
        all_micro.append(mi)
        all_jaccard.append(ja)

    print("\n===== CV Results =====")
    print(f"Subset: {np.mean(all_subset):.4f} ± {np.std(all_subset):.4f}")
    print(f"Macro : {np.mean(all_macro):.4f} ± {np.std(all_macro):.4f}")
    print(f"Micro : {np.mean(all_micro):.4f} ± {np.std(all_micro):.4f}")
    print(f"Jaccard: {np.mean(all_jaccard):.4f} ± {np.std(all_jaccard):.4f}")


# -----------------------------
# 7. Command-line argument parsing
# -----------------------------
def parse_args():
    """Build the CLI, parse ``sys.argv`` and validate the method/checkpoint pair.

    ``--ckpt`` is mandatory for every ``--method`` except ``w_o_pretrain``.
    Exits through ``ArgumentParser.error`` when that rule is violated.
    """
    parser = argparse.ArgumentParser(description="Node‐level multilabel gene classification")
    parser.add_argument(
        "--method",
        type=str,
        required=True,
        choices=list(LOAD_CONFIG),
        help="Pretraining method ('w_o_pretrain' = train from scratch)",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default=None,
        help="Path to pretrained checkpoint (omit if w_o_pretrain)",
    )
    parser.add_argument(
        "--data",
        type=str,
        required=True,
        help="Path to pickled graphs (.pkl)",
    )
    parser.add_argument(
        "--labels",
        type=str,
        required=True,
        help="CSV with columns 'Gene_Index','Category_Label'",
    )

    # Training / model hyper-parameters as (flag, type, default), in CLI order.
    hyperparams = (
        ("--batch-size", int, 8),
        ("--lr", float, 1e-3),
        ("--epochs", int, 500),
        ("--patience", int, 5),
        ("--n-splits", int, 10),
        ("--hid-ch", int, 64),
        ("--out-ch", int, 64),
        ("--heads", int, 8),
        ("--num-classes", int, 3),
        ("--seed", int, 0),
    )
    for flag, cast, default in hyperparams:
        parser.add_argument(flag, type=cast, default=default)

    args = parser.parse_args()
    needs_ckpt = args.method != "w_o_pretrain"
    if needs_ckpt and args.ckpt is None:
        parser.error("--ckpt is required unless --method w_o_pretrain")
    return args


# -----------------------------
# Main
# -----------------------------
if __name__ == "__main__":
    # Parse CLI options, load the node-task graph split and labels, then run
    # K-fold fine-tuning.
    args = parse_args()
    graphs, labels_dict, gene_list = load_node_data(
        graphs_path=args.data,
        labels_csv=args.labels,
        num_classes=args.num_classes,
        data_seed=args.seed,
    )
    run_finetune(graphs, labels_dict, gene_list, args)
