


import sys
sys.path.append('./')
from improved_filter import ImprovedCausalFilter
import torch
import argparse
import os
import os.path as osp
import numpy as np
import random
from datetime import datetime
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.data import DataLoader, Data
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import LEConv, Linear
from utils.mask import set_masks, clear_masks
from utils.logger import Logger
from utils.helper import set_seed, args_print
from utils.get_subgraph import split_graph, relabel
from gnn.spmotif_paper import SPMotifNet1


def set_env(seed):
    """Seed every random number generator used by the script and force cuDNN determinism.

    Args:
        seed: Value passed to Python's ``random``, NumPy's global RNG, torch's
            CPU RNG, the current CUDA device RNG and all CUDA device RNGs.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Give up cuDNN autotuning in exchange for reproducible kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True



class CausalAttNet(nn.Module):
    """Edge scorer whose scores are handed to ``split_graph`` to split the edge set.

    Two ``LEConv`` layers embed the nodes; each edge is scored by an MLP applied
    to the concatenated embeddings of its two endpoints. The scores and
    ``causal_ratio`` are then passed to ``split_graph``.

    Args:
        causal_ratio: Stored as ``self.ratio`` and forwarded to ``split_graph``.
        in_channels: Node feature dimension expected by the first ``LEConv``.
        channels: Hidden dimension of both ``LEConv`` layers.
    """

    def __init__(self, causal_ratio, in_channels, channels):
        super().__init__()
        # Attribute names and construction order fix the state_dict keys and
        # the order in which parameters consume the RNG; keep both stable.
        self.conv1 = LEConv(in_channels=in_channels, out_channels=channels)
        self.conv2 = LEConv(in_channels=channels, out_channels=channels)
        mlp_hidden = channels * 4
        self.mlp = nn.Sequential(
            nn.Linear(channels * 2, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, 1),
        )
        self.ratio = causal_ratio

    def forward(self, data):
        """Score every edge of ``data`` and split the graph with ``split_graph``.

        Side effect: if ``data`` has no ``batch`` (or it is ``None``), an
        all-zeros ``batch`` is attached. If it has no ``edge_attr`` (or it is
        ``None``), an all-ones ``edge_attr`` is attached. Both are attached to
        the caller's object, which ``split_graph`` then receives.

        Returns:
            ``(cei, cea, cew), (fei, fea, few), score`` — the two triples exactly
            as produced by ``split_graph`` (presumably edge_index / edge_attr /
            edge_weight of the causal and complementary parts — confirm against
            utils.get_subgraph), plus the raw per-edge scores as a 1-D tensor.
        """
        feats = data.x
        edge_index = data.edge_index
        num_edges = edge_index.size(1)

        # LEConv edge weights come from edge_attr when present, else all ones.
        if hasattr(data, 'edge_attr') and data.edge_attr is not None:
            weights = data.edge_attr.view(-1)
        else:
            weights = torch.ones(num_edges, device=feats.device)

        node_emb = F.relu(self.conv1(feats, edge_index, weights))
        node_emb = self.conv2(node_emb, edge_index, weights)

        src, dst = edge_index
        edge_repr = torch.cat([node_emb[src], node_emb[dst]], dim=-1)
        edge_scores = self.mlp(edge_repr).view(-1)

        # split_graph reads data.batch and data.edge_attr, so fill in defaults
        # for a single, attribute-free graph before calling it.
        if getattr(data, 'batch', None) is None:
            data.batch = torch.zeros(data.x.size(0), dtype=torch.long, device=data.x.device)
        if getattr(data, 'edge_attr', None) is None:
            data.edge_attr = torch.ones(num_edges, device=data.x.device)

        (c_index, c_attr, c_weight), (s_index, s_attr, s_weight) = split_graph(
            data, edge_scores, self.ratio
        )
        return (c_index, c_attr, c_weight), (s_index, s_attr, s_weight), edge_scores


def train_full_graph(g, att_net, feature_filter, data, optimizer, criterion, device):
    """Run one full-batch optimisation step on a single graph.

    The graph, optionally with ``feature_filter`` applied to its node features,
    is scored by ``att_net`` and split into two edge sets. The classifier ``g``
    is then run once on each set:

    - causal edges, with their weights installed via ``set_masks``;
    - complementary edges, without masks.

    The two ``criterion`` losses on the training nodes are summed and
    back-propagated in a single ``optimizer`` step.

    NOTE(review): ``g`` always receives the unfiltered ``data.x``; the filtered
    features only influence edge scoring. Confirm this is the intended design.

    Args:
        g: Classifier called as ``g(x=..., edge_index=..., edge_attr=..., batch=...)``.
        att_net: Edge scorer returning ``((cei, cea, cew), (fei, fea, few), score)``.
        feature_filter: Optional module applied to ``data.x`` before scoring;
            ``None`` disables it.
        data: Graph with ``x``, ``edge_index``, ``y`` and train/val/test masks.
        optimizer: Optimizer over all trainable parameters.
        criterion: Loss called as ``criterion(logits, targets)``.
        device: Target device for ``data``.

    Returns:
        ``(loss, train_acc)``. The accuracy is a fraction, computed from the
        causal-subgraph predictions on the training nodes (``0.0`` if the mask
        is empty). Returns ``(0.0, 0.0)`` without an optimizer step when either
        edge set is empty.
    """
    g.train()
    att_net.train()
    if feature_filter is not None:
        feature_filter.train()

    data = data.to(device)
    optimizer.zero_grad()

    filtered_data = data
    if feature_filter is not None:
        filtered_data = Data(
            x=feature_filter(data.x),
            edge_index=data.edge_index,
            edge_attr=getattr(data, 'edge_attr', None),
            y=data.y,
            batch=getattr(data, 'batch', None),
            train_mask=data.train_mask,
            val_mask=data.val_mask,
            test_mask=data.test_mask
        )

    (cei, cea, cew), (fei, fea, _), _ = att_net(filtered_data)

    # Nothing to learn from if the split left one side without edges.
    if cei.size(1) == 0 or fei.size(1) == 0:
        return 0.0, 0.0

    train_mask = data.train_mask
    targets = data.y[train_mask]

    # Causal branch: edge weights are installed on g only for this forward pass.
    set_masks(cew, g)
    try:
        out_c = g(x=data.x, edge_index=cei, edge_attr=cea, batch=getattr(data, 'batch', None))
    finally:
        # Always detach the masks, even if the forward raised, so g is never
        # left with stale masks for later calls.
        clear_masks(g)
    loss_c = criterion(out_c[train_mask], targets)

    # Complementary branch: unmasked forward on the remaining edges.
    out_f = g(x=data.x, edge_index=fei, edge_attr=fea, batch=getattr(data, 'batch', None))
    loss_f = criterion(out_f[train_mask], targets)

    loss = loss_c + loss_f
    loss.backward()
    optimizer.step()

    total = int(train_mask.sum().item())
    if total == 0:
        return loss.item(), 0.0
    correct = (out_c.argmax(dim=1)[train_mask] == targets).sum().item()
    return loss.item(), correct / total


def evaluate_full_graph(g, att_net, feature_filter, data, device, split_mask):
    """Return the accuracy of ``g`` on the causal subgraph for the nodes in ``split_mask``.

    All modules are switched to eval mode and run under ``torch.no_grad()``.
    Steps:

    1. ``feature_filter``, if given, transforms ``data.x`` for edge scoring only.
    2. ``att_net`` selects the causal edges, whose weights are installed on
       ``g`` via ``set_masks``.
    3. ``g`` is run on those edges with the original ``data.x``.

    Args:
        g: Classifier called as ``g(x=..., edge_index=..., edge_attr=..., batch=...)``.
        att_net: Edge scorer returning ``((cei, cea, cew), complement, score)``.
        feature_filter: Optional module applied to ``data.x``; ``None`` disables it.
        data: Graph with ``x``, ``edge_index`` and ``y``.
        device: Target device for ``data``.
        split_mask: Boolean mask selecting the nodes to score.

    Returns:
        Fraction of correctly classified nodes in ``split_mask``. Returns
        ``0.0`` if the causal edge set is empty or the mask selects no nodes.
    """
    g.eval()
    att_net.eval()
    if feature_filter is not None:
        feature_filter.eval()

    with torch.no_grad():
        data = data.to(device)

        filtered_data = data
        if feature_filter is not None:
            filtered_data = Data(
                x=feature_filter(data.x),
                edge_index=data.edge_index,
                edge_attr=getattr(data, 'edge_attr', None),
                y=data.y,
                batch=getattr(data, 'batch', None)
            )

        (cei, cea, cew), _, _ = att_net(filtered_data)
        if cei.size(1) == 0:
            return 0.0

        set_masks(cew, g)
        try:
            out = g(x=data.x, edge_index=cei, edge_attr=cea, batch=getattr(data, 'batch', None))
        finally:
            # Detach the masks even on failure so later forwards are unaffected.
            clear_masks(g)
        pred = out.argmax(dim=1)

        total = int(split_mask.sum().item())
        if total == 0:
            return 0.0
        correct = (pred[split_mask] == data.y[split_mask]).sum().item()

    return correct / total


def _build_arg_parser():
    """Create the command-line parser for the CiteSeer CRCG experiment."""
    parser = argparse.ArgumentParser('CRCG CiteSeer Node Classification')
    parser.add_argument('--cuda',      default=0,           type=int)
    parser.add_argument('--epochs',    default=1000,        type=int)
    parser.add_argument('--r',         default=0.7,         type=float, help='Ratio of edges for causal subgraph')
    # Currently unused: training is full-graph. Kept for CLI compatibility.
    parser.add_argument('--batch_size',default=1,           type=int, help='Batch size (use 1 for single graph)')
    parser.add_argument('--net_lr',    default=0.01,        type=float)
    parser.add_argument('--weight_decay', default=5e-4,     type=float)
    parser.add_argument('--hidden_channels', default=64,    type=int)
    parser.add_argument('--num_unit',  default=2,           type=int)
    parser.add_argument('--seed',      default=42,          type=int)
    parser.add_argument('--runs',      default=10,          type=int)
    parser.add_argument('--patience',  default=100,         type=int)
    # A lone store_true flag with default=True can never be switched off.
    # Keep --use_causal_filter for backward compatibility and add an explicit
    # opt-out that writes to the same destination.
    parser.add_argument('--use_causal_filter', dest='use_causal_filter', action='store_true', default=True,
                        help='Enable the ImprovedCausalFilter (default: enabled)')
    parser.add_argument('--no_causal_filter', dest='use_causal_filter', action='store_false',
                        help='Disable the ImprovedCausalFilter and run the CRCG baseline')
    parser.add_argument('--filter_lambda_init', type=float, default=10.0)
    parser.add_argument('--filter_lambda_min', type=float, default=-2.0)
    parser.add_argument('--filter_decay', type=float, default=0.95)
    parser.add_argument('--filter_temp', type=float, default=1.0)
    parser.add_argument('--filter_residual', type=float, default=0.2)
    return parser


def _run_once(args, data, num_features, num_classes, device):
    """Build fresh models, train with early stopping on validation accuracy, and
    return ``(best_val_acc, test_acc_at_best_val)`` as fractions."""
    g = SPMotifNet1(
        in_channels=num_features,
        hid_channels=args.hidden_channels,
        num_classes=num_classes,
        num_unit=args.num_unit
    ).to(device)

    # The edge scorer uses the same hidden width as the classifier's node embedding.
    hidden_ch = g.node_emb.out_features
    att_net = CausalAttNet(args.r, in_channels=num_features, channels=hidden_ch).to(device)

    feature_filter = None
    if args.use_causal_filter:
        feature_filter = ImprovedCausalFilter(
            input_dim=num_features,
            lambda_init=args.filter_lambda_init,
            lambda_min=args.filter_lambda_min,
            decay_rate=args.filter_decay,
            temperature=args.filter_temp,
            residual_weight=args.filter_residual,
            normalize=False,
            dropout=0.1
        ).to(device)

    params = list(g.parameters()) + list(att_net.parameters())
    if feature_filter is not None:
        params += list(feature_filter.parameters())
    optimizer = torch.optim.Adam(params, lr=args.net_lr, weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    print(f"Total parameter count: {sum(p.numel() for p in params if p.requires_grad)}")

    best_val = 0
    final_test_acc = 0
    patience_counter = 0

    for epoch in range(1, args.epochs + 1):
        loss, train_acc = train_full_graph(g, att_net, feature_filter, data, optimizer, criterion, device)
        val_acc = evaluate_full_graph(g, att_net, feature_filter, data, device, data.val_mask)
        test_acc = evaluate_full_graph(g, att_net, feature_filter, data, device, data.test_mask)

        # Advance the filter's per-epoch schedule after training and evaluation.
        if feature_filter is not None:
            feature_filter.step()

        # Model selection: report the test accuracy seen at the best validation epoch.
        if val_acc > best_val:
            best_val = val_acc
            final_test_acc = test_acc
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f'Early stopping at epoch {epoch}')
                break

        if epoch % 100 == 0 or epoch == 1:
            print(f'Epoch {epoch:04d}: Loss {loss:.4f} | Train Acc {train_acc:.4f} | Val Acc {val_acc:.4f} | Test Acc {test_acc:.4f}')
            if feature_filter is not None and epoch % 200 == 0:
                stats = feature_filter.get_stats()
                print(f"Filter Stats: lambda={stats['lambda']:.4f}, gate_mean={stats['gate_stats']['mean']:.4f}")

    return best_val, final_test_acc


def main():
    """Train CRCG on CiteSeer for ``--runs`` seeds and print mean ± std accuracies."""
    print("Starting training with CRCG on CiteSeer dataset:")

    args = _build_arg_parser().parse_args()

    set_env(args.seed)
    device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')

    print(f"Using device: {device}")
    print(f"Causal filter module: {'Enabled' if args.use_causal_filter else 'Disabled'}")
    print(f"Hidden layer dimension: {args.hidden_channels}, Causal edge ratio: {args.r}")

    # Construct the dataset once instead of reloading it from disk on every run.
    dataset = Planetoid(root='./data/Planetoid', name='CiteSeer', transform=NormalizeFeatures())
    num_features = dataset.num_node_features
    num_classes = dataset.num_classes

    info = dataset[0]
    print(f"Dataset info: Nodes={info.x.size(0)}, Edges={info.edge_index.size(1)}, Feature dimensions={num_features}, Number of classes={num_classes}")
    print(f"Training nodes={info.train_mask.sum().item()}, Validation nodes={info.val_mask.sum().item()}, Test nodes={info.test_mask.sum().item()}")

    test_accs = []
    val_accs = []

    for run in range(args.runs):
        print(f"\n=== Run {run + 1} ===")
        set_env(args.seed + run)

        # Re-index the dataset each run. Training attaches attributes to the
        # graph (e.g. CausalAttNet sets batch/edge_attr). NOTE(review): this
        # relies on dataset[0] returning a fresh copy; confirm for the installed
        # PyG version.
        data = dataset[0]

        best_val, final_test_acc = _run_once(args, data, num_features, num_classes, device)

        test_accs.append(final_test_acc * 100)
        val_accs.append(best_val * 100)
        print(f'Run {run + 1}: Best Val Acc: {best_val:.4f}, Test Acc: {final_test_acc:.4f}')

    if not test_accs:
        print('No runs were executed (--runs must be positive).')
        return

    # np.std is the population standard deviation (ddof=0).
    test_mean = np.mean(test_accs)
    test_std = np.std(test_accs)
    val_mean = np.mean(val_accs)
    val_std = np.std(val_accs)

    print('\n=== Final Results ===')
    print(f'Validation Accuracy: {val_mean:.2f} ± {val_std:.2f}%')
    print(f'Test Accuracy: {test_mean:.2f} ± {test_std:.2f}%')
    print(f'Individual test runs: {[f"{acc:.2f}" for acc in test_accs]}')

    print('\n=== Final Results of CRCG on CiteSeer Dataset ===')
    if args.use_causal_filter:
        print(f'CRCG + ImprovedCausalFilter: Test Accuracy = {test_mean:.2f} ± {test_std:.2f}%')
    else:
        print(f'CRCG (baseline): Test Accuracy = {test_mean:.2f} ± {test_std:.2f}%')


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not on import.
    main()
