#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import pickle
import random
import argparse
import ast

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import TransformerConv
from sklearn.metrics import accuracy_score, f1_score

# -----------------------------
# 0. Device
# -----------------------------
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# -----------------------------
# 1. Data loading & preprocessing
# -----------------------------
def load_graphs(graphs_path, data_seed):
    """Load pickled graphs and return the held-out 20% used for the node task.

    The pickled collection is copied, shuffled with Python's ``random``
    (seeded with ``data_seed``), and the last 20% of the shuffled copy is
    returned. The torch (and, when available, CUDA) RNGs are seeded with the
    same value as a side effect.

    Args:
        graphs_path: Path to a pickle file holding a list of graphs.
        data_seed: Integer seed that fixes the shuffle.

    Returns:
        The tail slice of the shuffled list (everything after the first 80%).
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(graphs_path, "rb") as handle:
        loaded = pickle.load(handle)

    # Seed every RNG touched here so the split is reproducible.
    random.seed(data_seed)
    torch.manual_seed(data_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(data_seed)

    # Shuffle a copy so the freshly loaded list keeps its original order.
    shuffled = loaded.copy()
    random.shuffle(shuffled)

    cutoff = int(0.8 * len(shuffled))
    held_out = shuffled[cutoff:]
    return held_out

def load_labels(labels_csv):
    """Read the gene-label CSV into a dict and a sorted list of gene indices.

    Args:
        labels_csv: Path to a CSV file containing at least the columns
            ``Gene_Index`` and ``Cancer_relation``.

    Returns:
        A tuple ``(labels, genes)`` where ``labels`` maps ``int`` gene index
        to ``int`` label and ``genes`` is the sorted list of gene indices.
        If a gene index appears more than once, the last row wins.

    Raises:
        KeyError: If either required column is missing from the CSV.
    """
    df = pd.read_csv(labels_csv)
    # Read the two columns directly instead of using iterrows(): iterrows()
    # coerces every row to a single dtype, so integer gene indices turn into
    # floats (and lose exactness for large values) whenever the CSV also has
    # a float column. Column-wise access is also much faster.
    labels = {
        int(gene): int(relation)
        for gene, relation in zip(df["Gene_Index"], df["Cancer_relation"])
    }
    genes = sorted(labels)
    return labels, genes


# -----------------------------
# 2. Checkpoint loader config
# -----------------------------
# Per-method description of where the encoder weights live in a checkpoint.
#   state_key: key of the state dict inside the loaded checkpoint object, or
#              None when the checkpoint object is itself the state dict.
#   prefix:    parameter-name prefix that marks encoder weights; it is
#              stripped before the weights are loaded into the encoder.
# "w_o_pretrain" has no entry because no checkpoint is loaded for it.
LOAD_CONFIG = {
    "supgcl": {"state_key": "model", "prefix": "encoder."},
    "gae": {"state_key": None, "prefix": "encoder."},
    "grace": {"state_key": None, "prefix": "encoder."},
    "graphcl": {"state_key": None, "prefix": "encoder."},
    "sgrl": {"state_key": None, "prefix": "online_encoder.backbone."},
    "w_o_pretrain": None,
}

def load_encoder_weights(encoder, ckpt_path, method):
    """Load pretrained weights for ``method`` into ``encoder``.

    The checkpoint is read with ``torch.load``, the state dict is located via
    ``LOAD_CONFIG[method]["state_key"]``, and only entries whose name starts
    with ``LOAD_CONFIG[method]["prefix"]`` are kept (with the prefix removed)
    before being loaded non-strictly.

    Args:
        encoder: Module that receives the weights (modified in place).
        ckpt_path: Path to the checkpoint file; ignored for ``w_o_pretrain``.
        method: Key of ``LOAD_CONFIG`` naming the pretraining method.

    Returns:
        The same ``encoder`` object.

    Raises:
        ValueError: If no checkpoint entry carries the expected prefix, i.e.
            nothing would be loaded at all.
    """
    if method == "w_o_pretrain":
        return encoder

    cfg = LOAD_CONFIG[method]
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state = torch.load(ckpt_path, map_location=device)
    state_key = cfg["state_key"]
    raw = state[state_key] if state_key else state

    prefix = cfg["prefix"]
    enc_state = {
        name[len(prefix):]: tensor
        for name, tensor in raw.items()
        if name.startswith(prefix)
    }
    # strict=False would accept an empty dict silently, turning a
    # "pretrained" run into training from scratch without any warning.
    if not enc_state:
        raise ValueError(
            f"No parameters with prefix '{prefix}' found in checkpoint "
            f"'{ckpt_path}' for method '{method}'"
        )

    result = encoder.load_state_dict(enc_state, strict=False)
    # Partial mismatches are still tolerated, but surface them to the user.
    if result.missing_keys or result.unexpected_keys:
        print(
            f"[load_encoder_weights] missing keys: {list(result.missing_keys)}, "
            f"unexpected keys: {list(result.unexpected_keys)}"
        )
    return encoder


# -----------------------------
# 3. Device transfer helper
# -----------------------------
def to_device(data, device):
    """Move a graph object's tensors to ``device`` and return the same object.

    ``x`` and ``edge_index`` are always moved. ``edge_attr`` is moved when it
    is not ``None`` and a 1-D ``edge_attr`` gains a trailing dimension.
    ``batch`` is moved only when the attribute exists and is not ``None``.

    Args:
        data: Object exposing ``x``, ``edge_index``, ``edge_attr`` and
            optionally ``batch``; it is modified in place.
        device: Target device accepted by ``Tensor.to``.

    Returns:
        ``data`` itself.
    """
    data.x = data.x.to(device)
    data.edge_index = data.edge_index.to(device)

    edge_features = data.edge_attr
    if edge_features is not None:
        edge_features = edge_features.to(device)
        # Give flat edge attributes an explicit feature dimension.
        if edge_features.dim() == 1:
            edge_features = edge_features.unsqueeze(-1)
        data.edge_attr = edge_features

    batch_vector = getattr(data, "batch", None)
    if batch_vector is not None:
        data.batch = batch_vector.to(device)

    return data


# -----------------------------
# 4. Dataset & collate_fn
# -----------------------------
class GraphBinaryDataset(torch.utils.data.Dataset):
    """Graphs paired with per-node binary labels and a supervision mask.

    For every graph with ``N`` nodes (``g.x.size(0)``) the dataset stores a
    float label tensor of shape ``[N, 1]`` and a boolean mask of shape ``[N]``.
    A gene index from ``gene_subset`` is used directly as a node index: when
    it is smaller than ``N`` the node receives ``labels_dict[gene]`` as label
    and is marked in the mask; larger indices are skipped for that graph.

    Args:
        graphs: Sequence of graph objects exposing a node-feature tensor ``x``.
        labels_dict: Mapping from gene index to its label.
        gene_subset: Gene indices that are supervised in this dataset.
    """

    def __init__(self, graphs, labels_dict, gene_subset):
        self.graphs = graphs
        self.labels = []
        self.masks = []
        for graph in graphs:
            num_nodes = graph.x.size(0)
            targets = torch.zeros((num_nodes, 1), dtype=torch.float)
            supervised = torch.zeros(num_nodes, dtype=torch.bool)

            # Keep only the genes that exist as nodes in this graph.
            present = [gene for gene in gene_subset if gene < num_nodes]
            if present:
                node_idx = torch.tensor(present, dtype=torch.long)
                values = [float(labels_dict[gene]) for gene in present]
                targets[node_idx, 0] = torch.tensor(values, dtype=torch.float)
                supervised[node_idx] = True

            self.labels.append(targets)
            self.masks.append(supervised)

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        # One sample is the triple (graph, node labels, node mask).
        return self.graphs[idx], self.labels[idx], self.masks[idx]

def collate_binary(batch):
    """Merge ``(graph, labels, mask)`` samples into one batched triple.

    Args:
        batch: Sequence of ``(graph, labels, mask)`` tuples as produced by
            ``GraphBinaryDataset.__getitem__``.

    Returns:
        ``(batched_graph, labels, mask)`` where the graphs are combined with
        ``Batch.from_data_list`` and the label / mask tensors are concatenated
        along the node dimension in the same sample order.
    """
    graph_seq, label_seq, mask_seq = zip(*batch)
    merged_graph = Batch.from_data_list(graph_seq)
    # Concatenate per-graph tensors so they line up with the batched nodes.
    node_labels = torch.cat(label_seq, dim=0)
    node_mask = torch.cat(mask_seq, dim=0)
    return merged_graph, node_labels, node_mask


# -----------------------------
# 5. Model definitions
# -----------------------------
class GraphEncoder(nn.Module):
    """Five stacked ``TransformerConv`` layers producing node embeddings.

    The first four layers use ``heads`` attention heads with ``hid_ch``
    channels each and are followed by ReLU; the last layer uses a single head
    with ``out_ch`` channels and no activation. All layers use scalar edge
    features (``edge_dim=1``) and attention dropout 0.1.

    Args:
        in_ch: Number of input node features.
        hid_ch: Channels per head in the hidden layers.
        out_ch: Channels of the final embedding.
        heads: Number of attention heads in the hidden layers.
    """

    def __init__(self, in_ch, hid_ch, out_ch, heads):
        super().__init__()
        shared = {"edge_dim": 1, "dropout": 0.1}
        # Input width of every layer after the first one.
        stacked_width = hid_ch * heads
        self.conv1 = TransformerConv(in_ch, hid_ch, heads=heads, **shared)
        self.conv2 = TransformerConv(stacked_width, hid_ch, heads=heads, **shared)
        self.conv3 = TransformerConv(stacked_width, hid_ch, heads=heads, **shared)
        self.conv4 = TransformerConv(stacked_width, hid_ch, heads=heads, **shared)
        self.conv5 = TransformerConv(stacked_width, out_ch, heads=1, **shared)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return node embeddings; ``batch`` is accepted but not used."""
        hidden = x
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            hidden = F.relu(layer(hidden, edge_index, edge_attr))
        return self.conv5(hidden, edge_index, edge_attr)

