import argparse
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.loader import DataLoader

from models_with_filter import GCNWithImprovedFilter as GCNModel


def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader``; return the mean loss per graph.

    The model is first called as ``model(batch)``. If that raises ``TypeError``,
    it is retried as ``model(x, edge_index, batch_vector)``, so models with
    either forward signature can be trained. If the output's first dimension
    differs from the number of targets, it is treated as node-level output and
    mean-pooled per graph before the loss is computed.

    Args:
        model: module producing class scores (per graph, or per node).
        loader: iterable of batches exposing ``to(device)``, ``x``,
            ``edge_index``, ``batch``, ``y`` and ``num_graphs``.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        The batch losses averaged over all graphs seen, weighted by
        ``num_graphs``. Returns 0.0 when ``loader`` is empty.
    """
    model.train()
    total_loss = 0.0
    total_graphs = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        try:
            out = model(batch)
        except TypeError:
            # NOTE(review): this also catches TypeErrors raised *inside* a
            # model that does accept ``model(batch)``; the retry then fails
            # with a less helpful error message.
            out = model(batch.x, batch.edge_index, getattr(batch, 'batch', None))

        if out.size(0) != batch.y.size(0):
            # Node-level output: pool to one row per graph. The import is lazy
            # so torch_geometric.nn is only needed when pooling is required.
            from torch_geometric.nn import global_mean_pool
            out = global_mean_pool(out, batch.batch)

        log_probs = F.log_softmax(out, dim=1)
        target = batch.y
        if target.ndim > 1:
            # Flatten e.g. (B, 1) -> (B,). Unlike ``squeeze()``, this keeps a
            # single-graph batch 1-D, which ``nll_loss`` requires for 2-D input.
            target = target.reshape(-1)

        loss = F.nll_loss(log_probs, target)
        loss.backward()
        optimizer.step()

        # nll_loss averages over the batch, so re-weight by batch size to get
        # a per-graph mean over the whole epoch.
        total_loss += loss.item() * batch.num_graphs
        total_graphs += batch.num_graphs

    return total_loss / max(1, total_graphs)


def evaluate(model, loader, device):
    """Compute classification accuracy of ``model`` over every graph in ``loader``.

    The model is put in eval mode and run without gradient tracking. It is
    called as ``model(batch)`` first, falling back to
    ``model(x, edge_index, batch_vector)`` on ``TypeError``. Outputs whose
    first dimension does not match the number of targets are mean-pooled per
    graph. The predicted class is the argmax over dimension 1.

    Returns:
        ``correct / num_graphs_seen`` as a float, or 0.0 if no graphs were seen.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            try:
                logits = model(data)
            except TypeError:
                logits = model(data.x, data.edge_index, getattr(data, 'batch', None))

            labels = data.y
            if logits.size(0) != labels.size(0):
                # Node-level output: reduce to one row per graph.
                from torch_geometric.nn import global_mean_pool
                logits = global_mean_pool(logits, data.batch)

            if labels.ndim > 1:
                labels = labels.squeeze()
            predicted = logits.argmax(dim=1)

            n_correct += int((predicted == labels).sum().item())
            n_seen += data.num_graphs

    if n_seen > 0:
        return n_correct / n_seen
    return 0.0


