import argparse
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from types import SimpleNamespace
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.loader import DataLoader

from model import CaNet


def train_epoch(model, loader, optimizer, device, criterion, args):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    For every batch the model is called with ``training=True`` and is expected
    to return a ``(logits, reg)`` pair. The optimised objective is
    ``criterion(logits, y) + args.lamda * reg``.

    Args:
        model: Module called as ``model(x, edge_index, batch=..., training=True)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y``, ``num_graphs`` and ``.to(device)``.
        optimizer: Optimizer stepped once per batch.
        device: Target device each batch is moved to.
        criterion: Supervised loss applied to ``(logits, y.view(-1))``.
        args: Namespace providing ``lamda``, the weight of the regulariser.

    Returns:
        Per-batch objective values averaged with ``num_graphs`` as weights;
        ``0.0`` when the loader is empty.
    """
    model.train()
    weighted_loss_sum, graphs_seen = 0.0, 0

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        logits, reg_term = model(data.x, data.edge_index, batch=data.batch, training=True)
        objective = criterion(logits, data.y.view(-1)) + args.lamda * reg_term
        objective.backward()
        optimizer.step()

        n_graphs = data.num_graphs
        weighted_loss_sum += objective.item() * n_graphs
        graphs_seen += n_graphs

    # max(1, ...) keeps an empty loader from dividing by zero.
    return weighted_loss_sum / max(1, graphs_seen)


def evaluate(model, loader, device):
    """Return classification accuracy of ``model`` over ``loader``.

    The model is put in eval mode and called with ``training=False`` under
    ``torch.no_grad()``. Its output is treated as per-graph logits, and the
    prediction is their ``argmax`` over dim 1.

    Args:
        model: Module called as ``model(x, edge_index, batch=..., training=False)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y``, ``num_graphs`` and ``.to(device)``.
        device: Target device each batch is moved to.

    Returns:
        Correct predictions divided by the total ``num_graphs``; ``0.0`` when
        the loader yields no graphs.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            logits = model(data.x, data.edge_index, batch=data.batch, training=False)
            hits = logits.argmax(dim=1).eq(data.y.view(-1))
            n_correct += int(hits.sum().item())
            n_seen += data.num_graphs

    if n_seen == 0:
        return 0.0
    return n_correct / n_seen