class NodeBinaryClassifier(nn.Module):
    """Encoder followed by a two-layer MLP emitting one logit per node.

    Args:
        encoder: Module called as ``encoder(x, edge_index, edge_attr, batch)``
            that returns node embeddings of width ``emb_dim``.
        emb_dim: Width of the encoder output.
        hid_dim: Hidden width of the head; falls back to ``emb_dim // 2``
            when omitted (or falsy).
    """

    def __init__(self, encoder, emb_dim, hid_dim=None):
        super().__init__()
        self.encoder = encoder
        hidden_width = hid_dim if hid_dim else emb_dim // 2
        layers = [
            nn.Linear(emb_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 1),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, data):
        """Return one raw logit per node, shape ``[total_nodes]``."""
        node_emb = self.encoder(
            data.x, data.edge_index, data.edge_attr, data.batch
        )
        logits = self.head(node_emb)
        return logits.squeeze(-1)



# -----------------------------
# 6. Training & evaluation
# -----------------------------
def _evaluate(model, loader):
    """Collect true labels and thresholded predictions for the masked nodes.

    Switches ``model`` to eval mode and runs it over ``loader`` without
    gradients. A node is predicted positive when its logit is greater than 0.

    Returns:
        ``(truths, preds)``: two parallel lists of integer class labels that
        cover only the nodes selected by each batch's mask.
    """
    model.eval()
    truths, preds = [], []
    with torch.no_grad():
        for data, labels, mask in loader:
            data = to_device(data, device)
            logits = model(data).cpu().numpy()
            keep = mask.view(-1).numpy()
            # Cast labels to int so both lists carry the same label type.
            truths.extend(labels.view(-1).numpy().astype(int)[keep])
            preds.extend((logits > 0.0).astype(int)[keep])
    return truths, preds


