


import argparse
import sys
import os, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F



from torch_geometric.nn import global_mean_pool 

from logger import Logger

from load_molecular_data import load_datasets, create_dataloaders

from data_utils import get_gpu_memory_map, count_parameters
from eval import eval_acc 
from parse import parser_add_main_args
from model import CaNet 
from improved_filter import ImprovedCausalFilter
import time


def fix_seed(seed):
    """Seed every random number generator this script relies on.

    The same seed is handed, in order, to Python's ``random`` module, NumPy,
    PyTorch, and PyTorch's CUDA generator; afterwards cuDNN is asked to use
    deterministic algorithms.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_rng in seeders:
        seed_rng(seed)

    torch.backends.cudnn.deterministic = True


def train_epoch(model, loader, optimizer, criterion, device, args, feature_filter=None):
    """Run one optimisation pass over ``loader`` and report loss and accuracy.

    For every batch the (optionally filtered) node features are fed to the
    model, the supervised loss is combined with the model's regularisation
    term as ``sup_loss + args.lamda * |reg_loss|``, and one optimizer step
    is taken.

    Args:
        model: Module called as ``model(x, edge_index, batch=..., training=True)``
            and expected to return ``(logits, reg_loss)``.
        loader: Iterable of batches exposing ``to(device)``, ``x``,
            ``edge_index``, ``batch``, ``y`` and ``num_graphs``.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the update.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device every batch is moved to.
        args: Namespace providing ``lamda``, the weight of the regularisation term.
        feature_filter: Optional callable applied to ``batch.x`` before the
            model; ``None`` feeds the raw features.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over the number of graphs
        seen. ``(0.0, 0.0)`` is returned when the loader yields no graphs.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_graphs = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        # Compare against None explicitly: a module that defines __len__
        # (e.g. an empty container) would otherwise be treated as "no filter".
        x_in = batch.x if feature_filter is None else feature_filter(batch.x)
        out, reg_loss = model(x_in, batch.edge_index, batch=batch.batch, training=True)
        true_label = batch.y.view(-1).long()
        sup_loss = criterion(out, true_label)

        # Only the magnitude of the regularisation term is penalised.
        loss = sup_loss + args.lamda * torch.abs(reg_loss)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by its graph count so the final value is a
        # per-graph mean even when the last batch is smaller.
        total_loss += loss.item() * batch.num_graphs
        total_correct += out.argmax(dim=1).eq(true_label).sum().item()
        total_graphs += batch.num_graphs

    if total_graphs == 0:
        # Nothing was seen; avoid dividing by zero.
        return 0.0, 0.0
    return total_loss / total_graphs, total_correct / total_graphs

@torch.no_grad()
def evaluate_epoch(model, loader, device, args, feature_filter=None):
    """Compute classification accuracy of ``model`` over ``loader``.

    Runs without gradient tracking. The model is switched to eval mode; a
    module-based ``feature_filter`` is switched to eval mode for the duration
    of the call and put back into training mode afterwards if that is how it
    arrived.

    Args:
        model: Module called as ``model(x, edge_index, batch=..., training=False)``
            and expected to return the logits.
        loader: Iterable of batches exposing ``to(device)``, ``x``,
            ``edge_index``, ``batch``, ``y`` and ``num_graphs``.
        device: Device every batch is moved to.
        args: Unused here; kept so the signature parallels ``train_epoch``.
        feature_filter: Optional callable applied to ``batch.x`` before the
            model; ``None`` feeds the raw features.

    Returns:
        Fraction of correctly classified graphs, or ``0.0`` when the loader
        yields no graphs.
    """
    model.eval()

    # Train-time behaviour of the filter (e.g. dropout) must not perturb the
    # features being scored, so force eval mode and restore it on the way out.
    restore_filter_training = isinstance(feature_filter, nn.Module) and feature_filter.training
    if restore_filter_training:
        feature_filter.eval()

    total_correct = 0
    total_graphs = 0
    try:
        for batch in loader:
            batch = batch.to(device)
            x_in = batch.x if feature_filter is None else feature_filter(batch.x)
            out = model(x_in, batch.edge_index, batch=batch.batch, training=False)
            # Same label handling as train_epoch: flatten and cast to integer ids.
            true_label = batch.y.view(-1).long()
            total_correct += out.argmax(dim=1).eq(true_label).sum().item()
            total_graphs += batch.num_graphs
    finally:
        if restore_filter_training:
            feature_filter.train()

    if total_graphs == 0:
        # Nothing was seen; avoid dividing by zero.
        return 0.0
    return total_correct / total_graphs



