#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import pickle
import random
import argparse

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool, TransformerConv
from sklearn.model_selection import KFold
from lifelines.utils import concordance_index

# -----------------------------
# 0. Device setup
# -----------------------------
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# -----------------------------
# 1. Data utilities
# -----------------------------
def load_data(graphs_path, meta_path, surv_csv):
    """Load graphs and survival labels, keeping only patients present in both.

    Args:
        graphs_path: Path to a pickle holding a sequence of graph objects.
        meta_path: Path to a pickle holding the patient IDs, one per graph and
            in the same order as the graphs.
        surv_csv: Path to a CSV whose first column is read as the DataFrame
            index; graphs are matched against that index by patient ID.

    Returns:
        A ``(graphs, surv_df)`` tuple. ``graphs`` is the list of graphs whose
        patient ID appears in the CSV index, and ``surv_df`` holds the
        matching rows in the same order as ``graphs``.

    Raises:
        ValueError: If the graph and patient-ID sequences differ in length, or
            if the survival table cannot be aligned one-to-one with the kept
            graphs (for example because a patient ID is duplicated in the CSV).
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only point these paths at trusted artefacts.
    with open(graphs_path, "rb") as f:
        all_graphs = pickle.load(f)
    with open(meta_path, "rb") as f:
        all_patients = pickle.load(f)

    # zip() would silently drop the tail of the longer sequence, which could
    # pair graphs with the wrong patient; fail loudly instead.
    if len(all_graphs) != len(all_patients):
        raise ValueError(
            f"Graph/metadata length mismatch: {len(all_graphs)} graphs vs "
            f"{len(all_patients)} patient IDs"
        )

    # Survival table, indexed by its first column.
    surv_df = pd.read_csv(surv_csv, index_col=0)

    # Keep only the graphs whose patient has a survival record.
    known_ids = surv_df.index
    kept = [
        (g, pid)
        for g, pid in zip(all_graphs, all_patients)
        if pid in known_ids
    ]
    graphs = [g for g, _ in kept]
    patients = [pid for _, pid in kept]

    # Reorder the survival rows to follow the graph order.
    surv_df = surv_df.loc[patients]
    # A real exception (not assert) so the check survives `python -O`.
    if list(surv_df.index) != patients:
        raise ValueError(
            "Survival table could not be aligned one-to-one with the graphs; "
            "check the survival CSV for duplicated patient IDs"
        )

    return graphs, surv_df

def to_device(data, device):
    """Move a graph's tensors onto ``device`` in place and return the graph.

    ``x`` and ``edge_index`` are always moved. ``edge_attr`` and ``batch`` are
    moved only when present and not ``None``. ``edge_attr`` is additionally
    reshaped so that a 1-D tensor becomes a single-column 2-D tensor, and a
    tensor whose first dimension is 1 (but whose second is not) is transposed.
    """
    data.x = data.x.to(device)
    data.edge_index = data.edge_index.to(device)

    if data.edge_attr is not None:
        moved = data.edge_attr.to(device)
        data.edge_attr = moved
        if moved.dim() == 1:
            # Flat vector of edge values -> one feature column per edge.
            data.edge_attr = moved.unsqueeze(-1)
        elif moved.size(0) == 1 and moved.size(1) != 1:
            # Stored as a single row -> flip so dim 0 is no longer the singleton.
            data.edge_attr = moved.transpose(0, 1)

    batch_vector = getattr(data, "batch", None)
    if batch_vector is not None:
        data.batch = batch_vector.to(device)

    return data

class HazardDataset(torch.utils.data.Dataset):
    """Pairs each graph with its survival label.

    The ``"OS"`` column of ``surv_df`` is stored as ``event`` and the
    ``"OS.time"`` column as ``time``, both as float32 tensors. Rows are
    matched to graphs by position, so ``surv_df`` must already be in the same
    order as ``graphs``.
    """

    def __init__(self, graphs, surv_df):
        self.graphs = graphs
        # Convert both label columns the same way: numpy values -> float32.
        self.event, self.time = (
            torch.from_numpy(surv_df[column].values).float()
            for column in ("OS", "OS.time")
        )

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        graph = self.graphs[idx]
        return graph, self.event[idx], self.time[idx]


# -----------------------------
# 2. Model definitions
# -----------------------------
class GraphEncoder(nn.Module):
    """Stack of five ``TransformerConv`` layers producing node embeddings.

    The first four layers use ``heads`` attention heads with ``hid_ch``
    channels each, so every following layer takes ``hid_ch * heads`` input
    features. The last layer uses a single head and emits ``out_ch`` features.
    All layers take 1-dimensional edge features and use dropout 0.1.
    """

    def __init__(self, in_ch, hid_ch, out_ch, heads):
        super().__init__()
        wide = hid_ch * heads
        shared = dict(edge_dim=1, dropout=0.1)

        # Built in order so the parameter names stay convs.0 ... convs.4.
        layers = [TransformerConv(in_ch, hid_ch, heads=heads, **shared)]
        layers.extend(
            TransformerConv(wide, hid_ch, heads=heads, **shared)
            for _ in range(3)
        )
        layers.append(TransformerConv(wide, out_ch, heads=1, **shared))
        self.convs = nn.ModuleList(layers)

    def forward(self, x, edge_index, edge_attr):
        *hidden_layers, output_layer = self.convs
        for layer in hidden_layers:
            x = F.relu(layer(x, edge_index, edge_attr))
        # No activation after the final layer.
        return output_layer(x, edge_index, edge_attr)

class HazardModel(nn.Module):
    """Graph encoder followed by a small MLP that outputs one risk score per graph.

    Node embeddings from ``encoder`` are mean-pooled per graph (using
    ``data.batch``) and passed through ``Linear -> ReLU -> Linear`` with a
    hidden width of ``emb_dim // 2``.
    """

    def __init__(self, encoder: GraphEncoder, emb_dim: int):
        super().__init__()
        self.encoder = encoder
        half_width = emb_dim // 2
        layers = [
            nn.Linear(emb_dim, half_width),
            nn.ReLU(),
            nn.Linear(half_width, 1),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, data):
        node_emb = self.encoder(data.x, data.edge_index, data.edge_attr)
        graph_emb = global_mean_pool(node_emb, data.batch)
        risk = self.head(graph_emb)
        # Drop the trailing size-1 dimension left by the final Linear layer.
        return risk.squeeze(-1)


# -----------------------------
# 3. Loss function
# -----------------------------
def cox_ph_loss(risk, time, event):
    """Negative Cox partial log-likelihood, averaged over observed events.

    Samples are sorted by ``time`` in descending order, so the running
    log-sum-exp of the sorted risks at position ``i`` covers the samples
    sorted at or before ``i``. Tied times get no special treatment. The sum is
    divided by the number of events, clamped to at least 1 so that a batch
    without events yields 0 rather than a division by zero.

    Args:
        risk: 1-D tensor of predicted risk scores.
        time: 1-D tensor of times, same length as ``risk``.
        event: 1-D tensor of event indicators used as per-sample weights.

    Returns:
        Scalar loss tensor.
    """
    by_time_desc = torch.argsort(time, descending=True)
    sorted_risk = risk[by_time_desc]
    sorted_event = event[by_time_desc]

    log_risk_set = torch.logcumsumexp(sorted_risk, dim=0)
    neg_log_lik = -(sorted_event * (sorted_risk - log_risk_set)).sum()

    n_events = sorted_event.sum().clamp(min=1.0)
    return neg_log_lik / n_events


# -----------------------------
# 4. Checkpoint loader
# -----------------------------
# Checkpoint layout for each pretraining method:
#   state_key - key of the sub-dict that holds the weights inside the
#               checkpoint, or None when the checkpoint itself is the
#               state dict.
#   prefix    - parameter-name prefix that marks encoder weights; it is
#               stripped before the weights are loaded.
# "w_o_pretrain" has no checkpoint and therefore no layout.
LOAD_CONFIG = {
    "supgcl":       {"state_key": "model",              "prefix": "encoder."},
    "gae":          {"state_key": None,                 "prefix": "encoder."},
    "grace":        {"state_key": None,                 "prefix": "encoder."},
    "graphcl":      {"state_key": None,                 "prefix": "encoder."},
    "sgrl":         {"state_key": None,                 "prefix": "online_encoder.backbone."},
    "w_o_pretrain": None,
}


def load_encoder_weights(encoder, ckpt_path, method):
    """Copy pretrained encoder weights from a checkpoint into ``encoder``.

    Args:
        encoder: Module to receive the weights; modified in place.
        ckpt_path: Path to the checkpoint file. Ignored for ``"w_o_pretrain"``.
        method: Key of ``LOAD_CONFIG`` describing the checkpoint layout.

    Returns:
        The same ``encoder`` object.

    Raises:
        KeyError: If ``method`` is not a key of ``LOAD_CONFIG``.
        RuntimeError: If no checkpoint entry carries the expected prefix, i.e.
            nothing would be loaded at all.
    """
    if method == "w_o_pretrain":
        return encoder

    import warnings

    cfg = LOAD_CONFIG[method]
    # Load onto the CPU: load_state_dict copies values into the encoder's
    # existing parameters, wherever they live, so a second on-device copy of
    # the checkpoint is unnecessary.
    full_state = torch.load(ckpt_path, map_location="cpu")
    raw = full_state[cfg["state_key"]] if cfg["state_key"] else full_state
    prefix = cfg["prefix"]
    enc_state = {
        k[len(prefix):]: v
        for k, v in raw.items()
        if k.startswith(prefix)
    }

    # strict=False accepts an empty dict without complaint, which would leave
    # the "pretrained" encoder at its random initialisation.
    if not enc_state:
        raise RuntimeError(
            f"No parameters with prefix {prefix!r} found in checkpoint "
            f"{ckpt_path!r} for method {method!r}"
        )

    result = encoder.load_state_dict(enc_state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"Partial encoder load from {ckpt_path!r}: "
            f"missing={list(result.missing_keys)}, "
            f"unexpected={list(result.unexpected_keys)}"
        )
    return encoder


# -----------------------------
# 5. Training & evaluation
# -----------------------------
def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


def _train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` with the Cox loss."""
    model.train()
    for batch, event, time in loader:
        batch = to_device(batch, device)
        event = event.to(device)
        time = time.to(device)
        optimizer.zero_grad()
        loss = cox_ph_loss(model(batch), time, event)
        loss.backward()
        optimizer.step()


