import argparse
import os
import random

import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from data_handling import get_planetoid_data
from models import GCN_SSM

def set_seed(seed: int) -> None:
    """Seed every random number generator the script can draw from.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch, and switches cuDNN to deterministic mode with autotuning
    (``benchmark``) disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Python's own generator must be seeded too; otherwise any code drawing
    # from ``random`` stays non-reproducible between runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Trade cuDNN autotuning for repeatable kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def load_data(dataset_name: str, device: torch.device):
    """Fetch the requested Planetoid dataset and place its graph on ``device``.

    Only ``"Cora"`` is accepted at the moment.

    Args:
        dataset_name: Name of the dataset to load.
        device: Device the first element of the dataset is moved to.

    Returns:
        A ``(data, num_classes)`` tuple, where ``data`` is the first element
        of the dataset after ``.to(device)`` and ``num_classes`` is the
        fixed class count used for Cora.

    Raises:
        ValueError: If ``dataset_name`` is anything other than ``"Cora"``.
    """
    is_supported = dataset_name == "Cora"
    if not is_supported:
        raise ValueError(f"Dataset '{dataset_name}' not recognized.")

    planetoid = get_planetoid_data(dataset_name=dataset_name)
    graph = planetoid[0].to(device)

    # Cora has 7 classes; the count is fixed here rather than read from the data.
    return graph, 7

def build_model(args, num_feats: int, num_classes: int, device: torch.device) -> torch.nn.Module:
    """Construct the network selected by ``args.model`` and move it to ``device``.

    Args:
        args: Parsed options; reads ``model``, ``nhid``, ``nlayers``,
            ``bnorm``, ``lin``, ``shared``, ``dyn`` and ``dyn_gamma_a``.
        num_feats: Input feature count, forwarded as ``nfeat``.
        num_classes: Output class count, forwarded as ``nclass``.
        device: Passed to the constructor and used for the final ``.to()``.

    Returns:
        The ``GCN_SSM`` instance after ``.to(device)``.

    Raises:
        ValueError: If ``args.model`` is not ``'GCN_SSM'``.
    """
    if not args.model == 'GCN_SSM':
        raise ValueError(f"Model '{args.model}' not supported.")

    # Collect the constructor options first so the call itself stays short.
    # The model name is also forwarded as ``conv_func``.
    ctor_options = {
        "nfeat": num_feats,
        "nhid": args.nhid,
        "nclass": num_classes,
        "conv_func": args.model,
        "nlayers": args.nlayers,
        "bnorm": args.bnorm,
        "lin": args.lin,
        "shared": args.shared,
        "dyn": args.dyn,
        "gamma_a": args.dyn_gamma_a,
        "device": device,
    }
    network = GCN_SSM(**ctor_options)
    return network.to(device)

def evaluate(model: torch.nn.Module, data, mask: str) -> float:
    """Compute classification accuracy on one split of ``data``.

    The model is switched to eval mode (and left there), and the forward pass
    runs with gradient tracking disabled so that no autograd graph is built
    during evaluation.

    Args:
        model: Network called as ``model(data)``; its output is arg-maxed
            over dimension 1 to obtain predictions.
        data: Object exposing ``y`` and a ``<mask>_mask`` attribute.
        mask: Split prefix; e.g. ``"val"`` selects ``data.val_mask``.

    Returns:
        Number of correct predictions inside the mask divided by the mask's
        sum. A mask that selects nothing results in a division by zero.
    """
    model.eval()
    mask_tensor = getattr(data, f"{mask}_mask")

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        pred = model(data).argmax(dim=1)
        num_correct = (pred[mask_tensor] == data.y[mask_tensor]).sum().item()

    return num_correct / mask_tensor.sum().item()

def train_epoch(model, data, optimizer, loss_fn) -> float:
    """Perform a single full-batch optimisation step.

    Args:
        model: Network called as ``model(data)``; put into train mode here.
        data: Object exposing ``train_mask`` and ``y``.
        optimizer: Optimizer whose gradients are cleared and then stepped.
        loss_fn: Callable taking ``(outputs, targets)`` and returning a
            scalar loss tensor.

    Returns:
        The training loss for this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(data)
    # Only the entries selected by the training mask contribute to the loss.
    selector = data.train_mask
    step_loss = loss_fn(logits[selector], data.y[selector])

    step_loss.backward()
    optimizer.step()
    return step_loss.item()


