#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import pickle
import random
import argparse

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool, TransformerConv
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import KFold

# -----------------------------
# 0. Device setup
# -----------------------------
# Use the CUDA device when one is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# -----------------------------
# 1. Load & prepare data
# -----------------------------
def load_data(graphs_path, subtype_csv):
    """Load pickled graphs and their integer-encoded subtype labels.

    Graphs and CSV rows are paired purely by position (row ``i`` labels graph
    ``i``); only the lengths are checked, not the sample identifiers.

    Args:
        graphs_path: Path to a pickle file holding a sized collection of graphs.
        subtype_csv: Path to a CSV whose first column is the index and which
            contains a ``subtype`` column with values among
            ``LumA``, ``LumB``, ``Basal``, ``Her2`` and ``Normal``.

    Returns:
        A ``(graphs, labels)`` tuple where ``labels`` is an ``int64`` NumPy
        array of class ids (LumA=0, LumB=1, Basal=2, Her2=3, Normal=4).

    Raises:
        ValueError: If a subtype is missing or not one of the known classes,
            or if the number of graphs differs from the number of labels.
    """
    # NOTE(security): pickle.load executes arbitrary code; only use trusted files.
    with open(graphs_path, "rb") as f:
        graphs = pickle.load(f)

    subtype_df = pd.read_csv(subtype_csv, index_col=0)
    mapping = {"LumA": 0, "LumB": 1, "Basal": 2, "Her2": 3, "Normal": 4}
    encoded = subtype_df["subtype"].map(mapping)

    # Series.map yields NaN for values absent from the mapping. Left unchecked,
    # that turns the labels into floats and NaN into an arbitrary class id.
    unknown = encoded.isna().to_numpy()
    if unknown.any():
        bad = sorted(subtype_df["subtype"][unknown].astype(str).unique())
        raise ValueError(
            f"Unrecognised subtype value(s) in {subtype_csv!r}: {bad}; "
            f"expected one of {sorted(mapping)}"
        )
    labels = encoded.to_numpy(dtype=np.int64)

    # Explicit raise rather than assert so the check survives `python -O`.
    if len(graphs) != len(labels):
        raise ValueError(
            "Graphs and labels must align in same order "
            f"(got {len(graphs)} graphs and {len(labels)} labels)"
        )
    return graphs, labels


# -----------------------------
# 2. Device transfer
# -----------------------------
def to_device(data, device):
    """Move a graph's tensors to ``device`` in place and return the same object.

    ``x`` and ``edge_index`` are always moved. When ``edge_attr`` is not
    ``None`` it is moved and normalised: a 1-D tensor gains a trailing axis,
    and a 2-D tensor with a single row but more than one column is transposed
    into a single column. ``batch`` is moved only if present and not ``None``.
    """
    data.x = data.x.to(device)
    data.edge_index = data.edge_index.to(device)

    edge_feats = data.edge_attr
    if edge_feats is not None:
        edge_feats = edge_feats.to(device)
        if edge_feats.dim() == 1:
            # (E,) -> (E, 1)
            edge_feats = edge_feats.unsqueeze(-1)
        elif edge_feats.size(0) == 1 and edge_feats.size(1) != 1:
            # (1, E) -> (E, 1)
            edge_feats = edge_feats.transpose(0, 1)
        data.edge_attr = edge_feats

    batch_vec = getattr(data, "batch", None)
    if batch_vec is not None:
        data.batch = batch_vec.to(device)

    return data


