import sys, os
import logging

# Put the repository root (the parent of this file's directory) at the front of
# sys.path so the project-local imports below (utils, datasets, gnn, ...)
# resolve when this file is launched directly as a script.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import copy
import torch
import argparse
import random
import numpy as np
from datetime import datetime
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.nn import LEConv
from utils.mask import set_masks, clear_masks
from utils.helper import set_seed, args_print
from utils.get_subgraph import split_graph, relabel
from datasets.molecular_dir_dataset import MolecularDIRDataset
from gnn.molhiv_gnn import MolHivNet
from improved_filter import ImprovedCausalFilter


class CausalAttNet(nn.Module):
    """Score every edge and split each batch into causal / complementary subgraphs.

    Two LEConv layers embed the nodes. An MLP over the concatenated endpoint
    embeddings of every edge produces one score per edge. ``split_graph`` then
    partitions the edges using ``causal_ratio``, and ``relabel`` re-indexes
    each part.
    """

    def __init__(self, causal_ratio, in_channels, channels):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # parameter initialisation is reproducible.
        self.conv1 = LEConv(in_channels=in_channels, out_channels=channels)
        self.conv2 = LEConv(in_channels=channels, out_channels=channels)
        widened = channels * 4
        self.mlp = nn.Sequential(
            nn.Linear(2 * channels, widened),
            nn.ReLU(),
            nn.Linear(widened, 1),
        )
        self.ratio = causal_ratio

    def forward(self, data):
        """Split ``data`` into a causal and a complementary subgraph.

        Side effect: overwrites ``data.edge_attr`` (as an ``(E, 1)`` column) and
        ``data.edge_weight`` (flat) on the passed-in batch. This is presumably
        so that ``split_graph`` can read them back (TODO confirm against
        utils.get_subgraph).

        Returns:
            ``((x, edge_index, edge_attr, edge_weight, batch)`` for the causal
            part, the same 5-tuple for the complementary part, and the per-edge
            scores)``. If the node embedding is empty, both parts are empty
            tensors except for the edge weights, and the flattened edge weights
            take the place of the scores.
        """
        # One weight per edge. This assumes edge_attr holds a single value per
        # edge; a multi-column edge_attr would flatten to the wrong length
        # (TODO confirm for this dataset).
        if data.edge_attr is None:
            weights = torch.ones(data.edge_index.size(1), device=data.x.device)
        else:
            weights = data.edge_attr.view(-1)
        data.edge_attr = weights.unsqueeze(1)
        data.edge_weight = weights

        hidden = F.relu(self.conv1(data.x, data.edge_index, weights))
        hidden = self.conv2(hidden, data.edge_index, weights)

        if hidden.size(0) == 0:
            def blank():
                return torch.empty((0,), device=data.x.device)
            return ((blank(), blank(), weights, blank(), blank()),
                    (blank(), blank(), weights, blank(), blank()),
                    weights)

        src, dst = data.edge_index
        pair_features = torch.cat((hidden[src], hidden[dst]), dim=-1)
        scores = self.mlp(pair_features).view(-1)

        causal_edges, conf_edges = split_graph(data, scores, self.ratio)
        c_index, c_attr, c_weight = causal_edges
        f_index, f_attr, f_weight = conf_edges
        c_x, c_index, c_batch, _ = relabel(hidden, c_index, data.batch)
        f_x, f_index, f_batch, _ = relabel(hidden, f_index, data.batch)

        return ((c_x, c_index, c_attr, c_weight, c_batch),
                (f_x, f_index, f_attr, f_weight, f_batch),
                scores)


