


import sys, os
# Put the repository root (the parent of this script's directory) on sys.path
# so the project-local imports below (`utils.*`, `improved_filter`) resolve.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import argparse
import copy
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.nn import GCNConv, LEConv
from torch_geometric.data import Data

from utils.mask import set_masks, clear_masks
from utils.helper import set_seed, args_print
from utils.get_subgraph import split_graph, relabel
from improved_filter import ImprovedCausalFilter


def set_env(seed):
    """Seed every RNG used by the experiment and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU, the current
    CUDA device and all CUDA devices), then turns off cuDNN autotuning and
    requests deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True



class GCNNodeClassifier(nn.Module):
    """Two-layer GCN mapping node features to per-class logits.

    Args:
        in_channels: input node-feature dimension.
        hidden_channels: width of the hidden GCN layer.
        num_classes: number of logits produced per node.
        dropout: dropout probability applied to the hidden representation
            (active only while the module is in training mode).
    """

    def __init__(self, in_channels, hidden_channels, num_classes, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return unnormalised class logits for every node."""
        hidden = F.dropout(
            F.relu(self.conv1(x, edge_index)),
            p=self.dropout,
            training=self.training,
        )
        return self.conv2(hidden, edge_index)



class CausalAttNet(nn.Module):
    """Edge scorer that splits a graph into causal and confounding edge sets.

    Two LEConv layers embed the nodes. An MLP scores every edge from the
    concatenated embeddings of its two endpoints. The project helper
    ``split_graph`` then partitions the edges using those scores and ``ratio``.
    """
    def __init__(self, ratio, in_channels, hidden_channels):
        super(CausalAttNet, self).__init__()
        # Node encoder (ReLU between the two layers is applied in forward).
        self.conv1 = LEConv(in_channels, hidden_channels)
        self.conv2 = LEConv(hidden_channels, hidden_channels)
        # Edge scorer: input is [h_src || h_dst], output is one scalar per edge.
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden_channels, 4 * hidden_channels),
            nn.ReLU(),
            nn.Linear(4 * hidden_channels, 1)
        )
        # Forwarded to split_graph. The CLI help for --r calls it the "Ratio of
        # edges for causal subgraph"; exact semantics live in utils.get_subgraph.
        self.ratio = ratio

    def forward(self, data):
        """Score edges and split them into causal / confounding subgraphs.

        Args:
            data: graph object providing ``x``, ``edge_index``, ``edge_attr``
                (may be None) and optionally ``batch``. Side effect: this
                method overwrites ``data.edge_attr`` (as a column vector) and
                sets ``data.edge_weight`` on the object it receives.

        Returns:
            ``(causal, conf, score)``. ``causal`` and ``conf`` are 5-tuples
            ``(x, edge_index, edge_attr, edge_weight, batch)`` built from
            ``split_graph``'s output. Both carry the full ``data.x``; only the
            edge sets differ. ``score`` holds one logit per edge.

            If the encoder yields no nodes, or either edge set is empty, both
            tuples are instead five empty tensors (one tensor object repeated).
            NOTE(review): in the no-node case the third element is the per-edge
            weight vector rather than ``score``, unlike the other exits.
        """
        # Flatten existing edge attributes to one weight per edge, or default to
        # all-ones. view(-1) assumes a single scalar attribute per edge — TODO
        # confirm; a multi-column edge_attr would yield a vector longer than E.
        edge_attr = data.edge_attr.view(-1) if data.edge_attr is not None else \
                    torch.ones(data.edge_index.size(1), device=data.x.device)
        # Written back onto `data` — presumably for split_graph to read; verify
        # against utils.get_subgraph.
        data.edge_attr = edge_attr.unsqueeze(1)
        data.edge_weight = edge_attr

        x = F.relu(self.conv1(data.x, data.edge_index, edge_attr))
        x = self.conv2(x, data.edge_index, edge_attr)

        if x.size(0) == 0:
            empty = lambda: torch.empty((0,), device=x.device)
            return (empty(),) * 5, (empty(),) * 5, edge_attr

        # Edge representation = concatenation of source and target embeddings.
        row, col = data.edge_index
        edge_rep = torch.cat([x[row], x[col]], dim=-1)
        score = self.mlp(edge_rep).view(-1)

        (c_edge_index, c_edge_attr, c_edge_weight), (f_edge_index, f_edge_attr, f_edge_weight) = \
            split_graph(data, score, self.ratio)

        # A degenerate split (one side has no edges) is signalled with empties.
        if c_edge_index.size(1) == 0 or f_edge_index.size(1) == 0:
            empty = lambda: torch.empty((0,), device=x.device)
            return (empty(),) * 5, (empty(),) * 5, score

        # Node-level task: both subgraphs keep every node and features; only
        # the edge sets differ.
        causal_x = data.x
        causal_edge_index = c_edge_index
        causal_batch = data.batch if hasattr(data, 'batch') else None
        conf_x = data.x
        conf_edge_index = f_edge_index
        conf_batch = data.batch if hasattr(data, 'batch') else None

        return (causal_x, causal_edge_index, c_edge_attr, c_edge_weight, causal_batch), \
               (conf_x, conf_edge_index, f_edge_attr, f_edge_weight, conf_batch), \
               score


