import torch
import os
import numpy as np
import argparse
import logging
import time
from datetime import datetime
from itertools import product

from greatx.nn.models import GCN, GAT, GNNGUARD, ElasticGNN, RUNG, RobustGCN, GCORN, NoisyGCN, APPNP, GPRGNN, TWIRLS, EvenNet, SoftMedianGCN, SoftMedianGDC, GRAND, GUARDDUAL
from greatx.training import Trainer, GRANDTrainer
from greatx.training.callbacks import ModelCheckpoint, EarlyStopping

import sys
sys.path.append("../")
from common import set_seed, load_inductive_graph_dataset_for_gnn

import yaml

# ------------ Argument Parsing ------------
# Command-line options for one (dataset, model, injection, embedding, ptb_rate)
# evaluation run. Every option carries help text so `--help` is self-describing.
parser = argparse.ArgumentParser(
    description="Evaluate GNN defenses in the inductive setting against pre-generated "
                "WTGIA injection attacks, sweeping the model's hyperparameter grid."
)
parser.add_argument("--root_path", type=str, default="/path/to/GraphAD_data/",
                    help="Data root; dataset files and attack artifacts are read below it.")
parser.add_argument("--graph_save_dir", type=str, default="atkg",
                    help="Sub-directory of --root_path holding the attack artifacts.")
parser.add_argument("--config", type=str, default="config.yaml",
                    help="YAML file with a 'hyperparams' mapping keyed by model name.")
parser.add_argument("--dataset", type=str, default="cora",
                    help="Dataset name, e.g. cora, citeseer, pubmed, computer, arxiv.")
parser.add_argument("--model", type=str, default='gcn',
                    help="Defense model key, e.g. gcn, gat, appnp, rung, grand.")
parser.add_argument("--attack", type=str, default="wtgia",
                    help="Attack name (currently not read; this script always evaluates WTGIA).")
parser.add_argument("--injection", type=str, default="random", choices=["random", "tdgia", "atdgia"],
                    help="WTGIA injection strategy whose attack files are loaded.")
parser.add_argument("--def_emb_type", type=str, default="bow", choices=['bow', 'roberta', 'Mistral-7B', 'MiniLM'],
                    help="Node-feature embedding used by the defense (attack features are always bow).")
parser.add_argument("--re_split", type=int, default=2,
                    help="Passed to the dataset loader as re_split (forced to 0 for arxiv).")
parser.add_argument("--ptb_rate", type=float, default=0.2,
                    help="Attack perturbation rate; 0.2 selects attack files tagged 20.")
parser.add_argument("--device", type=int, default=0,
                    help="CUDA device index; CPU is used when CUDA is unavailable.")
parser.add_argument("--patience", type=int, default=50,
                    help="Early-stopping patience (monitored on val_acc).")
parser.add_argument("--epochs", type=int, default=400,
                    help="Maximum number of training epochs.")
parser.add_argument("--use_existing_logs", action='store_true',
                    help="Skip the run if its log file already exists.")
args = parser.parse_args()