def _validation_cindex(model, loader):
    """Return the concordance index of ``model``'s risk scores on ``loader``."""
    model.eval()
    risks, events, times = [], [], []
    with torch.no_grad():
        for batch, event, time in loader:
            batch = to_device(batch, device)
            risks.append(model(batch).cpu())
            events.append(event)
            times.append(time)
    risk = torch.cat(risks).numpy()
    e = torch.cat(events).numpy()
    t = torch.cat(times).numpy()
    # Risk is negated so that a higher predicted risk ranks as a shorter
    # predicted survival time.
    return concordance_index(t, -risk, e)


def run_finetune(graphs, surv_df, args):
    """Fine-tune a hazard model with k-fold cross-validation and report C-index.

    For every fold a fresh encoder is built (and, unless
    ``args.method == "w_o_pretrain"``, initialised from ``args.ckpt``), trained
    with AdamW on the Cox loss, and scored on the held-out fold after each
    epoch. Training stops early once the validation C-index has not improved
    for ``args.patience`` consecutive epochs. The score reported per fold is
    the best validation C-index seen, so the held-out fold is also what drives
    early stopping.

    Args:
        graphs: Sequence of graph objects; ``graphs[0].x.shape[1]`` gives the
            input feature width.
        surv_df: Survival table aligned row-by-row with ``graphs``.
        args: Namespace providing ``hid_ch``, ``out_ch``, ``heads``, ``ckpt``,
            ``method``, ``seed``, ``n_splits``, ``lr``, ``batch_size``,
            ``epochs`` and ``patience``.

    Returns:
        A dict with ``cidx_per_fold`` (list of best C-index per fold),
        ``mean_cidx``, ``std_cidx``, ``best_fold`` (1-based fold number with
        the highest C-index, or ``None``) and ``best_state`` (CPU copy of that
        fold's best model state dict, or ``None``).

    Raises:
        ValueError: If ``graphs`` is empty.
    """
    if len(graphs) == 0:
        raise ValueError("Cannot fine-tune on an empty dataset")

    dataset = HazardDataset(graphs, surv_df)
    in_ch = graphs[0].x.shape[1]

    _seed_everything(args.seed)

    kf = KFold(n_splits=args.n_splits, shuffle=True, random_state=args.seed)
    idxs = np.arange(len(dataset))

    all_cidx = []
    best_overall, best_state, best_fold = 0.0, None, None

    for fold, (tr_idx, va_idx) in enumerate(kf.split(idxs), 1):
        print(f"-- Fold {fold} --")
        # Every fold starts from a freshly built (and freshly loaded) encoder
        # so that no weights leak from one fold into the next.
        encoder = load_encoder_weights(
            GraphEncoder(in_ch, args.hid_ch, args.out_ch, args.heads).to(device),
            args.ckpt,
            args.method,
        )
        model = HazardModel(encoder, args.out_ch).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

        tr_loader = DataLoader(torch.utils.data.Subset(dataset, tr_idx),
                               batch_size=args.batch_size, shuffle=True)
        va_loader = DataLoader(torch.utils.data.Subset(dataset, va_idx),
                               batch_size=args.batch_size, shuffle=False)

        best_cidx, no_imp, best_model_sd = 0.0, 0, None

        for _epoch in range(1, args.epochs + 1):
            _train_one_epoch(model, tr_loader, optimizer)
            cidx = _validation_cindex(model, va_loader)

            if cidx > best_cidx:
                best_cidx = cidx
                no_imp = 0
                # clone() is required: on a CPU model .cpu() returns the live
                # parameter tensor itself, so later optimizer steps would
                # silently overwrite the "best" snapshot.
                best_model_sd = {
                    k: v.detach().cpu().clone()
                    for k, v in model.state_dict().items()
                }
            else:
                no_imp += 1
                if no_imp >= args.patience:
                    break

        print(f" Fold {fold} | Best C-index: {best_cidx:.4f}")
        all_cidx.append(best_cidx)
        if best_cidx > best_overall:
            best_overall = best_cidx
            best_state = best_model_sd
            best_fold = fold

    mean_cidx, std_cidx = np.mean(all_cidx), np.std(all_cidx)
    print(f"\nAverage over {args.n_splits}-fold | C-index: {mean_cidx:.4f} ± {std_cidx:.4f}")

    return {
        "cidx_per_fold": all_cidx,
        "mean_cidx": mean_cidx,
        "std_cidx": std_cidx,
        "best_fold": best_fold,
        "best_state": best_state,
    }