def train_full_graph(gcn, att, feature_filter, data, optimizer, criterion, device):
    """Run one full-batch DIR-GNN training step on a single graph.

    The optional feature filter transforms the node features. ``att`` then
    splits the edges into causal and confounding subgraphs. ``gcn`` is trained
    with the summed loss of its predictions on both subgraphs, restricted to
    the training nodes.

    Args:
        gcn: node classifier called as ``gcn(x, edge_index)``.
        att: module called as ``att(graph)`` returning
            ``(causal_tuple, conf_tuple, score)``, where each tuple is
            ``(x, edge_index, edge_attr, edge_weight, batch)``.
        feature_filter: optional module mapping ``x`` to filtered ``x``.
            ``None`` disables filtering.
        data: graph with ``x``, ``y`` and ``train_mask``. It is moved with
            ``data.to(device)``. The filtered features, and anything ``att``
            writes, go on a shallow copy. Repeated calls therefore always start
            from the original features.
        optimizer: optimizer over the parameters of all trainable modules.
        criterion: loss taking ``(logits, targets)``.
        device: target torch device.

    Returns:
        ``(loss, train_acc)`` as Python floats. Returns ``(0.0, 0.0)`` when the
        edge split is degenerate and no optimizer step is taken. ``train_acc``
        is computed from the causal-subgraph predictions and is 0.0 when the
        training mask is empty.
    """
    gcn.train()
    att.train()
    if feature_filter is not None:
        feature_filter.train()

    data = data.to(device)
    optimizer.zero_grad()

    # Work on a shallow copy. The caller reuses `data` every epoch (and for
    # evaluation). Writing the filtered features back into `data.x` made the
    # filter compound: each call re-filtered the previous call's output and
    # chained onto an already-freed autograd graph. `att` also writes
    # edge_attr/edge_weight onto the object it is given.
    graph = copy.copy(data)
    if feature_filter is not None:
        graph.x = feature_filter(data.x)

    # Single-graph setting: assign every node to graph 0.
    if getattr(graph, 'batch', None) is None:
        graph.batch = torch.zeros(graph.x.size(0), dtype=torch.long, device=graph.x.device)

    (cx, c_edge_index, _, _, _), (fx, f_edge_index, _, _, _), _ = att(graph)

    # Degenerate split (no nodes, or an empty edge set): skip the update.
    if cx.size(0) == 0 or fx.size(0) == 0:
        return 0.0, 0.0

    out_c = gcn(cx, c_edge_index)
    out_f = gcn(fx, f_edge_index)

    mask = data.train_mask
    target = data.y[mask]
    loss = criterion(out_c[mask], target) + criterion(out_f[mask], target)

    loss.backward()
    optimizer.step()

    # Accuracy is reported on the causal-subgraph predictions only.
    correct = (out_c[mask].argmax(dim=1) == target).sum().item()
    total = int(mask.sum().item())
    return loss.item(), (correct / total if total else 0.0)


