


import argparse
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, InMemoryDataset, DataLoader
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from models_with_improved_filter import GINWithImprovedFilter as GINModel  


class CiteseerNodeDataset(InMemoryDataset):
    """CiteSeer nodes of one Planetoid split, each stored as its own graph.

    Every node selected by the split mask becomes a separate ``Data`` object
    holding that node's feature row, its label and an empty ``edge_index``;
    no citation edges are kept in the samples.
    """

    def __init__(self, split='train', transform=None):
        """Load CiteSeer and build the per-node samples for ``split``.

        Args:
            split: one of ``'train'``, ``'val'`` or ``'test'``.
            transform: forwarded to the ``InMemoryDataset`` constructor.

        Raises:
            ValueError: if ``split`` is not one of the three known names.
        """
        # ``None`` is passed as the first (root) argument of the base class.
        super().__init__(None, transform)
        self.split = split

        planetoid = Planetoid(root='./data/Planetoid', name='CiteSeer', transform=NormalizeFeatures())
        self.original_data = planetoid[0]

        # Cached so the properties below can answer without touching
        # the collated storage.
        self._num_classes = planetoid.num_classes
        self._num_features = self.original_data.x.size(1)

        self.data, self.slices = self.create_split_data()

    @property
    def num_features(self):
        """Width of the node feature matrix of the source graph."""
        return self._num_features

    @property
    def num_classes(self):
        """Number of classes reported by the Planetoid dataset."""
        return self._num_classes

    def create_split_data(self):
        """Collate one edge-less single-node graph per node of the split.

        Returns:
            The ``(data, slices)`` pair produced by ``self.collate``.

        Raises:
            ValueError: if ``self.split`` is not a known split name.
        """
        split_to_mask = (
            ('train', 'train_mask'),
            ('val', 'val_mask'),
            ('test', 'test_mask'),
        )
        mask_attr = next(
            (attr for name, attr in split_to_mask if self.split == name),
            None,
        )
        if mask_attr is None:
            raise ValueError(f"Unknown split: {self.split}")

        graph = self.original_data
        selected = getattr(graph, mask_attr).nonzero(as_tuple=False).view(-1)

        # Slicing with ``i:i + 1`` keeps a leading dimension of size one
        # for both the feature row and the label.
        singletons = [
            Data(
                x=graph.x[i:i + 1],
                edge_index=torch.empty((2, 0), dtype=torch.long),
                y=graph.y[i:i + 1],
            )
            for i in selected
        ]
        return self.collate(singletons)


class CiteseerFullGraphDataset(InMemoryDataset):
    """Single-graph dataset holding the whole CiteSeer graph for one split.

    The stored graph is a clone of the Planetoid graph with an extra
    ``node_mask`` attribute set to the mask of the requested split.
    """

    def __init__(self, split='train', transform=None):
        """Load CiteSeer and store the full graph tagged for ``split``.

        Args:
            split: one of ``'train'``, ``'val'`` or ``'test'``.
            transform: forwarded to the ``InMemoryDataset`` constructor.

        Raises:
            ValueError: if ``split`` is not one of the three known names.
        """
        # ``None`` is passed as the first (root) argument of the base class.
        super().__init__(None, transform)
        self.split = split

        planetoid = Planetoid(root='./data/Planetoid', name='CiteSeer', transform=NormalizeFeatures())
        self.original_data = planetoid[0]

        self._num_classes = planetoid.num_classes
        self._num_features = self.original_data.x.size(1)

        self.data, self.slices = self.create_full_graph_data()

    @property
    def num_features(self):
        """Width of the node feature matrix of the source graph."""
        return self._num_features

    @property
    def num_classes(self):
        """Number of classes reported by the Planetoid dataset."""
        return self._num_classes

    def create_full_graph_data(self):
        """Clone the graph, attach ``node_mask`` for the split and collate it.

        Returns:
            The ``(data, slices)`` pair produced by ``self.collate`` for a
            one-element list containing the cloned graph.

        Raises:
            ValueError: if ``self.split`` is not a known split name.
        """
        # Work on a clone so ``self.original_data`` is left without the
        # extra ``node_mask`` attribute.
        graph = self.original_data.clone()

        split_to_mask = (
            ('train', 'train_mask'),
            ('val', 'val_mask'),
            ('test', 'test_mask'),
        )
        mask_attr = next(
            (attr for name, attr in split_to_mask if self.split == name),
            None,
        )
        if mask_attr is None:
            raise ValueError(f"Unknown split: {self.split}")

        graph.node_mask = getattr(graph, mask_attr)
        return self.collate([graph])


