#!/usr/bin/env python3
"""
Main script for running adversarial robustness experiments on temporal graphs.
This script can be used as an alternative to SLURM-based execution.

Usage:
    python main.py --config configs/tgn_tgbl-wiki_trbcd.yaml
    python main.py --config configs/tncn_tgbl-wiki_memstranding.yaml
"""

import argparse
import yaml
import os
import sys
from pathlib import Path
import logging
import torch
import numpy as np
import timeit
import json
from datetime import datetime

# Add the current directory to Python path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Internal imports
from tgb.utils.utils import set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator, MeanAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.early_stopping import EarlyStopMonitorModular
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.utils.utils import load_pkl

# Model imports
from models.tgnw import WeightedGraphAttentionEmbedding, WeightedTGN
from models.tncnw import WeightedTNCNLinkPred
from modules.emb_module import TimeEmbedding

# Training utilities
from utils.tgnw_linkpred import train_one_epoch, test
from utils.tncnw_linkpred import (
    train_one_epoch as tncnw_train_one_epoch,
    test as tncnw_test,
)

# Attack utilities
from utils.attack import memstranding_attack, tgnw_attack, nat_attack

# Spotlight anomaly detection
# SPOTLIGHT is an optional dependency: when it cannot be imported, the
# anomaly-detection stage is skipped instead of aborting the whole script.
try:
    from modules.spotlight.spotlight import Spotlight
except ImportError:
    SPOTLIGHT_AVAILABLE = False
    print("Warning: SPOTLIGHT not available. Anomaly detection will be skipped.")
else:
    SPOTLIGHT_AVAILABLE = True


def setup_logging(log_file="experiment.log", level=logging.INFO):
    """Configure root logging to write to ``log_file`` and to stderr.

    ``logging.basicConfig`` does nothing once the root logger already has
    handlers. Repeated calls (e.g. one per experiment) therefore configure
    logging only the first time. The file handler is built only when it will
    actually be installed, so later calls do not leave an unused log file open.

    Args:
        log_file: Path of the log file. Defaults to ``experiment.log`` in the
            current working directory.
        level: Root logging level. Defaults to ``logging.INFO``.

    Returns:
        logging.Logger: The logger for this module.
    """
    if not logging.getLogger().handlers:
        logging.basicConfig(
            level=level,
            format="%(asctime)s - %(levelname)s - %(message)s",
            handlers=[
                logging.FileHandler(log_file, encoding="utf-8"),
                logging.StreamHandler(),
            ],
        )
    return logging.getLogger(__name__)


