import torch
import time
import os
import os.path as osp
import numpy as np
import argparse
import logging
from datetime import datetime
from itertools import product

from torch_geometric.utils import to_undirected
from torch_geometric.data import Data

from greatx.attack.untargeted import Metattack, DICEAttack, PGDAttack, PRBCDAttack, GRBCDAttack
from greatx.datasets import GraphDataset
from greatx.nn.models import GCN, GAT, GNNGUARD, ElasticGNN, RUNG, RobustGCN, GCORN, NoisyGCN, APPNP, GPRGNN, TWIRLS, EvenNet, SoftMedianGCN, SoftMedianGDC, GRAND, GUARDDUAL
from greatx.training import Trainer, GRANDTrainer
from greatx.training.callbacks import ModelCheckpoint, EarlyStopping

from greatx.utils import split_nodes

import sys
sys.path.append("../")
from common import set_seed, load_graph_dataset_for_gnn, load_inductive_graph_dataset_for_gnn

import yaml

# ------------ Argument Parsing ------------
# Embedding back-ends accepted for both the attacker and the defender side.
_EMBEDDING_CHOICES = ['bow', 'roberta', 'Mistral-7B', 'MiniLM']

# (flag, add_argument keyword arguments), registered in this order.
_CLI_OPTIONS = (
    ("--root_path", dict(type=str, default="/path/to/GraphAD_data/")),
    ("--graph_save_dir", dict(type=str, default="atkg")),
    ("--config", dict(type=str, default="config.yaml")),
    ("--dataset", dict(type=str, default="cora")),
    ("--model", dict(type=str, default='gcn')),
    ("--atk_type", dict(type=str, default='structure',
                        choices=['structure', 'text', 'hybrid'])),
    ("--attack", dict(type=str, default="pgd")),
    ("--atk_phase", dict(type=str, default="inductive",
                         choices=['inductive', 'transductive'])),
    ("--atk_emb_type", dict(type=str, default="bow", choices=_EMBEDDING_CHOICES)),
    ("--def_emb_type", dict(type=str, default="bow", choices=_EMBEDDING_CHOICES)),
    ("--re_split", dict(type=int, default=2)),
    ("--ptb_rate", dict(type=float, default=0.1)),
    ("--device", dict(type=int, default=0)),
    ("--patience", dict(type=int, default=50)),
    ("--epochs", dict(type=int, default=400)),
    ("--use_existing_logs", dict(action='store_true')),
)

parser = argparse.ArgumentParser()
for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)
args = parser.parse_args()

# ------------ Load Config ------------
with open(args.config, 'r') as f:
    config = yaml.safe_load(f)

# Hyperparameter grid for the selected model: {name: [candidate, ...]}.
# A model that is absent from the config, or present with an empty/null entry,
# yields an empty grid (the loop below then runs once with default values).
model_config = config['hyperparams'].get(args.model) or {}

# The grid is expanded with itertools.product further down, which needs every
# value to be a sequence of candidates. A scalar written directly in the YAML
# (e.g. ``alpha: 0.1``) is treated as a single-candidate list instead of
# crashing or, for strings, being iterated character by character.
hyperparams = {
    param_name: list(values) if isinstance(values, (list, tuple)) else [values]
    for param_name, values in model_config.items()
}


# Memory-hungry defenses are skipped on the large datasets.
# NOTE: every dataset name needs its own comma -- adjacent string literals are
# silently concatenated by Python ("history" "photo" -> "historyphoto").
if (args.model in ['rung', 'twirls', 'softmedian', 'softmediangdc']
        and args.dataset in ["history", "photo", "computer", "arxiv"]):
    # Logging is not configured yet at this point of the script, so report the
    # skip on stdout instead of through the (not yet defined) logger.
    print(f"Skipping {args.model} on {args.dataset} due to memory constraints")
    raise SystemExit(0)

