


import argparse
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, InMemoryDataset, DataLoader
from models_with_filter import GCNWithImprovedFilter 


class PaperNodeDataset(InMemoryDataset):
    """In-memory graph dataset loaded from a single pickled ``.npy`` file.

    The file must contain a dict (or a 0-d object array wrapping one) with the
    keys ``'edge_index'``, ``'features'`` and ``'role_id'``. Each value is a
    sequence with one entry per graph; the three sequences are zipped together,
    so extra entries in a longer sequence are silently ignored.

    Args:
        npy_path: Path of the ``.npy`` file to load.
        transform: Optional transform forwarded to ``InMemoryDataset``.
    """

    def __init__(self, npy_path, transform=None):
        # No root directory is used: everything is read directly from npy_path.
        super().__init__(None, transform)
        self.data, self.slices = self.load_npy(npy_path)
        # Feature width is taken from the collated node-feature matrix.
        self._num_features = self.data.x.size(1) if self.data.x is not None else 0

    @property
    def num_features(self):
        """Number of feature columns per node (0 when there are no features)."""
        return self._num_features

    def load_npy(self, npy_path):
        """Read ``npy_path`` and return the collated ``(data, slices)`` pair.

        Each graph's edges are normalized to shape ``[2, num_edges]``; features
        become a float tensor and role ids a long tensor used as labels.
        """
        # SECURITY: allow_pickle=True unpickles arbitrary objects and can execute
        # code embedded in the file. Only load files from a trusted source.
        raw = np.load(npy_path, allow_pickle=True)
        # A dict saved via np.save comes back wrapped in a 0-d object array;
        # .item() unwraps it. Anything without .item() is used as-is.
        data_dict = raw.item() if hasattr(raw, 'item') else raw

        edge_indices  = data_dict['edge_index']
        features_list = data_dict['features']
        role_ids_list = data_dict['role_id']

        data_list = []
        for edges, features, role_ids in zip(edge_indices, features_list, role_ids_list):
            e = torch.tensor(edges, dtype=torch.long)
            if e.numel() == 0:
                # An empty edge list (e.g. ``[]``) yields a 1-D tensor of shape
                # (0,), which is not a valid edge index. Use the canonical
                # empty [2, 0] layout so edge-less graphs collate correctly.
                e = torch.empty((2, 0), dtype=torch.long)
            elif e.ndim == 2 and e.size(0) != 2:
                # Edges stored as [num_edges, 2] rows are flipped to [2, num_edges].
                # NOTE(review): a [2, 2] input is ambiguous and is left untouched.
                e = e.t()
            e = e.contiguous()

            x = torch.tensor(features, dtype=torch.float)
            y = torch.tensor(role_ids, dtype=torch.long)

            data_list.append(Data(x=x, edge_index=e, y=y))

        return self.collate(data_list)


def train(model, loader, optimizer, device):
    """Run one training pass over ``loader`` and update ``model`` in place.

    Args:
        model: Callable module invoked as ``model(x, edge_index, batch)``; its
            output is passed to ``F.nll_loss`` together with the labels.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``y``; must also support ``len()``.
        optimizer: Optimizer stepping the model parameters.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the per-batch loss averaged over the
        number of batches, and the fraction of labels predicted correctly.
        Accuracy is measured on the outputs produced in training mode, before
        each optimizer step takes effect.
    """
    model.train()
    loss_sum = 0.0
    hits = 0
    seen = 0

    for batch in loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        scores = model(batch.x, batch.edge_index, batch.batch)
        batch_loss = F.nll_loss(scores, batch.y)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        # The highest-scoring class per row is the prediction.
        hits += scores.argmax(dim=1).eq(batch.y).sum().item()
        seen += batch.y.size(0)

    mean_loss = loss_sum / len(loader)
    accuracy = hits / seen
    return mean_loss, accuracy


