import os
import torch
import dgl
from sklearn.model_selection import train_test_split

from gad3.args import parse_args
from gad3.utils import set_seed
from gad3.binning import bin_features
from gad3.data import create_dataloaders
from gad3.simgraph import build_similarity_graph_from_leaves
from gad3.overlap import compute_neighbor_leaf_overlap
from gad3.models.base import BaseGraphSAGE
from gad3.models.booster import BoosterGraphSAGE
from gad3.models.stage3 import Stage3MLP
from gad3.training.base import train_base
from gad3.training.booster import train_booster, logits_after_stage12
from gad3.training.stage3 import train_stage3
from gad3.testing import test_all


def _load_graph(graph_path):
    """Load and return the first graph stored in a DGL binary graph file.

    Args:
        graph_path: Path to a file readable by ``dgl.load_graphs``.

    Returns:
        The first graph in the file.

    Raises:
        FileNotFoundError: If ``graph_path`` does not exist.
    """
    if not os.path.exists(graph_path):
        raise FileNotFoundError(f"Graph path {graph_path} does not exist.")
    # dgl.load_graphs returns (list_of_graphs, labels_dict); only the first graph is used.
    graphs, _ = dgl.load_graphs(graph_path)
    return graphs[0]


def _split_indices(labels, train_ratio, val_ratio=0.1, seed=42):
    """Stratified split of node ids into train / validation / test index tensors.

    ``train_ratio`` of all nodes form a labelled pool; the remaining
    ``1 - train_ratio`` become the test set. The pool is then split again so
    that ``val_ratio`` of it is used for validation and the rest for training.
    Both splits are stratified on ``labels`` and use ``random_state=seed``.

    Args:
        labels: Per-node label tensor (first dimension = number of nodes).
        train_ratio: Fraction of all nodes assigned to the train+validation pool.
        val_ratio: Fraction of that pool assigned to validation.
        seed: ``random_state`` passed to both ``train_test_split`` calls.

    Returns:
        ``(train_idx, val_idx, test_idx)`` as torch tensors of node ids.
    """
    labels_np = labels.cpu().numpy()
    all_indices = torch.arange(labels.shape[0]).numpy()

    # train_test_split returns (first_part, second_part, ...), where the SECOND
    # part receives `test_size` of the rows. Passing test_size=train_ratio thus
    # puts train_ratio of the nodes into the pool (second output) and the
    # remaining 1 - train_ratio into the test set (first output).
    test_idx, pool_idx, _, pool_labels = train_test_split(
        all_indices,
        labels_np,
        stratify=labels_np,
        test_size=train_ratio,
        random_state=seed,
    )
    # Carve the validation set out of the pool; the rest is the training set.
    train_idx, val_idx, _, _ = train_test_split(
        pool_idx,
        pool_labels,
        stratify=pool_labels,
        test_size=val_ratio,
        random_state=seed,
    )
    return torch.tensor(train_idx), torch.tensor(val_idx), torch.tensor(test_idx)


def main():
    """Run the full pipeline: data split, Stage-0 features, Stage-1/2/3 training, test.

    All configuration comes from ``parse_args()``; the graph is read from
    ``./datasets/<args.dataset>``. Final evaluation is delegated to ``test_all``.
    """
    set_seed()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args = parse_args()

    graph = _load_graph(os.path.join("./datasets", args.dataset))

    train_idx, val_idx, test_idx = _split_indices(graph.ndata['label'], args.train_ratio)
    print("train_num:", len(train_idx))
    print("val_num:", len(val_idx))
    print("test_num:", len(test_idx))

    # ===== Stage-0: XGBoost leaf features (fitted on training nodes only) =====
    graph, tree_model, bin_count = bin_features(graph, train_idx)
    print("bin_count:", bin_count)

    # ===== DataLoaders =====
    train_loader, val_loader, test_loader = create_dataloaders(
        graph, train_idx, val_idx, test_idx, batch_size=1024
    )

    # ===== Similarity graph (used by Stage-2) =====
    use_idf = not args.no_sim_idf
    # NOTE(review): a negative args.sim_max_per_leaf means "use the default cap
    # of 500", not "no cap". If unlimited was intended, pass None here once
    # build_similarity_graph_from_leaves is confirmed to accept it.
    max_per_leaf = args.sim_max_per_leaf if args.sim_max_per_leaf >= 0 else 500
    g_sim, _ = build_similarity_graph_from_leaves(
        graph.ndata['binned_feature'],
        top_k=args.sim_topk, use_idf=use_idf, sym=args.sim_sym,
        max_per_leaf=max_per_leaf
    )
    print(f"[Similarity Graph] edges={g_sim.num_edges()} nodes={g_sim.num_nodes()}")

    # ===== Stage-3 features: one-hop neighbourhood leaf co-occurrence vectors =====
    leaf_nb_overlap = compute_neighbor_leaf_overlap(
        graph, graph.ndata['binned_feature'], undirected=True
    )
    # 2-D per-node feature matrix; its second dimension is the Stage-3 input size.
    graph.ndata['leaf_nb_overlap'] = leaf_nb_overlap

    # ===== Stage-1 =====
    base_model = BaseGraphSAGE(args.in_feats, args.hidden_feats, 1, bin_count).to(device)
    base_model = train_base(base_model, graph, train_loader, val_loader, args.epochs_base, device)

    # ===== Stage-2 =====
    booster2 = BoosterGraphSAGE(args.in_feats, args.hidden_feats, 1, bin_count).to(device)
    # The second return value (base logits) is not needed here.
    booster2, _ = train_booster(
        booster2, base_model, graph, g_sim,
        train_loader, val_loader, args.epochs_boost2, device,
        eta=args.boost2_eta,
        boost_mode=args.boost2_mode,
        boost_weight_pos=args.boost2_weight_pos,
        boost_weight_neg=args.boost2_weight_neg,
        boost_gamma=args.boost2_gamma
    )

    # Freeze the Stage-1+2 logits; they serve as the fixed baseline for Stage-3.
    base12_logits_fixed = logits_after_stage12(
        base_model, booster2, graph, g_sim, device, eta1=args.boost2_eta
    )

    # ===== Stage-3 =====
    in_dim_s3 = leaf_nb_overlap.shape[1]  # presumably the number of trees
    stage3 = Stage3MLP(in_dim=in_dim_s3, hidden_dim=args.boost3_hidden, dropout=args.boost3_dropout).to(device)
    stage3 = train_stage3(
        stage3, graph, graph.ndata['leaf_nb_overlap'],
        train_loader, val_loader, args.epochs_boost3, device,
        base12_logits_fixed=base12_logits_fixed,
        eta2=args.boost3_eta,
        focus_mode=args.boost3_focus,
        weight_fn=args.boost3_weight_fn,
        weight_fp=args.boost3_weight_fp,
        gamma=args.boost3_gamma,
        log1p=bool(args.boost3_log1p)
    )

    # ===== Test: Base / Base+S2 / Base+S2+S3 =====
    test_all(
        base_model, booster2, stage3,
        graph, g_sim, graph.ndata['leaf_nb_overlap'],
        test_loader, device,
        eta1=args.boost2_eta, eta2=args.boost3_eta,
        log1p=bool(args.boost3_log1p)
    )


if __name__ == "__main__":
    # Script entry point: run the whole pipeline only when executed directly.
    main()
