import argparse
import logging
import os
import random
import time
from datetime import datetime
from typing import Optional, List

import numpy as np
import torch
import wandb
from dotenv import load_dotenv
from huggingface_hub import login

# This will load the NVCC flag along with other environment variables
load_dotenv(override=True)

from src.dataset_processing.common.processing.base_manager import BaseDatasetManager  # noqa: E402, F401

from src.model_loading.registry.registry import ModelRegistry  # noqa: E402, F401
from src.dataset_processing.common.enums.source_types import DatasetSourceType  # noqa: E402
from src.dataset_processing.datasets.commonsenseqa.processor import CommonSenseQAProcessor  # noqa: E402, F401
from src.dataset_processing.datasets.coqa.processor import CoQAProcessor  # noqa: E402, F401
from src.dataset_processing.datasets.triviaqa.processor import TriviaQAProcessor  # noqa: E402, F401
from src.dataset_processing.factory import DatasetFactory  # noqa: E402
from src.loggers.setup_logging import setup_logging  # noqa: E402
from src.model_loading.loaders.gptq import GPTQModelLoader  # noqa: E402, F401
from src.model_loading.registry.enhanced_registry import EnhancedModelRegistry  # noqa: E402
from src.reliability_eval.common.config.experiment import GenerationExperimentConfig  # noqa: E402
from src.reliability_eval.common.enums.mappings import (  # noqa: E402
    DATASET_TYPE_MAP,
    GENERATION_STRATEGY_MAP,
    PERTURBATION_TYPE_MAP,
    PIPELINE_TYPE_MAP,
    PROMPT_STRATEGY_MAP,
    SOURCE_TYPE_MAP,
)
from src.reliability_eval.pipeline.config import BatchConfig, IntegratedPipelineConfig  # noqa: E402
from src.reliability_eval.pipeline.core import IntegratedEvaluationPipeline  # noqa: E402
from src.reliability_eval.pipeline.evaluation_pipelines.registry import EVALUATION_PIPELINES_DICT  # noqa: E402
from src.reliability_eval.pipeline.evaluation_pipelines.types import PipelineType  # noqa: E402
from src.reliability_eval.pipeline.processor.results_processor import ResultsProcessor  # noqa: E402

