

import argparse
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, InMemoryDataset, DataLoader
from torch_geometric.utils import degree
from models_with_improved_filter import GINWithImprovedFilter as GINModel  


class PaperNodeDataset(InMemoryDataset):
    """In-memory graph dataset loaded from a single ``.npy`` file.

    The file must contain a dict with the keys ``'edge_index'``,
    ``'features'`` and ``'role_id'``. Each value is a sequence with one
    entry per graph; the entries become ``edge_index``, ``x`` and ``y`` of a
    ``Data`` object respectively.

    Args:
        npy_path: Path of the ``.npy`` file to load.
        transform: Optional transform forwarded to ``InMemoryDataset``.
    """

    def __init__(self, npy_path, transform=None):
        # root=None: nothing is downloaded or processed on disk; everything
        # comes from ``npy_path``.
        super(PaperNodeDataset, self).__init__(None, transform)
        self.data, self.slices = self.load_npy(npy_path)
        # Feature matrices are 2-D; the second dimension is the feature width.
        self._num_features = self.data.x.size(1)

    @property
    def num_features(self):
        """Number of feature columns in ``x``, captured at load time."""
        return self._num_features

    def load_npy(self, npy_path):
        """Read ``npy_path`` and collate its graphs.

        Args:
            npy_path: Path of the ``.npy`` file to load.

        Returns:
            The ``(data, slices)`` pair produced by ``self.collate``.

        Raises:
            ValueError: If the three per-graph sequences differ in length.
        """
        # NOTE(security): allow_pickle=True unpickles arbitrary objects, so
        # only load files from a trusted source.
        raw = np.load(npy_path, allow_pickle=True)
        # A dict stored with np.save comes back wrapped in an array; .item()
        # unwraps it.
        data_dict = raw.item() if hasattr(raw, 'item') else raw

        edge_indices = data_dict['edge_index']
        features_list = data_dict['features']
        role_ids_list = data_dict['role_id']

        # zip() would silently drop trailing graphs on a length mismatch.
        if not (len(edge_indices) == len(features_list) == len(role_ids_list)):
            raise ValueError(
                f"Inconsistent number of graphs in {npy_path}: "
                f"{len(edge_indices)} edge lists, "
                f"{len(features_list)} feature matrices, "
                f"{len(role_ids_list)} label arrays"
            )

        data_list = []
        for edges, features, role_ids in zip(edge_indices,
                                             features_list,
                                             role_ids_list):
            e = torch.tensor(edges, dtype=torch.long)
            if e.numel() == 0:
                # A graph without edges may arrive as shape (0,), (0, 2) or
                # (2, 0); normalise all of them to an empty (2, 0) index.
                e = e.reshape(2, 0)
            elif e.ndim == 2 and e.size(0) != 2:
                # Stored as (E, 2) pairs: transpose to the (2, E) layout.
                # NOTE(review): a (2, 2) input is ambiguous and is taken
                # as already being (2, E).
                e = e.t()
            e = e.contiguous()

            x = torch.tensor(features, dtype=torch.float)
            y = torch.tensor(role_ids, dtype=torch.long)

            data_list.append(Data(x=x, edge_index=e, y=y))

        return self.collate(data_list)