# -----------------------------
# 3. Dataset & Model
# -----------------------------
class SubtypeDataset(torch.utils.data.Dataset):
    """Index-aligned pairing of graphs with their subtype labels."""

    def __init__(self, graphs, labels):
        self.graphs, self.labels = graphs, labels

    def __len__(self):
        # The dataset size is defined by the number of graphs.
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return ``(graph, label)`` with the label as a ``torch.long`` tensor."""
        graph = self.graphs[idx]
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return graph, target


class GraphEncoder(nn.Module):
    """Stack of five ``TransformerConv`` layers producing node embeddings.

    Layers 1-4 use ``heads`` attention heads with ``hid_ch`` channels each, so
    their output width is ``hid_ch * heads``. Layer 5 uses a single head and
    outputs ``out_ch`` channels. Every layer is built with ``edge_dim=1`` and
    ``dropout=0.1``.
    """

    def __init__(self, in_ch, hid_ch, out_ch, heads):
        super().__init__()
        shared = {"edge_dim": 1, "dropout": 0.1}
        width = hid_ch * heads  # width of the multi-head layers' output
        # Attribute names and creation order are kept as conv1..conv5.
        self.conv1 = TransformerConv(in_ch, hid_ch, heads=heads, **shared)
        self.conv2 = TransformerConv(width, hid_ch, heads=heads, **shared)
        self.conv3 = TransformerConv(width, hid_ch, heads=heads, **shared)
        self.conv4 = TransformerConv(width, hid_ch, heads=heads, **shared)
        self.conv5 = TransformerConv(width, out_ch, heads=1, **shared)

    def forward(self, x, edge_index, edge_attr):
        """Apply ReLU after each of the first four layers; the last is linear."""
        hidden = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            hidden = F.relu(conv(hidden, edge_index, edge_attr))
        return self.conv5(hidden, edge_index, edge_attr)


class SubtypeModel(nn.Module):
    """Graph classifier: encoder, then per-graph mean pooling, then an MLP head.

    The head maps ``emb_dim`` to ``emb_dim // 2`` and then to ``num_classes``,
    with a ReLU in between.
    """

    def __init__(self, encoder, emb_dim, num_classes=5):
        super().__init__()
        self.encoder = encoder
        bottleneck = emb_dim // 2
        layers = [
            nn.Linear(emb_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, num_classes),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, data):
        """Return class logits, one row per graph in ``data``."""
        node_emb = self.encoder(data.x, data.edge_index, data.edge_attr)
        graph_emb = global_mean_pool(node_emb, data.batch)
        return self.head(graph_emb)


# -----------------------------
# 4. Checkpoint loader config
# -----------------------------
# Per-method checkpoint layout:
#   "state_key": key of the nested state dict inside the checkpoint, or None
#                when the checkpoint object is itself the state dict;
#   "prefix":    parameter-name prefix under which the encoder weights live.
# A value of None means the method has no checkpoint to load.
LOAD_CONFIG = {
    "supgcl": {"state_key": "model", "prefix": "encoder."},
    **{
        name: {"state_key": None, "prefix": "encoder."}
        for name in ("gae", "grace", "graphcl")
    },
    "sgrl": {"state_key": None, "prefix": "online_encoder.backbone."},
    "w_o_pretrain": None,
}

def load_encoder_weights(encoder, ckpt_path, method):
    """Initialise ``encoder`` in place from a pretraining checkpoint.

    The checkpoint layout for ``method`` is looked up in ``LOAD_CONFIG``:
    the optional ``state_key`` selects a nested state dict, and only entries
    whose name starts with ``prefix`` are kept (with the prefix stripped).

    Args:
        encoder: Module whose parameters are to be overwritten.
        ckpt_path: Path passed to ``torch.load``; ignored for ``w_o_pretrain``.
        method: A key of ``LOAD_CONFIG``.

    Returns:
        The same ``encoder`` object.

    Raises:
        KeyError: If ``method`` is unknown or ``state_key`` is absent from
            the checkpoint.
        RuntimeError: If no checkpoint entry carries the expected prefix, or
            if a matched tensor has an incompatible shape.
    """
    if method == "w_o_pretrain":
        # Training from scratch: leave the encoder's current weights untouched.
        return encoder

    import warnings  # only needed on the checkpoint path

    cfg = LOAD_CONFIG[method]
    full_state = torch.load(ckpt_path, map_location=device)
    state_key, prefix = cfg["state_key"], cfg["prefix"]
    raw = full_state[state_key] if state_key else full_state
    enc_state = {
        k[len(prefix):]: v
        for k, v in raw.items()
        if k.startswith(prefix)
    }

    # With strict=False an empty dict loads "successfully" and the run would
    # silently fine-tune a randomly initialised encoder; fail loudly instead.
    if not enc_state:
        raise RuntimeError(
            f"No parameters with prefix {prefix!r} found in checkpoint "
            f"{ckpt_path!r} for method {method!r}; the encoder would remain "
            "randomly initialised"
        )

    # strict=False still tolerates partial overlap, but report it so a
    # mismatched architecture does not go unnoticed.
    outcome = encoder.load_state_dict(enc_state, strict=False)
    if outcome.missing_keys or outcome.unexpected_keys:
        warnings.warn(
            f"Partial checkpoint load for method {method!r}: "
            f"missing keys {list(outcome.missing_keys)}, "
            f"unexpected keys {list(outcome.unexpected_keys)}"
        )
    return encoder


# -----------------------------
# 5. Training & evaluation
# -----------------------------
def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and, if available, CUDA) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


def _train_one_epoch(model, loader, optimizer, loss_fn):
    """Run one optimisation pass of ``model`` over every batch in ``loader``."""
    model.train()
    for g, label in loader:
        g = to_device(g, device)
        label = label.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(g), label)
        loss.backward()
        optimizer.step()


def _evaluate(model, loader):
    """Return ``(accuracy, macro_f1)`` of ``model`` over all batches in ``loader``."""
    model.eval()
    batch_logits, batch_targets = [], []
    with torch.no_grad():
        for g, label in loader:
            g = to_device(g, device)
            batch_logits.append(model(g).cpu())
            batch_targets.append(label)
    targets = torch.cat(batch_targets).numpy()
    preds = np.argmax(torch.cat(batch_logits).numpy(), axis=1)
    return accuracy_score(targets, preds), f1_score(targets, preds, average="macro")


def run_finetune(graphs, labels, args):
    """Fine-tune the subtype classifier with k-fold cross-validation.

    For every fold the encoder is reset to its starting weights (pretrained
    or freshly initialised), a new classification head is attached, and the
    model is trained with early stopping on validation macro-F1. Per-fold
    best metrics and their mean and standard deviation are printed.

    Args:
        graphs: Sequence of graphs; ``graphs[0].x.shape[1]`` gives the input
            feature width.
        labels: Integer class labels aligned with ``graphs``.
        args: Namespace providing ``seed``, ``hid_ch``, ``out_ch``, ``heads``,
            ``ckpt``, ``method``, ``n_splits``, ``lr``, ``batch_size``,
            ``epochs`` and ``patience``.

    Returns:
        None. Results are reported on stdout.
    """
    # Seed before building the encoder so its random initialisation (the only
    # initialisation when training without pretraining) is reproducible.
    _seed_everything(args.seed)

    dataset = SubtypeDataset(graphs, labels)
    in_ch = graphs[0].x.shape[1]
    encoder = GraphEncoder(in_ch, args.hid_ch, args.out_ch, args.heads).to(device)
    encoder = load_encoder_weights(encoder, args.ckpt, args.method)

    # Snapshot the starting weights. Restoring them at the start of each fold
    # guarantees no fold inherits parameters trained on another fold's data.
    # Previously the encoder was never reset for "w_o_pretrain", nor were any
    # parameters absent from the checkpoint.
    initial_state = {
        k: v.detach().clone() for k, v in encoder.state_dict().items()
    }

    all_acc, all_f1 = [], []

    kf = KFold(n_splits=args.n_splits, shuffle=True, random_state=args.seed)
    idxs = np.arange(len(dataset))

    for fold, (tr_idx, va_idx) in enumerate(kf.split(idxs), 1):
        print(f"-- Fold {fold} --")
        encoder.load_state_dict(initial_state)
        model = SubtypeModel(encoder, args.out_ch).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
        loss_fn = nn.CrossEntropyLoss()

        tr_loader = DataLoader(torch.utils.data.Subset(dataset, tr_idx),
                               batch_size=args.batch_size, shuffle=True)
        va_loader = DataLoader(torch.utils.data.Subset(dataset, va_idx),
                               batch_size=args.batch_size, shuffle=False)

        # Start below any attainable F1 so the first epoch always records
        # metrics; a 0.0 start left best_metrics empty when F1 stayed at 0.
        best_f1, no_imp = float("-inf"), 0
        best_metrics = {"acc": float("nan"), "f1": float("nan")}

        for epoch in range(1, args.epochs + 1):
            _train_one_epoch(model, tr_loader, optimizer, loss_fn)
            acc, f1 = _evaluate(model, va_loader)

            if f1 > best_f1:
                best_f1 = f1
                best_metrics = {"acc": acc, "f1": f1}
                no_imp = 0
            else:
                no_imp += 1
                if no_imp >= args.patience:
                    break

        print(f" Fold {fold} | Acc: {best_metrics['acc']:.4f} | Macro F1: {best_metrics['f1']:.4f}")
        all_acc.append(best_metrics["acc"])
        all_f1.append(best_metrics["f1"])

    mean_acc, std_acc = np.mean(all_acc), np.std(all_acc)
    mean_f1,  std_f1  = np.mean(all_f1),  np.std(all_f1)
    print(f"\nAverage over {args.n_splits}-fold | Acc: {mean_acc:.4f} ± {std_acc:.4f} | Macro F1: {mean_f1:.4f} ± {std_f1:.4f}")


# -----------------------------
# 6. Argument parser
# -----------------------------
def parse_args():
    """Build the command-line parser, parse the process arguments and validate them.

    Exits through ``parser.error`` when a pretrained method is selected
    without ``--ckpt``.
    """
    parser = argparse.ArgumentParser(description="Breast cancer subtype finetuning")

    # Method / path options.
    parser.add_argument(
        "--method", type=str, required=True, choices=list(LOAD_CONFIG),
        help="Pretraining method ('w_o_pretrain' = train from scratch)",
    )
    parser.add_argument(
        "--ckpt", type=str, default=None,
        help="Path to pretrained checkpoint (omit if w_o_pretrain)",
    )
    parser.add_argument(
        "--data", type=str, required=True,
        help="Path to pickled graphs (.pkl)",
    )
    parser.add_argument(
        "--subtype", type=str, required=True,
        help="Path to subtype CSV (index_col=0, column 'subtype')",
    )

    # Numeric hyper-parameters, registered in their original order.
    numeric_options = (
        ("--batch-size", int, 8),
        ("--lr", float, 1e-3),
        ("--epochs", int, 500),
        ("--patience", int, 5),
        ("--n-splits", int, 10),
        ("--hid-ch", int, 64),
        ("--out-ch", int, 64),
        ("--heads", int, 8),
        ("--seed", int, 0),
    )
    for flag, kind, default in numeric_options:
        parser.add_argument(flag, type=kind, default=default)

    args = parser.parse_args()

    needs_ckpt = args.method != "w_o_pretrain"
    if needs_ckpt and args.ckpt is None:
        parser.error("--ckpt is required unless --method w_o_pretrain")
    return args


# -----------------------------
# Main
# -----------------------------
def main():
    """Parse CLI options, load graphs and labels, and run the cross-validated fine-tuning."""
    cli_args = parse_args()
    graph_list, label_array = load_data(cli_args.data, cli_args.subtype)
    run_finetune(graph_list, label_array, cli_args)


if __name__ == "__main__":
    main()
