import argparse
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.loader import DataLoader

from models_with_improved_filter import GINWithImprovedFilter as GINModel


def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    The model is switched to training mode, and for every mini-batch the
    gradients are reset, the negative log-likelihood loss is back-propagated
    and the optimizer takes one step.

    The returned value is the mean loss *per sample*: each batch loss is
    weighted by the number of labels in that batch, so a smaller final batch
    does not skew the average. This matches how ``evaluate`` counts samples.

    Args:
        model: Callable invoked as ``model(x, edge_index, batch)``.
            NOTE(review): ``F.nll_loss`` expects log-probabilities, so the
            model is assumed to end in ``log_softmax`` -- confirm.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``to(device)``.
        optimizer: Optimizer updating the model's parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        float: Sample-weighted mean training loss, or ``0.0`` when the loader
        yields no batches.
    """
    model.train()
    loss_sum = 0.0
    num_samples = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()
        # nll_loss returns the batch mean; scale it back to a sum so that
        # batches of different sizes contribute proportionally.
        batch_size = data.y.size(0)
        loss_sum += loss.item() * batch_size
        num_samples += batch_size
    # max(1, ...) keeps the empty-loader case at 0.0 instead of dividing by 0.
    return loss_sum / max(1, num_samples)


def evaluate(model, loader, device):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is put in evaluation mode and gradients are disabled. For each
    batch the predicted class is the arg-max over dimension 1 of the model
    output, compared element-wise against ``y``.

    Args:
        model: Callable invoked as ``model(x, edge_index, batch)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``to(device)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        float: Fraction of correctly predicted labels, or ``0.0`` when the
        loader yields no samples.
    """
    model.eval()
    n_hits, n_seen = 0, 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            scores = model(batch.x, batch.edge_index, batch.batch)
            labels = batch.y
            matches = scores.argmax(dim=1) == labels
            n_hits += matches.sum().item()
            n_seen += labels.size(0)
    # Guard against an empty loader rather than dividing by zero.
    if n_seen <= 0:
        return 0.0
    return n_hits / n_seen


def main(argv=None):
    """Train and evaluate the GIN model on ENZYMES over several random splits.

    For each run the dataset indices are shuffled and split 80/10/10 into
    train/validation/test, a fresh model is trained with Adam, and the test
    accuracy of the epoch with the best validation accuracy (first epoch to
    reach it) is recorded. The mean and standard deviation (``np.std``) of
    those test accuracies are printed at the end.

    The ``random``, NumPy and torch RNGs are seeded once with 42 before the
    first run; later runs differ only because the RNG state has advanced.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=50,
                        help='training epochs per run')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='Adam learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-4,
                        help='Adam weight decay')
    parser.add_argument('--hidden_dim', type=int, default=64,
                        help='hidden_dim passed to the model')
    parser.add_argument('--num_layers', type=int, default=2,
                        help='num_layers passed to the model')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='batch size of the data loaders')
    parser.add_argument('--no_cuda', action='store_true',
                        help='run on CPU even if CUDA is available')
    parser.add_argument('--runs', type=int, default=5,
                        help='number of train/evaluate runs')
    parser.add_argument('--use_causal_filter', action='store_true',
                        help='enable the model filter and its per-epoch step')
    args = parser.parse_args(argv)

    use_cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device('cuda' if use_cuda else 'cpu')
    # NOTE(review): the label says "use_improved_filter" while the flag is
    # --use_causal_filter; left unchanged because it is runtime output.
    print(f"Device: {device}, use_improved_filter: {args.use_causal_filter}")

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    # The dataset does not change between runs, so load it once instead of
    # re-reading it from disk at the start of every run.
    dataset = TUDataset(root='./data/ENZYMES', name='ENZYMES', transform=NormalizeFeatures())
    num_features = dataset.num_node_features
    num_classes = dataset.num_classes
    n = len(dataset)
    n_train = int(n * 0.8)
    n_val = int(n * 0.1)

    test_accs = []

    for run in range(args.runs):
        print(f"\n=== Run {run + 1} ===")

        # Fresh random 80/10/10 split per run; the test set takes whatever
        # remains after the train and validation slices.
        indices = list(range(n))
        random.shuffle(indices)
        train_idx = indices[:n_train]
        val_idx = indices[n_train:n_train + n_val]
        test_idx = indices[n_train + n_val:]

        train_loader = DataLoader(dataset[train_idx], batch_size=args.batch_size, shuffle=True)
        val_loader = DataLoader(dataset[val_idx], batch_size=args.batch_size, shuffle=False)
        test_loader = DataLoader(dataset[test_idx], batch_size=args.batch_size, shuffle=False)

        # Rebuilt every run so the model always receives its own dict.
        # NOTE(review): entries exist only for 'input', 'layer_0' and
        # 'layer_1'; confirm how the model handles --num_layers > 2.
        filter_config = {
            'input': {'lambda_init': 10.0, 'decay_rate': 0.95, 'normalize': False},
            'layer_0': {'lambda_init': 10.0, 'decay_rate': 0.95, 'normalize': False},
            'layer_1': {'lambda_init': 10.0, 'decay_rate': 0.95, 'normalize': False},
        }

        model = GINModel(
            num_features,
            num_classes,
            hidden_dim=args.hidden_dim,
            num_layers=args.num_layers,
            task='graph',
            use_causal_filter=args.use_causal_filter,
            filter_config=filter_config if args.use_causal_filter else None,
        ).to(device)

        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        best_val = 0.0
        best_test = 0.0

        for epoch in range(1, args.epochs + 1):
            loss = train_epoch(model, train_loader, optimizer, device)
            train_acc = evaluate(model, train_loader, device)
            val_acc = evaluate(model, val_loader, device)
            test_acc = evaluate(model, test_loader, device)

            if args.use_causal_filter:
                # Presumably advances the filter's per-epoch schedule --
                # confirm against the model implementation.
                model.step_epoch()

            # Model selection on validation accuracy only; the test accuracy
            # is merely recorded at the selected epoch.
            if val_acc > best_val:
                best_val = val_acc
                best_test = test_acc

            if epoch % 10 == 0 or epoch == 1:
                print(f'Epoch {epoch:03d}: Loss {loss:.4f} | Train {train_acc:.4f} | Val {val_acc:.4f} | Test {test_acc:.4f}')
                if args.use_causal_filter:
                    # Best-effort diagnostics: a failure here must not abort
                    # training, but it should not be hidden either.
                    try:
                        print(model.get_filter_info())
                    except Exception as exc:
                        print(f'(filter info unavailable: {exc!r})')

        test_accs.append(best_test * 100)
        print(f'Run {run + 1}: Best Val {best_val:.4f}, Test @best_val {best_test:.4f}')

    if not test_accs:
        # Avoid np.mean([]) -> nan plus a RuntimeWarning when --runs < 1.
        print('\nNo runs were executed (--runs must be >= 1); nothing to report.')
        return

    mean = np.mean(test_accs)
    std = np.std(test_accs)
    print(f'\n=== Final ===\nTest Accuracy: {mean:.2f} ± {std:.2f}%')
    print('Individual runs:', [f"{a:.2f}" for a in test_accs])

# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
