import torch
import time
import os
import os.path as osp
import numpy as np
import argparse
import logging
import json
from datetime import datetime
from itertools import product

from torch_geometric.utils import to_undirected
from torch_geometric.data import Data

from greatx.attack.untargeted import Metattack, DICEAttack, PGDAttack, PRBCDAttack, GRBCDAttack
from greatx.datasets import GraphDataset
from greatx.nn.models import GCN, GAT, GNNGUARD, ElasticGNN, RUNG, RobustGCN, GCORN, NoisyGCN, APPNP, GPRGNN, TWIRLS, EvenNet, SoftMedianGCN, SoftMedianGDC, GRAND, GUARDDUAL
from greatx.training import Trainer, GRANDTrainer
from greatx.training.callbacks import ModelCheckpoint, EarlyStopping

from greatx.utils import split_nodes

import sys
sys.path.append("../")
from common import set_seed, load_inductive_graph_dataset_for_gnn, load_inductive_atk_graph_dataset_for_gnn

import yaml

# ------------ Argument Parsing ------------
# Embedding backbones accepted for both the attacker-side and the defender-side node features.
_EMB_TYPES = ['bow', 'roberta', 'Mistral-7B', 'MiniLM']

parser = argparse.ArgumentParser()

# (flag, add_argument keyword options) in the order they appear in --help.
_CLI_OPTIONS = [
    # Data / configuration locations
    ("--root_path", dict(type=str, default="/path/to/GraphAD_data/")),
    ("--config", dict(type=str, default="config.yaml")),
    # Experiment selection
    ("--dataset", dict(type=str, default="cora")),
    ("--model", dict(type=str, default='gcn')),
    ("--attack", dict(type=str, default="textfooler",
                      choices=['textfooler', 'llm', 'llm_Ministral', 'gpt'])),
    ("--atk_emb_type", dict(type=str, default="bow", choices=_EMB_TYPES)),
    ("--def_emb_type", dict(type=str, default="bow", choices=_EMB_TYPES)),
    ("--re_split", dict(type=int, default=2)),
    ("--ptb_rate", dict(type=float, default=0.1)),
    # Runtime / training
    ("--device", dict(type=int, default=0)),
    ("--patience", dict(type=int, default=50)),
    ("--epochs", dict(type=int, default=400)),
    ("--use_existing_logs", dict(action='store_true')),
]
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

args = parser.parse_args()