# Validate the evaluation protocol explicitly (asserts are stripped under
# ``python -O``, which would silently disable these checks).
if args.atk_phase != 'inductive':
    parser.error("Only atk_phase='inductive' is supported for defense evaluation")

if args.dataset != 'arxiv':
    if args.re_split != 2:
        parser.error("Only inductive split with re_split=2 (6/2/2) is supported for defense evaluation")
    seeds = range(3)
else:
    # arxiv is always evaluated on its original split with a single seed,
    # so any user-supplied --re_split is overridden here.
    args.re_split = 0
    seeds = [0]

# ------------ Fixed Settings ------------
# Directory layout is derived from the data root given on the command line.
root_path = args.root_path
emb_save_dir = f"{root_path}/datasets/{args.atk_emb_type}"
graph_save_dir = f"{root_path}/{args.graph_save_dir}"

# Run on the GPU index passed by the launcher when CUDA is usable, else on CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{args.device}')
else:
    device = torch.device('cpu')

# ------------ Logging Setup ------------
# One log file per (dataset, model, attack setup, embedding pair, budget).
log_dir = f"eval_logs_ind/{args.dataset}/{args.model}/{args.atk_type}_{args.attack}_{args.atk_phase}"
os.makedirs(log_dir, exist_ok=True)
log_file = f"{log_dir}/{args.atk_emb_type}_{args.def_emb_type}_ptb{int(args.ptb_rate*100)}.log"

# An existing log file is taken as proof that this experiment already ran.
if args.use_existing_logs:
    if osp.exists(log_file):
        print(f"Log file {log_file} already exists. Experiment already completed. Skipping...")
        exit(0)

# filemode='w' truncates any previous log for this configuration.
logging.basicConfig(
    filename=log_file,
    level=logging.INFO,
    filemode='w',
    format='%(message)s',
)
logger = logging.getLogger()

# Experiment header, framed by two separator lines.
_separator = "=" * 80
_header_lines = (
    _separator,
    f"EXPERIMENT - Dataset: {args.dataset}, Model: {args.model}",
    f"Attack Type: {args.atk_type}, Attack: {args.attack}, Attack Phase: {args.atk_phase}",
    f"Embedding: atk={args.atk_emb_type}, def={args.def_emb_type}, ptb_rate={args.ptb_rate}",
    f"Training: epochs={args.epochs}, patience={args.patience}",
    f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
    _separator,
)
for _line in _header_lines:
    logger.info(_line)