# ------------ Load Config ------------
with open(args.config, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Hyperparameter grid for the selected model: parameter name -> values to sweep.
# A model absent from the config (or present with an empty body, which YAML
# loads as None) yields an empty grid; itertools.product then yields a single
# empty combination, so the sweep runs once with the built-in defaults.
model_config = config['hyperparams'].get(args.model) or {}
# Scalar entries are wrapped in a one-element list so itertools.product can
# iterate them (a bare number would raise TypeError, and a bare string would be
# swept character by character).
hyperparams = {
    param_name: param_values if isinstance(param_values, (list, tuple)) else [param_values]
    for param_name, param_values in model_config.items()
}

# Skip memory-intensive models for large datasets
MEMORY_HEAVY_MODELS = ('rung', 'twirls', 'softmedian', 'softmediangdc')
LARGE_DATASETS = ('computer', 'arxiv')
if args.model in MEMORY_HEAVY_MODELS and args.dataset in LARGE_DATASETS:
    # The message previously also named "photo", which this condition never matched.
    # NOTE(review): confirm whether 'photo' belongs in LARGE_DATASETS.
    print(f"Skipping {args.model} for {args.dataset} due to memory constraints")
    # sys.exit rather than the builtin exit(), which is injected by the `site`
    # module and is not guaranteed to exist (e.g. under `python -S`).
    sys.exit(0)

if args.dataset != 'arxiv':
    seeds = range(3)
else:
    # arxiv: re_split forced to 0 and a single seed.
    args.re_split = 0
    seeds = [0]

# ------------ Fixed Settings ------------
root_path = args.root_path
graph_save_dir = f"{root_path}/{args.graph_save_dir}"
device = torch.device(f'cuda:{args.device}') if torch.cuda.is_available() else torch.device('cpu')

# ------------ Logging Setup ------------
# WTGIA always builds its attack from bow features; only the defense embedding
# varies between runs, so "bow" is fixed in the log file name.
attack_name = f"wtgia_{args.injection}"
ptb_tag = int(args.ptb_rate*100)
log_dir = f"eval_logs_ind_wtgia/{args.dataset}/{args.model}"
log_file = f"{log_dir}/{attack_name}_bow_{args.def_emb_type}_ptb{ptb_tag}.log"
os.makedirs(log_dir, exist_ok=True)

# An existing log file is treated as proof that this configuration already ran.
# NOTE(review): the log is opened in 'w' mode before training starts, so a run
# that crashed also leaves a log behind and would be skipped here — confirm
# this is acceptable.
if args.use_existing_logs and os.path.exists(log_file):
    print(f"Log file {log_file} already exists. Experiment already completed. Skipping...")
    exit(0)

logging.basicConfig(filename=log_file, level=logging.INFO, filemode='w', format='%(message)s')
logger = logging.getLogger()

# Run header: configuration summary framed by a rule line.
banner = "=" * 80
for header_line in (
    banner,
    f"WTGIA EVALUATION - Dataset: {args.dataset}, Model: {args.model}",
    f"Attack: {attack_name}, Injection: {args.injection}, Atk Embedding: bow, Def Embedding: {args.def_emb_type}",
    f"PTB Rate: {args.ptb_rate}",
    f"Training: epochs={args.epochs}, patience={args.patience}",
    f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
    banner,
):
    logger.info(header_line)

# ------------ Run over hyperparams ------------
# For every combination in the hyperparameter grid and every seed:
# 1. Train the defense on the clean inductive split.
# 2. Evaluate it on the clean test graph.
# 3. Evaluate it on the test graph with the pre-generated WTGIA injected nodes
#    attached.
# 4. Log mean/std over seeds.

# Percentage tag used in attack file names (e.g. 0.2 -> 20).
ptb_pct = int(args.ptb_rate * 100)

# Hidden width depends only on the dataset and the model, so it is computed once.
if args.dataset in ['cora', 'citeseer', 'pubmed']:
    hidden_dim = 128
else:
    hidden_dim = 256
if args.model == 'gat':
    # GAT gets one eighth of the width (presumably to offset multi-head
    # concatenation inside greatx's GAT — TODO confirm).
    hidden_dim = hidden_dim // 8

# Caches shared across hyperparameter combinations. Dataset, attack and
# embedding type are fixed for the whole process, so these never go stale.
#   seed -> injected-node embeddings in the defense embedding space
_injected_emb_cache = {}
#   dataset -> number of nodes in the full graph (read from the bow features)
_full_graph_sizes = {}


def _build_model(hp_dict, num_features, num_classes, dropout, train_data, val_data):
    """Instantiate the defense model selected by ``args.model``.

    Model-specific settings are read from ``hp_dict``, with per-model defaults
    used when a key is absent. ``train_data``/``val_data`` are used only by
    GuardDUAL, which needs node counts.

    Raises:
        ValueError: if ``args.model`` is not a supported model key.
    """
    name = args.model
    hids = [hidden_dim]
    if name == 'gcn':
        # NOTE(review): GCN uses a fixed width of 64 and greatx's default dropout,
        # ignoring both the sweep's dropout and hidden_dim — confirm intended.
        return GCN(num_features, num_classes, hids=[64])
    if name == 'gat':
        return GAT(num_features, num_classes, hids=hids, dropout=dropout)
    if name == 'appnp':
        return APPNP(num_features, num_classes, hids=hids, dropout=dropout,
                     alpha=hp_dict.get('alpha', 0.1))
    if name == 'gprgnn':
        return GPRGNN(num_features, num_classes, hids=hids, dropout=dropout,
                      alpha=hp_dict.get('alpha', 0.1))
    if name == 'robustgcn':
        return RobustGCN(num_features, num_classes, hids=hids, dropout=dropout)
    if name == 'elasticgnn':
        return ElasticGNN(num_features, num_classes, hids=hids, dropout=dropout,
                          lambda1=hp_dict.get('lambda1', 0), lambda2=hp_dict.get('lambda2', 0),
                          cached=False)
    if name == 'gnnguard':
        return GNNGUARD(num_features, num_classes, hids=hids, dropout=dropout,
                        threshold=hp_dict.get('threshold', 0.1))
    if name == 'guarddual':
        num_train_nodes = train_data.x.shape[0]
        # The subtraction suggests val_data also contains the train nodes — TODO confirm.
        num_val_nodes = val_data.x.shape[0] - num_train_nodes
        return GUARDDUAL(num_features, num_classes, hids=hids, dropout=dropout,
                         dataset_name=args.dataset, num_train_nodes=num_train_nodes,
                         num_val_nodes=num_val_nodes)
    if name == 'noisy_gcn':
        return NoisyGCN(num_features, num_classes, hids=hids, dropout=dropout,
                        noise_ratio_1=hp_dict.get('beta', 0.1))
    if name == 'gcorn':
        return GCORN(num_features, num_classes, hids=hids, dropout=dropout, order=3)
    if name == 'evennet':
        return EvenNet(num_features, num_classes, hids=hids, dropout=dropout,
                       K=hp_dict.get('K', 10), alpha=hp_dict.get('alpha', 0.1),
                       dprate=hp_dict.get('dprate', 0.5))
    if name == 'twirls':
        return TWIRLS(num_features, num_classes, hidden_channels=hidden_dim, dropout=dropout,
                      alp=hp_dict.get('alp', 1.0), lam=hp_dict.get('lam', 1.0),
                      prop_step=hp_dict.get('prop_step', 32),
                      attention=hp_dict.get('attention', True))
    if name == 'rung':
        return RUNG(num_features, num_classes, hids=hids, dropout=dropout,
                    lam_hat=hp_dict.get('lamb', 0.8), gamma=hp_dict.get('gamma', 2))
    if name == 'softmedian':
        return SoftMedianGCN(num_features, num_classes, hids=hids, dropout=dropout,
                             temperature=hp_dict.get('temperature', 0.5),
                             row_normalize=False, normalize=False, cached=False)
    if name == 'softmediangdc':
        return SoftMedianGDC(num_features, num_classes, hids=hids, dropout=dropout,
                             temperature=hp_dict.get('temperature', 1.0),
                             teleport_proba=hp_dict.get('teleport_proba', 0.15),
                             neighbors=hp_dict.get('neighbors', 64), cached=False)
    if name == 'grand':
        return GRAND(num_features, num_classes, hids=hids, dropout=dropout,
                     dropnode=hp_dict.get('dropnode', 0.5), order=hp_dict.get('order', 2),
                     mlp_input_dropout=hp_dict.get('mlp_input_dropout', 0.5))
    raise ValueError(f"Unsupported model: {args.model}")


def _full_graph_num_nodes():
    """Return the full-graph node count, read once from the dataset's bow features."""
    if args.dataset not in _full_graph_sizes:
        # WTGIA attack files use bow, so the bow feature matrix fixes the node count.
        bow_features = torch.load(f"{args.root_path}/datasets/bow/{args.dataset}.pt", map_location='cpu')
        _full_graph_sizes[args.dataset] = bow_features.shape[0]
    return _full_graph_sizes[args.dataset]


def _remap_edge_index_to_subgraph(edge_index, node_ids, n_full, n_sub, n_inject):
    """Translate attack edge ids from full-graph space into test-subgraph space.

    Id layouts:
    * Full-graph space (attack files): original nodes use ``[0, n_full)`` and
      injected nodes use ``[n_full, n_full + n_inject)``.
    * Subgraph space: test nodes use ``[0, n_sub)``, followed by injected nodes
      at ``[n_sub, n_sub + n_inject)``.

    ``edge_index`` is assumed to be a PyG-style ``(2, E)`` tensor.

    Returns:
        ``(remapped_edge_index, num_dropped)``. Edges touching a node that is
        not in the test subgraph have no counterpart there and are dropped.
        Previously such ids became -1, which silently indexed the last node.
    """
    dev = edge_index.device
    node_ids = torch.as_tensor(node_ids, dtype=torch.long, device=dev)
    mapping = torch.full((n_full + n_inject,), -1, dtype=torch.long, device=dev)
    mapping[node_ids] = torch.arange(node_ids.numel(), device=dev)
    # Injected nodes follow the full graph's nodes, not the subgraph's. Using
    # n_sub here would overwrite the mappings of real nodes whenever n_sub < n_full.
    mapping[n_full:] = torch.arange(n_sub, n_sub + n_inject, device=dev)
    remapped = mapping[edge_index]
    keep = (remapped >= 0).all(dim=0)
    return remapped[:, keep], int((~keep).sum())


def _injected_defense_embeddings(seed, n_inject):
    """Embed the injected WTGIA texts for ``seed`` with the defense's text encoder.

    The result is cached per seed, so the encoder (possibly a 7B LLM) is loaded
    at most once per seed rather than once per hyperparameter combination.

    Raises:
        FileNotFoundError: if the injected-text JSON for this seed is missing.
        ValueError: if the number of texts or embedding rows differs from ``n_inject``.
    """
    cached = _injected_emb_cache.get(seed)
    if cached is not None:
        return cached

    import json
    from common.lm import TextEncoder  # heavy import, only needed for non-bow defenses

    texts_path = f"{graph_save_dir}/{args.dataset}/{attack_name}_texts/llama-3.1-8B_{ptb_pct}_{seed}.json"
    if not os.path.exists(texts_path):
        raise FileNotFoundError(f"WTGIA texts not found at {texts_path}")
    with open(texts_path, 'r', encoding='utf-8') as f:
        injected_texts = json.load(f)['texts']
    # Each text becomes one injected node's feature row, so the counts must agree.
    # Check before loading the encoder so a mismatch fails fast.
    if len(injected_texts) != n_inject:
        raise ValueError(f"{texts_path} has {len(injected_texts)} texts, expected {n_inject}")

    encoder_type = "LLM" if args.def_emb_type in ["Mistral-7B", "Qwen-7B"] else "LM"
    text_encoder = TextEncoder(args.def_emb_type, encoder_type, device)
    with torch.no_grad():
        embeddings = torch.cat(
            [text_encoder.forward(text, pooling="mean", max_length=512) for text in injected_texts],
            dim=0,
        )
    del text_encoder
    torch.cuda.empty_cache()

    if embeddings.shape[0] != n_inject:
        raise ValueError(f"Encoded {embeddings.shape[0]} injected rows, expected {n_inject}")
    embeddings = embeddings.to(device)
    _injected_emb_cache[seed] = embeddings
    return embeddings


# GuardDUAL only supports roberta features. Decide once, instead of loading the
# dataset for every seed and combination only to skip it.
if args.model == 'guarddual' and args.def_emb_type != 'roberta':
    logger.info("GuardDUAL requires roberta embeddings, skipping...")
    sys.exit(0)

for hp_combo in product(*hyperparams.values()):
    hp_dict = dict(zip(hyperparams.keys(), hp_combo))

    # Extract common parameters
    lr = hp_dict.get('learning_rates', 0.01)
    dropout = hp_dict.get('dropouts', 0.5)
    # NOTE(review): hp_dict['weight_decays'] is swept in the config but is not
    # forwarded. The visible code only shows reset_optimizer accepting `lr`.
    # TODO: confirm whether it also takes weight_decay, and pass it through.

    logger.info("-" * 60)
    logger.info(f"Hyperparams: {hp_dict}")
    start_time = time.time()

    clean_accs, clean_val_accs, attack_accs_recomputed = [], [], []

    for seed in seeds:
        set_seed(seed)

        # Load clean data
        full_data, (train_data, val_data, test_data) = load_inductive_graph_dataset_for_gnn(
            args.dataset, device, re_split=args.re_split,
            path_prefix=args.root_path, emb_model=args.def_emb_type, seed=seed
        )
        train_data = train_data.to(device)
        val_data = val_data.to(device)
        test_data = test_data.to(device)

        num_features = full_data.x.shape[-1]
        num_classes = full_data.y.max().item() + 1

        model = _build_model(hp_dict, num_features, num_classes, dropout, train_data, val_data)
        if args.model.lower() == "grand":
            trainer = GRANDTrainer(model, device=device)
        else:
            trainer = Trainer(model, device=device, verbose=0)
        # This used to be hard-coded to 0.01, which made the learning-rate sweep a no-op.
        trainer.reset_optimizer(lr=lr)

        # Train model. The checkpoint name includes the injection type so that
        # concurrent runs that differ only in --injection do not clobber each other.
        ckp = ModelCheckpoint(
            f'{attack_name}_clean_{args.dataset}_{args.model}_{args.def_emb_type}_{args.ptb_rate}_{seed}.pth',
            monitor='val_acc',
        )
        es = EarlyStopping(monitor='val_acc', patience=args.patience)
        trainer.fit((train_data, val_data), epochs=args.epochs, callbacks=[ckp, es], verbose=0)

        # Evaluate clean performance
        clean_logs = trainer.evaluate(test_data, mask=full_data.test_mask)
        val_logs = trainer.evaluate(val_data)
        clean_accs.append(clean_logs['acc'])
        clean_val_accs.append(val_logs['acc'])

        # Load WTGIA attack edges and features (WTGIA always uses bow for the attack)
        attack_dir = f"{graph_save_dir}/{args.dataset}"
        attack_edge_index = torch.load(
            f"{attack_dir}/{attack_name}/bow_{ptb_pct}_{seed}.pt", map_location=device)
        attack_recomputed_features = torch.load(
            f"{attack_dir}/{attack_name}_recomputed_features/bow_{ptb_pct}_{seed}.pt", map_location=device)
        n_inject = attack_recomputed_features.shape[0]

        # Ids at or beyond the subgraph size indicate full-graph id space
        n_sub = test_data.x.shape[0]
        if (attack_edge_index >= n_sub).any():
            attack_edge_index, n_dropped = _remap_edge_index_to_subgraph(
                attack_edge_index, test_data.node_ids, _full_graph_num_nodes(), n_sub, n_inject)
            if n_dropped:
                logger.info(f"Seed {seed}: dropped {n_dropped} attack edges touching nodes outside the test subgraph")

        # Injected features must live in the defense's embedding space
        if args.def_emb_type != 'bow':
            injected_features = _injected_defense_embeddings(seed, n_inject)
        else:
            injected_features = attack_recomputed_features.to(device)

        attacked_test_data = test_data.clone()
        attacked_test_data.x = torch.cat([test_data.x, injected_features], dim=0)
        attacked_test_data.edge_index = attack_edge_index
        # Injected nodes get a placeholder label; they are excluded from the mask below
        attacked_test_data.y = torch.cat([attacked_test_data.y, torch.zeros(n_inject, dtype=torch.long, device=device)])
        extended_test_mask = torch.cat([full_data.test_mask, torch.zeros(n_inject, dtype=torch.bool, device=device)])

        # Evaluate attack
        attack_logs = trainer.evaluate(attacked_test_data, mask=extended_test_mask)
        attack_accs_recomputed.append(attack_logs['acc'])

        # Cleanup
        del model, trainer
        torch.cuda.empty_cache()

    # Log result
    if len(clean_accs) > 0:
        avg_clean = np.mean(clean_accs)
        avg_attack_recomp = np.mean(attack_accs_recomputed) if attack_accs_recomputed else avg_clean

        logger.info(f"> Clean Val Acc: {np.mean(clean_val_accs):.4f} ± {np.std(clean_val_accs):.4f}")
        logger.info(f"> Clean Test Acc: {avg_clean:.4f} ± {np.std(clean_accs):.4f}")
        logger.info(f"> Attacked Test Acc: {avg_attack_recomp:.4f} ± {np.std(attack_accs_recomputed):.4f}")

        # Average wall-clock time per seed (includes data loading and attack evaluation)
        logger.info(f"> Time: {(time.time() - start_time) / len(seeds):.2f}s")

