import torch
import time
import os
import os.path as osp
import numpy as np
import argparse
import logging
from datetime import datetime
from itertools import product

from torch_geometric.utils import to_undirected
from torch_geometric.data import Data

from greatx.attack.untargeted import Metattack, DICEAttack, PGDAttack, PRBCDAttack, GRBCDAttack
from greatx.datasets import GraphDataset
from greatx.nn.models import GCN, GAT, GNNGUARD, ElasticGNN, RUNG, RobustGCN, GCORN, NoisyGCN, APPNP, GPRGNN, ProGNN, Stable, PurificationGCN, EvenNet, TWIRLS, SoftMedianGCN, SoftMedianGDC, GRAND
from greatx.training import Trainer, GRANDTrainer
from greatx.training.callbacks import ModelCheckpoint, EarlyStopping
from greatx.utils import split_nodes

import sys
sys.path.append("../")
from common import set_seed, load_graph_dataset_for_gnn, load_inductive_graph_dataset_for_gnn

import yaml

# ------------ Argument Parsing ------------
# Command-line options as (flag, add_argument keyword arguments), registered in
# exactly this order so --help output and parsing are unchanged.
_EMB_CHOICES = ['bow', 'roberta', 'Mistral-7B', 'MiniLM']
_CLI_SPEC = [
    ("--root_path", {"type": str, "default": "/path/to/GraphAD_data/"}),
    ("--graph_save_dir", {"type": str, "default": "atkg"}),
    ("--config", {"type": str, "default": "config.yaml"}),
    ("--dataset", {"type": str, "default": "cora"}),
    ("--model", {"type": str, "default": 'gcn'}),
    ("--atk_type", {"type": str, "default": 'structure', "choices": ['structure', 'text', 'hybrid']}),
    ("--attack", {"type": str, "default": "strg"}),
    ("--atk_phase", {"type": str, "default": "transductive", "choices": ['inductive', 'transductive']}),
    ("--atk_emb_type", {"type": str, "default": "bow", "choices": _EMB_CHOICES}),
    ("--def_emb_type", {"type": str, "default": "bow", "choices": _EMB_CHOICES}),
    ("--re_split", {"type": int, "default": 1}),
    ("--ptb_rate", {"type": float, "default": 0.2}),
    ("--device", {"type": int, "default": 1}),
    ("--patience", {"type": int, "default": 100}),
    ("--epochs", {"type": int, "default": 200}),
    ("--use_existing_logs", {"action": 'store_true'}),
]

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_SPEC:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

