import argparse
import logging
import os
import random
import sys
import time
from typing import Optional

import numpy as np
import torch
from dotenv import load_dotenv
from huggingface_hub import login

# This will load the NVCC flag along with other environment variables
load_dotenv(override=True)

from src.dataset_processing.common.enums.source_types import DatasetSourceType  # noqa: E402
from src.dataset_processing.factory import DatasetFactory  # noqa: E402
from src.loggers.setup_logging import setup_logging  # noqa: E402
from src.reliability_eval.common.enums.mappings import (  # noqa: E402
    DATASET_TYPE_MAP,
    PERTURBATION_TYPE_MAP,
    SOURCE_TYPE_MAP,
)

from src.model_loading.registry.registry import ModelRegistry  # noqa: E402, F401
from src.dataset_processing.common.processing.base_manager import BaseDatasetManager  # noqa: E402, F401
from src.dataset_processing.datasets.commonsenseqa.processor import CommonSenseQAProcessor  # noqa: E402, F401
from src.dataset_processing.datasets.coqa.processor import CoQAProcessor  # noqa: E402, F401
from src.dataset_processing.datasets.triviaqa.processor import TriviaQAProcessor  # noqa: E402, F401
from src.model_loading.loaders.gptq import GPTQModelLoader  # noqa: E402, F401

# Configure logging: module-wide logger returned by the project's
# setup_logging() helper; functions in this module log through this name.
logger = setup_logging()

# Set up environment variables and cache paths
def setup_cache_paths():
    """Set up cache paths for HuggingFace models and other dependencies.

    The cache root is ``<BASE_USER_PATH>/.cache/huggingface``. ``BASE_USER_PATH``
    is read from the environment (``.env`` is loaded first). When it is unset or
    empty, the current user's home directory is used instead. A ``hub``
    sub-directory is created under the cache root if it does not exist.

    Side effects:
        * Sets ``MKL_SERVICE_FORCE_INTEL=1`` and ``TOKENIZERS_PARALLELISM=false``.
        * Sets ``TORCH_HOME``, ``HF_HOME``, ``HUGGINGFACE_HUB_CACHE``,
          ``HUGGINGFACE_ASSETS_CACHE`` and ``TRANSFORMERS_CACHE`` to the cache root.
        * Calls ``torch.hub.set_dir`` with the cache root.
        * Calls ``torch.cuda.empty_cache()``.
    """
    logger.info("Setting up cache paths...")

    load_dotenv()
    base_user_path = os.getenv("BASE_USER_PATH")
    if not base_user_path:
        base_user_path = os.path.expanduser("~") + "/"
        logger.warning(f"BASE_USER_PATH not found in .env, using default: {base_user_path}")

    # To avoid issues when running. Presumably these work around MKL
    # threading-layer conflicts and tokenizers fork/parallelism warnings
    # (TODO confirm).
    os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # os.path.join works whether or not BASE_USER_PATH ends with a separator.
    # The previous f"{base}.cache" form turned "/home/me" into
    # "/home/me.cache/huggingface".
    cache_path = os.path.join(base_user_path, ".cache", "huggingface")
    hub_path = os.path.join(cache_path, "hub")

    if not os.path.exists(hub_path):
        # exist_ok avoids a crash if the directory appears between the check and
        # the create (e.g. another job sharing the same home directory).
        os.makedirs(hub_path, exist_ok=True)
        logger.info(f"Creating huggingface hub path at {hub_path}")

    logger.info(f"Setting cache path to {cache_path}")
    # NOTE(review): HUGGINGFACE_HUB_CACHE points at the cache root, not at
    # hub_path, so the "hub" directory created above may go unused. Kept as-is
    # so previously downloaded files are still found; confirm intended layout.
    # NOTE(review): huggingface_hub is imported at the top of this file before
    # this function runs, and it generally resolves these variables when first
    # imported. They may therefore only affect libraries imported later or
    # subprocesses; confirm.
    for env_var in (
        "TORCH_HOME",
        "HF_HOME",
        "HUGGINGFACE_HUB_CACHE",
        "HUGGINGFACE_ASSETS_CACHE",
        "TRANSFORMERS_CACHE",
    ):
        os.environ[env_var] = cache_path

    torch.hub.set_dir(cache_path)

    # Release unoccupied memory held by PyTorch's CUDA caching allocator.
    # No autograd context is needed: no tensor operations happen here.
    torch.cuda.empty_cache()