def load_config(config_path):
    """Load experiment parameters from a YAML file.

    seml-style configs keep their parameters under a top-level ``fixed``
    section. When that section is present, only it is returned. Otherwise
    every top-level key except the ``seml`` and ``slurm`` scheduler sections
    is returned.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        dict: The experiment parameters. The dict is empty when the file, or
        its ``fixed`` section, is empty.

    Raises:
        ValueError: If the top-level YAML document is not a mapping.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # An empty file parses to None. Treat it as an empty config instead of
    # failing with an opaque TypeError on the ``in`` check below.
    if config is None:
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )

    if "fixed" in config:
        # ``fixed:`` with no value parses to None; callers expect a dict.
        return config["fixed"] or {}
    return {k: v for k, v in config.items() if k not in ("seml", "slurm")}


def get_attack_function(attack_type):
    """Return the attack callable registered under ``attack_type``.

    Some names share an implementation: "grbcd" and "random" both resolve to
    ``tgnw_attack``, "negatt" resolves to ``nat_attack`` and "memstranding"
    resolves to ``memstranding_attack``.

    Raises:
        ValueError: If ``attack_type`` is not one of the registered names.
    """
    registry = dict(
        grbcd=tgnw_attack,
        memstranding=memstranding_attack,
        random=tgnw_attack,
        negatt=nat_attack,
    )
    attack_fn = registry.get(attack_type)
    if attack_fn is not None:
        return attack_fn
    raise ValueError(
        f"Unknown attack type: {attack_type}. Available attacks: {list(registry)}"
    )


def setup_device():
    """Pick the compute device: CUDA when torch reports it available, else CPU.

    Prints which device was chosen (including the GPU name for CUDA).

    Returns:
        torch.device: ``cuda`` or ``cpu``.
    """
    if not torch.cuda.is_available():
        print("Using CPU")
        return torch.device("cpu")
    print(f"Using GPU: {torch.cuda.get_device_name()}")
    return torch.device("cuda")


def _infer_model_name(config_path):
    """Return the model implied by the config file-name prefix, or None.

    ``tgn_``/``tgnw_`` prefixes map to ``"tgnw"``; ``tncn_``/``tncnw_``
    prefixes map to ``"tncnw"``.
    """
    stem = Path(config_path).stem
    if stem.startswith(("tgn_", "tgnw_")):
        return "tgnw"
    if stem.startswith(("tncn_", "tncnw_")):
        return "tncnw"
    return None


def _run_attack(attack_type, logger):
    """Run the configured adversarial attack stage.

    Returns:
        None on success, otherwise the message of the exception that failed
        the attack.

    Raises:
        ValueError: If ``attack_type`` is unknown. This is a configuration
            error, so it propagates rather than being recorded.
    """
    logger.info("Running adversarial attack...")
    attack_func = get_attack_function(attack_type)  # noqa: F841 - used once implemented
    logger.info(f"Using attack: {attack_type}")
    try:
        # TODO: implement the actual attack with ``attack_func`` for the
        # selected model and dataset.
        logger.info("Attack execution placeholder - implement actual attack logic")
    except Exception as e:
        logger.error(f"Attack failed: {e}")
        return str(e)
    return None


def _run_anomaly_detection(logger):
    """Run the SPOTLIGHT anomaly-detection stage.

    A missing SPOTLIGHT installation only logs a warning and is not counted
    as an error.

    Returns:
        None on success or skip, otherwise the failing exception's message.
    """
    if not SPOTLIGHT_AVAILABLE:
        logger.warning("Anomaly detection requested but SPOTLIGHT not available")
        return None
    logger.info("Running anomaly detection...")
    try:
        spotlight = Spotlight()  # noqa: F841 - used once implemented
        # TODO: feed the event stream into ``spotlight``.
        logger.info(
            "SPOTLIGHT anomaly detection placeholder - implement actual detection logic"
        )
    except Exception as e:
        logger.error(f"Anomaly detection failed: {e}")
        return str(e)
    return None


def _write_results(results, model_name, dataset_name, attack_type, now):
    """Dump ``results`` as JSON to a timestamped file in the CWD; return its name."""
    output_file = (
        f"results_{model_name}_{dataset_name}_{attack_type}_"
        f"{now.strftime('%Y%m%d_%H%M%S')}.json"
    )
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    return output_file


def run_experiment(config_path):
    """Run a single experiment described by the YAML file at ``config_path``.

    The steps are:

    1. Load the config, call ``set_random_seed`` and load the dataset.
    2. Run the (placeholder) attack and anomaly-detection stages when the
       config enables them.
    3. Write a JSON summary to
       ``results_<model>_<dataset>_<attack>_<timestamp>.json`` in the
       current working directory.

    Failures inside the attack or detection stages do not abort the run. They
    are recorded under ``"errors"`` and ``"status"`` becomes
    ``"completed_with_errors"``.

    Args:
        config_path: Path to the experiment YAML file. Its file-name prefix
            (``tgn_``, ``tgnw_``, ``tncn_``, ``tncnw_``) selects the model;
            any other prefix falls back to ``tgnw``.

    Raises:
        KeyError: If ``dataset_name`` is missing from the config.
        ValueError: If the configured attack type is unknown.
        Exception: Whatever the dataset loader raises. It is re-raised so the
            caller sees a failure instead of a silent early return.
    """
    logger = setup_logging()
    logger.info(f"Starting experiment with config: {config_path}")

    config = load_config(config_path)
    logger.info(f"Configuration loaded: {config}")

    setup_device()  # TODO: bind and pass to model/attack once implemented
    set_random_seed(config.get("seed", 42))

    dataset_name = config["dataset_name"]

    model_name = _infer_model_name(config_path)
    if model_name is None:
        model_name = "tgnw"  # Default fallback
        logger.warning(
            f"Could not infer model from config name '{Path(config_path).name}'; "
            f"defaulting to {model_name}"
        )
    logger.info(f"Using model: {model_name}")

    logger.info(f"Loading dataset: {dataset_name}")
    try:
        dataset = PyGLinkPropPredDataset(name=dataset_name, root="datasets")
        logger.info(f"Dataset loaded successfully. Size: {len(dataset)}")
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        # Re-raise: returning here used to make a failed run exit with 0.
        raise

    # ``attack_params:`` with no value parses to None.
    attack_params = config.get("attack_params") or {}
    attack_type = attack_params.get("attack_type", "grbcd")
    adv_budget = attack_params.get("adv_budget", 0.05)
    logger.info(f"Attack configuration: {attack_type} with budget {adv_budget}")

    errors = {}
    if config.get("adv_attack", False):
        attack_error = _run_attack(attack_type, logger)
        if attack_error is not None:
            errors["attack"] = attack_error
    if config.get("detect_anomalies", False):
        detection_error = _run_anomaly_detection(logger)
        if detection_error is not None:
            errors["anomaly_detection"] = detection_error

    # One timestamp for both the JSON payload and the file name.
    now = datetime.now()
    results = {
        "config_file": str(config_path),  # str() so pathlib.Path is JSON-safe
        "model_name": model_name,
        "dataset_name": dataset_name,
        "attack_type": attack_type,
        "adv_budget": adv_budget,
        "timestamp": now.isoformat(),
        "status": "completed_with_errors" if errors else "completed",
    }
    if errors:
        results["errors"] = errors

    output_file = _write_results(results, model_name, dataset_name, attack_type, now)

    if errors:
        logger.warning(
            f"Experiment finished with errors in: {', '.join(errors)}. "
            f"Results saved to {output_file}"
        )
    else:
        logger.info(f"Experiment completed successfully! Results saved to {output_file}")


def main():
    """Parse command-line arguments and run a single experiment.

    Exits with status 1 if the config path is not an existing regular file,
    or if the experiment raises. Error messages go to stderr. ``--verbose``
    additionally prints the full traceback of a failed experiment.
    """
    parser = argparse.ArgumentParser(
        description="Run adversarial robustness experiments on temporal graphs"
    )
    parser.add_argument(
        "--config", "-c", required=True, help="Path to configuration YAML file"
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Print the full traceback if the experiment fails",
    )

    args = parser.parse_args()

    # isfile (not exists) rejects a directory up front instead of letting it
    # fail later inside open().
    if not os.path.isfile(args.config):
        print(
            f"Error: Configuration file '{args.config}' not found!", file=sys.stderr
        )
        sys.exit(1)

    try:
        run_experiment(args.config)
    except Exception as e:
        print(f"Error running experiment: {e}", file=sys.stderr)
        if args.verbose:
            import traceback

            traceback.print_exc()
        sys.exit(1)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
