import argparse
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_scipy_sparse_matrix

from models_with_filter import ChebyNetWithFilter as ChebyModel
from script import utility


def process_graph_data(data, device, gso_type):
    """Build the ChebyNet graph shift operator (GSO) for one graph as a sparse tensor.

    The adjacency matrix is sized from ``data.x`` rather than inferred from
    ``edge_index``, so nodes that appear in no edge still get a row/column.

    Args:
        data: a single PyG graph exposing ``x`` and ``edge_index``.
        device: target device for the returned tensor.
        gso_type: GSO variant name forwarded to ``utility.calc_gso``
            (e.g. ``'sym_norm_lap'`` as used by this script).

    Returns:
        Whatever ``utility.cnv_sparse_mat_to_coo_tensor`` produces for the
        ChebyNet-rescaled GSO (presumably a sparse COO tensor on ``device``).
    """
    adjacency = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.x.size(0))
    # Base operator first, then the ChebyNet-specific rescaling of it.
    cheb_gso = utility.calc_chebynet_gso(utility.calc_gso(adjacency, gso_type))
    return utility.cnv_sparse_mat_to_coo_tensor(cheb_gso, device)


def train_epoch(model, loader, optimizer, device, gso_type):
    """Run one optimisation pass over ``loader``; return the mean per-graph loss.

    Each mini-batch is unpacked into its individual graphs because the model is
    called on one graph at a time (node features plus that graph's own GSO).
    Node-level outputs are mean-pooled into a single row and scored with
    ``F.nll_loss`` against the graph label. The per-graph losses of a mini-batch
    are averaged, and one optimizer step is taken per mini-batch.

    NOTE(review): ``F.nll_loss`` expects log-probabilities; this assumes the
    model's per-node output is log-softmaxed (or already graph-pooled when
    ``task='graph'``) — confirm against the model definition.
    """
    model.train()
    loss_sum = 0.0
    graph_count = 0

    for mini_batch in loader:
        mini_batch = mini_batch.to(device)
        optimizer.zero_grad()

        n_graphs = mini_batch.num_graphs
        accumulated = 0.0
        for graph in mini_batch.to_data_list():
            node_out = model(graph.x, process_graph_data(graph, device, gso_type))
            pooled = node_out.mean(dim=0, keepdim=True)
            accumulated = accumulated + F.nll_loss(pooled, graph.y)

        # An empty mini-batch contributes nothing and must not trigger a step.
        if n_graphs <= 0:
            continue

        mean_loss = accumulated / n_graphs
        mean_loss.backward()
        optimizer.step()
        # Re-weight by graph count so the epoch average is per graph, not per batch.
        loss_sum += mean_loss.item() * n_graphs
        graph_count += n_graphs

    return loss_sum / max(1, graph_count)


