import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch_geometric.loader import DataLoader
from improved_models import GINModelWithImprovedFilter, GINModel
from data_loader import NPYGraphDataset

def train(model, loader, optimizer, device, task):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module invoked as ``model(x, edge_index, batch)``. Its output is
            passed straight to ``F.nll_loss`` (presumably log-probabilities;
            confirm against the model definition).
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``y``,
            ``to(device)`` and, for non-graph tasks, ``train_mask``. A
            ``batch`` attribute is forwarded when present, otherwise ``None``.
        optimizer: Optimizer whose gradients are reset before, and stepped
            after, every batch.
        device: Device each batch is moved to before the forward pass.
        task: ``'graph'`` computes the loss over every output row against the
            flattened ``y``; any other value restricts both outputs and labels
            to ``train_mask``.

    Returns:
        The unweighted mean of the per-batch losses as a float, or ``0.0``
        when the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, getattr(data, 'batch', None))
        if task == 'graph':
            loss = F.nll_loss(out, data.y.view(-1))
        else:
            loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    # Average over the batches actually seen: this avoids a ZeroDivisionError
    # on an empty loader and does not require the loader to implement __len__.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

def evaluate(model, loader, device, split='val', task='graph'):
    """Compute classification accuracy of ``model`` over ``loader``.

    Args:
        model: Module invoked as ``model(x, edge_index, batch)``; the predicted
            class is the arg-max over dimension 1 of its output.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``y`` and
            ``to(device)``. A ``batch`` attribute is forwarded when present,
            otherwise ``None``.
        device: Device each batch is moved to before the forward pass.
        split: Prefix of the mask attribute (``f'{split}_mask'``) used to
            select rows when ``task`` is not ``'graph'``; ignored for the
            graph task.
        task: ``'graph'`` scores every output row against the flattened ``y``;
            any other value scores only the rows selected by the split mask.

    Returns:
        ``correct / total`` as a float in ``[0, 1]``, or ``0.0`` when nothing
        was evaluated (empty loader or an all-false mask).
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, getattr(data, 'batch', None))
            pred = out.argmax(dim=1)
            if task == 'graph':
                target = data.y.view(-1)
                correct += (pred == target).sum().item()
                # Count the same flattened labels that were just compared.
                total += target.numel()
            else:
                mask = getattr(data, f'{split}_mask')
                correct += (pred[mask] == data.y[mask]).sum().item()
                total += mask.sum().item()

    # Nothing evaluated: report 0.0 instead of raising ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct / total

def compare_filter_methods(dataset_name='crcg', data_dir='./data/crcg', epochs=100,
                           batch_size=32, lr=0.01, weight_decay=5e-4, log_every=10):
    """Train a baseline GIN and the filtered GIN variant and compare test accuracy.

    For each configuration a fresh model is trained for ``epochs`` epochs on the
    graph-level task. After every epoch the model is scored on the validation
    and test loaders; the test accuracy recorded at the epoch with the highest
    validation accuracy (first occurrence) is reported as "best", and the mean
    of the last five per-epoch test accuracies as "final".

    Args:
        dataset_name: Suffix used in the ``train_/val_/test_<name>.npy`` file
            names.
        data_dir: Directory containing the three ``.npy`` split files.
        epochs: Number of training epochs per configuration (must be >= 1).
        batch_size: Batch size for all three loaders.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        log_every: Progress is printed every ``log_every`` epochs (must be >= 1).

    Returns:
        Dict mapping each configuration name to a dict with keys ``"best"``,
        ``"final"`` and ``"std"``, all expressed in percent. ``"std"`` is the
        standard deviation over the same trailing window as ``"final"`` so the
        printed ``final ± std`` describes a single sample.

    Raises:
        ValueError: If ``epochs`` or ``log_every`` is smaller than 1.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"device: {device}")

    task = 'graph'
    train_dataset = NPYGraphDataset(f'{data_dir}/train_{dataset_name}.npy', task=task)
    val_dataset = NPYGraphDataset(f'{data_dir}/val_{dataset_name}.npy', task=task)
    test_dataset = NPYGraphDataset(f'{data_dir}/test_{dataset_name}.npy', task=task)

    # Only the training split is shuffled; evaluation order is kept fixed.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    num_features = train_dataset.num_features
    # Class count is derived from the largest label in the training split.
    num_classes = int(train_dataset.data.y.max()) + 1
    print(f"feature num: {num_features}, classes num: {num_classes}")

    configs = [
        {"name": "No filtering module", "type": "baseline"},
        {"name": "Improved filtering module", "type": "improved"},
    ]

    # Trailing window (in epochs) used for both the "final" mean and its std.
    final_window = 5

    results = {}

    for config in configs:
        print(f"\n{'='*40}")
        print(f"Testing: {config['name']}")
        print(f"{'='*40}")

        if config["type"] == "baseline":
            model = GINModel(num_features, num_classes, task=task).to(device)
        else:
            model = GINModelWithImprovedFilter(
                num_features, num_classes,
                task=task,
                use_causal_filter=True,
                filter_type=config["type"]
            ).to(device)

        print(f"Number of model parameters: {sum(p.numel() for p in model.parameters())}")

        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        best_val_acc = 0
        best_test_acc = 0
        test_accs = []

        for epoch in range(1, epochs + 1):
            loss = train(model, train_loader, optimizer, device, task)

            # Models that expose a per-epoch filter hook get it called once
            # after each training pass (exact semantics live in the model).
            if hasattr(model, 'step_filters'):
                model.step_filters()

            val_acc = evaluate(model, val_loader, device, 'val', task)
            test_acc = evaluate(model, test_loader, device, 'test', task)
            test_accs.append(test_acc)

            # Model selection on validation accuracy; strict '>' keeps the
            # earliest epoch on ties.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc

            if epoch % log_every == 0:
                print(f"Epoch {epoch:2d}: Loss={loss:.4f}, Val={val_acc:.4f}, Test={test_acc:.4f}")

                if hasattr(model, 'get_filter_info'):
                    filter_info = model.get_filter_info()
                    if filter_info != "Causal filters are disabled":
                        print("Filtering module status:")
                        print(filter_info)

        # Mean and std are taken over the same trailing epochs so that the
        # reported "final ± std" is internally consistent.
        recent_accs = test_accs[-final_window:]
        final_test_acc = np.mean(recent_accs) * 100
        final_test_std = np.std(recent_accs) * 100
        best_test_acc *= 100

        results[config["name"]] = {
            "best": best_test_acc,
            "final": final_test_acc,
            "std": final_test_std,
        }

        print("\nResults:")
        print(f"Best test accuracy: {best_test_acc:.2f}%")
        print(f"Final test accuracy: {final_test_acc:.2f}% ± {final_test_std:.2f}%")

    print(f"\n{'='*80}")
    print("Final comparison results:")
    print(f"{'='*80}")

    # Locate the reference row through its config type instead of repeating
    # the display name as a second string literal.
    baseline_name = next(c["name"] for c in configs if c["type"] == "baseline")
    baseline_best = results[baseline_name]["best"]
    baseline_final = results[baseline_name]["final"]

    for name, result in results.items():
        best_improvement = result["best"] - baseline_best
        final_improvement = result["final"] - baseline_final

        print(f"\n{name}:")
        print(f"  Best accuracy: {result['best']:.2f}% ({best_improvement:+.2f}%)")
        print(f"  Final accuracy: {result['final']:.2f}% ± {result['std']:.2f}% ({final_improvement:+.2f}%)")

    return results

