import argparse
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.datasets import Actor, TUDataset, WebKB
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import Constant

from .utils import split_batch_to_graphs
from .utils import Uext_batch_from_tree_lists
from .coarsen import Make_tree_real2, HaarGOB_with_Sassign
from .models import NodeHaarUnpoolClassifier, Hetero_Graph_Attention_Layer
from .losses import loss_diversity_from_S, loss_reconstruction_from_lists

def set_seed(s=42):
    """Seed every random number generator used by this script with *s*.

    Covers Python's ``random`` module, NumPy's global generator, and the
    PyTorch CPU and CUDA (all devices) generators, in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(s)

def load_dataset(name, root='./data'):
    """Load a TU dataset by name and ensure node features exist.

    Args:
        name: Name of the TU dataset, passed straight to ``TUDataset``.
        root: Directory where the dataset is downloaded to / cached in.

    Returns:
        The ``TUDataset`` instance. When the dataset ships without node
        features, a ``Constant`` transform is attached so that every graph
        fetched from it carries a single constant feature of value 1.0 per
        node.
    """
    ds = TUDataset(root, name)
    if ds.num_features == 0:
        # Attach the feature through a transform rather than writing to the
        # collated storage (``ds.data.x``): the collated tensors are cut into
        # per-graph pieces via a slice table, and an attribute added there
        # has no slice entry, so individual graphs would still come back
        # without ``x``.
        ds.transform = Constant(value=1.0)
    return ds

def _build_arg_parser():
    """Return the command-line parser for the training script."""
    parser = argparse.ArgumentParser(description="Graph Classification Training")
    parser.add_argument('--dataset', choices=['MUTAG', 'ENZYMES', 'NCI1', 'PROTEINS', 'IMDB-BINARY'],
                        default='MUTAG', help='TU dataset name')
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs')
    parser.add_argument('--batch', type=int, default=64, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--hidden', type=int, default=64, help='hidden dimension')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='compute device')
    # Default of 5 matches the value that used to be hard-coded in the
    # Uext_batch_from_tree_lists call.
    parser.add_argument('--levels', type=int, default=5, help='number of coarsening levels')
    parser.add_argument('--runs', type=int, default=1, help='number of independent training runs')
    # NOTE(review): the auxiliary-loss weights were never defined anywhere;
    # 0.1 each is a placeholder default (0.8 + 0.1 + 0.1 = 1) -- confirm.
    parser.add_argument('--lambda-div', type=float, default=0.1,
                        help='weight of the diversity loss')
    parser.add_argument('--lambda-rec', type=float, default=0.1,
                        help='weight of the reconstruction loss')
    return parser


def _train_one_epoch(model, enc, opt, loader, args):
    """Run one optimisation pass over *loader* and return mean per-graph losses.

    Args:
        model: Classifier called as ``model(U, feats, tree)`` for each graph.
        enc: Encoder handed to ``Uext_batch_from_tree_lists``.
        opt: Optimizer covering the parameters of ``model`` and ``enc``.
        loader: Iterable of mini-batches that support ``.to(device)``.
        args: Parsed CLI arguments; reads ``device``, ``levels``,
            ``lambda_div`` and ``lambda_rec``.

    Returns:
        Dict with keys ``'ce'``, ``'div'``, ``'rec'`` and ``'total'`` holding
        the loss values averaged over all graphs seen in the epoch (all zero
        when the loader yields no graphs).
    """
    model.train()
    enc.train()
    sums = {'ce': 0.0, 'div': 0.0, 'rec': 0.0, 'total': 0.0}
    graphs_seen = 0

    for batch in loader:
        batch = batch.to(args.device)
        if batch.x is None:
            # Same constant-feature fallback as load_dataset, kept here so
            # this loop does not depend on the dataset having a transform.
            batch.x = torch.ones(batch.num_nodes, 1, device=args.device)

        X_list, edge_index_list, y_list = split_batch_to_graphs(batch)
        (U_batch, _eidx_batch, _n_nodes_batch, _n_edges_batch,
         feats_batch, tree_batch, S_batch) = Uext_batch_from_tree_lists(
            X_list, edge_index_list, enc,
            levels=args.levels, ratio=0.3, temp=0.1, tau=0.5,
        )
        n_graphs = len(U_batch)
        if n_graphs == 0:
            continue

        batch_loss = 0.0
        for U, feats, tree, S, y in zip(U_batch, feats_batch, tree_batch, S_batch, y_list):
            # Average the model output over its first dimension to get one
            # logit vector for the graph.
            logits = model(U, feats, tree).mean(dim=0)
            dev = logits.device

            # Give cross_entropy an explicit batch of one so it works whether
            # the label is a 0-d tensor, a 1-element tensor or a plain int.
            # Assumes exactly one label per graph -- TODO confirm.
            target = torch.as_tensor(y, dtype=torch.long, device=dev).reshape(1)
            l_ce = F.cross_entropy(logits.unsqueeze(0), target)

            l_div = loss_diversity_from_S(S, device=dev)
            # NOTE(review): the original call site named
            # `loss_reconstruction_from_treeG`, which this file never imports;
            # `loss_reconstruction_from_lists` is the only reconstruction loss
            # in scope. Arguments are passed unchanged -- confirm the signature.
            l_rec = loss_reconstruction_from_lists(tree, device=dev)

            l_total = 0.8 * l_ce + args.lambda_div * l_div + args.lambda_rec * l_rec
            batch_loss = batch_loss + l_total

            sums['ce'] += l_ce.item()
            sums['div'] += l_div.item()
            sums['rec'] += l_rec.item()
            sums['total'] += l_total.item()
        graphs_seen += n_graphs

        # One optimizer step per mini-batch. NOTE(review): the encoder is
        # invoked once for the whole batch above, so stepping after every
        # graph would presumably update encoder weights in place before the
        # remaining graphs back-propagate through them -- confirm.
        opt.zero_grad()
        (batch_loss / n_graphs).backward()
        opt.step()

    if graphs_seen == 0:
        return sums
    return {key: value / graphs_seen for key, value in sums.items()}


def main():
    """Parse CLI arguments and train the Haar-unpool graph classifier.

    For each of ``--runs`` runs, a fresh encoder, classifier and Adam
    optimizer are created (seeded with ``--seed + run``) and trained for
    ``--epochs`` epochs on mini-batches drawn from the selected TU dataset.
    Mean per-graph losses are printed after every epoch.
    """
    args = _build_arg_parser().parse_args()

    set_seed(args.seed)
    dataset = load_dataset(args.dataset)

    # Dataset-level statistics: a single graph only carries its own label, so
    # the class count must not be derived from one sample. Width 1 matches
    # the constant-feature fallback used when the dataset has no features.
    in_dim = dataset.num_features or 1
    num_classes = dataset.num_classes

    for run in range(args.runs):
        # Distinct but reproducible seed per run, so runs are independent.
        set_seed(args.seed + run)
        loader = DataLoader(dataset, batch_size=args.batch, shuffle=True)

        enc = Hetero_Graph_Attention_Layer(
            in_features=in_dim, out_features=in_dim, num_layers=1,
        ).to(args.device)
        model = NodeHaarUnpoolClassifier(
            in_dim=in_dim, hid_dim=args.hidden, num_classes=num_classes,
            max_K=1024, num_levels=args.levels - 1,
        ).to(args.device)
        opt = Adam(list(model.parameters()) + list(enc.parameters()), lr=args.lr)

        for epoch in range(args.epochs):
            stats = _train_one_epoch(model, enc, opt, loader, args)
            print(f"Run {run} Epoch {epoch}: CE={stats['ce']:.4f}  Div={stats['div']:.4f}  "
                  f"Rec={stats['rec']:.4f}  Total={stats['total']:.4f}")

# Script entry point. The relative imports at the top of this file
# (``from .utils import ...``) only resolve when the file is executed as a
# module inside its package (``python -m <package>.<module>``), not when it
# is run directly by path.
if __name__ == "__main__":
    main()
