import argparse
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool

from improved_filter import ImprovedCausalFilter
from utils.helper import set_seed
from utils.mask import set_masks, clear_masks
from utils.get_subgraph import split_graph, relabel
from gnn import SPMotifNet


class CausalAttNet(nn.Module):
    """Score edges with a small GNN and split a batch into two edge subgraphs.

    Node embeddings come from two stacked ``LEConv`` layers. Every edge is
    scored by an MLP applied to the concatenated embeddings of its two
    endpoints. ``split_graph`` then uses those scores together with
    ``causal_ratio`` to divide the edges into a "causal" part and a
    "confounding" part.
    """

    def __init__(self, causal_ratio, in_channels, channels):
        super().__init__()
        from torch_geometric.nn import LEConv

        hidden_width = channels * 4
        # Registration order (conv1, conv2, mlp) is kept fixed. It determines
        # parameter order and the RNG consumption during initialisation.
        self.conv1 = LEConv(in_channels=in_channels, out_channels=channels)
        self.conv2 = LEConv(in_channels=channels, out_channels=channels)
        self.mlp = nn.Sequential(
            nn.Linear(2 * channels, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 1),
        )
        self.ratio = causal_ratio

    def forward(self, data):
        """Split ``data`` into causal / confounding subgraphs.

        Side effect: ``data.edge_attr`` is overwritten with an ``(E, 1)``
        column and ``data.edge_weight`` is set to the flat weights. The
        weights come from the incoming edge attributes, or are all ones when
        ``data.edge_attr`` is None. NOTE(review): ``view(-1)`` assumes one
        scalar attribute per edge; multi-column edge attributes would be
        flattened to the wrong length. Confirm before using such data.

        Returns:
            Two 5-tuples, presumably ``(x, edge_index, edge_attr,
            edge_weight, batch)`` for the causal and the confounding part,
            followed by the per-edge scores. If the embedded graph has no
            nodes, every tensor slot is an empty tensor except the edge
            attribute slot and the scores, which carry the flat weights.
        """
        if data.edge_attr is None:
            weights = torch.ones(data.edge_index.size(1), device=data.x.device)
        else:
            weights = data.edge_attr.view(-1)
        data.edge_attr = weights.unsqueeze(1)
        data.edge_weight = weights

        hidden = F.relu(self.conv1(data.x, data.edge_index, weights))
        hidden = self.conv2(hidden, data.edge_index, weights)

        if hidden.size(0) == 0:
            def blank():
                return torch.empty((0,), device=hidden.device)

            causal_part = (blank(), blank(), weights, blank(), blank())
            conf_part = (blank(), blank(), weights, blank(), blank())
            return causal_part, conf_part, weights

        src, dst = data.edge_index
        pair_features = torch.cat((hidden[src], hidden[dst]), dim=-1)
        scores = self.mlp(pair_features).view(-1)

        causal_edges, conf_edges = split_graph(data, scores, self.ratio)
        cei, cea, cew = causal_edges
        fei, fea, few = conf_edges

        # Keep only the nodes touched by each edge subset and re-index them.
        causal_x, cei, causal_batch, _ = relabel(hidden, cei, data.batch)
        conf_x, fei, conf_batch, _ = relabel(hidden, fei, data.batch)

        causal_out = (causal_x, cei, cea, cew, causal_batch)
        conf_out = (conf_x, fei, fea, few, conf_batch)
        return causal_out, conf_out, scores


def _to_graph_logits(out, batch, num_graphs):
    """Return one row of logits per graph in the mini-batch.

    If ``out`` already has ``num_graphs`` rows, it is treated as graph-level
    output and returned unchanged. Otherwise it is treated as node-level
    output and mean-pooled per graph. ``size=num_graphs`` is passed so that
    graphs whose subgraph kept no nodes still get a (zero) row. Without it,
    the output would silently shrink and misalign with the targets.
    """
    if out.size(0) == num_graphs:
        return out
    return global_mean_pool(out, batch, size=num_graphs)