def train_full_graph(model, data, optimizer, device, split_mask):
    """Run one optimisation step on the whole graph.

    The model is called on every node, but the loss and the reported
    accuracy only use the nodes selected by ``split_mask``.

    Args:
        model: callable as ``model(x, edge_index)``; its output is passed to
            ``F.nll_loss``, so it is expected to be log-probabilities.
        data: graph object exposing ``x``, ``edge_index``, ``y`` and ``to``.
        optimizer: optimizer holding the model parameters.
        device: device the graph is moved to before the forward pass.
        split_mask: boolean mask over the nodes that contribute to the loss.

    Returns:
        tuple: ``(loss, accuracy)`` as Python numbers. The accuracy is
        computed from the same forward pass as the loss, i.e. before the
        parameter update takes effect.
    """
    model.train()
    graph = data.to(device)

    optimizer.zero_grad()
    log_probs = model(graph.x, graph.edge_index)
    targets = graph.y[split_mask]
    loss = F.nll_loss(log_probs[split_mask], targets)
    loss.backward()
    optimizer.step()

    predicted = log_probs.argmax(dim=1)[split_mask]
    n_correct = (predicted == targets).sum().item()
    n_selected = split_mask.sum().item()

    return loss.item(), n_correct / n_selected


def evaluate_full_graph(model, data, device, split_mask):
    """Measure accuracy on the masked nodes of the whole graph.

    Args:
        model: callable as ``model(x, edge_index)`` returning per-node scores.
        data: graph object exposing ``x``, ``edge_index``, ``y`` and ``to``.
        device: device the graph is moved to before the forward pass.
        split_mask: boolean mask selecting the nodes to score.

    Returns:
        float: fraction of masked nodes whose arg-max prediction equals the
        label.
    """
    model.eval()
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        graph = data.to(device)
        predicted = model(graph.x, graph.edge_index).argmax(dim=1)
        matches = predicted[split_mask] == graph.y[split_mask]
        n_correct = matches.sum().item()
        n_selected = split_mask.sum().item()
    return n_correct / n_selected


def train_batch(model, loader, optimizer, device):
    """Train for one pass over ``loader``, one optimizer step per batch.

    Args:
        model: callable as ``model(x, edge_index, batch)``; its output is
            passed to ``F.nll_loss``, so it is expected to be
            log-probabilities.
        loader: sized iterable of batches exposing ``x``, ``edge_index``,
            ``batch``, ``y`` and ``to``.
        optimizer: optimizer holding the model parameters.
        device: device each batch is moved to.

    Returns:
        tuple: ``(mean_loss, accuracy)``. The loss is the sum of the
        per-batch losses divided by the number of batches; the accuracy is
        taken over all samples seen.
    """
    model.train()
    loss_sum = 0
    hits = 0
    seen = 0

    for batch in loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        scores = model(batch.x, batch.edge_index, batch.batch)
        batch_loss = F.nll_loss(scores, batch.y)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        hits += (scores.argmax(dim=1) == batch.y).sum().item()
        seen += batch.y.size(0)

    mean_loss = loss_sum / len(loader)
    accuracy = hits / seen
    return mean_loss, accuracy


def evaluate_batch(model, loader, device):
    """Measure accuracy over every batch yielded by ``loader``.

    Args:
        model: callable as ``model(x, edge_index, batch)`` returning
            per-sample scores.
        loader: iterable of batches exposing ``x``, ``edge_index``,
            ``batch``, ``y`` and ``to``.
        device: device each batch is moved to.

    Returns:
        float: number of correct arg-max predictions divided by the number
        of samples seen.
    """
    model.eval()
    hits = 0
    seen = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            scores = model(batch.x, batch.edge_index, batch.batch)
            hits += (scores.argmax(dim=1) == batch.y).sum().item()
            seen += batch.y.size(0)
    return hits / seen


