import sys, os
# Put the repository root (the parent of this script's directory) at the front
# of the import path so the project-local packages imported below resolve when
# this file is executed directly.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import argparse
import copy
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.data import DataLoader

from utils.mask import set_masks, clear_masks
from utils.helper import set_seed, args_print
from utils.get_subgraph import split_graph, relabel
from gnn.molhiv_gnn import MolHivNet
from improved_filter import ImprovedCausalFilter


class CausalAttNet(nn.Module):
    """Score edges with a small GNN + MLP and split a batch into two subgraphs.

    Nodes are embedded with two ``LEConv`` layers. Each edge is then scored by
    an MLP applied to the concatenated embeddings of its two endpoints.
    ``split_graph`` uses those scores together with ``causal_ratio`` to
    separate a "causal" edge set from a "confounder" edge set, and ``relabel``
    compacts each of them.

    Args:
        causal_ratio: ratio forwarded unchanged to ``split_graph``.
        in_channels: node-feature dimension of the incoming graphs.
        channels: hidden width of both ``LEConv`` layers.
    """

    def __init__(self, causal_ratio, in_channels, channels):
        super().__init__()
        from torch_geometric.nn import LEConv

        # NOTE: creation order determines both parameter-initialisation RNG
        # consumption and parameter registration order; keep it stable.
        self.conv1 = LEConv(in_channels=in_channels, out_channels=channels)
        self.conv2 = LEConv(in_channels=channels, out_channels=channels)
        pair_dim = 2 * channels  # embedding of (source node, target node)
        self.mlp = nn.Sequential(
            nn.Linear(pair_dim, 2 * pair_dim),
            nn.ReLU(),
            nn.Linear(2 * pair_dim, 1),
        )
        self.ratio = causal_ratio

    def forward(self, data):
        """Split ``data`` into causal and confounder subgraphs.

        Side effect: ``data.edge_attr`` is overwritten with the flattened edge
        attributes reshaped into a single column, and ``data.edge_weight`` with
        the flat vector itself. When ``data.edge_attr`` is ``None``, a vector
        of ones with one entry per edge is used instead.

        Returns:
            ``(causal, confounder, edge_score)``. Each of ``causal`` and
            ``confounder`` is ``(x, edge_index, edge_attr, edge_weight, batch)``.
            If the node embedding has zero rows, both tuples contain empty
            tensors except for the edge-attribute slot. In that case the flat
            edge weights are returned in place of the edge scores.
        """
        dev = data.x.device
        if data.edge_attr is None:
            weights = torch.ones(data.edge_index.size(1), device=dev)
        else:
            weights = data.edge_attr.view(-1)
        data.edge_attr = weights.unsqueeze(1)
        data.edge_weight = weights

        hidden = self.conv1(data.x, data.edge_index, weights)
        hidden = self.conv2(F.relu(hidden), data.edge_index, weights)

        if hidden.size(0) == 0:
            def _blank():
                return torch.empty((0,), device=dev)

            causal_part = (_blank(), _blank(), weights, _blank(), _blank())
            conf_part = (_blank(), _blank(), weights, _blank(), _blank())
            return causal_part, conf_part, weights

        src, dst = data.edge_index
        pair_repr = torch.cat([hidden[src], hidden[dst]], dim=-1)
        scores = self.mlp(pair_repr).view(-1)

        causal_edges, conf_edges = split_graph(data, scores, self.ratio)
        c_ei, c_ea, c_ew = causal_edges
        f_ei, f_ea, f_ew = conf_edges
        c_x, c_ei, c_batch, _ = relabel(hidden, c_ei, data.batch)
        f_x, f_ei, f_batch, _ = relabel(hidden, f_ei, data.batch)

        causal_part = (c_x, c_ei, c_ea, c_ew, c_batch)
        conf_part = (f_x, f_ei, f_ea, f_ew, f_batch)
        return causal_part, conf_part, scores