# ------------ Load Config ------------
with open(args.config, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Model-specific hyperparameter grid: {param_name: list of candidate values}.
# A model absent from the config (or listed with no entries) yields an empty
# grid, i.e. a single run with the defaults used by the training loop.
# Scalar entries are treated as a single candidate value: passed straight to
# itertools.product, a bare number/bool would raise TypeError and a bare string
# would be iterated character by character.
model_config = config['hyperparams'].get(args.model) or {}
hyperparams = {
    param_name: (param_values if isinstance(param_values, (list, tuple)) else [param_values])
    for param_name, param_values in model_config.items()
}

# ------------ Skip unsupported configurations ------------
# (models, datasets) pairs that are skipped due to memory constraints.
_MEMORY_SKIPS = (
    ({'rung', 'twirls'}, {"computer", "arxiv", "photo", "history"}),
    ({'softmedian', 'softmediangdc'}, {"computer", "arxiv", "photo"}),
)
for _skip_models, _skip_datasets in _MEMORY_SKIPS:
    if args.model in _skip_models and args.dataset in _skip_datasets:
        print(f"Skipping {args.model} on {args.dataset} due to memory constraints")
        sys.exit(0)

# GUARDDUAL is only evaluated with RoBERTa defence embeddings; exit before any
# data is loaded or a (NaN-filled) log file is written.
if args.model == 'guarddual' and args.def_emb_type != 'roberta':
    print("Guarddual only use roberta as embedding")
    sys.exit(0)

# ------------ Split / seed selection ------------
# Validated with parser.error rather than assert so the check survives `python -O`.
if args.dataset == 'arxiv':
    # arxiv always uses its original split (re_split=0). The CLI default of 2 is
    # accepted and overridden; explicitly passing 0 is accepted as well.
    if args.re_split not in (0, 2):
        parser.error("Only inductive split with re_split=0 (original split) is supported for dataset arxiv")
    args.re_split = 0
    seeds = [0]
else:
    if args.re_split != 2:
        parser.error("Only inductive split with re_split=2 (6/2/2) is supported for defense evaluation")
    seeds = range(3)

# ------------ Fixed Settings ------------
root_path = args.root_path
# Use the GPU index passed on the command line; fall back to CPU when CUDA is unavailable.
device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

# ------------ Logging Setup ------------
log_dir = f"eval_logs_ind_text/{args.dataset}/{args.model}/{args.attack}"
os.makedirs(log_dir, exist_ok=True)
# Perturbation rate as an integer percentage for the file name. round() rather
# than int(): float error makes e.g. 0.29 * 100 == 28.999999999999996, which
# int() would truncate to 28.
ptb_percent = round(args.ptb_rate * 100)
log_file = f"{log_dir}/{args.atk_emb_type}_{args.def_emb_type}_ptb{ptb_percent}.log"

# With --use_existing_logs, an existing log file marks the experiment as done.
if args.use_existing_logs and os.path.exists(log_file):
    print(f"Log file {log_file} already exists. Experiment already completed. Skipping...")
    exit(0)

# filemode='w' overwrites any previous log for this configuration.
logging.basicConfig(filename=log_file, level=logging.INFO, filemode='w', format='%(message)s')
logger = logging.getLogger()

logger.info("=" * 80)
logger.info(f"EXPERIMENT - Dataset: {args.dataset}, Model: {args.model}")
logger.info(f"Text Attack: {args.attack}")
logger.info(f"Embedding: atk={args.atk_emb_type}, def={args.def_emb_type}, ptb_rate={args.ptb_rate}")
logger.info(f"Training: epochs={args.epochs}, patience={args.patience}")
logger.info(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
logger.info("=" * 80)

# ------------ Run over hyperparams ------------
# GUARDDUAL is only evaluated with RoBERTa defence embeddings. This does not
# depend on the hyperparameters or the seed, so check once up front instead of
# loading data for every seed, skipping it, and then logging NaN statistics
# computed from empty result lists.
if args.model == 'guarddual' and args.def_emb_type != 'roberta':
    print("Guarddual only use roberta as embedding")
    logger.info("Skipped: guarddual only supports def_emb_type=roberta")
    sys.exit(0)

# Hidden width depends only on the dataset and model, not on the hyperparameters.
if args.dataset in ['cora', 'citeseer', 'pubmed', 'instagram', 'wikics']:
    hidden_dim = 128
else:
    hidden_dim = 256
if args.model == 'gat':
    # NOTE(review): presumably compensates for GAT concatenating several attention
    # heads per layer -- confirm against the GAT defaults in greatx.
    hidden_dim = hidden_dim // 8


def _build_model(param_dict, num_features, num_classes, dropout, num_train_nodes, num_val_nodes):
    """Instantiate the defence model selected by ``args.model`` for one run.

    Model-specific hyperparameters are read from ``param_dict`` with the
    defaults shown below; the hidden width is the module-level ``hidden_dim``.
    ``num_train_nodes``/``num_val_nodes`` are only used by GUARDDUAL.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """
    name = args.model
    if name == 'gcn':
        return GCN(num_features, num_classes, hids=[hidden_dim], dropout=dropout)
    if name == 'gat':
        return GAT(num_features, num_classes, hids=[hidden_dim], dropout=dropout)
    if name == 'appnp':
        alpha = param_dict.get('alpha', 0.1)
        return APPNP(num_features, num_classes, hids=[hidden_dim], dropout=dropout, alpha=alpha)
    if name == 'gprgnn':
        alpha = param_dict.get('alpha', 0.1)
        return GPRGNN(num_features, num_classes, hids=[hidden_dim], dropout=dropout, alpha=alpha)
    if name == 'robustgcn':
        return RobustGCN(num_features, num_classes, hids=[hidden_dim], dropout=dropout)
    if name == 'elasticgnn':
        lambda1 = param_dict.get('lambda1', 0)
        lambda2 = param_dict.get('lambda2', 0)
        return ElasticGNN(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                          lambda1=lambda1, lambda2=lambda2, cached=False)
    if name == 'gnnguard':
        threshold = param_dict.get('threshold', 0.1)
        return GNNGUARD(num_features, num_classes, hids=[hidden_dim], dropout=dropout, threshold=threshold)
    if name == 'guarddual':
        return GUARDDUAL(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                         dataset_name=args.dataset, num_train_nodes=num_train_nodes,
                         num_val_nodes=num_val_nodes)
    if name == 'noisy_gcn':
        beta = param_dict.get('beta', 0.1)
        return NoisyGCN(num_features, num_classes, hids=[hidden_dim], dropout=dropout, noise_ratio_1=beta)
    if name == 'gcorn':
        return GCORN(num_features, num_classes, hids=[hidden_dim], dropout=dropout, order=3)
    if name == 'evennet':
        K = param_dict.get('K', 10)
        alpha = param_dict.get('alpha', 0.1)
        dprate = param_dict.get('dprate', 0.5)
        return EvenNet(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                       K=K, alpha=alpha, dprate=dprate)
    if name == 'twirls':
        alp = param_dict.get('alp', 1.0)
        lam = param_dict.get('lam', 1.0)
        prop_step = param_dict.get('prop_step', 32)
        attention = param_dict.get('attention', True)
        return TWIRLS(num_features, num_classes, hidden_channels=hidden_dim, dropout=dropout,
                      alp=alp, lam=lam, prop_step=prop_step, attention=attention)
    if name == 'rung':
        lamb = param_dict.get('lamb', 0.8)
        gamma = param_dict.get('gamma', 2)
        return RUNG(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                    lam_hat=lamb, gamma=gamma)
    if name == 'softmedian':
        temperature = param_dict.get('temperature', 0.5)
        return SoftMedianGCN(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                             temperature=temperature,
                             row_normalize=False, normalize=False, cached=False)
    if name == 'softmediangdc':
        temperature = param_dict.get('temperature', 1.0)
        teleport_proba = param_dict.get('teleport_proba', 0.15)
        neighbors = param_dict.get('neighbors', 64)
        return SoftMedianGDC(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                             temperature=temperature, teleport_proba=teleport_proba,
                             neighbors=neighbors, cached=False)
    if name == 'grand':
        dropnode = param_dict.get('dropnode', 0.5)
        order = param_dict.get('order', 2)
        mlp_input_dropout = param_dict.get('mlp_input_dropout', 0.5)
        return GRAND(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                     dropnode=dropnode, order=order, mlp_input_dropout=mlp_input_dropout)
    raise ValueError(f"Unsupported model: {args.model}")


def _build_trainer(model, param_dict):
    """Return a GRANDTrainer for GRAND (configured from ``param_dict``), else a plain Trainer."""
    if args.model == 'grand':
        n_samples = param_dict.get('n_samples', 2)
        reg_consistency = param_dict.get('reg_consistency', 1.0)
        sharpening_temperature = param_dict.get('sharpening_temperature', 0.5)
        return GRANDTrainer(model, device=device, verbose=0, n_samples=n_samples,
                            reg_consistency=reg_consistency,
                            sharpening_temperature=sharpening_temperature)
    return Trainer(model, device=device, verbose=0)


# Create all combinations of hyperparameters (an empty grid yields one run with defaults).
param_names = list(hyperparams.keys())
param_combinations = list(product(*hyperparams.values()))

for params in param_combinations:
    param_dict = dict(zip(param_names, params))

    # Common optimisation / regularisation hyperparameters
    lr = param_dict.get('learning_rates', 0.01)
    wd = param_dict.get('weight_decays', 0.0)
    dropout = param_dict.get('dropouts', 0.5)

    clean_accs, attacked_accs = [], []
    clean_val_accs = []

    logger.info("-" * 60)
    logger.info(f"Hyperparams: {param_dict}")
    start_time = time.time()

    for seed in seeds:
        set_seed(seed)

        # Load clean data
        full_data, (train_data, val_data, test_data) = load_inductive_graph_dataset_for_gnn(
            args.dataset, device, re_split=args.re_split,
            path_prefix=root_path, emb_model=args.def_emb_type, seed=seed)
        full_data = full_data.to(device)
        train_data = train_data.to(device)
        val_data = val_data.to(device)
        test_data = test_data.to(device)

        num_features = full_data.x.shape[-1]
        num_classes = full_data.y.max().item() + 1
        # val_data contains the train nodes followed by the validation nodes,
        # so the validation-only count is the difference of the two sizes.
        num_train_nodes = train_data.x.shape[0]
        num_val_nodes = val_data.x.shape[0] - num_train_nodes

        model = _build_model(param_dict, num_features, num_classes, dropout,
                             num_train_nodes, num_val_nodes)
        trainer = _build_trainer(model, param_dict)
        trainer.reset_optimizer(lr=lr, weight_decay=wd)
        ckp = ModelCheckpoint(f'model_before_{args.dataset}_{args.model}_{args.attack}_{args.atk_emb_type}_{args.def_emb_type}_{args.ptb_rate}_{seed}.pth', monitor='val_acc')
        early_stopping = EarlyStopping(monitor='val_acc', patience=args.patience, mode='max')

        # Validation mask over val_data: False for the leading train nodes, True for val nodes
        val_mask_for_val_data = torch.cat([
            torch.zeros(num_train_nodes, dtype=torch.bool, device=device),
            torch.ones(num_val_nodes, dtype=torch.bool, device=device),
        ])

        trainer.fit((train_data, val_data), mask=(None, val_mask_for_val_data),
                    verbose=1, callbacks=[ckp, early_stopping], epochs=args.epochs)

        # Evaluate on clean test set
        test_logs = trainer.evaluate(test_data, mask=full_data.test_mask)
        val_acc = ckp.best
        clean_val_accs.append(val_acc)
        clean_accs.append(test_logs['acc'])

        # Load text-attacked data; only its test graph is evaluated.
        atk_meta_info = {
            'attack': args.attack,
            'ptb_rate': args.ptb_rate,
            'atk_emb_type': args.atk_emb_type,
            'seed': seed,
        }
        try:
            _, (_, _, atk_test_data) = load_inductive_atk_graph_dataset_for_gnn(
                args.dataset, device, atk_meta_info, re_split=args.re_split,
                path_prefix=root_path, emb_model=args.def_emb_type, seed=seed)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Missing attack data for seed {seed}: {e}") from e
        atk_test_data = atk_test_data.to(device)

        # Uses the clean data's test mask, as in the clean evaluation above.
        attacked_logs = trainer.evaluate(atk_test_data, mask=full_data.test_mask)
        attacked_accs.append(attacked_logs['acc'])

    # Log result
    logger.info(f"> Clean Val Acc: {np.mean(clean_val_accs):.4f} ± {np.std(clean_val_accs):.4f}")
    logger.info(f"> Clean Test Acc: {np.mean(clean_accs):.4f} ± {np.std(clean_accs):.4f}")
    logger.info(f"> Attacked Test Acc: {np.mean(attacked_accs):.4f} ± {np.std(attacked_accs):.4f}")

    # Average wall-clock time per seed for this hyperparameter combination.
    logger.info(f"> Time: {(time.time() - start_time) / len(seeds):.2f}s")