import torch
import time
import os
import os.path as osp
import numpy as np
import argparse
import logging
import json
from datetime import datetime
from itertools import product

from torch_geometric.utils import to_undirected
from torch_geometric.data import Data

from greatx.attack.untargeted import Metattack, DICEAttack, PGDAttack, PRBCDAttack, GRBCDAttack
from greatx.datasets import GraphDataset
from greatx.nn.models import GCN, GAT, GNNGUARD, ElasticGNN, RUNG, RobustGCN, GCORN, NoisyGCN, APPNP, GPRGNN, ProGNN, Stable, PurificationGCN, EvenNet, TWIRLS, SoftMedianGCN, SoftMedianGDC, GRAND
from greatx.training import Trainer, GRANDTrainer
from greatx.training.callbacks import ModelCheckpoint, EarlyStopping
from greatx.utils import split_nodes

import sys
sys.path.append("../")
from common import set_seed, load_graph_dataset_for_gnn, load_atk_graph_dataset_for_gnn

import yaml

# ------------ Argument Parsing ------------
# Both embedding options (attack-side and defense-side) accept the same names.
_EMB_CHOICES = ['bow', 'roberta', 'Mistral-7B', 'MiniLM']
_ATTACK_CHOICES = ['textfooler', 'llm', 'llm_Ministral', 'gpt']

