import sys, os
import logging
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
import argparse
import copy
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv

from utils.mask import set_masks, clear_masks
from utils.helper import set_seed, args_print
from utils.get_subgraph import split_graph, relabel
from datasets import PaperDIRDataset
from improved_filter import ImprovedCausalFilter


class GCNNodeClassifier(nn.Module):
    """Two-layer GCN whose second layer emits one value per class for each node.

    Args:
        in_channels: Input feature size handed to the first ``GCNConv``.
        hidden_channels: Output size of the first ``GCNConv`` / input size of
            the second.
        num_classes: Output size of the second ``GCNConv``.
        dropout: Dropout probability applied to the hidden activations; only
            active while the module is in training mode.
    """

    def __init__(self, in_channels, hidden_channels, num_classes, dropout=0.5):
        super().__init__()

        self.dropout = dropout
        # Layers are created in conv1 -> conv2 order so seeded initialisation
        # and state_dict key order stay stable.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        """Return the unnormalised class scores for every node in ``x``.

        Args:
            x: Node feature tensor passed to the first convolution.
            edge_index: Edge index tensor passed unchanged to both convolutions.

        Returns:
            The raw output of the second convolution (no softmax applied).
        """
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index)


class CausalAttNet(nn.Module):
    """Scores every edge and asks ``split_graph`` to divide the edges in two.

    Node embeddings come from two ``LEConv`` layers; each edge is scored by an
    MLP applied to the concatenated embeddings of its two endpoints.

    Args:
        ratio: Forwarded unchanged to ``split_graph`` as its third argument.
        in_channels: Input feature size of the first ``LEConv``.
        hidden_channels: Output size of both ``LEConv`` layers.
    """

    def __init__(self, ratio, in_channels, hidden_channels):
        super().__init__()

        from torch_geometric.nn import LEConv

        self.conv1 = LEConv(in_channels, hidden_channels)
        self.conv2 = LEConv(hidden_channels, hidden_channels)

        # Edge scorer: [emb(src) || emb(dst)] -> one scalar per edge.
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden_channels, 4 * hidden_channels),
            nn.ReLU(),
            nn.Linear(4 * hidden_channels, 1),
        )
        self.ratio = ratio

    @staticmethod
    def _empty_result(device, third):
        """Build the "nothing to classify" return value.

        Each 5-tuple repeats a single zero-length tensor five times; ``third``
        is passed through as the last element of the result.
        """
        causal_placeholder = (torch.empty((0,), device=device),) * 5
        conf_placeholder = (torch.empty((0,), device=device),) * 5
        return causal_placeholder, conf_placeholder, third

    def forward(self, data):
        """Split ``data`` into a causal and a confounding edge subset.

        Side effect: ``data.edge_attr`` is overwritten with a column view of
        the flattened edge weights and ``data.edge_weight`` is set to the
        flattened weights (all ones when ``data.edge_attr`` is ``None``).

        Args:
            data: Graph object exposing ``x``, ``edge_index``, ``edge_attr``
                and optionally ``batch``.

        Returns:
            ``(causal, conf, score)`` where ``causal`` and ``conf`` are
            ``(x, edge_index, edge_attr, edge_weight, batch)`` 5-tuples sharing
            the original ``data.x``. When there are no nodes, or when either
            edge subset is empty, both 5-tuples hold zero-length tensors; the
            third element is then the edge weights (no nodes) or the edge
            scores (empty subset).
        """
        if data.edge_attr is None:
            weights = torch.ones(data.edge_index.size(1), device=data.x.device)
        else:
            weights = data.edge_attr.view(-1)

        data.edge_attr = weights.unsqueeze(1)
        data.edge_weight = weights

        node_emb = self.conv1(data.x, data.edge_index, weights)
        node_emb = self.conv2(F.relu(node_emb), data.edge_index, weights)

        if node_emb.size(0) == 0:
            return self._empty_result(node_emb.device, weights)

        src, dst = data.edge_index
        endpoint_pairs = torch.cat([node_emb[src], node_emb[dst]], dim=-1)
        score = self.mlp(endpoint_pairs).view(-1)

        causal_part, conf_part = split_graph(data, score, self.ratio)
        c_edge_index, c_edge_attr, c_edge_weight = causal_part
        f_edge_index, f_edge_attr, f_edge_weight = conf_part

        if c_edge_index.size(1) == 0 or f_edge_index.size(1) == 0:
            return self._empty_result(node_emb.device, score)

        # Both subsets keep every node; only the edge sets differ.
        batch = data.batch if hasattr(data, 'batch') else None
        causal = (data.x, c_edge_index, c_edge_attr, c_edge_weight, batch)
        conf = (data.x, f_edge_index, f_edge_attr, f_edge_weight, batch)
        return causal, conf, score


def main():
    """Train the GCN classifier on the Paper dataset and report accuracies.

    Parses the command line, loads the train/val/test splits, optionally
    builds an ``ImprovedCausalFilter`` for the node features, trains
    ``GCNNodeClassifier`` + ``CausalAttNet`` for ``--epochs`` epochs, restores
    the weights with the best validation accuracy, prints a summary line and
    writes the weights to ``best_paper_*.pth`` in the working directory.

    Raises:
        ValueError: If any of the three dataset splits is empty.
    """
    parser = argparse.ArgumentParser(description="Train DIR-GNN on Paper (Node Classification)")
    parser.add_argument('--datadir', default='./data/paper', type=str, help='Path to paper dataset directory')
    parser.add_argument('--cuda', default=0, type=int, help='CUDA device id (if available)')
    parser.add_argument('--epochs', default=50, type=int, help='Number of training epochs')
    parser.add_argument('--batch_size', default=128, type=int, help='Graphs per batch (use 1 for node classification)')
    parser.add_argument('--hidden', default=64, type=int, help='Hidden dimension for GCN')
    parser.add_argument('--lr', default=0.01, type=float, help='Learning rate')
    parser.add_argument('--r', default=0.7, type=float, help='Ratio of edges for causal subgraph')
    parser.add_argument('--seed', default=42, type=int, help='Random seed')
    parser.add_argument('--use_causal_filter', action='store_true', default=False, help='Enable improved causal filter before GCN classifier')
    parser.add_argument('--filter_lambda_init', type=float, default=1.0)
    parser.add_argument('--filter_lambda_min', type=float, default=-2.0)
    parser.add_argument('--filter_decay', type=float, default=0.99)
    parser.add_argument('--filter_temp', type=float, default=1.0)
    parser.add_argument('--filter_residual', type=float, default=0.2)
    args = parser.parse_args()

    set_seed(args.seed)
    device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')

    # --- data -------------------------------------------------------------
    train_ds = PaperDIRDataset(args.datadir, mode='train')
    val_ds = PaperDIRDataset(args.datadir, mode='val')
    test_ds = PaperDIRDataset(args.datadir, mode='test')
    # Report the split by the name we passed in rather than relying on the
    # dataset object exposing a `mode` attribute.
    for split_name, ds in (('train', train_ds), ('val', val_ds), ('test', test_ds)):
        if len(ds) == 0:
            raise ValueError(f"Empty split: {split_name}")

    in_feats = train_ds.num_node_features

    # The class count is taken over all splits so labels that only occur in
    # val/test still get an output unit.
    all_labels = torch.cat(
        [data.y for ds in (train_ds, val_ds, test_ds) for data in ds], dim=0
    )
    num_classes = int(all_labels.max().item()) + 1

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)

    # --- models -----------------------------------------------------------
    feature_filter = None
    if args.use_causal_filter:
        feature_filter = ImprovedCausalFilter(
            input_dim=in_feats,
            lambda_init=args.filter_lambda_init,
            lambda_min=args.filter_lambda_min,
            decay_rate=args.filter_decay,
            temperature=args.filter_temp,
            residual_weight=args.filter_residual,
            normalize=False,
            dropout=0.1
        ).to(device)
    # Compare against None explicitly: truthiness of an nn.Module is not a
    # reliable "is it configured" test (container modules define __len__).
    use_filter = feature_filter is not None

    gcn = GCNNodeClassifier(in_channels=in_feats, hidden_channels=args.hidden, num_classes=num_classes).to(device)
    att = CausalAttNet(ratio=args.r, in_channels=in_feats, hidden_channels=args.hidden).to(device)

    trainable = list(gcn.parameters()) + list(att.parameters())
    if use_filter:
        trainable += list(feature_filter.parameters())
    optimizer = torch.optim.Adam(trainable, lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    modules = [gcn, att] + ([feature_filter] if use_filter else [])

    def set_training(flag):
        """Switch every model between train (True) and eval (False) mode."""
        for module in modules:
            module.train(flag)

    def split_edges(graph):
        """Move a batch to the device, filter its features, run the splitter."""
        graph = graph.to(device)
        if use_filter:
            graph.x = feature_filter(graph.x)
        causal, conf, _score = att(graph)
        return graph, causal, conf

    @torch.no_grad()  # evaluation must not build autograd graphs
    def evaluate(loader):
        """Return node accuracy in [0, 1] of the GCN on the causal edges.

        Batches for which the splitter returns an empty result are skipped;
        0.0 is returned when no node was evaluated.
        """
        set_training(False)
        total_correct = 0
        total_nodes = 0
        for graph in loader:
            graph, causal, conf = split_edges(graph)
            cx, c_edge_index = causal[0], causal[1]
            if cx.size(0) == 0 or conf[0].size(0) == 0:
                continue
            pred = gcn(cx, c_edge_index).argmax(dim=1)
            total_correct += int((pred == graph.y).sum().item())
            total_nodes += graph.y.size(0)
        return (total_correct / total_nodes) if total_nodes > 0 else 0.0

    # --- training ---------------------------------------------------------
    best_val_acc = 0.0
    best_model_state = None
    best_att_state = None
    best_filter_state = None
    final_train_accuracies = []
    final_in_test_accuracies = []  # per-epoch validation accuracy
    final_test_accuracies = []     # per-epoch test accuracy

    for epoch in range(1, args.epochs + 1):
        set_training(True)
        epoch_loss = 0.0
        for graph in train_loader:
            graph, causal, conf = split_edges(graph)
            cx, c_edge_index = causal[0], causal[1]
            fx, f_edge_index = conf[0], conf[1]
            if cx.size(0) == 0 or fx.size(0) == 0:
                continue
            # NOTE(review): the edge weights / scores returned by `att` are
            # not fed to the classifier, so the loss likely has no gradient
            # path into `att` - confirm against `split_graph` before relying
            # on the attention network being trained.
            loss_c = criterion(gcn(cx, c_edge_index), graph.y)
            loss_f = criterion(gcn(fx, f_edge_index), graph.y)
            loss = loss_c + loss_f
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        train_acc = evaluate(train_loader) * 100
        val_acc = evaluate(val_loader) * 100
        test_acc_eval = evaluate(test_loader) * 100
        final_train_accuracies.append(train_acc)
        final_in_test_accuracies.append(val_acc)
        # Record the *test* accuracy here; the std printed next to the final
        # test accuracy is computed from this list.
        final_test_accuracies.append(test_acc_eval)
        print(f"Epoch {epoch}/{args.epochs}, Loss = {epoch_loss:.4f}, Train Acc = {train_acc:.4f}, Val Acc = {val_acc:.4f}")

        if use_filter:
            feature_filter.step()

        # Model selection is on validation accuracy only.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = copy.deepcopy(gcn.state_dict())
            best_att_state = copy.deepcopy(att.state_dict())
            if use_filter:
                best_filter_state = copy.deepcopy(feature_filter.state_dict())

    # --- summary ----------------------------------------------------------
    final_train_mean = np.mean(final_train_accuracies)
    final_train_std = np.std(final_train_accuracies)
    final_test_std = np.std(final_test_accuracies)

    if best_model_state is not None:
        gcn.load_state_dict(best_model_state)
        att.load_state_dict(best_att_state)
        if use_filter and best_filter_state is not None:
            feature_filter.load_state_dict(best_filter_state)

    test_acc = evaluate(test_loader) * 100
    print(f'Train Acc  {final_train_mean:.2f} ± {final_train_std:.2f}  Test Acc  {test_acc:.2f} ± {final_test_std:.2f}')

    torch.save(gcn.state_dict(), "best_paper_gcn.pth")
    torch.save(att.state_dict(), "best_paper_att.pth")
    if use_filter:
        torch.save(feature_filter.state_dict(), "best_paper_filter.pth")


if __name__ == '__main__':
    main()