def authenticate_huggingface(max_retries: int = 3, initial_delay: float = 1.0) -> None:
    """
    Authenticate with Hugging Face with retry logic.

    The token is read from the ``HUGGINGFACE_TOKEN`` environment variable (after
    loading ``.env``) and passed to ``huggingface_hub.login``. Failed attempts are
    retried with exponential backoff: ``initial_delay``, ``2 * initial_delay``,
    ``4 * initial_delay``, ...

    Args:
        max_retries: Maximum number of login attempts (must be >= 1).
        initial_delay: Delay before the first retry in seconds. It doubles after
            each further failure.

    Raises:
        ValueError: If ``max_retries`` is less than 1, or if ``HUGGINGFACE_TOKEN``
            is unset or empty.
        ValueError | OSError: Re-raises the exception from the final failed login
            attempt when every attempt fails.
    """
    # Without at least one attempt, the final `raise last_exception` would be
    # `raise None` (a confusing TypeError).
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    # Load environment variables
    load_dotenv()
    huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

    # An empty value (e.g. `HUGGINGFACE_TOKEN=` in .env) is as unusable as a
    # missing one.
    if not huggingface_token:
        raise ValueError(
            f"Please set the HUGGINGFACE_TOKEN environment variable. "
            f"Looking in {os.path.join(os.getcwd(), '.env')}"
        )

    # Use the module-level logger configured by setup_logging(), so these
    # messages go to the same handlers as the rest of the script.
    logger.info("Hugging Face token loaded successfully.")

    last_exception: Optional[Exception] = None

    for attempt in range(1, max_retries + 1):
        try:
            login(token=huggingface_token, add_to_git_credential=True)
        # ValueError: rejected token. OSError: presumably network failures
        # (requests' exceptions derive from OSError in requests-based
        # huggingface_hub versions), which are the transient errors worth
        # retrying. TODO confirm against the installed huggingface_hub.
        except (ValueError, OSError) as e:
            last_exception = e
            if attempt < max_retries:
                delay = initial_delay * (2 ** (attempt - 1))  # Exponential backoff
                logger.warning(
                    f"Authentication attempt {attempt} failed. "
                    f"Retrying in {delay:.1f} seconds... "
                    f"Error: {str(e)}"
                )
                time.sleep(delay)
        else:
            logger.info("Successfully authenticated with Hugging Face")
            return

    # If we get here, all retries failed
    logger.error(f"Failed to authenticate after {max_retries} attempts")
    raise last_exception