# ------------ Run over hyperparams ------------
def _build_model(model_name, param_dict, num_features, num_classes,
                 hidden_dim, dropout, dataset_name, full_data):
    """Instantiate the defense model selected by ``model_name``.

    Args:
        model_name: Value of ``--model`` (e.g. ``'gcn'``, ``'rung'``).
        param_dict: Current hyperparameter combination; model-specific entries
            are read with ``.get`` so missing keys fall back to the defaults
            written below.
        num_features: Input feature dimension.
        num_classes: Number of output classes.
        hidden_dim: Width of the single hidden layer.
        dropout: Dropout rate passed to the model.
        dataset_name: Dataset name, forwarded to GUARDDUAL only.
        full_data: Full graph; its ``train_mask``/``val_mask`` are forwarded to
            GUARDDUAL only.

    Returns:
        The constructed (untrained) model.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models.
    """
    if model_name == 'gcn':
        return GCN(num_features, num_classes, hids=[hidden_dim], dropout=dropout)
    if model_name == 'gat':
        return GAT(num_features, num_classes, hids=[hidden_dim], dropout=dropout)
    if model_name == 'appnp':
        return APPNP(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                     alpha=param_dict.get('alpha', 0.1))
    if model_name == 'gprgnn':
        return GPRGNN(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                      alpha=param_dict.get('alpha', 0.1))
    if model_name == 'robustgcn':
        return RobustGCN(num_features, num_classes, hids=[hidden_dim], dropout=dropout)
    if model_name == 'elasticgnn':
        return ElasticGNN(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                          lambda1=param_dict.get('lambda1', 0),
                          lambda2=param_dict.get('lambda2', 0),
                          cached=False)
    if model_name == 'gnnguard':
        return GNNGUARD(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                        threshold=param_dict.get('threshold', 0.1))
    if model_name == 'guarddual':
        return GUARDDUAL(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                         dataset_name=dataset_name,
                         train_mask=full_data.train_mask,
                         val_mask=full_data.val_mask)
    if model_name == 'noisy_gcn':
        return NoisyGCN(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                        noise_ratio_1=param_dict.get('beta', 0.1))
    if model_name == 'gcorn':
        return GCORN(num_features, num_classes, hids=[hidden_dim], dropout=dropout, order=3)
    if model_name == 'evennet':
        return EvenNet(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                       K=param_dict.get('K', 10),
                       alpha=param_dict.get('alpha', 0.1),
                       dprate=param_dict.get('dprate', 0.5))
    if model_name == 'twirls':
        return TWIRLS(num_features, num_classes, hidden_channels=hidden_dim, dropout=dropout,
                      alp=param_dict.get('alp', 1.0),
                      lam=param_dict.get('lam', 1.0),
                      prop_step=param_dict.get('prop_step', 32),
                      attention=param_dict.get('attention', True))
    if model_name == 'rung':
        return RUNG(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                    lam_hat=param_dict.get('lamb', 0.8),
                    gamma=param_dict.get('gamma', 2))
    if model_name == 'softmedian':
        return SoftMedianGCN(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                             temperature=param_dict.get('temperature', 0.5),
                             row_normalize=False, normalize=False, cached=False)
    if model_name == 'softmediangdc':
        return SoftMedianGDC(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                             temperature=param_dict.get('temperature', 1.0),
                             teleport_proba=param_dict.get('teleport_proba', 0.15),
                             neighbors=param_dict.get('neighbors', 64),
                             cached=False)
    if model_name == 'grand':
        return GRAND(num_features, num_classes, hids=[hidden_dim], dropout=dropout,
                     dropnode=param_dict.get('dropnode', 0.5),
                     order=param_dict.get('order', 2),
                     mlp_input_dropout=param_dict.get('mlp_input_dropout', 0.5))
    raise ValueError(f"Unsupported model: {model_name}")


def _build_trainer(model_name, model, param_dict, device):
    """Return a GRANDTrainer for the GRAND model and a standard Trainer otherwise."""
    if model_name == 'grand':
        return GRANDTrainer(model, device=device, verbose=0,
                            n_samples=param_dict.get('n_samples', 2),
                            reg_consistency=param_dict.get('reg_consistency', 1.0),
                            sharpening_temperature=param_dict.get('sharpening_temperature', 0.5))
    return Trainer(model, device=device, verbose=0)


# Create all combinations of hyperparameters
param_names = list(hyperparams.keys())
param_values = list(hyperparams.values())
param_combinations = list(product(*param_values))

# GUARDDUAL is only evaluated with roberta embeddings. Decide this once, before
# any data is loaded, instead of skipping every seed individually and then
# averaging empty result lists (which logs "nan ± nan").
if args.model == 'guarddual' and args.def_emb_type != 'roberta':
    print("Guarddual only use roberta as embedding")
    logger.info("Skipping: guarddual is only evaluated with def_emb_type=roberta")
    param_combinations = []

# Hidden width depends only on the dataset/model, not on the hyperparameters.
if args.dataset in ['cora', 'citeseer', 'pubmed', 'instagram', 'wikics']:
    hidden_dim = 128
else:
    hidden_dim = 256
if args.model == 'gat':
    hidden_dim = hidden_dim // 8

