import argparse
import random
import sys

import numpy as np
import torch
import torch.nn as nn

from parse import parser_add_main_args
from load_paper_data import load_datasets, create_dataloaders
from model import CaNet
from improved_filter import ImprovedCausalFilter  


def fix_seed(seed):
    """Seed every RNG this script relies on so runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and torch (plus the current
    CUDA device when CUDA is available), and configures cuDNN for
    deterministic execution.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning selects kernels by timing them, which can vary from run
    # to run; it has to be off for `deterministic` to give reproducible results.
    torch.backends.cudnn.benchmark = False


def train_epoch(model, loader, optimizer, criterion, device, args, feature_filter=None):
    """Run one optimisation pass over ``loader``.

    Each batch is moved to ``device``, optionally passed through
    ``feature_filter``, and fed to ``model(x, edge_index, training=True)``,
    which must return ``(logits, reg_loss)``. The objective is
    ``criterion(logits, y) + args.lamda * |reg_loss|``.

    Args:
        model: module returning ``(logits, reg_loss)`` in training mode.
        loader: iterable of batches exposing ``.to(device)``, ``.x``,
            ``.edge_index`` and ``.y``.
        optimizer: optimizer stepped once per batch.
        criterion: classification loss applied to ``(logits, y)``.
        device: device each batch is moved to.
        args: namespace providing ``lamda``, the regulariser weight.
        feature_filter: optional module applied to ``data.x`` before the model.

    Returns:
        ``(mean_batch_loss, accuracy)``. Both are 0.0 when ``loader`` yields
        no batches / no nodes, instead of raising ``ZeroDivisionError``.
    """
    model.train()
    if feature_filter is not None:
        # Mirror model.train(): the filter must be in train mode while optimising.
        feature_filter.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    num_batches = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        # `is not None` rather than truthiness: an nn.Module that defines
        # __len__ (e.g. an empty container) would otherwise be skipped.
        x_in = feature_filter(data.x) if feature_filter is not None else data.x
        out, reg_loss = model(x_in, data.edge_index, training=True)

        # NOTE(review): abs() masks a negative regulariser instead of fixing
        # its source in the model; kept to preserve the training objective.
        reg_loss = torch.abs(reg_loss)

        ce_loss = criterion(out, data.y)
        loss = ce_loss + args.lamda * reg_loss
        loss_value = loss.item()  # single host sync per batch

        # Can still trigger if args.lamda is negative.
        if loss_value < 0:
            print(f"Warning: Negative loss detected! CE Loss: {ce_loss.item()}, Reg Loss: {reg_loss.item()}")

        loss.backward()
        optimizer.step()

        total_loss += loss_value
        pred = out.argmax(dim=1)
        total_correct += pred.eq(data.y).sum().item()
        total_samples += data.y.size(0)
        num_batches += 1

    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = total_correct / total_samples if total_samples else 0.0
    return mean_loss, accuracy


@torch.no_grad()
def evaluate_epoch(model, loader, device, args, feature_filter=None):
    """Return the node-classification accuracy of ``model`` over ``loader``.

    Runs without gradient tracking. ``model`` and, if given, ``feature_filter``
    are switched to eval mode first, so stochastic layers such as the filter's
    dropout are disabled while scoring.

    Args:
        model: called as ``model(x, edge_index, training=False)``; the argmax
            of its output over dim 1 is taken as the prediction.
        loader: iterable of batches exposing ``.to(device)``, ``.x``,
            ``.edge_index`` and ``.y``.
        device: device each batch is moved to.
        args: not read here; kept for signature parity with ``train_epoch``.
        feature_filter: optional module applied to ``data.x`` before the model.

    Returns:
        Fraction of correctly classified nodes, or 0.0 when ``loader`` yields
        no nodes.
    """
    model.eval()
    if feature_filter is not None:
        # Otherwise the filter stays in the train mode the training loop left
        # it in, and its dropout perturbs validation/test scores.
        feature_filter.eval()
    total_correct = 0
    total_samples = 0
    for data in loader:
        data = data.to(device)
        x_in = data.x if feature_filter is None else feature_filter(data.x)
        out = model(x_in, data.edge_index, training=False)
        pred = out.argmax(dim=1)
        total_correct += pred.eq(data.y).sum().item()
        total_samples += data.y.size(0)
    return total_correct / total_samples if total_samples else 0.0

if __name__ == "__main__":
    # Checkpoints for the best-on-validation model and its feature filter.
    model_ckpt = "best_model_paper.pth"
    filter_ckpt = "best_filter_paper.pth"

    parser = argparse.ArgumentParser(description='CaNet for Node Classification on Paper Dataset')
    parser_add_main_args(parser)
    parser.add_argument('--batch_size', type=int, default=128)
    # The causal filter is on by default. A `store_true` flag with default=True
    # could never be switched off, so `--no_causal_filter` writes False to the
    # same destination; `--use_causal_filter` stays accepted for old scripts.
    parser.add_argument('--use_causal_filter', default=True, action='store_true')
    parser.add_argument('--no_causal_filter', dest='use_causal_filter', action='store_false')
    parser.add_argument('--filter_lambda_init', type=float, default=1.0)
    parser.add_argument('--filter_lambda_min', type=float, default=-2.0)
    parser.add_argument('--filter_decay', type=float, default=0.99)
    parser.add_argument('--filter_temp', type=float, default=1.0)
    parser.add_argument('--filter_residual', type=float, default=0.2)
    args = parser.parse_args()
    args.dataset = 'paper'

    fix_seed(args.seed)
    device = torch.device('cuda:' + str(args.device) if (not args.cpu and torch.cuda.is_available()) else 'cpu')

    train_ds, val_ds, test_ds = load_datasets()
    if not train_ds:
        sys.exit("Dataset loading failed.")

    # Feature width comes from the training split; the class count is the
    # largest label seen in any split + 1 (assumes 0-based integer labels —
    # TODO confirm against load_datasets).
    d = train_ds.num_node_features
    max_label = -1
    for split in (train_ds, val_ds, test_ds):
        if split.data.y is not None:
            max_label = max(max_label, int(split.data.y.max().item()))
    c = max_label + 1
    print(f"Dataset properties: Features={d}, Classes={c}")

    train_loader, val_loader, test_loader = create_dataloaders(train_ds, val_ds, test_ds, batch_size=args.batch_size)
    if not train_loader:
        sys.exit("Failed to create dataloaders.")

    # Model is built before the filter to keep the original RNG consumption order.
    model = CaNet(d, c, args, device).to(device)
    criterion = nn.CrossEntropyLoss()

    feature_filter = None
    if args.use_causal_filter:
        feature_filter = ImprovedCausalFilter(
            input_dim=d,
            lambda_init=args.filter_lambda_init,
            lambda_min=args.filter_lambda_min,
            decay_rate=args.filter_decay,
            temperature=args.filter_temp,
            residual_weight=args.filter_residual,
            normalize=False,
            dropout=0.1
        ).to(device)

    # Optimise the filter jointly with the model. Previously only
    # model.parameters() reached Adam, so any learnable filter parameters were
    # never updated even though the filter is trained and checkpointed.
    trainable_params = list(model.parameters())
    if feature_filter is not None:
        trainable_params += list(feature_filter.parameters())
    optimizer = torch.optim.Adam(trainable_params, lr=args.lr, weight_decay=args.weight_decay)

    # -inf guarantees the first epoch writes a checkpoint, so the reload below
    # cannot fail (or pick up a stale file from an earlier run) when validation
    # accuracy never rises above 0.
    best_val_acc = float('-inf')
    train_accuracies = []
    test_accuracies = []

    for epoch in range(1, args.epochs + 1):
        if feature_filter is not None:
            feature_filter.train()
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device, args, feature_filter)
        # Disable the filter's dropout while scoring.
        if feature_filter is not None:
            feature_filter.eval()
        val_acc = evaluate_epoch(model, val_loader, device, args, feature_filter)
        test_acc = evaluate_epoch(model, test_loader, device, args, feature_filter)

        train_accuracies.append(train_acc)
        test_accuracies.append(test_acc)

        print(f"Epoch {epoch:03d}: Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}, Test Acc={test_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), model_ckpt)
            if feature_filter is not None:
                torch.save(feature_filter.state_dict(), filter_ckpt)

    if not train_accuracies:
        sys.exit("No training epochs were run.")

    final_train_mean = np.mean(train_accuracies)
    final_train_std = np.std(train_accuracies)
    final_test_std = np.std(test_accuracies)

    # Score the best-on-validation checkpoint with the same filter it was
    # trained with; previously it was evaluated on raw, unfiltered features.
    model.load_state_dict(torch.load(model_ckpt, map_location=device))
    if feature_filter is not None:
        feature_filter.load_state_dict(torch.load(filter_ckpt, map_location=device))
        feature_filter.eval()
    final_test_acc = evaluate_epoch(model, test_loader, device, args, feature_filter)
    print(f'Train Acc  {final_train_mean*100:.2f} ± {final_train_std*100:.2f}  Test Acc  {final_test_acc*100:.2f} ± {final_test_std*100:.2f}')