# ------------ Load Config ------------
with open(args.config, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Get model-specific hyperparameters. Each entry maps a parameter name to the
# values to sweep over (the sweep is the Cartesian product of all entries).
# - A model key with an empty body (`gcn:`) loads as None -> no sweep, defaults only.
# - A scalar value (e.g. `attention: true`, `method: jaccard`) is treated as a
#   single-value sweep; otherwise `product` would fail on a bare number/bool or
#   silently iterate over the characters of a string.
model_config = config['hyperparams'].get(args.model) or {}
hyperparams = {
    param_name: list(param_values) if isinstance(param_values, (list, tuple)) else [param_values]
    for param_name, param_values in model_config.items()
}

# Reject CLI combinations this script does not support. parser.error() prints
# the usage line plus the message and exits with status 2; unlike `assert`, it is
# not stripped when Python runs with -O.
if args.atk_phase != 'transductive':
    parser.error("Only the transductive attack phase is supported")
if args.atk_type != 'structure':
    parser.error("Only structure attack is supported")
if args.re_split != 1:
    parser.error("Only transductive split with re_split=1 (1/1/8) is supported for defense evaluation")
if args.dataset == 'arxiv':
    parser.error("arxiv dataset is not supported for transductive evaluation")

# ------------ Fixed Settings ------------
# Paths derived from --root_path, the seeds every configuration is averaged
# over, and the torch device (the requested GPU index, or CPU without CUDA).
root_path = args.root_path
emb_save_dir = "{}/datasets/{}".format(root_path, args.atk_emb_type)
graph_save_dir = "{}/{}".format(root_path, args.graph_save_dir)
seeds = range(0, 3)
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda:{}'.format(args.device)) if _cuda_ok else torch.device('cpu')

# ------------ Logging Setup ------------
# One log file per (dataset, model, attack setting, embeddings, ptb rate).
log_dir = f"eval_logs_trans/{args.dataset}/{args.model}/{args.atk_type}_{args.attack}_{args.atk_phase}"
os.makedirs(log_dir, exist_ok=True)
log_file = f"{log_dir}/{args.atk_emb_type}_{args.def_emb_type}_ptb{int(args.ptb_rate*100)}.log"

# With --use_existing_logs, an existing log file is treated as a finished run.
# NOTE(review): basicConfig below creates the file immediately, so a run that
# crashed part-way also leaves a log behind and would be skipped here as well.
if args.use_existing_logs and os.path.exists(log_file):
    print(f"Log file {log_file} already exists. Experiment already completed. Skipping...")
    # sys.exit rather than the site-injected `exit`, which is meant for the REPL.
    sys.exit(0)

logging.basicConfig(filename=log_file, level=logging.INFO, filemode='w', format='%(message)s')
logger = logging.getLogger()

# Experiment header.
logger.info("=" * 80)
logger.info(f"EXPERIMENT - Dataset: {args.dataset}, Model: {args.model}")
logger.info(f"Attack Type: {args.atk_type}, Attack: {args.attack}, Attack Phase: {args.atk_phase}")
logger.info(f"Embedding: atk={args.atk_emb_type}, def={args.def_emb_type}, ptb_rate={args.ptb_rate}")
logger.info(f"Training: epochs={args.epochs}, patience={args.patience}")
logger.info(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
logger.info("=" * 80)

# Skip model/dataset pairs treated as too large for the given model. The skip is
# logged, so the log file records why no results follow the header.
_SKIP_LARGE = ["computer", "photo"]
_SKIP_DATASETS = {
    'prognn': ["pubmed", "reddit", "history", "computer", "photo", "wikics", "instagram"],
    'stable': _SKIP_LARGE,
    'rung': _SKIP_LARGE,
    'twirls': _SKIP_LARGE,
    'softmediangdc': _SKIP_LARGE,
}
if args.dataset in _SKIP_DATASETS.get(args.model, ()):
    # The message names the model actually being skipped (the old 'stable'
    # branch wrongly logged "Skipping prognn and stable ...").
    logger.info(f"Skipping {args.model} for large datasets")
    sys.exit(0)

# ------------ Run over hyperparams ------------
# Models trained through their own ``fit``/``test`` API (graph structure
# learning / purification) rather than through a greatx ``Trainer``.
SELF_FITTING_MODELS = ('prognn', 'stable', 'purify_jaccard', 'purify_cosine')


def _hidden_dim(dataset, model_name):
    """Return the hidden width used for ``model_name`` on ``dataset``.

    128 for the listed datasets and 256 otherwise; for GAT that value is divided by 8.
    """
    hidden_dim = 128 if dataset in ('cora', 'citeseer', 'pubmed', 'instagram', 'wikics') else 256
    if model_name == 'gat':
        # NOTE(review): presumably compensates for GAT's multi-head
        # concatenation -- confirm against the GAT defaults.
        hidden_dim //= 8
    return hidden_dim


def _build_model(name, num_features, num_classes, hidden_dim, dropout, lr, param_dict):
    """Instantiate defense model ``name``, reading extra hyperparameters from ``param_dict``.

    Args:
        name: model key as given by ``--model``.
        num_features: input feature dimension.
        num_classes: number of output classes.
        hidden_dim: hidden width (``hids=[hidden_dim]``, or ``hidden_channels`` for TWIRLS).
        dropout: dropout rate.
        lr: learning rate; only ProGNN receives it (with ``lr_adj = lr * 10``).
        param_dict: one hyperparameter combination; missing keys use the defaults below.

    Raises:
        ValueError: if ``name`` is not a supported model.
    """
    get = param_dict.get
    common = {'hids': [hidden_dim], 'dropout': dropout}
    if name == 'gcn':
        return GCN(num_features, num_classes, **common)
    if name == 'gat':
        return GAT(num_features, num_classes, **common)
    if name == 'appnp':
        return APPNP(num_features, num_classes, alpha=get('alpha', 0.1), **common)
    if name == 'gprgnn':
        return GPRGNN(num_features, num_classes, alpha=get('alpha', 0.1), **common)
    if name == 'robustgcn':
        return RobustGCN(num_features, num_classes, **common)
    if name == 'elasticgnn':
        return ElasticGNN(num_features, num_classes, lambda1=get('lambda1', 0),
                          lambda2=get('lambda2', 0), cached=False, **common)
    if name == 'gnnguard':
        return GNNGUARD(num_features, num_classes, threshold=get('threshold', 0.1), **common)
    if name == 'noisy_gcn':
        return NoisyGCN(num_features, num_classes, noise_ratio_1=get('beta', 0.1), **common)
    if name == 'gcorn':
        return GCORN(num_features, num_classes, order=3, **common)
    if name == 'prognn':
        return ProGNN(num_features, num_classes, alpha=get('alpha', 0.1), beta=get('beta', 2.0),
                      lr=lr, lr_adj=lr * 10, device=device, epochs=args.epochs,
                      symmetric=False, gamma=1.0, lambda_=0.001, phi=0.0, **common)
    if name == 'evennet':
        return EvenNet(num_features, num_classes, K=get('K', 10), alpha=get('alpha', 0.1),
                       dprate=get('dprate', 0.5), **common)
    if name == 'twirls':
        return TWIRLS(num_features, num_classes, hidden_channels=hidden_dim, dropout=dropout,
                      alp=get('alp', 1.0), lam=get('lam', 1.0),
                      prop_step=get('prop_step', 32), attention=get('attention', True))
    if name == 'softmedian':
        return SoftMedianGCN(num_features, num_classes, temperature=get('temperature', 0.5),
                             row_normalize=False, normalize=False, cached=True, **common)
    if name == 'softmediangdc':
        return SoftMedianGDC(num_features, num_classes, temperature=get('temperature', 1.0),
                             teleport_proba=get('teleport_proba', 0.15),
                             neighbors=get('neighbors', 64), cached=True, **common)
    if name == 'grand':
        return GRAND(num_features, num_classes, dropnode=get('dropnode', 0.5),
                     order=get('order', 2), mlp_input_dropout=get('mlp_input_dropout', 0.5),
                     **common)
    if name == 'stable':
        return Stable(num_features, num_classes, device=device, alpha=get('alpha', 0.3),
                      cosine_threshold=get('cos', 0.3), jaccard_threshold=get('jar', 0.03),
                      **common)
    if name == 'purify_jaccard':
        return PurificationGCN(num_features, num_classes, purification_method='jaccard',
                               jaccard_threshold=get('jaccard_threshold', 0.03),
                               cosine_threshold=None, device=device, **common)
    if name == 'purify_cosine':
        return PurificationGCN(num_features, num_classes, purification_method='cosine',
                               jaccard_threshold=None,
                               cosine_threshold=get('cosine_threshold', 0.1),
                               device=device, **common)
    if name == 'rung':
        return RUNG(num_features, num_classes, lam_hat=get('lamb', 0.8),
                    gamma=get('gamma', 2), **common)
    raise ValueError(f"Unsupported model: {name}")


def _make_trainer(model, param_dict, lr, wd):
    """Return a greatx trainer for ``model`` (GRANDTrainer for GRAND) with a fresh optimizer."""
    if args.model == 'grand':
        trainer = GRANDTrainer(model, device=device, verbose=0,
                               n_samples=param_dict.get('n_samples', 2),
                               reg_consistency=param_dict.get('reg_consistency', 1.0),
                               sharpening_temperature=param_dict.get('sharpening_temperature', 0.5))
    else:
        trainer = Trainer(model, device=device, verbose=0)
    trainer.reset_optimizer(lr=lr, weight_decay=wd)
    return trainer


def _fit_with_trainer(model, data, param_dict, lr, wd, ckpt_path):
    """Train ``model`` on ``data`` with checkpointing and early stopping on ``val_acc``.

    Returns:
        (``ckp.best`` -- the best ``val_acc`` tracked by ModelCheckpoint,
         test accuracy from ``trainer.evaluate`` on ``data.test_mask``).
    """
    trainer = _make_trainer(model, param_dict, lr, wd)
    ckp = ModelCheckpoint(ckpt_path, monitor='val_acc')
    early_stopping = EarlyStopping(monitor='val_acc', patience=args.patience, mode='max')
    trainer.fit(data, mask=(data.train_mask, data.val_mask),
                callbacks=[ckp, early_stopping], epochs=args.epochs, verbose=0)
    test_logs = trainer.evaluate(data, mask=data.test_mask)
    return ckp.best, test_logs['acc']


def _fit_self(model, data):
    """Train a self-fitting model with its own ``fit``; return (val acc, test acc)."""
    model.fit(data.x, data.edge_index, data.y, data.train_mask, data.val_mask, epochs=args.epochs)
    val_acc = model.test(data.x, data.y, data.val_mask)
    test_acc = model.test(data.x, data.y, data.test_mask)
    return val_acc, test_acc


def _attack_graph_path(seed):
    """Return the path of the pre-computed perturbed ``edge_index`` for ``seed``."""
    return (f"{graph_save_dir}/{args.dataset}/{args.attack}/"
            f"{args.atk_emb_type}_{int(args.ptb_rate*100)}_{seed}.pt")


def _load_perturbed_data(clean, atk_path):
    """Return a copy of ``clean`` whose edges are the attack graph stored at ``atk_path``."""
    # map_location lets graphs saved on any GPU (or with CUDA) load onto this device.
    perturbed_edge_index = torch.load(atk_path, map_location=device)
    return Data(x=clean.x, y=clean.y, edge_index=perturbed_edge_index,
                train_mask=clean.train_mask,
                val_mask=clean.val_mask,
                test_mask=clean.test_mask)


def _reset_for_retraining(model):
    """Re-initialise ``model`` before it is trained again on a different graph."""
    model.reset_parameters()
    # Models built with cached=True (softmedian, softmediangdc) may keep the clean
    # graph's propagation cache; clear it when the model exposes cache_clear() so
    # the poisoned edges are really used.
    # NOTE(review): confirm whether greatx's reset_parameters() already does this.
    if hasattr(model, 'cache_clear'):
        model.cache_clear()


def _mean_std(values):
    """Format ``values`` as 'mean ± std' with 4 decimals (population std)."""
    return f"{np.mean(values):.4f} ± {np.std(values):.4f}"


# Every combination of the configured hyperparameter values is evaluated.
param_names = list(hyperparams.keys())
param_values = list(hyperparams.values())
param_combinations = list(product(*param_values))
hidden_dim = _hidden_dim(args.dataset, args.model)

for params in param_combinations:
    param_dict = dict(zip(param_names, params))

    # Common optimisation hyperparameters (aligned with the inductive settings).
    lr = param_dict.get('learning_rates', 0.01)
    wd = param_dict.get('weight_decays', 0.0)
    dropout = param_dict.get('dropouts', 0.5)

    clean_accs, poison_accs = [], []          # test accuracy per seed
    clean_val_accs, poison_val_accs = [], []  # validation accuracy per seed

    logger.info("-" * 60)
    logger.info(f"Hyperparams: {param_dict}")
    start_time = time.time()

    for seed in seeds:
        set_seed(seed)
        full_data = load_graph_dataset_for_gnn(args.dataset, device, re_split=args.re_split,
                                               path_prefix=root_path, emb_model=args.def_emb_type)
        full_data = full_data.to(device)

        num_features = full_data.x.shape[-1]
        num_classes = full_data.y.max().item() + 1
        model = _build_model(args.model, num_features, num_classes, hidden_dim, dropout, lr,
                             param_dict)

        # Fail fast: make sure the attack graph exists before spending a full
        # clean training run on this seed.
        atk_path = _attack_graph_path(seed)
        if not os.path.exists(atk_path):
            raise FileNotFoundError(f"Missing attack graph: {atk_path}")

        # Checkpoint names include the dataset so concurrent runs on different
        # datasets (same model/attack/ptb/seed) don't overwrite each other's files.
        ckpt_tag = (f"{args.dataset}_{args.model}_{args.attack}_{args.atk_emb_type}_"
                    f"{args.def_emb_type}_{args.ptb_rate}_{seed}.pth")
        self_fitting = args.model in SELF_FITTING_MODELS

        # --- Clean graph: train and evaluate ---
        # Self-fitting models now report real validation accuracy here (the
        # test accuracy was previously logged as "Clean Val Acc").
        if self_fitting:
            val_acc, test_acc = _fit_self(model, full_data)
        else:
            val_acc, test_acc = _fit_with_trainer(model, full_data, param_dict, lr, wd,
                                                  f"model_before_{ckpt_tag}")
        clean_val_accs.append(val_acc)
        clean_accs.append(test_acc)

        # --- Poisoning: re-initialise and retrain on the perturbed graph ---
        perturbed_data = _load_perturbed_data(full_data, atk_path)
        _reset_for_retraining(model)
        if self_fitting:
            val_acc, test_acc = _fit_self(model, perturbed_data)
        else:
            val_acc, test_acc = _fit_with_trainer(model, perturbed_data, param_dict, lr, wd,
                                                  f"model_after_{ckpt_tag}")
        poison_val_accs.append(val_acc)
        poison_accs.append(test_acc)

    # Log result - align with inductive logging format
    logger.info(f"> Clean Val Acc: {_mean_std(clean_val_accs)}")
    logger.info(f"> Clean Test Acc: {_mean_std(clean_accs)}")
    logger.info(f"> Poison Val Acc: {_mean_std(poison_val_accs)}")
    logger.info(f"> Poison Test Acc: {_mean_std(poison_accs)}")
    logger.info(f"> Time: {(time.time() - start_time) / len(seeds):.2f}s")