def evaluate(model, loader, device, gso_type):
    """Return graph-level classification accuracy of ``model`` over ``loader``.

    Every graph is scored individually: node outputs are mean-pooled and the
    argmax over classes is compared with the graph label. Runs under
    ``torch.no_grad()`` with the model in eval mode. Returns 0.0 when the loader
    yields no graphs.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for mini_batch in loader:
            mini_batch = mini_batch.to(device)
            for graph in mini_batch.to_data_list():
                scores = model(graph.x, process_graph_data(graph, device, gso_type))
                predicted = scores.mean(dim=0, keepdim=True).argmax(dim=1)
                hits += (predicted == graph.y).sum().item()
            seen += mini_batch.num_graphs

    if seen > 0:
        return hits / seen
    return 0.0


def _parse_args(argv=None):
    """Parse the experiment's command-line options.

    The causal filter is on by default. ``--use_causal_filter`` is still
    accepted (for existing invocations) and ``--no_causal_filter`` disables it.
    Previously the flag was ``store_true`` with ``default=True``, so it could
    never be turned off.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.005)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--n_hid', type=int, default=64)
    parser.add_argument('--Ko', type=int, default=3)
    parser.add_argument('--Kl', type=int, default=2)
    parser.add_argument('--droprate', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--runs', type=int, default=5)
    parser.add_argument('--use_causal_filter', dest='use_causal_filter', action='store_true')
    parser.add_argument('--no_causal_filter', dest='use_causal_filter', action='store_false')
    # Parser-level defaults override per-argument defaults, so this is the
    # value when neither flag is given.
    parser.set_defaults(use_causal_filter=True)
    return parser.parse_args(argv)


def _try_model_hook(model, method_name):
    """Call ``model.<method_name>()`` if the model defines it; return its result.

    Returns None when the method is absent. Errors raised by the method are
    reported (not silently dropped) but do not abort the run.
    """
    hook = getattr(model, method_name, None)
    if hook is None:
        return None
    try:
        return hook()
    except Exception as err:  # best-effort diagnostic hook; surface, don't crash
        print(f"Warning: model.{method_name}() failed: {err!r}")
        return None


def _run_once(dataset, args, device, gso_type):
    """Train one model on a fresh random 80/10/10 split.

    Returns:
        ``(best_val, best_test)`` where ``best_test`` is the test accuracy
        recorded at the epoch with the highest validation accuracy.
    """
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    n = len(indices)
    n_train = int(n * 0.8)
    n_val = int(n * 0.1)
    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]

    train_loader = DataLoader(dataset[train_idx], batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(dataset[val_idx], batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(dataset[test_idx], batch_size=args.batch_size, shuffle=False)

    # Built fresh per run in case the model keeps/mutates the dict it is given.
    filter_config = {
        'input': {'lambda_init': 1.0, 'decay_rate': 0.99},
        'hidden': {'lambda_init': 1.0, 'decay_rate': 0.99},
    }

    # NOTE(review): the positional `True` is passed through unchanged; its
    # meaning depends on ChebyNetWithFilter's signature — confirm there.
    model = ChebyModel(
        dataset.num_node_features, args.n_hid, dataset.num_classes, True,
        args.Ko, args.Kl, args.droprate,
        use_causal_filter=args.use_causal_filter,
        filter_config=filter_config if args.use_causal_filter else None,
        task='graph',
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    best_val = 0.0
    best_test = 0.0

    for epoch in range(1, args.epochs + 1):
        loss = train_epoch(model, train_loader, optimizer, device, gso_type)
        train_acc = evaluate(model, train_loader, device, gso_type)
        val_acc = evaluate(model, val_loader, device, gso_type)
        test_acc = evaluate(model, test_loader, device, gso_type)

        if args.use_causal_filter:
            _try_model_hook(model, 'step_epoch')

        # Model selection on validation accuracy; strict '>' keeps the earliest best.
        if val_acc > best_val:
            best_val = val_acc
            best_test = test_acc

        if epoch % 10 == 0 or epoch == 1:
            print(f'Epoch {epoch:03d}: Loss {loss:.4f} | Train {train_acc:.4f} | Val {val_acc:.4f} | Test {test_acc:.4f}')
            if args.use_causal_filter:
                info = _try_model_hook(model, 'get_filter_info')
                if info is not None:
                    print(info)

    return best_val, best_test


def main():
    """Run the ENZYMES graph-classification experiment and report test accuracy.

    Trains ``--runs`` independent models, each on its own random 80/10/10 split,
    and prints the mean and (population) standard deviation of the test
    accuracy, taken at each run's best validation epoch. Global seeds are set
    once, so successive runs use different splits but the whole experiment is
    repeatable.
    """
    args = _parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    print(f"Device: {device}, use_causal_filter: {args.use_causal_filter}")

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    gso_type = 'sym_norm_lap'
    # The dataset is identical across runs; load it once and re-split per run.
    dataset = TUDataset(root='./data/ENZYMES', name='ENZYMES', transform=NormalizeFeatures())

    test_accs = []
    for run in range(args.runs):
        print(f"\n=== Run {run + 1} ===")
        best_val, best_test = _run_once(dataset, args, device, gso_type)
        test_accs.append(best_test * 100)
        print(f'Run {run + 1}: Best Val {best_val:.4f}, Test @best_val {best_test:.4f}')

    if not test_accs:
        # np.mean/np.std of an empty list would yield NaN plus RuntimeWarnings.
        print('\n=== Final ===\nNo runs completed.')
        return

    mean = np.mean(test_accs)
    std = np.std(test_accs)
    print(f'\n=== Final ===\nTest Accuracy: {mean:.2f} \u00b1 {std:.2f}%')
    print('Individual runs:', [f"{a:.2f}" for a in test_accs])


if __name__ == '__main__':
    main()