def evaluate_full_graph(gcn, att, feature_filter, data, device, split_mask):
    """Return the model's accuracy on the nodes selected by ``split_mask``.

    Predictions come from ``gcn`` run on the causal subgraph produced by
    ``att`` (after the optional feature filter), with gradients disabled and
    all modules in eval mode.

    Args:
        gcn: node classifier called as ``gcn(x, edge_index)``.
        att: module called as ``att(graph)`` returning
            ``(causal_tuple, conf_tuple, score)``.
        feature_filter: optional feature filter module, or ``None``.
        data: graph with ``x`` and ``y``. It is moved with ``data.to(device)``.
            Filtered features and anything ``att`` writes go on a shallow copy,
            so evaluation never alters the caller's features.
        device: target torch device.
        split_mask: boolean node mask selecting the nodes to score.

    Returns:
        Accuracy as a float. Returns 0.0 when the causal subgraph is empty or
        the mask selects no nodes.
    """
    gcn.eval()
    att.eval()
    if feature_filter is not None:
        feature_filter.eval()

    with torch.no_grad():
        data = data.to(device)

        # Shallow copy so the filtered features are not written back into the
        # caller's object (which previously made val and test see different,
        # repeatedly re-filtered inputs).
        graph = copy.copy(data)
        if feature_filter is not None:
            graph.x = feature_filter(data.x)

        # Single-graph setting: assign every node to graph 0.
        if getattr(graph, 'batch', None) is None:
            graph.batch = torch.zeros(graph.x.size(0), dtype=torch.long, device=graph.x.device)

        (cx, c_edge_index, _, _, _), _, _ = att(graph)
        if cx.size(0) == 0:
            return 0.0

        pred = gcn(cx, c_edge_index).argmax(dim=1)
        # Keep the mask on the same device as the predictions.
        split_mask = split_mask.to(pred.device)
        correct = (pred[split_mask] == data.y[split_mask]).sum().item()
        total = int(split_mask.sum().item())

    return correct / total if total else 0.0


def parse_args(argv=None):
    """Parse command-line options for the DIR-GNN CiteSeer experiment.

    Args:
        argv: argument list to parse; ``None`` parses ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` holding the experiment hyper-parameters.
    """
    parser = argparse.ArgumentParser(description="Train DIR-GNN on CiteSeer (Node Classification)")
    parser.add_argument('--epochs',     type=int,   default=1000)
    parser.add_argument('--lr',         type=float, default=0.01)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--hidden',     type=int,   default=64, help='Hidden dimension for GCN')
    parser.add_argument('--r',          type=float, default=0.7, help='Ratio of edges for causal subgraph')
    parser.add_argument('--dropout',    type=float, default=0.5)
    parser.add_argument('--no_cuda',    action='store_true')
    # `store_true` with `default=True` on its own can never be switched off,
    # which made the baseline configuration unreachable. Keep the positive flag
    # for compatibility and add a negative flag writing the same destination.
    parser.add_argument('--use_causal_filter', dest='use_causal_filter', action='store_true', default=True,
                        help='Whether to use causal filter module')
    parser.add_argument('--no_causal_filter', dest='use_causal_filter', action='store_false',
                        help='Disable the causal filter module (baseline DIR-GNN)')
    parser.add_argument('--filter_lambda_init', type=float, default=10.0)
    parser.add_argument('--filter_lambda_min', type=float, default=-2.0)
    parser.add_argument('--filter_decay', type=float, default=0.95)
    parser.add_argument('--filter_temp', type=float, default=1.0)
    parser.add_argument('--filter_residual', type=float, default=0.2)
    parser.add_argument('--runs',       type=int,   default=10, help='Number of runs')
    parser.add_argument('--seed',       type=int,   default=42, help='Random seed')
    parser.add_argument('--patience',   type=int,   default=100, help='Early stopping patience')
    return parser.parse_args(argv)