if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Command-line arguments
    # ------------------------------------------------------------------
    # The description literal keeps the trailing tab it has always had.
    parser = argparse.ArgumentParser(description='CaNet Training for Molecular Graph Classification\t')
    parser_add_main_args(parser)

    # Options for the feature filter applied to node features before the model.
    parser.add_argument('--use_causal_filter', default=True, action='store_true')
    # `--use_causal_filter` is a store_true flag that already defaults to True,
    # so by itself it can never turn the filter off; this companion flag can.
    parser.add_argument('--no_causal_filter', dest='use_causal_filter', action='store_false',
                        help='disable the causal feature filter')
    parser.add_argument('--filter_lambda_init', type=float, default=1.0)
    parser.add_argument('--filter_lambda_min', type=float, default=-2.0)
    parser.add_argument('--filter_decay', type=float, default=0.99)
    parser.add_argument('--filter_temp', type=float, default=1.0)
    parser.add_argument('--filter_residual', type=float, default=0.2)
    args = parser.parse_args()

    # This script always runs on the molecular dataset, whatever was passed.
    args.dataset = 'molecular'
    print("--- Running CaNet for Molecular Graph Classification ---")
    print(args)

    fix_seed(args.seed)

    # Use the requested CUDA device unless --cpu was given or CUDA is missing.
    use_cuda = not args.cpu and torch.cuda.is_available()
    device = torch.device("cuda:" + str(args.device)) if use_cuda else torch.device("cpu")
    print(f"Using device: {device}")

    # ------------------------------------------------------------------
    # Data
    # ------------------------------------------------------------------
    train_dataset, val_dataset, test_dataset = load_datasets()
    if not train_dataset:
        sys.exit("Failed to load datasets. Exiting.")

    d = train_dataset.num_features

    # Number of classes = largest label found in any available split, plus one.
    max_label = 0
    for split in (train_dataset, val_dataset, test_dataset):
        if split and split.data.y is not None:
            max_label = max(max_label, split.data.y.max().item())
    c = int(max_label) + 1

    print(f"Dataset properties: Features={d}, Classes={c}")

    # Fall back to 32 when the shared argument parser defines no batch size.
    batch_size = getattr(args, 'batch_size', 32)
    train_loader, val_loader, test_loader = create_dataloaders(
        train_dataset, val_dataset, test_dataset, batch_size=batch_size)
    if not train_loader:
        sys.exit("Failed to create dataloaders. Exiting.")

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    model = CaNet(d, c, args, device).to(device)
    criterion = nn.CrossEntropyLoss()
    logger = Logger(args.runs, args)

    print(f'MODEL: {model}')
    print(f'Total Parameters: {count_parameters(model)}')

    def evaluate_filtered(loader, feature_filter):
        """Return the accuracy of ``model`` on ``loader`` using ``feature_filter``.

        The filter (if it is currently in training mode) is switched to eval
        mode while scoring so its train-time behaviour, such as the dropout
        it is configured with, does not perturb the reported accuracy; the
        previous mode is restored afterwards. An empty loader scores 0.0.
        """
        was_training = getattr(feature_filter, 'training', False)
        if was_training:
            feature_filter.eval()
        try:
            return evaluate_epoch(model, loader, device, args, feature_filter=feature_filter)
        except ZeroDivisionError:
            # No graphs were seen, so there is no accuracy to divide out.
            return 0.0
        finally:
            if was_training:
                feature_filter.train()

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    print("--- Starting Training ---")
    for run in range(args.runs):
        print(f"--- Run {run+1}/{args.runs} ---")

        # Every run starts from freshly initialised weights and optimizer state.
        model.reset_parameters()
        # NOTE(review): only model.parameters() are optimised. If
        # ImprovedCausalFilter has learnable parameters they are never
        # updated -- confirm this is intended.
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        best_val_acc = 0.0

        feature_filter = None
        if args.use_causal_filter:
            feature_filter = ImprovedCausalFilter(
                input_dim=d,
                lambda_init=args.filter_lambda_init,
                lambda_min=args.filter_lambda_min,
                decay_rate=args.filter_decay,
                temperature=args.filter_temp,
                residual_weight=args.filter_residual,
                normalize=False,
                dropout=0.1
            ).to(device)

        for epoch in range(args.epochs):
            start_time = time.time()
            if feature_filter is not None:
                feature_filter.train()

            train_loss, train_acc = train_epoch(
                model, train_loader, optimizer, criterion, device, args,
                feature_filter=feature_filter)

            val_acc = evaluate_filtered(val_loader, feature_filter)
            test_in_acc = evaluate_filtered(test_loader, feature_filter)

            # Presumably advances the filter's own per-epoch schedule --
            # confirm against improved_filter.
            if feature_filter is not None:
                feature_filter.step()
            epoch_time = time.time() - start_time

            # The fourth slot is a constant 0.0; the layout Logger expects is
            # not visible from this file.
            result_line = (train_acc, val_acc, test_in_acc, 0.0)
            logger.add_result(run, result_line)

            # Keep the filter weights from the best validation epoch so far.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                if feature_filter is not None:
                    torch.save(feature_filter.state_dict(), f"best_filter_molecular_run{run}.pth")

            if epoch % args.display_step == 0:
                print(f'Run: {run+1:02d}, Epoch: {epoch:03d}, '
                    f'Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, '
                    f'Val Acc: {val_acc:.4f}, Test Acc: {test_in_acc:.4f}, '
                    f'Time: {epoch_time:.2f}s')

        # Score the final model through the same filter it was trained with;
        # feeding it raw features here would not match its training inputs.
        test_acc = evaluate_filtered(test_loader, feature_filter)
        print(f"Run {run+1} finished. Test Accuracy: {test_acc:.4f}")

        logger.print_statistics(run)

    print("--- Training Finished ---")
    logger.print_statistics()

    print("\nNote: This script assumes graph classification. The CaNet model architecture, particularly")
    print("the final layer(s), might need adjustment to properly handle pooled graph embeddings.")
    print("The current implementation applies the last linear layer of CaNet AFTER global mean pooling.")