for params in param_combinations:
    # Create a dictionary of parameter values
    param_dict = dict(zip(param_names, params))

    # Extract common parameters
    lr = param_dict.get('learning_rates', 0.01)
    wd = param_dict.get('weight_decays', 0.0)
    dropout = param_dict.get('dropouts', 0.5)

    # poison_* lists are kept for compatibility; this loop only fills the
    # clean and evasion results.
    clean_accs, evasion_accs, poison_accs = [], [], []
    clean_val_accs, poison_val_accs = [], []

    logger.info("-" * 60)
    logger.info(f"Hyperparams: {param_dict}")
    start_time = time.time()

    for seed in seeds:
        set_seed(seed)

        # Fail fast: check for the pre-computed attack graph before spending
        # time on data loading and training.
        atk_path = f"{graph_save_dir}/{args.dataset}/{args.attack}/{args.atk_emb_type}_{int(args.ptb_rate*100)}_{seed}.pt"
        if not os.path.exists(atk_path):
            raise FileNotFoundError(f"Missing attack graph: {atk_path}")

        full_data, (train_data, val_data, test_data) = load_inductive_graph_dataset_for_gnn(
            args.dataset, device, re_split=args.re_split, path_prefix=root_path,
            emb_model=args.def_emb_type, seed=seed)
        full_data = full_data.to(device)
        train_data = train_data.to(device)
        val_data = val_data.to(device)
        test_data = test_data.to(device)

        num_features = full_data.x.shape[-1]
        num_classes = full_data.y.max().item() + 1

        # Initialize model and trainer with the current hyperparameters
        model = _build_model(args.model, param_dict, num_features, num_classes,
                             hidden_dim, dropout, args.dataset, full_data)
        trainer = _build_trainer(args.model, model, param_dict, device)
        trainer.reset_optimizer(lr=lr, weight_decay=wd)
        ckp = ModelCheckpoint(f'model_before_{args.dataset}_{args.model}_{args.attack}_{args.atk_emb_type}_{args.def_emb_type}_{args.ptb_rate}_{seed}.pth', monitor='val_acc')
        early_stopping = EarlyStopping(monitor='val_acc', patience=args.patience, mode='max')

        # NOTE(review): this assumes val_data holds the train nodes first,
        # followed by the validation nodes -- confirm against the loader.
        num_train_nodes = train_data.x.shape[0]
        num_val_nodes = val_data.x.shape[0] - num_train_nodes

        # Create validation mask: False for train nodes, True for val nodes
        val_mask_for_val_data = torch.cat([
            torch.zeros(num_train_nodes, dtype=torch.bool, device=device),  # train nodes
            torch.ones(num_val_nodes, dtype=torch.bool, device=device)      # val nodes
        ])

        trainer.fit((train_data, val_data), mask=(None, val_mask_for_val_data),
                    verbose=1, callbacks=[ckp, early_stopping], epochs=args.epochs)

        # Evaluate on clean test set
        test_logs = trainer.evaluate(test_data, mask=full_data.test_mask)
        clean_val_accs.append(ckp.best)
        clean_accs.append(test_logs['acc'])

        # Evasion attack: same test features/labels, perturbed structure.
        # map_location keeps the loaded edge index on the same device as the
        # features, whatever device it was saved from.
        perturbed_edge_index = torch.load(atk_path, map_location=device)
        perturbed_test_data = Data(
            x=test_data.x,
            y=test_data.y,
            edge_index=perturbed_edge_index
        )

        evasion_logs = trainer.evaluate(perturbed_test_data, mask=full_data.test_mask)
        evasion_accs.append(evasion_logs['acc'])

    # Log result (mean ± std over seeds)
    logger.info(f"> Clean Val Acc: {np.mean(clean_val_accs):.4f} ± {np.std(clean_val_accs):.4f}")
    logger.info(f"> Clean Test Acc: {np.mean(clean_accs):.4f} ± {np.std(clean_accs):.4f}")
    logger.info(f"> Evasion Acc: {np.mean(evasion_accs):.4f} ± {np.std(evasion_accs):.4f}")

    # Average wall-clock time per seed
    logger.info(f"> Time: {(time.time() - start_time) / len(seeds):.2f}s")