# -----------------------------
# 6. Argument parser
# -----------------------------
def parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.

    Exits via ``ArgumentParser.error`` when ``--ckpt`` is missing for any
    method other than ``w_o_pretrain``.
    """
    p = argparse.ArgumentParser(description="Breast cancer hazard finetuning")
    p.add_argument("--method",      type=str, required=True,
                   choices=list(LOAD_CONFIG),
                   help="Pretraining method ('w_o_pretrain' = train from scratch)")
    p.add_argument("--ckpt",        type=str, default=None,
                   help="Path to pretrained checkpoint (omit if w_o_pretrain)")
    p.add_argument("--data",        type=str, required=True,
                   help="Path to pickled graphs (.pkl)")
    p.add_argument("--meta",        type=str, required=True,
                   help="Path to pickled metadata (patient IDs list)")
    p.add_argument("--surv",        type=str, required=True,
                   help="Path to survival CSV")
    p.add_argument("--batch-size",  type=int, default=8)
    p.add_argument("--lr",          type=float, default=1e-3)
    p.add_argument("--epochs",      type=int, default=500)
    p.add_argument("--patience",    type=int, default=5)
    p.add_argument("--n-splits",    type=int, default=10)
    p.add_argument("--hid-ch",      type=int, default=64)
    p.add_argument("--out-ch",      type=int, default=64)
    p.add_argument("--heads",       type=int, default=8)
    p.add_argument("--seed",        type=int, default=0)

    # The parse call was missing, so `args` below was an undefined name.
    args = p.parse_args(argv)

    if args.method != "w_o_pretrain" and args.ckpt is None:
        p.error("--ckpt is required unless --method w_o_pretrain")
    return args


# -----------------------------
# 7. Main
# -----------------------------
def main():
    """Script entry point: parse options, load the data, run fine-tuning."""
    cli_args = parse_args()
    graph_list, survival_table = load_data(
        cli_args.data, cli_args.meta, cli_args.surv
    )
    run_finetune(graph_list, survival_table, cli_args)


if __name__ == "__main__":
    main()