def set_seeds(random_seed: int = 42):
    """Seed every random number generator used by the pipeline.

    Covers Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU
    and all-CUDA-device generators. It also switches cuDNN into deterministic,
    non-benchmarking mode and exports ``PYTHONHASHSEED``.

    NOTE: the interpreter reads ``PYTHONHASHSEED`` at startup. Setting it here
    only influences Python subprocesses launched afterwards, not string-hash
    randomisation in the current process.
    """
    seeders = (
        random.seed,                 # Python's built-in random
        np.random.seed,              # NumPy global RNG
        torch.manual_seed,           # PyTorch default generator
        torch.cuda.manual_seed_all,  # every visible GPU
    )
    for seed_fn in seeders:
        seed_fn(random_seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(random_seed)


def process_datasets(args):
    """Process and save datasets according to configuration.

    Builds a dataset config from the parsed CLI arguments, seeds all RNGs, and
    runs the processor created by ``DatasetFactory`` for the requested dataset.

    Args:
        args: Parsed arguments (e.g. an ``argparse.Namespace``). It must provide
            ``dataset_type``, ``source_type``, ``num_entries``, ``random_seed``,
            ``perturbation_type``, ``perturbation_intensity`` and
            ``force_reprocess``.

    Returns:
        dict: ``{"success": True, "execution_time": <seconds>}``. The time covers
        creating the processor and running ``process_dataset``.

    Raises:
        KeyError: If a type string is not a key of its ``*_TYPE_MAP``.
    """
    logger.info("Dataset processing configuration:")
    logger.info(f"Dataset type: {args.dataset_type}")
    logger.info(f"Source type: {args.source_type}")
    logger.info(f"Number of entries: {args.num_entries}")
    logger.info(f"Random seed: {args.random_seed}")
    logger.info(f"Perturbation type: {args.perturbation_type}")
    logger.info(f"Perturbation intensity: {args.perturbation_intensity}")
    logger.info(f"Force reprocess: {args.force_reprocess}")

    # Convert strings to enums
    dataset_type_enum = DATASET_TYPE_MAP[args.dataset_type]
    source_type_enum = SOURCE_TYPE_MAP[args.source_type]
    perturbation_type_enum = PERTURBATION_TYPE_MAP[args.perturbation_type]

    # Set seeds for reproducibility
    set_seeds(random_seed=args.random_seed)

    # Set dataset name and num shots (zero-shot only in this script)
    dataset_name = args.dataset_type.lower()
    num_shots = 0

    # Resolve the config class for this dataset/source combination
    config_class = DatasetFactory.create_config(dataset_type_enum, source_type_enum)

    # Create dataset config
    dataset_config_dict = {
        "dataset_type": dataset_type_enum,
        "dataset_name": dataset_name,
        "source_type": source_type_enum,
        "num_entries": args.num_entries,
        "num_shots": num_shots,
        "force_reprocess": args.force_reprocess,
        "random_seed": args.random_seed,
    }

    # Perturbation settings are only passed to the config for PROCESSED sources;
    # for other sources they are ignored.
    if source_type_enum == DatasetSourceType.PROCESSED:
        logger.info("Setting up processed dataset configuration...")
        dataset_config_dict["perturbation_type"] = perturbation_type_enum
        dataset_config_dict["perturbation_intensity"] = args.perturbation_intensity

    dataset_config = config_class(**dataset_config_dict)

    # Process dataset. perf_counter is monotonic, so the measured duration is
    # immune to wall-clock adjustments, unlike time.time().
    start_time = time.perf_counter()
    processor = DatasetFactory.create_processor(
        dataset_type=dataset_type_enum
    )
    processor.process_dataset(dataset_config)
    execution_time = time.perf_counter() - start_time

    logger.info(f"Dataset processing complete in {execution_time:.2f} seconds.")

    return {"success": True, "execution_time": execution_time}


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the dataset-processing script."""
    parser = argparse.ArgumentParser(description="Process datasets for reliability evaluation")

    # Dataset configuration
    parser.add_argument("--dataset_type", type=str, required=True,
                        choices=list(DATASET_TYPE_MAP.keys()),
                        help="Type of dataset to process")
    parser.add_argument("--source_type", type=str, default="raw",
                        choices=list(SOURCE_TYPE_MAP.keys()),
                        help="Source type (raw or processed)")
    parser.add_argument("--num_entries", type=int, default=None,
                        help="Number of entries to process (None for all)")
    parser.add_argument("--random_seed", type=int, default=42,
                        help="Random seed for reproducibility")

    # Perturbation parameters (for processed datasets)
    parser.add_argument("--perturbation_type", type=str, default="none",
                        choices=list(PERTURBATION_TYPE_MAP.keys()),
                        help="Type of perturbation to apply")
    parser.add_argument("--perturbation_intensity", type=int, default=0,
                        help="Intensity of perturbation")

    # Additional options
    parser.add_argument("--force_reprocess", action="store_true",
                        help="Force reprocessing of datasets")
    parser.add_argument("--skip_hf_auth", action="store_true",
                        help="Skip HuggingFace authentication")

    return parser


def main(argv: Optional[list] = None):
    """Parse CLI arguments, prepare the environment and process the dataset.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which matches the previous behavior. Passing
            a list allows programmatic or test invocation.

    Returns:
        The result dict returned by ``process_datasets``.

    Raises:
        SystemExit: From argparse on invalid arguments or ``--help``.
        Exception: Any authentication error is logged and re-raised unchanged.
    """
    args = _build_arg_parser().parse_args(argv)

    # Setup cache paths
    setup_cache_paths()

    # Authenticate with HuggingFace if needed
    if not args.skip_hf_auth:
        try:
            authenticate_huggingface(max_retries=3, initial_delay=1.0)
        except Exception as e:
            logger.error(f"Authentication failed: {str(e)}")
            # Bare raise re-raises the active exception with its original traceback.
            raise

    # Process the datasets
    return process_datasets(args)


if __name__ == "__main__":
    # Append the parent of this file's directory (presumably the project root)
    # to the module search path.
    # NOTE(review): this runs only after the module-level `src.*` imports at the
    # top of the file have executed, so it cannot help resolve those. It only
    # affects imports performed later (e.g. lazily inside main()).
    sys.path.append(
        os.path.dirname(
            os.path.dirname(os.path.abspath(__file__))
        )
    )

    # Script entry point; main()'s return value is intentionally discarded.
    main()