def _parse_args():
    """Parse the command-line options for a training session."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--runs', type=int, default=5)
    # The causal filter is on by default. A lone ``store_true`` flag with
    # ``default=True`` can never be switched off, so ``--no_causal_filter``
    # writes False to the same destination (same convention as --no_cuda).
    # ``--use_causal_filter`` is still accepted for backward compatibility.
    parser.add_argument('--use_causal_filter', dest='use_causal_filter',
                        action='store_true', default=True)
    parser.add_argument('--no_causal_filter', dest='use_causal_filter',
                        action='store_false')
    return parser.parse_args()


def _split_indices(n, train_frac=0.8, val_frac=0.1):
    """Shuffle ``range(n)`` with the global ``random`` RNG and split it.

    Returns ``(train_idx, val_idx, test_idx)``. The train and validation sizes
    are truncated with ``int``; the test split takes every remaining index.
    """
    indices = list(range(n))
    random.shuffle(indices)
    n_train = int(n * train_frac)
    n_val = int(n * val_frac)
    return (indices[:n_train],
            indices[n_train:n_train + n_val],
            indices[n_train + n_val:])


def _call_optional(model, method_name):
    """Call ``model.<method_name>()`` if it exists; return its result, else None.

    A missing method is skipped quietly, since not every model variant defines
    it. An exception raised *by* the method is reported as a RuntimeWarning
    rather than silently swallowed, so a broken hook is visible while the run
    still completes.
    """
    import warnings

    method = getattr(model, method_name, None)
    if not callable(method):
        return None
    try:
        return method()
    except Exception as exc:  # best-effort hook: surface the failure, don't crash
        warnings.warn(f'{method_name}() failed: {exc!r}', RuntimeWarning)
        return None


def main():
    """Train and evaluate the filtered GCN on ENZYMES over several random splits.

    The Python, NumPy and torch RNGs are seeded once with 42 for the whole
    session. Each run therefore draws a different but reproducible 80/10/10
    train/val/test split. For each run a fresh model is trained, and the test
    accuracy at the epoch with the best validation accuracy is recorded. At
    the end the mean and population standard deviation of those accuracies
    (in percent) are printed.
    """
    args = _parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    print(f"Device: {device}, use_causal_filter: {args.use_causal_filter}")

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    # Loading the dataset does not consume any RNG state, so load it once
    # instead of re-reading it from disk on every run.
    dataset = TUDataset(root='./data/ENZYMES', name='ENZYMES', transform=NormalizeFeatures())
    num_features = dataset.num_node_features
    num_classes = dataset.num_classes

    filter_config = {
        'input': {'lambda_init': 1.0, 'decay_rate': 0.99, 'normalize': False},
        'hidden': {'lambda_init': 1.0, 'decay_rate': 0.99, 'normalize': False},
    }

    test_accs = []

    for run in range(args.runs):
        print(f"\n=== Run {run + 1} ===")

        train_idx, val_idx, test_idx = _split_indices(len(dataset))
        train_loader = DataLoader(dataset[train_idx], batch_size=args.batch_size, shuffle=True)
        val_loader = DataLoader(dataset[val_idx], batch_size=args.batch_size, shuffle=False)
        test_loader = DataLoader(dataset[test_idx], batch_size=args.batch_size, shuffle=False)

        model = GCNModel(
            num_features,
            num_classes,
            hidden_channels=args.hidden_dim,
            use_causal_filter=args.use_causal_filter,
            filter_config=filter_config if args.use_causal_filter else None
        ).to(device)

        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        best_val = 0.0
        best_test = 0.0

        for epoch in range(1, args.epochs + 1):
            loss = train_epoch(model, train_loader, optimizer, device)
            train_acc = evaluate(model, train_loader, device)
            val_acc = evaluate(model, val_loader, device)
            test_acc = evaluate(model, test_loader, device)

            if args.use_causal_filter:
                _call_optional(model, 'step_epoch')

            # Model selection on validation accuracy; test accuracy is only
            # reported at the selected epoch.
            if val_acc > best_val:
                best_val = val_acc
                best_test = test_acc

            if epoch % 10 == 0 or epoch == 1:
                print(f'Epoch {epoch:03d}: Loss {loss:.4f} | Train {train_acc:.4f} | Val {val_acc:.4f} | Test {test_acc:.4f}')
                if args.use_causal_filter:
                    info = _call_optional(model, 'get_filter_info')
                    if info is not None:
                        print(info)

        test_accs.append(best_test * 100)
        print(f'Run {run + 1}: Best Val {best_val:.4f}, Test @best_val {best_test:.4f}')

    if not test_accs:
        # --runs <= 0: nothing to summarise (np.mean([]) would be NaN).
        print('No runs executed (--runs must be positive).')
        return

    mean = np.mean(test_accs)
    std = np.std(test_accs)  # population std (ddof=0)
    print(f'\n=== Final ===\nTest Accuracy: {mean:.2f} \u00b1 {std:.2f}%')
    print('Individual runs:', [f"{a:.2f}" for a in test_accs])


# Script entry point: only runs the experiment when executed directly, not on import.
if __name__ == '__main__':
    main()