def _parse_args():
    """Define and parse the command-line options of the training script.

    ``--use_causal_filter`` and ``--use_full_graph`` are ``store_true`` flags
    whose default is already ``True``, so by themselves they can never turn
    the feature off.  They are kept for backward compatibility, and the
    companion flags ``--no_causal_filter`` / ``--no_full_graph`` write
    ``False`` to the same destinations.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs',     type=int,   default=200)
    parser.add_argument('--lr',         type=float, default=0.01)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--hidden_dim', type=int,   default=64)
    parser.add_argument('--num_layers', type=int,   default=2)
    parser.add_argument('--no_cuda',    action='store_true')
    parser.add_argument('--use_causal_filter', action='store_true', default=True, help='Whether to use causal filtering module')
    parser.add_argument('--no_causal_filter', dest='use_causal_filter', action='store_false', help='Disable the causal filtering module')
    parser.add_argument('--use_full_graph', action='store_true', default=True, help='Whether to use complete graph structure for training')
    parser.add_argument('--no_full_graph', dest='use_full_graph', action='store_false', help='Train on batches of single-node samples instead of the complete graph')
    parser.add_argument('--runs',       type=int,   default=10, help='Number of runs')
    args = parser.parse_args()

    # With zero runs the final summary would average an empty list.
    if args.runs < 1:
        parser.error('--runs must be at least 1')
    return args


def main():
    """Train and evaluate the GIN model on CiteSeer over several runs.

    Each run builds a fresh model and optimizer, trains for ``--epochs``
    epochs and records the test accuracy of the epoch with the highest
    validation accuracy.  The mean and standard deviation of those test
    accuracies (in percent) are printed at the end.
    """
    print("Starting Citeseer dataset training:")

    args = _parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    print(f"Using device: {device}, Causal filtering module: {'Enabled' if args.use_causal_filter else 'Disabled'}")
    print(f"Training using {'complete graph structure' if args.use_full_graph else 'batch processing'}")

    # Seeded once, before the run loop: the generators are not reset between
    # runs, so successive runs draw different random numbers.
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    # The data does not depend on the run, so it is loaded once instead of
    # once per run.
    if args.use_full_graph:
        dataset = Planetoid(root='./data/Planetoid', name='CiteSeer', transform=NormalizeFeatures())
        data = dataset[0]
        num_features = dataset.num_node_features
        num_classes = dataset.num_classes
    else:
        train_dataset = CiteseerNodeDataset('train')
        val_dataset   = CiteseerNodeDataset('val')
        test_dataset  = CiteseerNodeDataset('test')

        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        val_loader   = DataLoader(val_dataset, batch_size=64, shuffle=False)
        test_loader  = DataLoader(test_dataset, batch_size=64, shuffle=False)

        num_features = train_dataset.num_features
        num_classes = train_dataset.num_classes

    test_accs = []

    for run in range(args.runs):
        print(f"\n=== Run {run + 1} ===")

        # Rebuilt for every run so each model receives its own dict.
        # NOTE(review): only 'input', 'layer_0' and 'layer_1' are configured;
        # confirm how the model treats further layers when --num_layers > 2.
        filter_config = {
            'input': {'lambda_init': 10.0, 'decay_rate': 0.95},
            'layer_0': {'lambda_init': 10.0, 'decay_rate': 0.95},
            'layer_1': {'lambda_init': 10.0, 'decay_rate': 0.95},
        }

        model = GINModel(
            num_features,
            num_classes,
            hidden_dim=args.hidden_dim,
            num_layers=args.num_layers,
            task='node',
            use_causal_filter=args.use_causal_filter,
            filter_config=filter_config if args.use_causal_filter else None
        ).to(device)

        optimizer = optim.Adam(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay
        )

        best_val = 0
        final_test_acc = 0

        for epoch in range(1, args.epochs + 1):
            if args.use_full_graph:
                loss, train_acc = train_full_graph(model, data, optimizer, device, data.train_mask)
                val_acc = evaluate_full_graph(model, data, device, data.val_mask)
                test_acc = evaluate_full_graph(model, data, device, data.test_mask)
            else:
                loss, train_acc = train_batch(model, train_loader, optimizer, device)
                val_acc = evaluate_batch(model, val_loader, device)
                test_acc = evaluate_batch(model, test_loader, device)

            # Per-epoch hook of the filter module (presumably advances its
            # decay schedule -- confirm against the model implementation).
            if args.use_causal_filter:
                model.step_epoch()

            # Model selection: keep the test accuracy of the epoch with the
            # highest validation accuracy (strict '>' keeps the earliest on ties).
            if val_acc > best_val:
                best_val = val_acc
                final_test_acc = test_acc

            if epoch % 50 == 0:
                print(f'Epoch {epoch:03d}: Loss {loss:.4f} | Train Acc {train_acc:.4f} | Val Acc {val_acc:.4f} | Test Acc {test_acc:.4f}')
                if args.use_causal_filter:
                    print(model.get_filter_info())

        test_accs.append(final_test_acc * 100)
        print(f'Run {run + 1}: Best Val Acc: {best_val:.4f}, Test Acc: {final_test_acc:.4f}')

    # np.std uses its default ddof=0, i.e. the population standard deviation.
    test_mean = np.mean(test_accs)
    test_std = np.std(test_accs)

    print(f'Test Accuracy: {test_mean:.2f} ± {test_std:.2f}%')
    print(f'Individual runs: {[f"{acc:.2f}" for acc in test_accs]}')


if __name__ == '__main__':
    main()