def parse_args():
    """Declare the command-line options and parse the process arguments.

    Returns:
        The ``argparse.Namespace`` produced by parsing the command line.
    """
    # Each entry is (flag, keyword arguments for ``add_argument``). The list
    # order is the order in which options are registered with the parser.
    option_table = [
        ("--nhid", {"type": int, "default": 128, "help": "# of hidden features"}),
        ("--nlayers", {"type": int, "default": 2, "help": "# of layers"}),
        ("--epochs", {"type": int, "default": 200, "help": "Max epochs"}),
        ("--device", {
            "type": str,
            "default": "cuda",
            "choices": ["cpu", "cuda"],
            "help": "Computing device; 'cuda' will use GPU if available",
        }),
        ("--lr", {"type": float, "default": 0.01, "help": "Learning rate"}),
        ("--reduce_factor", {"type": float, "default": 0.5, "help": "LR scheduler factor"}),
        ("--seed", {"type": int, "default": 1112, "help": "Random seed"}),
        ("--model", {"type": str, "default": "GCN_SSM", "help": "Model architecture"}),
        ("--dataset", {"type": str, "default": "Cora", "help": "Dataset: Cora"}),
        ("--out_dir", {"type": str, "default": "res", "help": "Output directory"}),
        ("--bnorm", {"action": "store_true", "help": "Enable batchnorm"}),
        ("--lin", {"action": "store_true", "help": "Enable linearity"}),
        ("--shared", {"action": "store_true", "help": "Share layers"}),
        ("--dyn", {"action": "store_true", "help": "Enable dynamics in shared layers"}),
        ("--dyn_gamma_a", {"type": float, "default": 1.0, "help": "Gamma A for dynamics"}),
    ]

    cli = argparse.ArgumentParser(description="Train GCN_SSM on Planetoid data")
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


def main():
    """Run the full training pipeline and print the best validation result.

    Parses the command line, selects the device, seeds the RNGs, loads the
    data, builds the model, then trains for up to ``args.epochs`` epochs.
    After every epoch the validation and test accuracies are printed, and the
    test accuracy belonging to the best validation accuracy is remembered.
    """
    args = parse_args()

    # Use the GPU only when it was requested *and* is actually present.
    wants_gpu = args.device == "cuda"
    device = torch.device("cuda" if wants_gpu and torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    set_seed(args.seed)
    data, num_classes = load_data(args.dataset, device)
    model = build_model(args, data.num_features, num_classes, device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # mode='max': the scheduler is stepped with validation accuracy below.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=args.reduce_factor)
    loss_fn = torch.nn.CrossEntropyLoss()

    os.makedirs(args.out_dir, exist_ok=True)

    # Training stops early once the learning rate drops below this value.
    # NOTE(review): ReduceLROnPlateau skips updates smaller than its default
    # ``eps``, so this floor may rarely be reached in practice - confirm.
    lr_floor = 1e-9

    best_val = best_test = 0.0
    for epoch in range(1, args.epochs + 1):
        loss = train_epoch(model, data, optimizer, loss_fn)
        val_acc = evaluate(model, data, mask="val")
        test_acc = evaluate(model, data, mask="test")

        scheduler.step(val_acc)
        curr_lr = optimizer.param_groups[0]["lr"]

        report = (
            f"Epoch {epoch:03d} | Loss: {loss:.4f} | Val Acc: {val_acc:.4f}"
            f" | Test Acc: {test_acc:.4f} | LR: {curr_lr:.2e}"
        )
        print(report)

        # Model selection is driven by validation accuracy only; the test
        # accuracy is merely recorded alongside it.
        if val_acc > best_val:
            best_val = val_acc
            best_test = test_acc

        if curr_lr < lr_floor:
            print("LR below threshold; stopping early.")
            break

    print("----- Finished Training -----")
    print(f"Best Val Acc: {best_val:.4f} | Corresponding Test Acc: {best_test:.4f}")


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()