def eval_loader(loader, g, att_net, feature_filter, device):
    """Return the accuracy of the causal-subgraph prediction over ``loader``.

    Each batch is moved to ``device`` and optionally passed through
    ``feature_filter``. It is then split by ``att_net``, and only the causal
    subgraph is fed to ``g``. Batches whose causal subgraph has no nodes still
    count toward the denominator, so they score as all-wrong.

    Every supplied module is switched to eval mode for the duration of the call,
    which disables dropout in the filter and the networks. Each module is
    restored to its previous train/eval mode afterwards, even on error.

    Args:
        loader: iterable of PyG batches exposing ``to``, ``x``, ``y`` and
            ``num_graphs``.
        g: graph classifier called as ``g(x=, edge_index=, edge_attr=, batch=)``.
        att_net: module returning ``(causal, confounder, edge_score)``, where
            ``causal`` is ``(x, edge_index, edge_attr, edge_weight, batch)``.
        feature_filter: optional module applied to node features, or ``None``.
        device: device each batch is moved to.

    Returns:
        Accuracy as a fraction in ``[0, 1]``; ``0.0`` if no graphs were seen.
    """
    modules = [m for m in (g, att_net, feature_filter) if m is not None]
    previous_modes = [m.training for m in modules]
    for m in modules:
        m.eval()

    correct, total = 0, 0
    try:
        with torch.no_grad():
            for graph in loader:
                graph = graph.to(device)
                if feature_filter is not None:
                    graph.x = feature_filter(graph.x)

                (cx, ei, ea, ew, cb), _, _ = att_net(graph)
                if cx.size(0) > 0:
                    # Same conversion as in training: 2-D integer edge attributes.
                    ea_local = (ea.unsqueeze(1) if ea.dim() == 1 else ea).long()
                    set_masks(ew, g)
                    try:
                        out = g(x=cx, edge_index=ei, edge_attr=ea_local, batch=cb)
                    finally:
                        # Never leave edge masks installed on g.
                        clear_masks(g)
                    correct += (out.argmax(dim=1) == graph.y.view(-1)).sum().item()
                total += graph.num_graphs
    finally:
        for m, mode in zip(modules, previous_modes):
            m.train(mode)
    return correct / total if total > 0 else 0.0


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train DIR on ENZYMES (Graph Classification)')
    parser.add_argument('--cuda', default=0, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--channels', default=128, type=int)
    parser.add_argument('--r', default=0.7, type=float)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--net_lr', default=1e-3, type=float)
    # `store_true` combined with `default=True` could never be switched off.
    # Keep the original flag (the filter stays enabled by default) and add an
    # explicit opt-out that writes to the same destination.
    parser.add_argument('--use_causal_filter', dest='use_causal_filter',
                        action='store_true', default=True,
                        help='enable the causal feature filter (default)')
    parser.add_argument('--no_causal_filter', dest='use_causal_filter',
                        action='store_false',
                        help='disable the causal feature filter')
    parser.add_argument('--filter_lambda_init', type=float, default=1.0)
    parser.add_argument('--filter_decay', type=float, default=0.99)
    parser.add_argument('--filter_lambda_min', type=float, default=-2.0)
    parser.add_argument('--filter_temp', type=float, default=1.0)
    parser.add_argument('--filter_residual', type=float, default=0.2)
    parser.add_argument('--runs', type=int, default=5)
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()
    if args.runs < 1:
        # Otherwise the final mean/std would be computed over an empty list (nan).
        parser.error('--runs must be at least 1')

    set_seed(args.seed)
    device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')

    print(f"Device: {device}, use_causal_filter: {args.use_causal_filter}")

    # The dataset object is identical for every run; only the random split
    # below changes, so load it once.
    dataset = TUDataset(root='./data/ENZYMES', name='ENZYMES', transform=NormalizeFeatures())
    num_features = dataset.num_node_features
    num_classes = dataset.num_classes

    test_accs = []  # per run: test accuracy (%) at the best-validation epoch

    for run in range(args.runs):
        print(f"\n=== Run {run + 1} ===")
        set_seed(args.seed + run)

        # Random 80/10/10 split; integer truncation leaves the remainder to test.
        indices = list(range(len(dataset)))
        random.shuffle(indices)
        n = len(indices)
        n_train = int(n * 0.8)
        n_val = int(n * 0.1)
        train_idx = indices[:n_train]
        val_idx = indices[n_train:n_train + n_val]
        test_idx = indices[n_train + n_val:]

        train_dataset = dataset[train_idx]
        val_dataset = dataset[val_idx]
        test_dataset = dataset[test_idx]

        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

        feature_filter = None
        if args.use_causal_filter:
            feature_filter = ImprovedCausalFilter(
                input_dim=num_features,
                lambda_init=args.filter_lambda_init,
                lambda_min=args.filter_lambda_min,
                decay_rate=args.filter_decay,
                temperature=args.filter_temp,
                residual_weight=args.filter_residual,
                normalize=False,
                dropout=0.1
            ).to(device)

        g = MolHivNet(num_tasks=num_classes, emb_dim=args.channels).to(device)
        att_net = CausalAttNet(causal_ratio=args.r, in_channels=num_features, channels=args.channels).to(device)
        params = list(g.parameters()) + list(att_net.parameters())
        if feature_filter is not None:
            params += list(feature_filter.parameters())
        optimizer = torch.optim.Adam(params, lr=args.net_lr)
        criterion = nn.CrossEntropyLoss()

        best_val = 0.0
        best_test = 0.0

        for epoch in range(1, args.epochs + 1):
            g.train(); att_net.train()
            if feature_filter is not None:
                feature_filter.train()
            total_loss, n_batches = 0.0, 0

            for graph in train_loader:
                graph = graph.to(device)
                if graph.num_graphs == 0:
                    continue
                if feature_filter is not None:
                    graph.x = feature_filter(graph.x)

                (cx, cei, cea, cew, cb), (fx, fei, fea, few, fb), _ = att_net(graph)

                # g is fed edge attributes as a 2-D integer tensor.
                cea = (cea.unsqueeze(1) if cea.dim() == 1 else cea).long()
                fea = (fea.unsqueeze(1) if fea.dim() == 1 else fea).long()

                # Causal branch: predict labels from the causal subgraph.
                set_masks(cew, g)
                out_c = g(x=cx, edge_index=cei, edge_attr=cea, batch=cb)
                clear_masks(g)
                loss_c = criterion(out_c, graph.y.view(-1))

                # Confounder branch: same classifier on the complementary subgraph.
                set_masks(few, g)
                out_f = g(x=fx, edge_index=fei, edge_attr=fea, batch=fb)
                clear_masks(g)
                loss_f = criterion(out_f, graph.y.view(-1))

                loss = loss_c + loss_f
                optimizer.zero_grad(); loss.backward(); optimizer.step()
                total_loss += loss.item(); n_batches += 1

            avg_loss = total_loss / n_batches if n_batches else 0.0

            train_acc = eval_loader(train_loader, g, att_net, feature_filter, device) * 100
            val_acc = eval_loader(val_loader, g, att_net, feature_filter, device) * 100
            test_acc = eval_loader(test_loader, g, att_net, feature_filter, device) * 100

            # Model selection on validation accuracy; ties prefer the later epoch.
            if val_acc >= best_val:
                best_val = val_acc
                best_test = test_acc

            if feature_filter is not None:
                feature_filter.step()

            if epoch % 10 == 0 or epoch == 1:
                print(f'Epoch {epoch:03d}: Loss {avg_loss:.4f} | Train {train_acc/100:.4f} | Val {val_acc/100:.4f} | Test {test_acc/100:.4f}')
                if feature_filter is not None and epoch % 20 == 0:
                    stats = feature_filter.get_stats()
                    print(f"Filter Stats epoch {epoch}: lambda={stats.get('lambda',0):.4f}, gate_mean={stats.get('gate_stats',{}).get('mean',0):.4f}")

        test_accs.append(best_test)
        print(f'Run {run + 1}: Best Val {best_val/100:.4f}, Test @best_val {best_test/100:.4f}')

    mean = np.mean(test_accs)
    std = np.std(test_accs)
    print(f'\n=== Final ===\nTest Accuracy: {mean:.2f} ± {std:.2f}%')
    print('Individual runs:', [f"{a:.2f}" for a in test_accs])