def test_acc(loader, net, predictor, device, feature_filter=None):
    """Return the accuracy of ``predictor`` on the causal subgraphs produced by ``net``.

    For each batch, ``net`` splits the graph. The causal part is classified by
    ``predictor``, with the causal edge weights installed as message-passing
    masks via ``set_masks`` for the duration of the forward pass. A batch whose
    causal part has no nodes contributes zero correct predictions but still
    counts toward the total.

    Args:
        loader: iterable of batched graphs exposing ``.to``, ``.num_graphs``,
            ``.x`` and ``.y``.
        net: module returning ``((x, edge_index, edge_attr, edge_weight, batch), _, _)``.
        predictor: classifier called as
            ``predictor(x=..., edge_index=..., edge_attr=..., batch=...)``.
        device: device each batch is moved to.
        feature_filter: optional callable applied to ``graph.x`` before ``net``,
            the same preprocessing the training loop applies. ``None`` (the
            default) keeps the previous behaviour of no filtering.

    Returns:
        Fraction of correctly classified graphs; 0.0 when the loader is empty.

    Note:
        Modules are not switched to eval mode here; call ``.eval()`` first.
    """
    correct, total = 0, 0
    # Pure inference: skip building autograd graphs (saves memory and time).
    with torch.no_grad():
        for graph in loader:
            graph = graph.to(device)
            total += graph.num_graphs
            if feature_filter is not None:
                graph.x = feature_filter(graph.x)
            (cx, ei, ea, ew, cb), _conf, _scores = net(graph)
            if cx.size(0) == 0:
                continue
            # The predictor is fed 2-D integer edge attributes (hence .long()).
            if ea.dim() == 1:
                ea = ea.unsqueeze(1)
            set_masks(ew, predictor)
            out = predictor(x=cx, edge_index=ei, edge_attr=ea.long(), batch=cb)
            clear_masks(predictor)
            # NOTE(review): assumes every graph of the batch is still present in
            # `cb`, so the rows of `out` align with graph.y — confirm for graphs
            # whose causal subgraph is empty.
            correct += (out.argmax(dim=1) == graph.y.view(-1)).sum().item()
    return correct / total if total > 0 else 0.0

