# train_hpc.py
import os
import torch
import numpy as np
import argparse
from sklearn.metrics import f1_score, roc_auc_score

from hpc_utils import create_line_graph, get_relational_spectral_encoding, get_topological_motif_encoding
from hpc_sgt import HPC_SGT
from common import DATA_EMB_DIC

def _build_arg_parser():
    """Create the command-line parser holding the training hyper-parameters."""
    cli = argparse.ArgumentParser()
    # One row per option: (flag, type, default, help text).
    # Rows are registered top to bottom, which is the order --help lists them.
    option_table = (
        ('--dataset_name', str, 'ml-1m-1', None),
        ('--emb_size', int, 32, None),
        ('--epochs', int, 100, None),
        ('--lr', float, 0.001, None),
        ('--gamma1', float, 0.1, 'LVO loss weight'),
        ('--gamma2', float, 0.1, 'Co-training loss weight'),
    )
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli


# Parsed once at module load; `run()` reads its settings from `args`.
parser = _build_arg_parser()
args = parser.parse_args()

def run():
    """Train HPC_SGT on the line graph of the training edges and print metrics.

    Hyper-parameters are read from the module-level ``args`` namespace. The
    line graph and its two structural encodings are built once up front; the
    loop then performs one Adam step per epoch for ``args.epochs`` epochs and
    prints the loss. Every 10th epoch the model is re-run in eval mode and
    AUC / macro-F1 are printed.

    NOTE(review): the metrics are computed on the *training* edges. The
    validation and test splits returned by ``load_data`` are not used here.
    """
    # 1. Load data.
    # NOTE(review): `load_data` is neither defined nor imported in the visible
    # part of this file -- confirm which module is expected to provide it.
    train_edgelist, _val_edgelist, _test_edgelist = load_data(args.dataset_name)
    set_a_num, set_b_num = DATA_EMB_DIC[args.dataset_name]

    print("Creating line graph and structural encodings...")
    line_adj, edge_signs, edge_to_idx = create_line_graph(train_edgelist, set_a_num, set_b_num)

    rse_bias = get_relational_spectral_encoding(line_adj, edge_signs, k=args.emb_size // 4)

    tme_paths = get_topological_motif_encoding(line_adj, max_path_len=3)

    # One line-graph node per training edge.
    num_line_nodes = len(train_edgelist)
    model = HPC_SGT(n=set_a_num, m=set_b_num, num_line_nodes=num_line_nodes, embed_size=args.emb_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    labels = edge_signs
    # The less frequent sign is the minority class; a tie resolves to +1.
    minority_class = -1 if torch.sum(labels == -1) < torch.sum(labels == 1) else 1
    minority_mask = (labels == minority_class)
    # Inverse of edge_to_idx: line-graph index -> original edge.
    idx_to_edge = {v: k for k, v in edge_to_idx.items()}

    # Loss settings shared by the training step and the periodic evaluation so
    # both always score the model under the same objective configuration.
    loss_kwargs = dict(gamma1=args.gamma1, gamma2=args.gamma2,
                       beta_vae=1.0, beta1=0.1, beta2=0.1, tau=0.1)

    for epoch in range(args.epochs):
        model.train()
        optimizer.zero_grad()

        h_b, h_l = model(idx_to_edge, rse_bias, tme_paths)

        loss, _ = model.loss(h_b, h_l, labels, minority_mask, **loss_kwargs)

        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}/{args.epochs}, Loss: {loss.item():.4f}")

        if (epoch + 1) % 10 == 0:
            model.eval()
            with torch.no_grad():
                # Re-run the forward pass: the logits from the training step
                # were produced in train mode and before optimizer.step(), so
                # reusing them would make model.eval() a no-op and report
                # metrics for the previous weights.
                eval_h_b, eval_h_l = model(idx_to_edge, rse_bias, tme_paths)
                _, eval_logits = model.loss(eval_h_b, eval_h_l, labels, minority_mask,
                                            **loss_kwargs)

                # detach().cpu() keeps the NumPy conversion valid even when the
                # tensors do not live on the CPU.
                pred_probs = torch.sigmoid(eval_logits).detach().cpu().numpy()
                # Maps signs to binary targets (-1 -> 0, +1 -> 1); assumes the
                # labels only take the values -1 and +1.
                true_labels = ((labels + 1) / 2).detach().cpu().numpy()

                auc = roc_auc_score(true_labels, pred_probs)
                preds = (pred_probs > 0.5).astype(int)
                macro_f1 = f1_score(true_labels, preds, average='macro')
                print(f"  Eval - AUC: {auc:.4f}, Macro-F1: {macro_f1:.4f}")


# Script entry point: start the training run when this file is executed directly.
if __name__ == "__main__":
    run()