def run_finetune(graphs, labels_dict, all_genes, args):
    """Fine-tune a node classifier over several seeds and print the results.

    For each seed in ``range(args.num_seeds)`` the genes are class-balanced by
    undersampling, split 80/20 into train/test genes, and a fresh
    ``GraphEncoder`` + ``NodeBinaryClassifier`` is trained with AdamW and
    ``BCEWithLogitsLoss``. Training stops early after ``args.patience`` epochs
    without an accuracy improvement. Per-seed accuracy and F1 are printed,
    followed by one mean/std summary after all seeds have finished.

    NOTE(review): early stopping and the reported accuracy both use the test
    genes, so the reported numbers are optimistically biased. This mirrors the
    original protocol and is left unchanged.

    Args:
        graphs: Non-empty sequence of graphs (every graph is used for both the
            train and the test dataset; the split is over genes/nodes).
        labels_dict: Mapping gene index -> 0/1 label.
        all_genes: Iterable of candidate gene indices.
        args: Namespace providing ``num_seeds``, ``batch_size``, ``hid_ch``,
            ``out_ch``, ``heads``, ``ckpt``, ``method``, ``lr``, ``epochs``
            and ``patience``.

    Raises:
        ValueError: If ``graphs`` is empty.
    """
    if not graphs:
        raise ValueError("run_finetune requires at least one graph")

    in_ch = graphs[0].x.size(1)
    seeds = list(range(args.num_seeds))

    # Class membership does not depend on the seed, so compute it once.
    ones = [g for g in all_genes if labels_dict[g] == 1]
    zeros = [g for g in all_genes if labels_dict[g] == 0]

    test_accs = []
    test_f1s = []

    for seed in seeds:
        # Seed Python and torch so that sampling, weight initialisation and
        # loader shuffling are all determined by the seed.
        random.seed(seed)
        torch.manual_seed(seed)

        # Undersample the majority class (whichever it is) to balance classes.
        if len(zeros) >= len(ones):
            genes_bal = ones + random.sample(zeros, len(ones))
        else:
            genes_bal = random.sample(ones, len(zeros)) + zeros

        # 80/20 split over genes
        random.shuffle(genes_bal)
        split = int(0.8 * len(genes_bal))
        train_genes, test_genes = genes_bal[:split], genes_bal[split:]

        # torch_geometric's DataLoader discards a user-supplied collate_fn, so
        # use the plain PyTorch loader to make sure collate_binary is applied
        # (required for graphs with differing node counts).
        tr_ds = GraphBinaryDataset(graphs, labels_dict, train_genes)
        te_ds = GraphBinaryDataset(graphs, labels_dict, test_genes)
        tr_loader = torch.utils.data.DataLoader(
            tr_ds, batch_size=args.batch_size, shuffle=True,
            collate_fn=collate_binary)
        te_loader = torch.utils.data.DataLoader(
            te_ds, batch_size=args.batch_size, shuffle=False,
            collate_fn=collate_binary)

        # model init
        encoder = GraphEncoder(in_ch, args.hid_ch, args.out_ch, args.heads).to(device)
        encoder = load_encoder_weights(encoder, args.ckpt, args.method)
        model = NodeBinaryClassifier(encoder, emb_dim=args.out_ch).to(device)
        opt = torch.optim.AdamW(model.parameters(), lr=args.lr)
        loss_fn = nn.BCEWithLogitsLoss()

        # Start below any reachable accuracy so the first epoch always
        # records a snapshot (otherwise best_state could remain None).
        best_acc, no_imp = -1.0, 0
        best_state = None

        for epoch in range(1, args.epochs + 1):
            model.train()
            for data, labels, mask in tr_loader:
                data = to_device(data, device)
                lab = labels.view(-1).to(device)
                m = mask.view(-1).to(device)

                opt.zero_grad()
                logits = model(data)
                # Only the supervised (masked) nodes contribute to the loss.
                loss = loss_fn(logits[m], lab[m])
                loss.backward()
                opt.step()

            all_t, all_p = _evaluate(model, te_loader)
            acc = accuracy_score(all_t, all_p)
            if acc > best_acc:
                best_acc, no_imp = acc, 0
                # clone() is essential: Tensor.cpu() returns the very same
                # tensor when it already lives on the CPU, so without a copy
                # the snapshot would keep changing with later optimizer steps.
                best_state = {
                    k: v.detach().cpu().clone()
                    for k, v in model.state_dict().items()
                }
            else:
                no_imp += 1
                if no_imp >= args.patience:
                    print(f"Seed {seed}: early stop at epoch {epoch}, best_acc={best_acc:.4f}")
                    break

        # compute F1 on test set with best model
        if best_state is not None:
            model.load_state_dict(best_state)
        all_t, all_p = _evaluate(model, te_loader)
        if best_state is None:
            # No epoch ran (args.epochs < 1): report the untrained accuracy.
            best_acc = accuracy_score(all_t, all_p)
        f1 = f1_score(all_t, all_p)
        print(f"Seed {seed} → Acc: {best_acc:.4f}, F1: {f1:.4f}")
        test_accs.append(best_acc)
        test_f1s.append(f1)

    if not test_accs:
        # args.num_seeds < 1: nothing to summarise.
        return

    # --- Mean ± standard deviation over all seeds (printed once) ---
    mean_acc, std_acc = np.mean(test_accs), np.std(test_accs)
    mean_f1, std_f1 = np.mean(test_f1s), np.std(test_f1s)
    print(f"\n===== {len(test_accs)}-seed CV Summary =====")
    print(f"Test Accuracy: {mean_acc:.4f} ± {std_acc:.4f}")
    print(f"Test F1 Score : {mean_f1:.4f} ± {std_f1:.4f}")