if __name__ == '__main__':
    # Local import: only the script entry point needs it (safe parsing of --seed).
    import ast

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` is wrong for argparse: ``bool('False')`` is True, so any
        non-empty string would enable the option.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('1', 'true', 't', 'yes', 'y'):
            return True
        if text in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    def _evaluate(loader, gnn, attn, filt):
        """Return the accuracy of ``gnn`` on the causal subgraphs ``attn`` extracts.

        ``filt`` (optional) is applied to node features first, mirroring the
        training loop. Batches whose causal part is empty count toward the total
        with zero correct predictions. Runs under ``torch.no_grad()``; callers
        must put the modules in eval mode.
        """
        correct, seen = 0, 0
        with torch.no_grad():
            for batch in loader:
                batch = batch.to(device)
                if filt is not None:
                    batch.x = filt(batch.x)
                (cx, ei, ea, ew, cb), _conf, _scores = attn(batch)
                if cx.size(0) > 0:
                    ea = (ea.unsqueeze(1) if ea.dim() == 1 else ea).long()
                    set_masks(ew, gnn)
                    out = gnn(x=cx, edge_index=ei, edge_attr=ea, batch=cb)
                    clear_masks(gnn)
                    correct += (out.argmax(dim=1) == batch.y.view(-1)).sum().item()
                seen += batch.num_graphs
        return correct / seen if seen > 0 else 0.0

    parser = argparse.ArgumentParser(description='Training DIR on Molecular Dataset')
    parser.add_argument('--cuda', default=0, type=int)
    parser.add_argument('--datadir', default='./data/molecular_dataset', type=str)
    parser.add_argument('--epoch', default=50, type=int)
    parser.add_argument('--reg', default=True, type=_str2bool)
    parser.add_argument('--seed', nargs='?', default='[1,2,3]')
    parser.add_argument('--channels', default=128, type=int)
    parser.add_argument('--exp_name', default='molecular_dir', type=str)
    parser.add_argument('--r', default=0.7, type=float)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--net_lr', default=1e-3, type=float)
    parser.add_argument('--use_causal_filter', action='store_true', default=True, help='Enable improved causal filter on node features before attention net')
    # The filter is on by default, so a store_false twin is the only way to turn it off.
    parser.add_argument('--no_causal_filter', dest='use_causal_filter', action='store_false',
                        help='Disable the improved causal filter (enabled by default)')
    parser.add_argument('--filter_lambda_init', type=float, default=1.0)
    parser.add_argument('--filter_decay', type=float, default=0.99)
    parser.add_argument('--filter_lambda_min', type=float, default=-2.0)
    parser.add_argument('--filter_temp', type=float, default=1.0)
    parser.add_argument('--filter_residual', type=float, default=0.2)
    args = parser.parse_args()

    # --seed accepts a literal int or list; only the first list entry is used.
    # literal_eval (unlike eval) cannot execute arbitrary code from the CLI.
    seed_val = ast.literal_eval(args.seed)
    if isinstance(seed_val, (list, tuple)) and seed_val:
        args.seed = int(seed_val[0])
    else:
        args.seed = int(seed_val)
    set_seed(args.seed)

    device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')
    dataset_root = args.datadir
    print(f"Loading dataset from root: {dataset_root}")

    train_dataset = MolecularDIRDataset(root=dataset_root, mode='train')
    val_dataset   = MolecularDIRDataset(root=dataset_root, mode='val')
    test_dataset  = MolecularDIRDataset(root=dataset_root, mode='test')

    proc_paths = [train_dataset.processed_paths[0], val_dataset.processed_paths[0], test_dataset.processed_paths[0]]
    if not all(os.path.exists(p) for p in proc_paths):
        print("Processed files not found, processing splits...")
        # Keep the re-created datasets; previously they were built and discarded.
        train_dataset = MolecularDIRDataset(root=dataset_root, mode='train')
        val_dataset   = MolecularDIRDataset(root=dataset_root, mode='val')
        test_dataset  = MolecularDIRDataset(root=dataset_root, mode='test')
        print("Processing done.")

    if any(len(ds) == 0 for ds in (train_dataset, val_dataset, test_dataset)):
        raise ValueError("One or more splits are empty.")

    num_features = train_dataset.num_features
    all_labels = torch.cat([d.y for d in train_dataset] + [d.y for d in val_dataset] + [d.y for d in test_dataset])
    num_classes = int(all_labels.max().item()) + 1
    print(f"Num Features: {num_features}, Num Classes: {num_classes}")

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader   = DataLoader(val_dataset,   batch_size=args.batch_size)
    test_loader  = DataLoader(test_dataset,  batch_size=args.batch_size)

    now = datetime.now().strftime("%Y%m%d-%H%M%S")
    exp = f"{args.exp_name}.r_{args.r}.lr_{args.net_lr}.bs_{args.batch_size}.ch_{args.channels}.seed_{args.seed}.{now}"
    exp_dir = os.path.join('logs', args.exp_name, exp)
    os.makedirs(exp_dir, exist_ok=True)
    logging.basicConfig(filename=os.path.join(exp_dir, '_output_.log'),
                        level=logging.INFO,
                        format='%(asctime)s %(levelname)s: %(message)s')
    logger = logging.getLogger()
    args_print(args, logger)

    feature_filter = None
    if args.use_causal_filter:
        feature_filter = ImprovedCausalFilter(
            input_dim=num_features,
            lambda_init=args.filter_lambda_init,
            lambda_min=args.filter_lambda_min,
            decay_rate=args.filter_decay,
            temperature=args.filter_temp,
            residual_weight=args.filter_residual,
            normalize=False,
            dropout=0.1
        ).to(device)

    g = MolHivNet(num_tasks=num_classes, emb_dim=args.channels).to(device)
    att_net = CausalAttNet(causal_ratio=args.r, in_channels=num_features, channels=args.channels).to(device)
    params = list(g.parameters()) + list(att_net.parameters())
    if feature_filter is not None:
        params += list(feature_filter.parameters())
    optimizer = torch.optim.Adam(params, lr=args.net_lr)
    criterion = nn.CrossEntropyLoss()

    best_val = 0.0
    best_g, best_att, best_filter = None, None, None
    val_acc_list_5_epoch = []
    final_train_accuracies = []
    # Per-epoch *validation* accuracies; their spread is the "±" reported next
    # to the test accuracy in the final summary.
    final_val_accuracies = []

    for epoch in range(args.epoch):
        g.train()
        att_net.train()
        if feature_filter is not None:
            feature_filter.train()
        total_loss, n_batches = 0.0, 0
        for graph in train_loader:
            graph = graph.to(device)
            if graph.num_graphs == 0:
                continue
            if feature_filter is not None:
                graph.x = feature_filter(graph.x)

            (cx, cei, cea, cew, cb), (fx, fei, fea, few, fb), _ = att_net(graph)
            # The predictor is fed 2-D integer edge attributes.
            cea = (cea.unsqueeze(1) if cea.dim() == 1 else cea).long()
            fea = (fea.unsqueeze(1) if fea.dim() == 1 else fea).long()

            # Both the causal and the complementary subgraph are trained
            # against the full-graph label.
            set_masks(cew, g)
            out_c = g(x=cx, edge_index=cei, edge_attr=cea, batch=cb)
            clear_masks(g)
            loss_c = criterion(out_c, graph.y.view(-1))

            set_masks(few, g)
            out_f = g(x=fx, edge_index=fei, edge_attr=fea, batch=fb)
            clear_masks(g)
            loss_f = criterion(out_f, graph.y.view(-1))

            loss = loss_c + loss_f
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1

        avg_loss = total_loss / n_batches if n_batches else 0.0

        g.eval()
        att_net.eval()
        if feature_filter is not None:
            feature_filter.eval()
        train_acc = _evaluate(train_loader, g, att_net, feature_filter) * 100
        val_acc = _evaluate(val_loader, g, att_net, feature_filter) * 100

        final_train_accuracies.append(train_acc)
        final_val_accuracies.append(val_acc)
        val_acc_list_5_epoch.append(val_acc)

        logger.info(f"Epoch {epoch:03d}: Loss={avg_loss:.4f}, TrainAcc={train_acc:.4f}, ValAcc={val_acc:.4f}")

        if (epoch + 1) % 5 == 0:
            std_val = np.std(val_acc_list_5_epoch)
            logger.info(f"[Epoch {epoch - 4:03d}~{epoch:03d}] ValAcc Std: {std_val:.4f}")
            val_acc_list_5_epoch = []

        if val_acc >= best_val:
            best_val = val_acc
            best_g = copy.deepcopy(g)
            best_att = copy.deepcopy(att_net)
            # feature_filter.step() runs every epoch (below) and presumably
            # updates its state, so snapshot the filter together with the
            # models it was validated with.
            best_filter = copy.deepcopy(feature_filter) if feature_filter is not None else None
        if feature_filter is not None:
            feature_filter.step()
            if (epoch + 1) % 10 == 0:
                stats = feature_filter.get_stats()
                logger.info(f"Filter stats epoch {epoch}: {stats}")

    # Test and save the snapshot selected on validation accuracy, not the
    # last-epoch weights.
    if best_g is None:
        # No epoch ran (e.g. --epoch 0): fall back to the current models.
        best_g, best_att, best_filter = g, att_net, feature_filter
    best_g.eval()
    best_att.eval()
    if best_filter is not None:
        best_filter.eval()
    test_acc_val = _evaluate(test_loader, best_g, best_att, best_filter) * 100
    logger.info(f"Best ValAcc={best_val:.4f}, TestAcc={test_acc_val:.4f}")

    final_train_mean = np.mean(final_train_accuracies)
    final_train_std = np.std(final_train_accuracies)
    final_test_std = np.std(final_val_accuracies)
    print(f"Train Acc  {final_train_mean:.2f} ± {final_train_std:.2f}  Test Acc  {test_acc_val:.2f} ± {final_test_std:.2f}")
    torch.save(best_g.state_dict(), os.path.join(exp_dir, 'best_g.pth'))
    torch.save(best_att.state_dict(), os.path.join(exp_dir, 'best_att.pth'))
    if best_filter is not None:
        torch.save(best_filter.state_dict(), os.path.join(exp_dir, 'best_filter.pth'))