def _parse_args(argv=None):
    """Build the command-line parser and return the parsed arguments.

    Args:
        argv: Optional list of CLI tokens; ``None`` makes argparse read ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--hidden_channels', type=int, default=64)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--K', type=int, default=3)
    parser.add_argument('--backbone_type', type=str, default='gcn')
    parser.add_argument('--variant', action='store_true', default=False)
    parser.add_argument('--tau', type=float, default=1.0)
    parser.add_argument('--env_type', type=str, default='node')
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--runs', type=int, default=5)
    # The filter is on by default. Previously this was a store_true flag with
    # default=True, which made it impossible to disable; --use_causal_filter is
    # still accepted, and --no_causal_filter now turns the filter off.
    parser.add_argument('--use_causal_filter', dest='use_causal_filter', action='store_true')
    parser.add_argument('--no_causal_filter', dest='use_causal_filter', action='store_false',
                        help='disable the causal filter (enabled by default)')
    parser.set_defaults(use_causal_filter=True)
    parser.add_argument('--filter_lambda_init', type=float, default=1.0)
    parser.add_argument('--filter_lambda_min', type=float, default=-2.0)
    parser.add_argument('--filter_decay', type=float, default=0.99)
    parser.add_argument('--filter_temp', type=float, default=1.0)
    parser.add_argument('--filter_residual', type=float, default=0.2)
    parser.add_argument('--lamda', type=float, default=1.0, help='weight for reg loss from CaNet')
    parser.add_argument('--seed', type=int, default=42, help='seed for random/numpy/torch')
    return parser.parse_args(argv)


def _split_indices(n, train_frac=0.8, val_frac=0.1):
    """Shuffle ``range(n)`` with the global ``random`` state and split it.

    Returns ``(train, val, test)`` index lists. Train and val sizes are
    ``int(n * frac)``; test gets the remainder.
    """
    indices = list(range(n))
    random.shuffle(indices)
    n_train = int(n * train_frac)
    n_val = int(n * val_frac)
    return (
        indices[:n_train],
        indices[n_train:n_train + n_val],
        indices[n_train + n_val:],
    )


def _build_canet_args(args):
    """Copy the subset of CLI arguments that CaNet consumes into a fresh namespace."""
    keys = (
        'dropout', 'num_layers', 'tau', 'env_type', 'use_causal_filter',
        'filter_lambda_init', 'filter_lambda_min', 'filter_decay', 'filter_temp',
        'filter_residual', 'hidden_channels', 'K', 'backbone_type', 'variant', 'lamda',
    )
    return SimpleNamespace(**{key: getattr(args, key) for key in keys})


def _call_model_hook(model, method_name):
    """Call ``model.<method_name>()`` if it exists; return its result or ``None``.

    A missing method is skipped. A method that raises is reported instead of
    silently swallowed; training continues because the hook is best-effort.
    """
    method = getattr(model, method_name, None)
    if not callable(method):
        return None
    try:
        return method()
    except Exception as err:  # best-effort hook: report, don't abort training
        print(f'Warning: model.{method_name}() failed: {err!r}')
        return None


def main(argv=None):
    """Train CaNet on ENZYMES graph classification over several random splits.

    Each run reshuffles the dataset into 80/10/10 train/val/test splits and
    trains a fresh model for ``--epochs`` epochs. It records the test accuracy
    from the epoch with the highest validation accuracy (first such epoch on
    ties). Finally it prints the mean and population std of those test
    accuracies in percent.

    Args:
        argv: Optional list of CLI tokens; ``None`` reads ``sys.argv``.
    """
    args = _parse_args(argv)

    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    print(f"Device: {device}, use_causal_filter: {args.use_causal_filter}")

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # The dataset is identical for every run (only the split changes), so load it once.
    dataset = TUDataset(root='./data/ENZYMES', name='ENZYMES', transform=NormalizeFeatures())
    d = dataset.num_node_features
    c = dataset.num_classes

    test_accs = []

    for run in range(args.runs):
        print(f"\n=== Run {run + 1} ===")

        train_idx, val_idx, test_idx = _split_indices(len(dataset))
        train_loader = DataLoader(dataset[train_idx], batch_size=args.batch_size, shuffle=True)
        val_loader = DataLoader(dataset[val_idx], batch_size=args.batch_size, shuffle=False)
        test_loader = DataLoader(dataset[test_idx], batch_size=args.batch_size, shuffle=False)

        ca_args = _build_canet_args(args)
        model = CaNet(d, c, ca_args, device).to(device)
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        criterion = nn.CrossEntropyLoss()

        best_val = 0.0
        best_test = 0.0
        best_epoch = None  # ensures the first epoch is selected even if val acc is 0

        for epoch in range(1, args.epochs + 1):
            loss = train_epoch(model, train_loader, optimizer, device, criterion, ca_args)
            train_acc = evaluate(model, train_loader, device)
            val_acc = evaluate(model, val_loader, device)
            test_acc = evaluate(model, test_loader, device)

            if args.use_causal_filter:
                _call_model_hook(model, 'step_epoch')

            if best_epoch is None or val_acc > best_val:
                best_epoch = epoch
                best_val = val_acc
                best_test = test_acc

            if epoch % 10 == 0 or epoch == 1:
                print(f'Epoch {epoch:03d}: Loss {loss:.4f} | Train {train_acc:.4f} | Val {val_acc:.4f} | Test {test_acc:.4f}')
                if args.use_causal_filter:
                    info = _call_model_hook(model, 'get_filter_info')
                    if info is not None:
                        print(info)

        test_accs.append(best_test * 100)
        print(f'Run {run + 1}: Best Val {best_val:.4f}, Test @best_val {best_test:.4f}')

    if not test_accs:
        # np.mean/np.std of an empty list would emit warnings and print nan.
        print('\n=== Final ===\nNo runs executed (--runs must be >= 1).')
        return

    mean = np.mean(test_accs)
    std = np.std(test_accs)
    print(f'\n=== Final ===\nTest Accuracy: {mean:.2f} \u00b1 {std:.2f}%')
    print('Individual runs:', [f"{a:.2f}" for a in test_accs])


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