def run_single(args, dataset, device, run):
    """Train one seeded run and return ``(best_val_acc, test_acc_at_best_val)``.

    The test accuracy reported for the run is the one measured at the epoch
    with the highest validation accuracy. Training stops early after
    ``args.patience`` consecutive epochs without a validation improvement.
    """
    # Reseed per run so runs differ from each other but each is reproducible.
    set_env(args.seed + run)

    # `dataset[0]` is taken per run so each run starts from its own Data object.
    data = dataset[0]
    num_features = dataset.num_node_features
    num_classes = dataset.num_classes

    feature_filter = None
    if args.use_causal_filter:
        feature_filter = ImprovedCausalFilter(
            input_dim=num_features,
            lambda_init=args.filter_lambda_init,
            lambda_min=args.filter_lambda_min,
            decay_rate=args.filter_decay,
            temperature=args.filter_temp,
            residual_weight=args.filter_residual,
            normalize=False,
            dropout=0.1
        ).to(device)

    gcn = GCNNodeClassifier(
        in_channels=num_features,
        hidden_channels=args.hidden,
        num_classes=num_classes,
        dropout=args.dropout
    ).to(device)

    att = CausalAttNet(
        ratio=args.r,
        in_channels=num_features,
        hidden_channels=args.hidden
    ).to(device)

    # One optimizer over every trainable module.
    params = list(gcn.parameters()) + list(att.parameters())
    if feature_filter is not None:
        params += list(feature_filter.parameters())
    optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)

    criterion = nn.CrossEntropyLoss()

    print(f"Total parameter count: {sum(p.numel() for p in params if p.requires_grad)}")

    best_val = 0
    final_test_acc = 0
    patience_counter = 0

    for epoch in range(1, args.epochs + 1):
        loss, train_acc = train_full_graph(gcn, att, feature_filter, data, optimizer, criterion, device)

        val_acc = evaluate_full_graph(gcn, att, feature_filter, data, device, data.val_mask)
        test_acc = evaluate_full_graph(gcn, att, feature_filter, data, device, data.test_mask)

        # Advance the filter's schedule once per epoch.
        if feature_filter is not None:
            feature_filter.step()

        if val_acc > best_val:
            best_val = val_acc
            final_test_acc = test_acc
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f'Early stopping at epoch {epoch}')
                break

        if epoch % 100 == 0 or epoch == 1:
            print(f'Epoch {epoch:04d}: Loss {loss:.4f} | Train Acc {train_acc:.4f} | Val Acc {val_acc:.4f} | Test Acc {test_acc:.4f}')
            if feature_filter is not None and epoch % 200 == 0:
                stats = feature_filter.get_stats()
                print(f"Filter Stats: lambda={stats['lambda']:.4f}, gate_mean={stats['gate_stats']['mean']:.4f}")

    return best_val, final_test_acc


def main(argv=None):
    """Run ``args.runs`` seeded trainings and print mean ± std accuracies."""
    print("Starting training with DIR-GNN on CiteSeer dataset:")
    args = parse_args(argv)

    set_env(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    print(f"Using device: {device}")
    print(f"Causal filter module: {'Enabled' if args.use_causal_filter else 'Disabled'}")
    print(f"Hidden layer dimension: {args.hidden}, Causal edge ratio: {args.r}")

    # Load the dataset once instead of re-reading it from disk on every run.
    dataset = Planetoid(root='./data/Planetoid', name='CiteSeer', transform=NormalizeFeatures())
    sample = dataset[0]
    print(f"Dataset info: Nodes={sample.x.size(0)}, Edges={sample.edge_index.size(1)}, Feature dimensions={dataset.num_node_features}, Number of classes={dataset.num_classes}")
    print(f"Training nodes={sample.train_mask.sum().item()}, Validation nodes={sample.val_mask.sum().item()}, Test nodes={sample.test_mask.sum().item()}")

    test_accs = []
    val_accs = []

    for run in range(args.runs):
        print(f"\n=== Run {run + 1} ===")
        best_val, final_test_acc = run_single(args, dataset, device, run)
        test_accs.append(final_test_acc * 100)
        val_accs.append(best_val * 100)
        print(f'Run {run + 1}: Best Val Acc: {best_val:.4f}, Test Acc: {final_test_acc:.4f}')

    test_mean, test_std = np.mean(test_accs), np.std(test_accs)
    val_mean, val_std = np.mean(val_accs), np.std(val_accs)

    print('\n=== Final Results ===')
    print(f'Validation Accuracy: {val_mean:.2f} ± {val_std:.2f}%')
    print(f'Test Accuracy: {test_mean:.2f} ± {test_std:.2f}%')
    print(f'Individual test runs: {[f"{acc:.2f}" for acc in test_accs]}')

    print('\n=== Final Results of DIR-GNN on CiteSeer Dataset ===')
    label = 'DIR-GNN + ImprovedCausalFilter' if args.use_causal_filter else 'DIR-GNN (baseline)'
    print(f'{label}: Test Accuracy = {test_mean:.2f} ± {test_std:.2f}%')


if __name__ == '__main__':
    main()