# (flag, add_argument keyword arguments), registered in this order so the
# generated --help output lists the options in the same sequence.
_CLI_OPTIONS = (
    ("--root_path", dict(type=str, default="/path/to/GraphAD_data/")),
    ("--config", dict(type=str, default="config.yaml")),
    ("--dataset", dict(type=str, default="cora")),
    ("--model", dict(type=str, default='gcn')),
    ("--attack", dict(type=str, default="textfooler", choices=_ATTACK_CHOICES)),
    ("--atk_emb_type", dict(type=str, default="bow", choices=_EMB_CHOICES)),
    ("--def_emb_type", dict(type=str, default="bow", choices=_EMB_CHOICES)),
    ("--re_split", dict(type=int, default=1)),
    ("--ptb_rate", dict(type=float, default=0.2)),
    ("--device", dict(type=int, default=1)),
    ("--patience", dict(type=int, default=100)),
    ("--epochs", dict(type=int, default=200)),
    ("--use_existing_logs", dict(action='store_true')),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

# ------------ Load Config ------------
# safe_load returns None for an empty file; treat that as "no configuration".
with open(args.config, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f) or {}

# Get model-specific hyperparameters. A model that is absent from the config
# (or whose entry is empty/null) simply gets no search space, so the training
# loop falls back to its built-in defaults.
model_config = (config.get('hyperparams') or {}).get(args.model) or {}
hyperparams = {}
for param_name, param_values in model_config.items():
    # The grid is later expanded with itertools.product, which needs an
    # iterable of candidates per parameter. Wrap a bare scalar so that
    # `alpha: 0.1` means "one candidate" instead of raising (numbers) or
    # being split into characters (strings).
    if not isinstance(param_values, (list, tuple)):
        param_values = [param_values]
    hyperparams[param_name] = param_values

# Explicit raises instead of assert: assertions are removed under `python -O`,
# which would let unsupported settings through silently.
if args.re_split != 1:
    raise ValueError("Only transductive split with re_split=1 (1/1/8) is supported for defense evaluation")
if args.dataset == 'arxiv':
    raise ValueError("arxiv dataset is not supported for transductive evaluation")

# ------------ Fixed Settings ------------
root_path = args.root_path
seeds = range(3)  # every configuration is run with seeds 0, 1 and 2

# Use the requested GPU when CUDA is present, otherwise run on the CPU.
# NOTE(review): --device defaults to 1, i.e. the second GPU; presumably
# single-GPU machines must pass --device 0 explicitly - confirm.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{args.device}')
else:
    device = torch.device('cpu')

# ------------ Logging Setup ------------
# One log file per (dataset, model, attack, embeddings, perturbation rate).
log_dir = f"eval_logs_trans_text/{args.dataset}/{args.model}/{args.attack}"
os.makedirs(log_dir, exist_ok=True)
# round() rather than int(): binary floating point makes e.g. 0.29 * 100
# evaluate to 28.999999999999996, which int() would truncate to 28 and so
# mislabel the file (and collide with the log of ptb_rate=0.28).
ptb_percent = round(args.ptb_rate * 100)
log_file = f"{log_dir}/{args.atk_emb_type}_{args.def_emb_type}_ptb{ptb_percent}.log"

# Check if log file exists, skip experiment if it does
if args.use_existing_logs and os.path.exists(log_file):
    print(f"Log file {log_file} already exists. Experiment already completed. Skipping...")
    # sys.exit instead of the `exit` helper, which is only installed by the
    # `site` module and is meant for interactive use.
    sys.exit(0)

# filemode='w' truncates any previous log of the same configuration.
logging.basicConfig(filename=log_file, level=logging.INFO, filemode='w', format='%(message)s')
logger = logging.getLogger()

# Header describing the run.
logger.info("=" * 80)
logger.info(f"EXPERIMENT - Dataset: {args.dataset}, Model: {args.model}")
logger.info(f"Text Attack: {args.attack}")
logger.info(f"Embedding: atk={args.atk_emb_type}, def={args.def_emb_type}, ptb_rate={args.ptb_rate}")
logger.info(f"Training: epochs={args.epochs}, patience={args.patience}")
logger.info(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
logger.info("=" * 80)

# Skip model/dataset pairs that are excluded as too large. Each model maps to
# the datasets on which it is not run; models not listed here are never skipped.
_SKIPPED_DATASETS = {
    'prognn': ["pubmed", "reddit", "history", "computer", "photo", "wikics", "instagram"],
    'stable': ["computer", "photo"],
    'rung': ["computer", "photo"],
    'twirls': ["computer", "photo"],
    'softmediangdc': ["computer", "photo"],
}

if args.dataset in _SKIPPED_DATASETS.get(args.model, ()):
    # Name the model that is actually being skipped (the previous message for
    # 'stable' also mentioned prognn).
    logger.info(f"Skipping {args.model} for large datasets")
    # Exit code 0: a skipped pair is not an error for batch launchers.
    sys.exit(0)

# ------------ Run over hyperparams ------------
# Expand the per-parameter candidate lists into the full Cartesian grid.
# Names and value lists are kept in the same order so each combination tuple
# can be zipped back with param_names. With no hyperparameters the grid holds
# exactly one empty combination.
param_names = [*hyperparams]
param_values = [hyperparams[name] for name in param_names]
param_combinations = [*product(*param_values)]

# Models that are trained and scored through their own fit()/test() methods
# instead of a greatx Trainer.
_SELF_TRAINING_MODELS = ('prognn', 'stable', 'purify_jaccard', 'purify_cosine')


def _build_model(param_dict, num_features, num_classes, hidden_dim, dropout, lr):
    """Instantiate the defense model selected by ``args.model``.

    Args:
        param_dict: hyperparameters of the current grid point; every
            model-specific value falls back to a default when absent.
        num_features: input feature dimension.
        num_classes: number of target classes.
        hidden_dim: width of the single hidden layer.
        dropout: dropout rate passed to the model.
        lr: learning rate; only ProGNN takes it at construction time.

    Returns:
        A freshly constructed model.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """
    name = args.model
    if name == 'gcn':
        return GCN(num_features, num_classes, hids=[hidden_dim], dropout=dropout)
    if name == 'gat':
        return GAT(num_features, num_classes, hids=[hidden_dim], dropout=dropout)
    if name == 'appnp':
        alpha = param_dict.get('alpha', 0.1)
        return APPNP(num_features, num_classes, hids=[hidden_dim], dropout=dropout, alpha=alpha)
    if name == 'gprgnn':
        alpha = param_dict.get('alpha', 0.1)
        return GPRGNN(num_features, num_classes, hids=[hidden_dim], dropout=dropout, alpha=alpha)
    if name == 'robustgcn':
        return RobustGCN(num_features, num_classes, hids=[hidden_dim], dropout=dropout)
    if name == 'elasticgnn':
        lambda1 = param_dict.get('lambda1', 0)
        lambda2 = param_dict.get('lambda2', 0)
        return ElasticGNN(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                          lambda1=lambda1, lambda2=lambda2, cached=False)
    if name == 'gnnguard':
        threshold = param_dict.get('threshold', 0.1)
        return GNNGUARD(num_features, num_classes, hids=[hidden_dim], dropout=dropout, threshold=threshold)
    if name == 'noisy_gcn':
        beta = param_dict.get('beta', 0.1)
        return NoisyGCN(num_features, num_classes, hids=[hidden_dim], dropout=dropout, noise_ratio_1=beta)
    if name == 'gcorn':
        return GCORN(num_features, num_classes, hids=[hidden_dim], dropout=dropout, order=3)
    if name == 'prognn':
        alpha = param_dict.get('alpha', 0.1)
        beta = param_dict.get('beta', 2.0)
        # Apply all relevant hyperparameters for ProGNN
        return ProGNN(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                      alpha=alpha, beta=beta, lr=lr, lr_adj=lr*10, device=device,
                      epochs=args.epochs, symmetric=False, gamma=1.0, lambda_=0.001, phi=0.0)
    if name == 'evennet':
        K = param_dict.get('K', 10)
        alpha = param_dict.get('alpha', 0.1)
        dprate = param_dict.get('dprate', 0.5)
        return EvenNet(num_features, num_classes, hids=[hidden_dim], dropout=dropout, K=K, alpha=alpha, dprate=dprate)
    if name == 'twirls':
        alp = param_dict.get('alp', 1.0)
        lam = param_dict.get('lam', 1.0)
        prop_step = param_dict.get('prop_step', 32)
        attention = param_dict.get('attention', True)
        return TWIRLS(num_features, num_classes, hidden_channels=hidden_dim, dropout=dropout, alp=alp, lam=lam, prop_step=prop_step, attention=attention)
    if name == 'softmedian':
        temperature = param_dict.get('temperature', 0.5)
        return SoftMedianGCN(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                             temperature=temperature,
                             row_normalize=False, normalize=False, cached=True)
    if name == 'softmediangdc':
        temperature = param_dict.get('temperature', 1.0)
        teleport_proba = param_dict.get('teleport_proba', 0.15)
        neighbors = param_dict.get('neighbors', 64)
        return SoftMedianGDC(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                             temperature=temperature, teleport_proba=teleport_proba,
                             neighbors=neighbors, cached=True)
    if name == 'grand':
        dropnode = param_dict.get('dropnode', 0.5)
        order = param_dict.get('order', 2)
        mlp_input_dropout = param_dict.get('mlp_input_dropout', 0.5)
        return GRAND(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                     dropnode=dropnode, order=order, mlp_input_dropout=mlp_input_dropout)
    if name == 'stable':
        cos = param_dict.get('cos', 0.3)
        jar = param_dict.get('jar', 0.03)
        alpha = param_dict.get('alpha', 0.3)
        return Stable(num_features, num_classes, hids=[hidden_dim], dropout=dropout, device=device, alpha=alpha, cosine_threshold=cos, jaccard_threshold=jar)
    if name == 'purify_jaccard':
        jaccard_threshold = param_dict.get('jaccard_threshold', 0.03)
        return PurificationGCN(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                               purification_method='jaccard',
                               jaccard_threshold=jaccard_threshold,
                               cosine_threshold=None,
                               device=device)
    if name == 'purify_cosine':
        cosine_threshold = param_dict.get('cosine_threshold', 0.1)
        return PurificationGCN(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                               purification_method='cosine',
                               jaccard_threshold=None,
                               cosine_threshold=cosine_threshold,
                               device=device)
    if name == 'rung':
        lamb = param_dict.get('lamb', 0.8)
        gamma = param_dict.get('gamma', 2)
        return RUNG(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                    lam_hat=lamb, gamma=gamma)
    raise ValueError(f"Unsupported model: {args.model}")


def _fit_and_test(model, data):
    """Train a self-training model on ``data`` and return (val_acc, test_acc).

    Uses the model's own fit()/test() methods with the train/val/test masks
    stored on ``data``.
    """
    model.fit(data.x, data.edge_index, data.y,
              data.train_mask, data.val_mask, epochs=args.epochs)
    val_acc = model.test(data.x, data.y, data.val_mask)
    test_acc = model.test(data.x, data.y, data.test_mask)
    return val_acc, test_acc


def _train_and_evaluate(model, data, param_dict, lr, wd, ckpt_path):
    """Train ``model`` on ``data`` with a greatx trainer; return (val_acc, test_acc).

    GRAND gets its dedicated trainer with consistency-regularisation settings
    taken from ``param_dict``; every other model uses the generic Trainer.
    Training monitors ``val_acc`` for both checkpointing and early stopping.

    Args:
        model: model to train.
        data: graph carrying train/val/test masks.
        param_dict: hyperparameters of the current grid point.
        lr: learning rate for the optimizer.
        wd: weight decay for the optimizer.
        ckpt_path: file used by the ModelCheckpoint callback.

    Returns:
        Tuple of the checkpoint callback's ``best`` value (presumably the best
        monitored val_acc - confirm against greatx) and the test accuracy
        reported by ``trainer.evaluate``.
    """
    if args.model == 'grand':
        n_samples = param_dict.get('n_samples', 2)
        reg_consistency = param_dict.get('reg_consistency', 1.0)
        sharpening_temperature = param_dict.get('sharpening_temperature', 0.5)
        trainer = GRANDTrainer(model, device=device, verbose=0, n_samples=n_samples,
                               reg_consistency=reg_consistency, sharpening_temperature=sharpening_temperature)
    else:
        trainer = Trainer(model, device=device, verbose=0)
    trainer.reset_optimizer(lr=lr, weight_decay=wd)
    ckp = ModelCheckpoint(ckpt_path, monitor='val_acc')
    early_stopping = EarlyStopping(monitor='val_acc', patience=args.patience, mode='max')
    trainer.fit(data, mask=(data.train_mask, data.val_mask), callbacks=[ckp, early_stopping], epochs=args.epochs, verbose=0)
    test_logs = trainer.evaluate(data, mask=data.test_mask)
    return ckp.best, test_logs['acc']


# Hidden width depends only on the dataset and the model, so compute it once
# instead of once per hyperparameter combination.
if args.dataset in ['cora', 'citeseer', 'pubmed', 'instagram', 'wikics']:
    hidden_dim = 128
else:
    hidden_dim = 256
if args.model == 'gat':
    hidden_dim = hidden_dim // 8

for params in param_combinations:
    # Create a dictionary of parameter values
    param_dict = dict(zip(param_names, params))

    # Extract common parameters - align with inductive settings
    lr = param_dict.get('learning_rates', 0.01)
    wd = param_dict.get('weight_decays', 0.0)
    dropout = param_dict.get('dropouts', 0.5)

    clean_accs, attacked_accs = [], []
    clean_val_accs, attacked_val_accs = [], []

    logger.info("-" * 60)
    logger.info(f"Hyperparams: {param_dict}")
    start_time = time.time()

    for seed in seeds:
        set_seed(seed)

        # Load clean data
        full_data = load_graph_dataset_for_gnn(args.dataset, device, re_split=args.re_split, path_prefix=root_path, emb_model=args.def_emb_type)
        full_data = full_data.to(device)

        num_features = full_data.x.shape[-1]
        num_classes = full_data.y.max().item() + 1

        model = _build_model(param_dict, num_features, num_classes, hidden_dim, dropout, lr)
        self_training = args.model in _SELF_TRAINING_MODELS

        # The dataset is part of the checkpoint name so that runs on different
        # datasets launched from the same directory cannot overwrite each
        # other's checkpoint files.
        ckpt_tag = f'{args.dataset}_{args.model}_{args.attack}_{args.atk_emb_type}_{args.def_emb_type}_{args.ptb_rate}_{seed}'

        # ---- Train and evaluate on the clean graph ----
        if self_training:
            clean_val_acc, clean_test_acc = _fit_and_test(model, full_data)
        else:
            clean_val_acc, clean_test_acc = _train_and_evaluate(
                model, full_data, param_dict, lr, wd, f'model_before_{ckpt_tag}.pth')
        clean_val_accs.append(clean_val_acc)
        clean_accs.append(clean_test_acc)

        # ---- Load text attacked data ----
        atk_meta_info = {
            'attack': args.attack,
            'ptb_rate': args.ptb_rate,
            'atk_emb_type': args.atk_emb_type,
            'seed': seed
        }

        # Only the load is guarded: a FileNotFoundError raised later during
        # training (e.g. a missing checkpoint file) must not be reported as
        # missing attack data.
        try:
            atk_data = load_atk_graph_dataset_for_gnn(
                args.dataset, device, atk_meta_info, re_split=args.re_split,
                path_prefix=root_path, emb_model=args.def_emb_type)
        except FileNotFoundError as e:
            if not self_training:
                raise FileNotFoundError(f"Missing attack data for seed {seed}: {e}") from e
            # NOTE(review): self-training models keep the existing best-effort
            # behaviour of reporting the clean accuracy in place of the attacked
            # one, whereas Trainer-based models abort above. The fallback makes
            # the attacked numbers look identical to the clean ones - confirm
            # whether it should abort as well.
            logger.info(f"Missing attack data for seed {seed}: {e}")
            attacked_accs.append(clean_accs[-1])  # Use clean accuracy as fallback
            attacked_val_accs.append(clean_val_accs[-1])
            continue

        # ---- Retrain from scratch on the attacked graph and evaluate ----
        model.reset_parameters()
        if self_training:
            attacked_val_acc, attacked_test_acc = _fit_and_test(model, atk_data)
        else:
            attacked_val_acc, attacked_test_acc = _train_and_evaluate(
                model, atk_data, param_dict, lr, wd, f'model_after_{ckpt_tag}.pth')
        attacked_val_accs.append(attacked_val_acc)
        attacked_accs.append(attacked_test_acc)

    # Log result - align with inductive logging format
    logger.info(f"> Clean Val Acc: {np.mean(clean_val_accs):.4f} ± {np.std(clean_val_accs):.4f}")
    logger.info(f"> Clean Test Acc: {np.mean(clean_accs):.4f} ± {np.std(clean_accs):.4f}")
    logger.info(f"> Attacked Val Acc: {np.mean(attacked_val_accs):.4f} ± {np.std(attacked_val_accs):.4f}")
    logger.info(f"> Attacked Test Acc: {np.mean(attacked_accs):.4f} ± {np.std(attacked_accs):.4f}")
    # Average wall-clock time per seed for this hyperparameter combination.
    logger.info(f"> Time: {(time.time() - start_time) / len(seeds):.2f}s")