def train(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module called as ``model(x, edge_index, batch)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``batch``, ``y`` and ``to(device)``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        A ``(mean_loss, accuracy)`` tuple. ``mean_loss`` is the average of
        the per-batch losses (divided by ``len(loader)``); ``accuracy`` is
        the number of correct predictions divided by the number of labels.
    """
    model.train()

    loss_sum = 0
    n_correct = 0
    n_seen = 0

    for batch in loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        log_probs = model(batch.x, batch.edge_index, batch.batch)
        # nll_loss expects log-probabilities.
        # NOTE(review): assumes the model ends in log_softmax - confirm.
        batch_loss = F.nll_loss(log_probs, batch.y)
        batch_loss.backward()
        optimizer.step()

        # Accuracy uses the outputs computed before this step's update.
        hits = log_probs.argmax(dim=1).eq(batch.y)

        loss_sum += batch_loss.item()
        n_correct += hits.sum().item()
        n_seen += batch.y.size(0)

    mean_loss = loss_sum / len(loader)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy


def evaluate(model, loader, device):
    """Compute prediction accuracy of ``model`` over ``loader``.

    Args:
        model: Module called as ``model(x, edge_index, batch)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``batch``, ``y`` and ``to(device)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The fraction of labels whose arg-max prediction is correct.
    """
    model.eval()
    hits, seen = 0, 0

    # Inference only: gradients are not needed.
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            scores = model(batch.x, batch.edge_index, batch.batch)
            matches = scores.argmax(dim=1) == batch.y
            hits += matches.sum().item()
            seen += batch.y.size(0)

    return hits / seen


if __name__ == '__main__':
    # Script entry point: for each data split, train the GIN model, keep the
    # checkpoint with the best validation accuracy, and report its test
    # accuracy.

    # The CLI is identical for every round, so parse it once up front.
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs',     type=int,   default=100)
    parser.add_argument('--lr',         type=float, default=0.01)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--hidden_dim', type=int,   default=64)
    parser.add_argument('--num_layers', type=int,   default=2)
    parser.add_argument('--no_cuda',    action='store_true')
    parser.add_argument('--use_causal_filter', action='store_true', default=True, help='Whether to use causal filtering module')
    # store_true combined with default=True can never produce False, so an
    # explicit off switch is needed to actually disable the filter.
    parser.add_argument('--no_causal_filter', dest='use_causal_filter', action='store_false', help='Disable the causal filtering module')
    args = parser.parse_args()

    # With no epochs there would be no statistics and no checkpoint to load.
    if args.epochs < 1:
        parser.error('--epochs must be at least 1')

    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    checkpoint_path = './best_model_paper_gin.pt'

    for i in ['casual_1_3']:
        print(f"Round {i}:")
        print(f"Using device: {device}, Causal filtering module: {'Enabled' if args.use_causal_filter else 'Disabled'}")

        # Fixed seeds so every round starts from the same RNG state.
        random.seed(42)
        np.random.seed(42)
        torch.manual_seed(42)

        train_dataset = PaperNodeDataset(f'./data/paper/train_{i}.npy')
        val_dataset   = PaperNodeDataset(f'./data/paper/val_{i}.npy')
        test_dataset  = PaperNodeDataset(f'./data/paper/test_{i}.npy')

        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        val_loader   = DataLoader(val_dataset, batch_size=64, shuffle=False)
        test_loader  = DataLoader(test_dataset, batch_size=64, shuffle=False)

        num_features = train_dataset.num_features
        # Labels are taken to be 0-based class ids, so the class count is
        # the largest training label plus one.
        num_classes  = int(train_dataset.data.y.max().item()) + 1

        # NOTE(review): the entries cover 'input' plus two layers, which
        # matches the default --num_layers=2; confirm how the model handles
        # other layer counts.
        filter_config = {
            'input': {'lambda_init': 10.0, 'decay_rate': 0.95},
            'layer_0': {'lambda_init': 10.0, 'decay_rate': 0.95},
            'layer_1': {'lambda_init': 10.0, 'decay_rate': 0.95},
        }

        model = GINModel(
            num_features,
            num_classes,
            hidden_dim=args.hidden_dim,
            num_layers=args.num_layers,
            task='node',
            use_causal_filter=args.use_causal_filter,
            filter_config=filter_config if args.use_causal_filter else None
        ).to(device)

        optimizer = optim.Adam(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay
        )

        # Start below any reachable accuracy so the first epoch always writes
        # a checkpoint; otherwise a run whose validation accuracy stays at 0
        # would later load a missing or stale file.
        best_val = float('-inf')
        final_train_accuracies = []
        final_test_accuracies = []
        for epoch in range(1, args.epochs + 1):
            loss, train_acc = train(model, train_loader, optimizer, device)
            val_acc = evaluate(model, val_loader, device)
            test_acc = evaluate(model, test_loader, device)
            final_train_accuracies.append(train_acc * 100)
            final_test_accuracies.append(test_acc * 100)
            if args.use_causal_filter:
                # Advance the filter's per-epoch schedule.
                model.step_epoch()

            # Model selection uses validation accuracy only.
            if val_acc > best_val:
                best_val = val_acc
                torch.save(model.state_dict(), checkpoint_path)

            if epoch % 10 == 0:
                print(f'Epoch {epoch:03d}: Loss {loss:.4f} | Train Acc {train_acc:.4f} | Val Acc {val_acc:.4f} | Test Acc {test_acc:.4f}')
                if args.use_causal_filter:
                    print(model.get_filter_info())

        # The reported statistics are taken over epochs, not over repeated
        # runs.
        final_train_mean = np.mean(final_train_accuracies)
        final_train_std = np.std(final_train_accuracies)
        final_test_std = np.std(final_test_accuracies)

        # Report the test accuracy of the best-validation checkpoint.
        # map_location keeps the load working when the target device differs
        # from the one the checkpoint was saved on.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        final_test = evaluate(model, test_loader, device)*100
        print(f'Train Acc  {final_train_mean:.2f} ± {final_train_std:.2f}  Test Acc  {final_test:.2f} ± {final_test_std:.2f}')