# Process-wide runtime settings, applied once when this module is imported.
# NOTE(review): CUDA_LAUNCH_BLOCKING only takes effect if set before the CUDA
# context is created (torch creates it lazily, so this is presumably still in
# time); TORCH_USE_CUDA_DSA is described by PyTorch as a build-time option and
# may have no effect on prebuilt wheels — confirm.
os.environ.update(
    {
        "CUDA_LAUNCH_BLOCKING": "1",
        "TORCH_USE_CUDA_DSA": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

# Module-wide logger produced by the project's logging setup.
logger = setup_logging()

# Model registries created once at import time. run_reliability_evaluation
# resolves CLI model-name strings through `model_registry`, which wraps the
# base registry.
base_registry = ModelRegistry()
model_registry = EnhancedModelRegistry(base_registry)


def set_seeds(random_seed: int = 42):
    """Seed every random number generator used by the evaluation.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch on the CPU
    and on all CUDA devices. Switches cuDNN to deterministic, non-autotuned
    kernels and exports ``PYTHONHASHSEED``.

    Args:
        random_seed: Seed applied to every generator (default 42).
    """
    seed = random_seed

    # Same order as before: stdlib, NumPy, torch CPU, torch CUDA (all devices).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Trade speed for reproducibility: fixed cuDNN algorithms, no autotuning.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE: the interpreter reads PYTHONHASHSEED only at start-up, so this
    # influences child processes launched afterwards, not str hashing here.
    os.environ["PYTHONHASHSEED"] = str(seed)


def authenticate_huggingface(max_retries: int = 5, initial_delay: float = 1.5) -> None:
    """
    Authenticate with Hugging Face, retrying failed login attempts.

    Loads ``.env``, reads ``HUGGINGFACE_TOKEN`` and passes it to
    ``huggingface_hub.login``.

    Wait applied after failed attempt ``k`` (before attempt ``k + 1``):
    k=1: ``initial_delay``; k=2: ``2 * initial_delay``; k=3: ``4 * initial_delay``;
    k=4: 120 s; k>=5: 300 s. No wait follows the final attempt.

    Args:
        max_retries: Maximum number of login attempts (default 5). Values below
            1 are treated as 1, so at least one attempt is always made.
        initial_delay: First wait in seconds; the next two waits double it.

    Raises:
        ValueError: If ``HUGGINGFACE_TOKEN`` is not set.
        ValueError | OSError: Re-raised from the final attempt when every
            attempt fails.
    """
    # Load environment variables
    load_dotenv()
    huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

    if huggingface_token is None:
        raise ValueError(
            f"Please set the HUGGINGFACE_TOKEN environment variable. "
            f"Looking in {os.path.join(os.getcwd(), '.env')}"
        )
    logger.info("Hugging Face token loaded successfully.")

    # Guard against max_retries <= 0, which previously ended in
    # `raise None` (a TypeError) without ever attempting a login.
    max_attempts = max(1, max_retries)
    # Waits (seconds) after failed attempts 1, 2, 3, 4 and 5+.
    wait_schedule = (initial_delay, initial_delay * 2, initial_delay * 4, 120, 300)

    for attempt in range(1, max_attempts + 1):
        try:
            login(token=huggingface_token, add_to_git_credential=True)
        except (ValueError, OSError) as e:
            # ValueError keeps the previous retry behaviour. OSError covers
            # network/HTTP failures: requests' exceptions (used by
            # requests-based huggingface_hub releases) derive from OSError —
            # confirm for the installed huggingface_hub version.
            if attempt >= max_attempts:
                logger.error(f"Failed to authenticate after {max_attempts} attempts")
                raise
            delay = wait_schedule[min(attempt, len(wait_schedule)) - 1]
            logger.warning(
                f"Authentication attempt {attempt} failed. "
                f"Retrying in {delay:.1f} seconds... "
                f"Error: {str(e)}"
            )
            time.sleep(delay)
        else:
            logger.info("Successfully authenticated with Hugging Face")
            return


def get_automatic_batch_size(
    model_name: str, dataset_type: str, num_repeats: int, partition: str
) -> int:
    """Automatically determine batch size based on model name, dataset type, number of repeats and GPU type.

    The default is 32. For the ``"COQA"`` dataset:

    - on ``gpu_gtx_1080`` it is 8 for the Llama 3.2 1B/3B models
      (``"llama32_1b"`` / ``"llama32_3b"`` in ``model_name``) and 4 otherwise;
    - on ``gpu_a100`` it is half the default.

    When ``num_repeats > 1`` the size is divided by ``num_repeats`` and clamped
    to at least 1, so a usable batch size is always returned.

    Args:
        model_name: Model name string as given on the CLI.
        dataset_type: Dataset key, e.g. ``"COQA"``.
        num_repeats: Number of generation repeats per example.
        partition: GPU partition, ``"gpu_gtx_1080"`` or ``"gpu_a100"``.

    Returns:
        A positive batch size.

    Raises:
        ValueError: If ``partition`` is not supported.
    """
    if partition not in ("gpu_gtx_1080", "gpu_a100"):
        raise ValueError(f"Unsupported partition: {partition}")

    batch_size = 32
    if dataset_type == "COQA":
        if partition == "gpu_gtx_1080":
            # Previously the 8 was unconditionally overwritten by 4 (missing
            # `else`), so the small-Llama branch never had any effect.
            small_llama = "llama32_1b" in model_name or "llama32_3b" in model_name
            batch_size = 8 if small_llama else 4
        else:  # gpu_a100
            batch_size //= 2

    if num_repeats > 1:
        # Integer division can reach 0 for large repeat counts; a batch size
        # of 0 is unusable, so clamp to 1.
        batch_size = max(1, batch_size // num_repeats)

    return batch_size


def convert_pipeline_types(pipeline_types_str: List[str]) -> List[PipelineType]:
    """Map pipeline-type names to their ``PipelineType`` enum members.

    Args:
        pipeline_types_str: Pipeline names, each a key of ``PIPELINE_TYPE_MAP``.

    Returns:
        The matching ``PipelineType`` values, in input order.

    Raises:
        KeyError: If a name has no entry in ``PIPELINE_TYPE_MAP``.
    """
    return [PIPELINE_TYPE_MAP[name] for name in pipeline_types_str]


def setup_cache_paths():
    """Set up cache paths for HuggingFace models and other dependencies.

    The cache root is ``<BASE_USER_PATH>/.cache/huggingface``. ``BASE_USER_PATH``
    is read from the environment / ``.env`` and falls back to the user's home
    directory. The ``hub`` sub-directory is created if it does not exist.

    Side effects:
        Sets ``MKL_SERVICE_FORCE_INTEL``, ``TORCH_HOME``, ``HF_HOME``,
        ``HUGGINGFACE_HUB_CACHE``, ``HUGGINGFACE_ASSETS_CACHE`` and
        ``TRANSFORMERS_CACHE`` in ``os.environ``. Calls ``torch.hub.set_dir`` and
        ``torch.cuda.empty_cache``.
    """
    logger.info("Setting up cache paths...")

    load_dotenv()
    BASE_USER_PATH = os.getenv("BASE_USER_PATH")
    if not BASE_USER_PATH:
        BASE_USER_PATH = os.path.expanduser("~") + "/"
        logger.warning(f"BASE_USER_PATH not found in .env, using default: {BASE_USER_PATH}")

    # Work around the MKL threading-layer error reported in
    # https://github.com/pytorch/pytorch/issues/37377
    os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"

    # os.path.join works whether or not BASE_USER_PATH ends in "/"; plain string
    # concatenation would yield "<path>.cache/huggingface" without it.
    CACHE_PATH = os.path.join(BASE_USER_PATH, ".cache", "huggingface")
    HUB_PATH = os.path.join(CACHE_PATH, "hub")

    if not os.path.isdir(HUB_PATH):
        logger.info(f"Creating huggingface hub path at {HUB_PATH}")
    # exist_ok=True avoids a check-then-create race when several jobs start at once.
    os.makedirs(HUB_PATH, exist_ok=True)

    logger.info(f"Setting cache path to {CACHE_PATH}")
    # NOTE(review): huggingface_hub (imported at module top) and transformers
    # may read these variables at import time; if so, setting them here only
    # affects code that reads os.environ later — confirm.
    # NOTE(review): HUGGINGFACE_HUB_CACHE points at CACHE_PATH rather than
    # HUB_PATH, so the hub/ directory created above may go unused — confirm intent.
    os.environ["TORCH_HOME"] = CACHE_PATH
    os.environ["HF_HOME"] = CACHE_PATH
    os.environ["HUGGINGFACE_HUB_CACHE"] = CACHE_PATH
    os.environ["HUGGINGFACE_ASSETS_CACHE"] = CACHE_PATH
    os.environ["TRANSFORMERS_CACHE"] = CACHE_PATH

    torch.hub.set_dir(CACHE_PATH)

    # Release cached, unused CUDA memory held by the allocator.
    torch.cuda.empty_cache()


def _build_pipeline_config(args, batch_size: int) -> IntegratedPipelineConfig:
    """Translate parsed CLI arguments into an ``IntegratedPipelineConfig``.

    Args:
        args: Parsed CLI namespace; ``args.pipeline_types`` must already be set.
        batch_size: Resolved integer batch size.

    Returns:
        The populated pipeline configuration.

    Raises:
        KeyError: If a string option has no entry in its ``*_MAP`` lookup.
        ValueError: If ``args.model_name`` is unknown to ``model_registry``.
    """
    # Set dataset name and num shots
    dataset_name = args.dataset_type.lower()
    num_shots = 0

    # Convert CLI strings to the enums expected by the config classes
    dataset_type_enum = DATASET_TYPE_MAP[args.dataset_type]
    source_type_enum = SOURCE_TYPE_MAP[args.source_type]
    perturbation_type_enum = PERTURBATION_TYPE_MAP[args.perturbation_type]
    prompt_strategy_enum = PROMPT_STRATEGY_MAP[args.prompt_strategy]
    generation_strategy_enum = GENERATION_STRATEGY_MAP[args.generation_strategy]
    pipeline_types_enum = convert_pipeline_types(args.pipeline_types)

    # The factory returns the dataset-config class for this dataset/source pair
    config_class = DatasetFactory.create_config(dataset_type_enum, source_type_enum)
    dataset_config_dict = {
        "dataset_type": dataset_type_enum,
        "dataset_name": dataset_name,
        "source_type": source_type_enum,
        "num_entries": args.num_entries,
        "num_shots": num_shots,
        "force_reprocess": args.force_reprocess,
        "random_seed": args.random_seed,
    }
    # Perturbation settings are only passed for PROCESSED sources
    if source_type_enum == DatasetSourceType.PROCESSED:
        dataset_config_dict["perturbation_type"] = perturbation_type_enum
        dataset_config_dict["perturbation_intensity"] = args.perturbation_intensity
    dataset_config = config_class(**dataset_config_dict)

    # Get model identifier using the string-based registry
    model_identifier = model_registry.get_model_by_string(args.model_name)
    if model_identifier is None:
        raise ValueError(f"Unknown model name: {args.model_name}")

    generation_config_kwargs = {
        "dataset_name": dataset_name,
        "prompt_strategy": prompt_strategy_enum,
        "generation_strategy": generation_strategy_enum,
        "num_repeats": args.num_repeats,
        "num_shots": num_shots,
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
    }
    # Forward optional decoding parameters only when given on the CLI, so
    # GenerationExperimentConfig presumably keeps its own defaults otherwise.
    optional_generation_params = {
        "num_beams": args.num_beams,
        "num_return_sequences": args.num_return_sequences,
        "top_k": args.top_k,
        "top_p": args.top_p,
    }
    generation_config_kwargs.update(
        {key: value for key, value in optional_generation_params.items() if value is not None}
    )
    generation_config = GenerationExperimentConfig(**generation_config_kwargs)

    batch_config = BatchConfig(
        batch_size=batch_size, shuffle=args.shuffle, drop_last=args.drop_last
    )

    return IntegratedPipelineConfig(
        model_identifier=model_identifier,
        device=args.device,
        max_memory=args.max_memory,
        apply_compile=args.apply_compile,
        random_seed=args.random_seed,
        dataset_config=dataset_config,
        generation_config=generation_config,
        batch_config=batch_config,
        pipeline_types=pipeline_types_enum,
        evaluation_config=EVALUATION_PIPELINES_DICT,
        exp_id=args.exp_id,
        num_excel_entries=0,
    )


def _log_results_summary(processed_summary: dict) -> None:
    """Log every summary entry, expanding one level of nested dicts."""
    logger.info("Results Summary:")
    for key, value in processed_summary.items():
        if isinstance(value, dict):
            logger.info(f"{key}:")
            for subkey, subvalue in value.items():
                logger.info(f"  {subkey}: {subvalue}")
        else:
            logger.info(f"{key}: {value}")


def run_reliability_evaluation(args):
    """Run reliability evaluation pipeline with the provided arguments.

    Steps:

    1. Override ``args.pipeline_types`` with a fixed list.
    2. Resolve the batch size.
    3. Optionally start a Weights & Biases run.
    4. Seed all RNGs.
    5. Build the pipeline config and run it.
    6. Post-process the results.

    If anything after the wandb run starts raises an exception, the run is
    finished with a non-zero exit code and the exception propagates.

    Args:
        args: ``argparse.Namespace`` built by ``main``; ``args.batch_size`` is
            either an ``int`` or the string ``"auto"``.

    Returns:
        The post-processed summary dict from ``ResultsProcessor`` with an added
        ``"execution_time"`` entry (seconds).

    Raises:
        ValueError: If the model name is unknown or the partition unsupported.
    """
    # Log configuration
    logger.info("Received configuration:")
    logger.info(f"Experiment ID: {args.exp_id}")
    logger.info(f"Model: {args.model_name}")
    logger.info(f"Dataset: {args.dataset_type}")
    logger.info(f"Perturbation: {args.perturbation_type} (intensity: {args.perturbation_intensity})")
    logger.info(f"Device: {args.device}")

    # Fixed set of evaluation pipelines. semantic_entropy_pipeline is left out:
    # it is only meant to be enabled when num_repeats > 1.
    args.pipeline_types = [
        "nll_pipeline",
        "confidence_pipeline",
        "entropy_pipeline",
        "topk_pipeline",
    ]

    # Resolve the batch size before starting wandb so the run config records
    # the value actually used, not just "auto".
    if args.batch_size == "auto":
        batch_size = get_automatic_batch_size(
            args.model_name, args.dataset_type, args.num_repeats, args.partition
        )
        logger.info(f"Automatically determined batch size: {batch_size}")
    else:
        batch_size = args.batch_size

    use_wandb = not args.disable_wandb
    if use_wandb:
        wandb.init(
            project="llm-reliability-eval",
            name=f"{args.exp_id}-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
            config={
                "exp_id": args.exp_id,
                "model_name": args.model_name,
                "device": args.device,
                "random_seed": args.random_seed,
                "max_memory": args.max_memory,
                "apply_compile": args.apply_compile,
                "partition": args.partition,
                "dataset_type": args.dataset_type,
                "source_type": args.source_type,
                "num_entries": args.num_entries,
                "perturbation_type": args.perturbation_type,
                "perturbation_intensity": args.perturbation_intensity,
                "force_reprocess": args.force_reprocess,
                "prompt_strategy": args.prompt_strategy,
                "generation_strategy": args.generation_strategy,
                "num_repeats": args.num_repeats,
                "max_new_tokens": args.max_new_tokens,
                "temperature": args.temperature,
                "num_beams": args.num_beams,
                "num_return_sequences": args.num_return_sequences,
                "top_k": args.top_k,
                "top_p": args.top_p,
                "batch_size": args.batch_size,
                "resolved_batch_size": batch_size,
                "shuffle": args.shuffle,
                "drop_last": args.drop_last,
                "pipeline_types": args.pipeline_types,
            },
        )

    try:
        # Set seeds for reproducibility
        set_seeds(random_seed=args.random_seed)

        pipeline_config = _build_pipeline_config(args, batch_size)

        logger.info("Running reliability evaluation pipeline...")
        pipeline = IntegratedEvaluationPipeline()
        start_time = time.time()
        results = pipeline.run_evaluation(pipeline_config)
        execution_time = time.time() - start_time
        logger.info(f"Execution time: {execution_time:.2f} seconds")

        processed_summary = ResultsProcessor.post_process_results(results)
        processed_summary["execution_time"] = execution_time
        logger.info("Reliability evaluation pipeline complete.")
    except Exception:
        # Explicitly close the wandb run as failed instead of relying on
        # interpreter-exit cleanup; then let the error propagate.
        if use_wandb:
            wandb.finish(exit_code=1)
        raise

    # Log results to wandb if enabled
    if use_wandb:
        wandb.log(processed_summary)
        wandb.finish()

    _log_results_summary(processed_summary)
    return processed_summary


def main():
    """Parse CLI options, prepare the environment and run the evaluation.

    Returns:
        The summary dict returned by ``run_reliability_evaluation``.

    Raises:
        Exception: Re-raised from Hugging Face authentication when it fails.
    """
    # Local import: only needed here to parse --max_memory safely.
    import ast

    parser = argparse.ArgumentParser(description="Run reliability evaluation for language models")

    # Experiment ID
    parser.add_argument("--exp_id", type=str, required=True, help="Experiment ID")

    # Model configuration
    parser.add_argument("--model_name", type=str, required=True, help="Model name")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on (cuda or cpu)")
    parser.add_argument("--max_memory", type=str, default=None, help="Maximum memory to use (format as python dict string)")
    parser.add_argument("--apply_compile", action="store_true", help="Apply torch.compile to the model")
    parser.add_argument("--partition", type=str, default="gpu_gtx_1080", choices=["gpu_gtx_1080", "gpu_a100"],
                        help="GPU partition type to use for automatic batch size determination")
    parser.add_argument("--random_seed", type=int, default=42, help="Random seed for reproducibility")

    # Dataset configuration
    parser.add_argument("--dataset_type", type=str, default="COMMONSENSEQA",
                        choices=list(DATASET_TYPE_MAP.keys()), help="Dataset type")
    parser.add_argument("--source_type", type=str, default="raw",
                        choices=list(SOURCE_TYPE_MAP.keys()), help="Source type")
    parser.add_argument("--num_entries", type=int, default=None, help="Number of dataset entries to use")
    parser.add_argument("--perturbation_type", type=str, default="none",
                        choices=list(PERTURBATION_TYPE_MAP.keys()), help="Perturbation type")
    parser.add_argument("--perturbation_intensity", type=int, default=0, help="Perturbation intensity")
    parser.add_argument("--force_reprocess", action="store_true", help="Force reprocessing of dataset")

    # Generation configuration
    parser.add_argument("--prompt_strategy", type=str, default="Original",
                        choices=list(PROMPT_STRATEGY_MAP.keys()), help="Prompt strategy")
    parser.add_argument("--generation_strategy", type=str, default="multinomial_sampling",
                        choices=list(GENERATION_STRATEGY_MAP.keys()), help="Generation strategy")
    parser.add_argument("--num_repeats", type=int, default=1, help="Number of generation repeats")
    parser.add_argument("--max_new_tokens", type=int, default=25, help="Maximum number of new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling")
    parser.add_argument("--num_beams", type=int, default=None, help="Number of beams for beam search")
    parser.add_argument("--num_return_sequences", type=int, default=None, help="Number of sequences to return")
    parser.add_argument("--top_k", type=int, default=None, help="Top-k sampling parameter")
    parser.add_argument("--top_p", type=float, default=None, help="Top-p sampling parameter")

    # Batch configuration
    parser.add_argument("--batch_size", type=str, default="auto", help="Batch size or 'auto' for automatic determination")
    parser.add_argument("--shuffle", action="store_true", help="Shuffle dataset")
    parser.add_argument("--drop_last", action="store_true", help="Drop last incomplete batch")

    # Additional options
    parser.add_argument("--disable_wandb", action="store_true", help="Disable Weights & Biases logging")
    parser.add_argument("--skip_hf_auth", action="store_true", help="Skip HuggingFace authentication")

    args = parser.parse_args()

    # --max_memory arrives as a Python dict literal string. ast.literal_eval
    # accepts only literals, so unlike eval() it cannot execute arbitrary code
    # supplied on the command line.
    if args.max_memory and isinstance(args.max_memory, str):
        try:
            args.max_memory = ast.literal_eval(args.max_memory)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            logger.warning(f"Could not parse max_memory string: {args.max_memory}, setting to None")
            args.max_memory = None

    # Parse batch_size if not 'auto'; fall back to 'auto' for unusable values.
    if args.batch_size != "auto":
        try:
            args.batch_size = int(args.batch_size)
        except ValueError:
            logger.warning(f"Could not parse batch_size: {args.batch_size}, using 'auto'")
            args.batch_size = "auto"
        else:
            if args.batch_size < 1:
                logger.warning(f"Non-positive batch_size: {args.batch_size}, using 'auto'")
                args.batch_size = "auto"

    # Setup cache paths
    setup_cache_paths()

    # Authenticate with HuggingFace if needed
    if not args.skip_hf_auth:
        try:
            authenticate_huggingface(max_retries=3, initial_delay=1.0)
        except Exception as e:
            logger.error(f"Authentication failed: {str(e)}")
            # Bare raise re-raises the original exception without an extra frame.
            raise

    # Run the evaluation
    return run_reliability_evaluation(args)


if __name__ == "__main__":
    main()