def detailed_analysis():
    """Print the filter module's status while briefly training the filtered GIN.

    Builds a ``GINModelWithImprovedFilter`` on the ``crcg`` training split and
    runs ten short epochs. Each epoch performs a single optimisation step on
    the first batch yielded by the (shuffled) loader, calls
    ``model.step_filters()``, and prints the loss together with
    ``model.get_filter_info()``.

    The printed loss is the mean over the batches actually processed in that
    epoch, not over the full loader length.
    """
    print("\n" + "="*80)
    print("Detailed analysis of filtering module working mechanism")
    print("="*80)

    dataset_name = 'crcg'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    task = 'graph'
    train_dataset = NPYGraphDataset(f'./data/crcg/train_{dataset_name}.npy', task=task)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    num_features = train_dataset.num_features
    # Class count is derived from the largest label in the training split.
    num_classes = int(train_dataset.data.y.max()) + 1

    model = GINModelWithImprovedFilter(
        num_features, num_classes,
        task=task,
        use_causal_filter=True,
        filter_type='improved'
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    print("Analysis of filtering module behavior during training:")
    print("-" * 50)

    # Only this many batches are trained on per epoch to keep the probe cheap.
    max_batches_per_epoch = 1

    model.train()
    for epoch in range(1, 11):
        total_loss = 0.0
        batches_seen = 0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data.x, data.edge_index, getattr(data, 'batch', None))
            loss = F.nll_loss(out, data.y.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            batches_seen += 1

            # Stop before the loader is asked for another batch.
            if batches_seen >= max_batches_per_epoch:
                break

        model.step_filters()

        # Divide by the number of batches that contributed to total_loss;
        # dividing by len(train_loader) would understate the loss because the
        # loop above exits early.
        avg_loss = total_loss / batches_seen if batches_seen else 0.0
        print(f"\nEpoch {epoch}: Loss = {avg_loss:.4f}")
        filter_info = model.get_filter_info()
        print(filter_info)

if __name__ == '__main__':
    # Script entry point: run the comparison and report any failure with a
    # full traceback.
    import traceback

    try:
        results = compare_filter_methods()
    except Exception as e:
        print(f"Runtime error: {e}")
        traceback.print_exc()
        # Exit non-zero so shells and job schedulers can detect the failure;
        # previously the script reported the error but still exited with 0.
        raise SystemExit(1) from e