# -----------------------------
# 7. Parser (just above main)
# -----------------------------
def parse_args(argv=None):
    """Parse the command-line options for the fine-tuning script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``sys.argv[1:]`` is parsed, exactly as before; passing a list
            allows programmatic use and testing.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits (via ``ArgumentParser.error``) when ``--ckpt`` is missing for any
    method other than ``w_o_pretrain``.
    """
    p = argparse.ArgumentParser(description="Binary gene classification finetuning")
    p.add_argument("--method",     type=str,  required=True,
                   choices=list(LOAD_CONFIG),
                   help="Pretraining method ('w_o_pretrain' = scratch)")
    p.add_argument("--ckpt",       type=str,  default=None,
                   help="Path to pretrained checkpoint (omit if w_o_pretrain)")
    p.add_argument("--data",       type=str,  required=True,
                   help="Path to pickled graphs (.pkl)")
    p.add_argument("--labels",     type=str,  required=True,
                   help="CSV with 'Gene_Index','Cancer_relation'")
    p.add_argument("--batch-size", type=int,  default=8)
    p.add_argument("--lr",         type=float, default=1e-3)
    p.add_argument("--epochs",     type=int,  default=500)
    p.add_argument("--patience",   type=int,  default=5)
    p.add_argument("--num-seeds",  type=int,  default=10)
    p.add_argument("--hid-ch",     type=int,  default=64)
    p.add_argument("--out-ch",     type=int,  default=64)
    p.add_argument("--heads",      type=int,  default=8)
    p.add_argument("--seed",       type=int,  default=42,
                   help="Seed for initial graph split")
    args = p.parse_args(argv)
    # A checkpoint is mandatory for every pretrained method.
    if args.method != "w_o_pretrain" and args.ckpt is None:
        p.error("--ckpt is required unless --method w_o_pretrain")
    return args


# -----------------------------
# Main
# -----------------------------
if __name__ == "__main__":
    # Parse options, load the held-out graphs and the gene labels, then run
    # the multi-seed fine-tuning experiment.
    cli_args = parse_args()
    held_out_graphs = load_graphs(cli_args.data, cli_args.seed)
    gene_labels, gene_ids = load_labels(cli_args.labels)
    run_finetune(held_out_graphs, gene_labels, gene_ids, cli_args)