def train_epoch(model, att, feature_filter, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    For each mini-batch, the optional feature filter is applied to the node
    features. ``att`` then splits the batch into causal and confounding
    subgraphs. ``model`` runs on each subgraph with that subgraph's edge
    weights installed as masks. The loss is the sum of ``criterion`` on both
    predictions against the same graph labels.

    Args:
        model: Graph classifier called as
            ``model(x=, edge_index=, edge_attr=, batch=)``.
        att: Edge-splitting module returning causal / confounding tuples.
        feature_filter: Module applied to node features, or ``None`` to skip.
        loader: Iterable of batched graphs.
        optimizer: Optimizer over all trainable parameters.
        criterion: Loss taking ``(logits, targets)``.
        device: Device the batches are moved to.

    Returns:
        Average loss per graph (per-batch loss weighted by batch size), or
        0.0 if no graphs were processed.
    """
    model.train()
    att.train()
    if feature_filter is not None:
        feature_filter.train()

    total_loss = 0.0
    total_graphs = 0

    for graph in loader:
        graph = graph.to(device)
        num_graphs = graph.num_graphs
        if num_graphs == 0:
            continue

        if feature_filter is not None:
            graph.x = feature_filter(graph.x)

        (cx, cei, cea, cew, cb), (fx, fei, fea, few, fb), _ = att(graph)

        # Promote 1-D edge attributes to a single column
        # (presumably the model expects shape (E, d)).
        if cea.dim() == 1:
            cea = cea.unsqueeze(1)
        if fea.dim() == 1:
            fea = fea.unsqueeze(1)

        # Always clear masks, even if a forward pass raises, so the model is
        # not left with stale edge masks installed.
        set_masks(cew, model)
        try:
            out_c = model(x=cx, edge_index=cei, edge_attr=cea, batch=cb)
        finally:
            clear_masks(model)

        set_masks(few, model)
        try:
            out_f = model(x=fx, edge_index=fei, edge_attr=fea, batch=fb)
        finally:
            clear_masks(model)

        out_c_g = _to_graph_logits(out_c, cb, num_graphs)
        out_f_g = _to_graph_logits(out_f, fb, num_graphs)

        target = graph.y.view(-1)
        loss = criterion(out_c_g, target) + criterion(out_f_g, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * num_graphs
        total_graphs += num_graphs

    return total_loss / max(1, total_graphs)


def evaluate(model, att, feature_filter, loader, device):
    """Return the accuracy of the causal-subgraph prediction over ``loader``.

    Only the causal part produced by ``att`` is used for prediction; the
    confounding part is discarded. Runs under ``torch.no_grad()`` with every
    module in eval mode.

    Args:
        model: Graph classifier called as
            ``model(x=, edge_index=, edge_attr=, batch=)``.
        att: Edge-splitting module returning causal / confounding tuples.
        feature_filter: Module applied to node features, or ``None`` to skip.
        loader: Iterable of batched graphs.
        device: Device the batches are moved to.

    Returns:
        Fraction of correctly classified graphs, or 0.0 if ``loader`` yielded
        no graphs.
    """
    model.eval()
    att.eval()
    if feature_filter is not None:
        feature_filter.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for graph in loader:
            graph = graph.to(device)
            num_graphs = graph.num_graphs
            if num_graphs == 0:
                continue
            if feature_filter is not None:
                graph.x = feature_filter(graph.x)

            (cx, cei, cea, cew, cb), _unused, _ = att(graph)

            # Promote 1-D edge attributes to a single column
            # (presumably the model expects shape (E, d)).
            if cea.dim() == 1:
                cea = cea.unsqueeze(1)

            set_masks(cew, model)
            try:
                out = model(x=cx, edge_index=cei, edge_attr=cea, batch=cb)
            finally:
                clear_masks(model)

            # Node-level output: pool per graph. size=num_graphs keeps one row
            # per graph even when a graph's causal subgraph kept no nodes, so
            # predictions stay aligned with graph.y.
            if out.size(0) != num_graphs:
                out = global_mean_pool(out, cb, size=num_graphs)

            pred = out.argmax(dim=1)
            correct += (pred == graph.y.view(-1)).sum().item()
            total += num_graphs

    return correct / total if total > 0 else 0.0


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--channels', type=int, default=64)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--runs', type=int, default=5)
    parser.add_argument('--use_causal_filter', action='store_true', default=True)
    # '--use_causal_filter' is store_true with default=True, so on its own it
    # can never switch the filter off. This flag writes False to the same dest.
    parser.add_argument('--no_causal_filter', dest='use_causal_filter', action='store_false')
    parser.add_argument('--filter_lambda_init', type=float, default=1.0)
    parser.add_argument('--filter_lambda_min', type=float, default=-2.0)
    parser.add_argument('--filter_decay', type=float, default=0.99)
    parser.add_argument('--filter_temp', type=float, default=1.0)
    parser.add_argument('--filter_residual', type=float, default=0.2)
    return parser.parse_args()


def _call_if_available(obj, method_name):
    """Call ``obj.<method_name>()`` if it exists, as a best-effort hook.

    Returns the method's result, or ``None`` when the method is missing or
    raised. Failures are reported instead of being silently swallowed, so
    training continues but the problem is visible.
    """
    method = getattr(obj, method_name, None)
    if not callable(method):
        return None
    try:
        return method()
    except Exception as exc:
        print(f"Warning: {type(obj).__name__}.{method_name}() failed: {exc!r}")
        return None


def _split_indices(num_items):
    """Randomly split ``range(num_items)`` 80/10/10 into train/val/test index lists.

    Uses the global ``random`` module, so the split follows its current seed.
    """
    indices = list(range(num_items))
    random.shuffle(indices)
    n_train = int(num_items * 0.8)
    n_val = int(num_items * 0.1)
    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]
    return train_idx, val_idx, test_idx


def _run_once(args, dataset, device):
    """Train fresh modules on a new random split of ``dataset``.

    Returns:
        ``(best_val, best_test)``: the best validation accuracy and the test
        accuracy at that epoch.
    """
    train_idx, val_idx, test_idx = _split_indices(len(dataset))

    train_loader = DataLoader(dataset[train_idx], batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(dataset[val_idx], batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(dataset[test_idx], batch_size=args.batch_size, shuffle=False)

    num_features = dataset.num_node_features
    num_classes = dataset.num_classes

    # Construction order (filter, model, att) is kept stable so parameter
    # initialisation consumes the seeded RNG in the same order.
    feature_filter = None
    if args.use_causal_filter:
        feature_filter = ImprovedCausalFilter(
            input_dim=num_features,
            lambda_init=args.filter_lambda_init,
            lambda_min=args.filter_lambda_min,
            decay_rate=args.filter_decay,
            temperature=args.filter_temp,
            residual_weight=args.filter_residual,
            normalize=False,
            dropout=0.1
        ).to(device)

    model = SPMotifNet(args.channels, num_classes=num_classes).to(device)
    att = CausalAttNet(causal_ratio=0.7, in_channels=num_features, channels=args.channels).to(device)

    params = list(model.parameters()) + list(att.parameters())
    if feature_filter is not None:
        params += list(feature_filter.parameters())
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    best_val = 0.0
    best_test = 0.0

    for epoch in range(1, args.epochs + 1):
        loss = train_epoch(model, att, feature_filter, train_loader, optimizer, criterion, device)
        train_acc = evaluate(model, att, feature_filter, train_loader, device)
        val_acc = evaluate(model, att, feature_filter, val_loader, device)
        test_acc = evaluate(model, att, feature_filter, test_loader, device)

        if feature_filter is not None:
            _call_if_available(feature_filter, 'step')

        # Always record epoch 1. Otherwise a validation accuracy that stays at
        # 0.0 would leave best_test at 0.0 regardless of the real test accuracy.
        if epoch == 1 or val_acc > best_val:
            best_val = val_acc
            best_test = test_acc

        if epoch % 10 == 0 or epoch == 1:
            print(f'Epoch {epoch:03d}: Loss {loss:.4f} | Train {train_acc:.4f} | Val {val_acc:.4f} | Test {test_acc:.4f}')
            if feature_filter is not None and epoch % 20 == 0:
                stats = _call_if_available(feature_filter, 'get_stats')
                if stats is not None:
                    print(stats)

    return best_val, best_test


def main():
    """Run ``args.runs`` independent train/eval runs on ENZYMES.

    Prints per-run results and the mean ± std of test accuracy (in percent)
    taken at the best-validation epoch of each run.
    """
    args = _parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    print(f"Device: {device}, use_causal_filter: {args.use_causal_filter}")

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    # The dataset is constructed with identical arguments for every run, so
    # load it once. Per-run splits are drawn in _run_once after reseeding.
    dataset = TUDataset(root='./data/ENZYMES', name='ENZYMES', transform=NormalizeFeatures())

    test_accs = []

    for run in range(args.runs):
        print(f"\n=== Run {run + 1} ===")
        set_seed(42 + run)

        best_val, best_test = _run_once(args, dataset, device)

        test_accs.append(best_test * 100)
        print(f'Run {run + 1}: Best Val {best_val:.4f}, Test @best_val {best_test:.4f}')

    if not test_accs:
        # np.mean of an empty list would emit a RuntimeWarning and print nan.
        print('\n=== Final ===\nNo runs were executed.')
        return

    mean = np.mean(test_accs)
    std = np.std(test_accs)
    print(f'\n=== Final ===\nTest Accuracy: {mean:.2f} \u00b1 {std:.2f}%')
    print('Individual runs:', [f"{a:.2f}" for a in test_accs])


# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == '__main__':
    main()