def evaluate(model, loader, device):
    """Measure classification accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and gradients are disabled for the
    whole pass.

    Args:
        model: Callable module invoked as ``model(x, edge_index, batch)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``y``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Fraction of labels whose arg-max prediction matches ``y``.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            scores = model(batch.x, batch.edge_index, batch.batch)
            hits += scores.argmax(dim=1).eq(batch.y).sum().item()
            seen += batch.y.size(0)
    return hits / seen


if __name__ == "__main__":
    # ---- Command-line configuration -------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--use_causal_filter', dest='use_causal_filter', action='store_true', default=True,
                        help='Whether to use causal filter module')
    # The filter is enabled by default, so the store_true flag above could
    # never turn it off. This opt-out writes False to the same destination.
    parser.add_argument('--no_causal_filter', dest='use_causal_filter', action='store_false',
                        help='Disable causal filter module')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    print(f"Using device: {device}, Causal filter module: {'Enabled' if args.use_causal_filter else 'Disabled'}")

    # Fixed seeds for the Python, NumPy and PyTorch random number generators.
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    # Each entry names one experiment; it selects the train/val/test files.
    for i in ['casual_1_3']:
        print(f'\nCurrent experiment: {i}')

        train_dataset = PaperNodeDataset(f"./data/paper/train_{i}.npy")
        val_dataset   = PaperNodeDataset(f"./data/paper/val_{i}.npy")
        test_dataset  = PaperNodeDataset(f"./data/paper/test_{i}.npy")

        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        val_loader   = DataLoader(val_dataset, batch_size=64, shuffle=False)
        test_loader  = DataLoader(test_dataset, batch_size=64, shuffle=False)

        num_features = train_dataset.num_features
        # Class count is inferred from the largest label in the training split.
        num_classes  = int(train_dataset.data.y.max().item()) + 1

        # Passed through to the model unchanged; the meaning of these keys is
        # defined by GCNWithImprovedFilter, not by this script.
        filter_config = {
            'input': {'lambda_init': 1.0, 'decay_rate': 0.99},
            'hidden': {'lambda_init': 1.0, 'decay_rate': 0.99},
        }

        model = GCNWithImprovedFilter(
            num_features,
            num_classes,
            hidden_channels=args.hidden_dim,
            use_causal_filter=args.use_causal_filter,
            filter_config=filter_config if args.use_causal_filter else None,
            task='node'
        ).to(device)

        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        # Model-selection bookkeeping: best validation accuracy seen so far,
        # the epoch it occurred at, and the test accuracy at that same epoch.
        best_val_acc = 0
        best_epoch = 0
        test_acc_at_best_val = 0.0
        # Per-epoch accuracies in percent. The summary statistics below are
        # computed across the epochs of this single run.
        final_train_accuracies = []
        final_test_accuracies = []

        for epoch in range(1, args.epochs + 1):
            loss, train_acc = train(model, train_loader, optimizer, device)
            val_acc = evaluate(model, val_loader, device)
            test_acc = evaluate(model, test_loader, device)

            final_train_accuracies.append(train_acc * 100)
            final_test_accuracies.append(test_acc * 100)

            # Advance the filter's per-epoch state once per epoch.
            if args.use_causal_filter:
                model.step_epoch()

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_epoch = epoch
                test_acc_at_best_val = test_acc

            if epoch % 10 == 0:
                print(f'Epoch {epoch:03d}: Loss {loss:.4f} | Train Acc {train_acc:.4f} | Val Acc {val_acc:.4f} | Test Acc {test_acc:.4f}')
                if args.use_causal_filter:
                    print(model.get_filter_info())

        final_train_mean = np.mean(final_train_accuracies)
        final_train_std = np.std(final_train_accuracies)
        final_test_mean = np.mean(final_test_accuracies)
        final_test_std = np.std(final_test_accuracies)

        print(f"Train Acc  {final_train_mean:.2f} ± {final_train_std:.2f}  |  Test Acc  {final_test_mean:.2f} ± {final_test_std:.2f}")
        # Report the result selected by validation accuracy, which was
        # previously tracked but never shown.
        print(f"Best Val Acc  {best_val_acc * 100:.2f} (epoch {best_epoch})  |  Test Acc at best val  {test_acc_at_best_val * 100:.2f}")
