import argparse
import ast
import logging
import os
import random
import subprocess
import sys
import time
from datetime import datetime
from typing import Dict, Optional

import numpy as np
import torch
import wandb
from dotenv import load_dotenv
from huggingface_hub import login

# This will load the NVCC flag along with other environment variables
load_dotenv(override=True)

from src.dataset_processing.perplexity.loader import load_perplexity_dataset  # noqa: E402
from src.loggers.setup_logging import setup_logging  # noqa: E402
from src.model_loading.common.config.model_config import ModelConfig  # noqa: E402
from src.model_loading.manager import ModelManager  # noqa: E402
from src.model_loading.registry.enhanced_registry import EnhancedModelRegistry  # noqa: E402
from src.perplexity.evaluator import ModelEvaluator  # noqa: E402

from src.model_loading.registry.registry import ModelRegistry  # noqa: E402, F401
from src.dataset_processing.common.processing.base_manager import BaseDatasetManager  # noqa: E402, F401
from src.dataset_processing.datasets.commonsenseqa.processor import CommonSenseQAProcessor  # noqa: E402, F401
from src.dataset_processing.datasets.coqa.processor import CoQAProcessor  # noqa: E402, F401
from src.dataset_processing.datasets.triviaqa.processor import TriviaQAProcessor  # noqa: E402, F401
from src.model_loading.loaders.gptq import GPTQModelLoader  # noqa: E402, F401

# To avoid the unsupported GNU version error for AQLM models, nvcc must be told
# to accept host compilers newer than it officially supports.
# Running `export` through subprocess only changes the environment of that
# short-lived child shell, never of this process. The flag is therefore set
# directly in os.environ, which processes spawned later inherit. Any value
# already loaded from .env is kept, and the flag is appended only if missing.
_NVCC_ALLOW_UNSUPPORTED = "--allow-unsupported-compiler"
_nvcc_flags = os.environ.get("NVCC_APPEND_FLAGS", "")
if _NVCC_ALLOW_UNSUPPORTED not in _nvcc_flags.split():
    os.environ["NVCC_APPEND_FLAGS"] = f"{_nvcc_flags} {_NVCC_ALLOW_UNSUPPORTED}".strip()

# Configure logging for the script. `logger` is the module-wide logger used by
# most functions in this file. Its handlers and levels are decided by
# src.loggers.setup_logging, which is not visible here.
logger = setup_logging()

# Set up environment variables
def setup_cache_paths():
    """Point HuggingFace/Torch caches at a user-controlled directory.

    Reads ``BASE_USER_PATH`` from the environment after reloading ``.env``.
    If it is unset or empty, the user's home directory is used instead. The
    cache root becomes ``<BASE_USER_PATH>/.cache/huggingface``, and its
    ``hub`` subdirectory is created if missing. Several cache-related
    environment variables and ``torch.hub``'s directory are pointed at the
    cache root. Finally, PyTorch's CUDA caching allocator is asked to release
    unused cached memory.

    NOTE(review): ``huggingface_hub`` is imported at module load, and
    transformers possibly is too (via the ``src`` imports), before this
    function runs. Libraries that read ``HF_HOME`` / ``TRANSFORMERS_CACHE``
    at import time may therefore ignore the values set here. Confirm this, or
    set them in ``.env`` instead (``.env`` is loaded before those imports).
    """
    logger.info("Setting up cache paths...")

    load_dotenv()
    base_user_path = os.getenv("BASE_USER_PATH")
    if not base_user_path:
        # Trailing separator kept so the logged default matches the old format.
        base_user_path = os.path.join(os.path.expanduser("~"), "")
        logger.warning(f"BASE_USER_PATH not found in .env, using default: {base_user_path}")

    # Runtime workarounds. MKL_SERVICE_FORCE_INTEL presumably avoids the
    # MKL/libgomp threading-layer conflict error. TOKENIZERS_PARALLELISM=false
    # disables HF tokenizers' internal parallelism, which avoids its fork
    # warnings.
    os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # os.path.join works whether or not BASE_USER_PATH ends with a separator.
    # Plain string concatenation turned "/data/me" into "/data/me.cache/...".
    cache_path = os.path.join(base_user_path, ".cache", "huggingface")
    hub_path = os.path.join(cache_path, "hub")

    if not os.path.isdir(hub_path):
        # exist_ok guards against another process creating it concurrently.
        os.makedirs(hub_path, exist_ok=True)
        logger.info(f"Creating huggingface hub path at {hub_path}")

    logger.info(f"Setting cache path to {cache_path}")
    os.environ["TORCH_HOME"] = cache_path
    os.environ["HF_HOME"] = cache_path
    # NOTE(review): the hub cache is pointed at the cache root, not at
    # `hub_path` created above. Verify this is intended. Changing it would
    # relocate (and force re-download of) already cached models.
    os.environ["HUGGINGFACE_HUB_CACHE"] = cache_path
    os.environ["HUGGINGFACE_ASSETS_CACHE"] = cache_path
    os.environ["TRANSFORMERS_CACHE"] = cache_path

    torch.hub.set_dir(cache_path)

    # Release unused memory held by PyTorch's CUDA caching allocator.
    torch.cuda.empty_cache()


# Initialize model registry. The base ModelRegistry is wrapped in an
# EnhancedModelRegistry. `model_registry` is what load_model_and_tokenizer uses
# to resolve model-name strings. This runs once, at import time.
base_registry = ModelRegistry()
model_registry = EnhancedModelRegistry(base_registry)


def authenticate_huggingface(max_retries: int = 3, initial_delay: float = 1.0) -> None:
    """
    Authenticate with Hugging Face with retry logic.

    The token is read from the ``HUGGINGFACE_TOKEN`` environment variable
    after ``.env`` is reloaded. Failed login attempts are retried with
    exponential backoff: initial_delay, 2 * initial_delay, 4 * initial_delay, ...

    Args:
        max_retries: Maximum number of login attempts. Values below 1 are
            treated as 1, so at least one attempt is always made.
        initial_delay: Delay before the first retry, in seconds. It doubles
            after each further failure.

    Raises:
        ValueError: If HUGGINGFACE_TOKEN is not set, or re-raised from the
            final failed login attempt.
        OSError: Re-raised from the final attempt if it failed with a
            network/IO error.
    """
    # Load environment variables
    load_dotenv()
    huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

    if huggingface_token is None:
        raise ValueError(
            f"Please set the HUGGINGFACE_TOKEN environment variable. "
            f"Looking in {os.path.join(os.getcwd(), '.env')}"
        )

    logger.info("Hugging Face token loaded successfully.")

    # With max_retries <= 0 the old loop never ran and ended in `raise None`
    # (a TypeError). Always make at least one attempt.
    attempts = max(1, max_retries)

    for attempt in range(1, attempts + 1):
        try:
            login(token=huggingface_token, add_to_git_credential=True)
        # ValueError is what the original loop retried; presumably
        # huggingface_hub raises it for e.g. a rejected token. OSError covers
        # transient connection failures, since requests' exceptions derive
        # from it.
        except (ValueError, OSError) as e:
            if attempt == attempts:
                logger.error(f"Failed to authenticate after {attempts} attempts")
                raise
            delay = initial_delay * (2 ** (attempt - 1))  # Exponential backoff
            logger.warning(
                f"Authentication attempt {attempt} failed. "
                f"Retrying in {delay:.1f} seconds... "
                f"Error: {str(e)}"
            )
            time.sleep(delay)
        else:
            logger.info("Successfully authenticated with Hugging Face")
            return


def set_seeds(random_seed: int = 42):
    """Seed every RNG used by the evaluation and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (the default
    generators plus every CUDA device). It switches cuDNN to deterministic,
    non-benchmarking mode and exports ``PYTHONHASHSEED``.

    Note: ``PYTHONHASHSEED`` only affects interpreters started after it is set,
    such as subprocesses. The current process's hash randomization is fixed
    at interpreter startup.
    """
    seeders = (
        random.seed,                 # Python's built-in RNG
        np.random.seed,              # NumPy global RNG
        torch.manual_seed,           # PyTorch default generators
        torch.cuda.manual_seed_all,  # every CUDA device (multi-GPU)
    )
    for seed_fn in seeders:
        seed_fn(random_seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(random_seed)


def load_model_and_tokenizer(
    model_name: str,
    device: str,
    max_memory: Optional[Dict[str, str]] = None,
    apply_compile: bool = True,
) -> tuple:
    """Resolve ``model_name`` through the registry and load model + tokenizer.

    Args:
        model_name: Name understood by ``model_registry.get_model_by_string``.
        device: Device string forwarded to ``ModelConfig``.
        max_memory: Optional memory limits forwarded to ``ModelConfig``.
        apply_compile: Forwarded to ``ModelConfig``.

    Returns:
        ``(model, tokenizer)``. The model is put in eval mode and the
        tokenizer's pad token id is set to its EOS token id.

    Raises:
        ValueError: If the registry does not know ``model_name``.
    """
    identifier = model_registry.get_model_by_string(model_name)
    if identifier is None:
        raise ValueError(f"Unknown model name: {model_name}")

    config = ModelConfig(
        identifier=identifier,
        device=device,
        trust_remote_code=True,
        max_memory=max_memory,
        apply_compile=apply_compile,
    )
    model, tokenizer = ModelManager().load_model(config)

    # NOTE(review): this overwrites any pad token the tokenizer already had.
    # Confirm that the evaluation relies on pad == EOS before changing it.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.eval()
    return model, tokenizer


def _parse_max_memory(raw):
    """Convert the ``--max_memory`` CLI value into a dict, or None.

    ``raw`` is expected to be a Python dict literal string. It is parsed with
    ``ast.literal_eval``, which accepts only literals. Unlike ``eval``, it
    cannot execute arbitrary code passed on the command line.

    Non-string values are returned unchanged. Blank strings, unparsable
    strings and non-dict literals yield None, with a warning for the latter
    two.
    """
    if not isinstance(raw, str):
        return raw
    if not raw.strip():
        return None
    try:
        parsed = ast.literal_eval(raw.strip())
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        logger.warning(f"Could not parse max_memory string: {raw}, setting to None")
        return None
    if not isinstance(parsed, dict):
        logger.warning(
            f"max_memory must be a dict literal, got {type(parsed).__name__}: {raw}, setting to None"
        )
        return None
    logger.info(f"Parsed max_memory as {parsed}")
    return parsed


def run_perplexity_evaluation(args):
    """Run a perplexity evaluation described by the parsed CLI ``args``.

    Steps:
      1. Log the configuration.
      2. Optionally start a Weights & Biases run.
      3. Seed all RNGs.
      4. Load the model/tokenizer and the evaluation dataloader.
      5. Evaluate the requested metrics and record the elapsed time.

    Args:
        args: Namespace built by ``main``. Uses exp_id, model_name, device,
            random_seed, dataset_name, split, n_samples, seq_length,
            batch_size, metrics, max_memory, apply_compile and disable_wandb.

    Returns:
        The results mapping from ``ModelEvaluator.evaluate``, with an added
        ``"execution_time"`` entry in seconds.
    """
    # Log configuration
    logger.info("Received configuration:")
    logger.info(f"Experiment ID: {args.exp_id}")
    logger.info(f"Model: {args.model_name}")
    logger.info(f"Device: {args.device}")
    logger.info(f"Random seed: {args.random_seed}")
    logger.info(f"Dataset: {args.dataset_name} ({args.split})")
    logger.info(f"Split: {args.split}")
    logger.info(f"Number of samples: {args.n_samples}")
    logger.info(f"Sequence length: {args.seq_length}")
    logger.info(f"Batch size: {args.batch_size}")
    logger.info(f"Metrics: {args.metrics}")

    # Initialize wandb if not disabled
    if not args.disable_wandb:
        wandb.init(
            project="llm-perpl-eval",
            name=f"{args.exp_id}-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
            config={
                "exp_id": args.exp_id,
                "model_name": args.model_name,
                "device": args.device,
                "random_seed": args.random_seed,
                "dataset_name": args.dataset_name,
                "split": args.split,
                "n_samples": args.n_samples,
                "seq_length": args.seq_length,
                "batch_size": args.batch_size,
                "metrics": args.metrics,
            },
        )

    # Set seed
    set_seeds(args.random_seed)

    # If dataset is C4, divide the number of entries by 20
    n_samples = args.n_samples
    if args.dataset_name.lower() == "c4":
        # Floor division alone turns any positive n_samples < 20 (including
        # the CLI default of 10) into 0 samples, so keep at least one.
        n_samples = max(1, n_samples // 20) if n_samples > 0 else n_samples
        logger.info(f"Adjusting n_samples for C4 dataset to {n_samples}")

    # Parse max_memory safely; it comes straight from the command line.
    max_memory = _parse_max_memory(args.max_memory)

    # Load model and tokenizer
    logger.info("Loading model and tokenizer")
    model, tokenizer = load_model_and_tokenizer(
        args.model_name, args.device, max_memory=max_memory, apply_compile=args.apply_compile
    )

    # Load dataset
    logger.info("Loading evaluation dataset")
    dataloader = load_perplexity_dataset(
        dataset_name=args.dataset_name,
        split=args.split,
        tokenizer=tokenizer,
        n_samples=n_samples,
        seq_length=args.seq_length,
        batch_size=args.batch_size,
        seed=args.random_seed,
    )

    # Run evaluation. perf_counter is monotonic, so it is unaffected by
    # wall-clock adjustments.
    logger.info("Starting evaluation")
    start_time = time.perf_counter()
    evaluator = ModelEvaluator(device=args.device)
    results = evaluator.evaluate(
        model=model,
        dataloader=dataloader,
        metrics=args.metrics,
        n_samples=n_samples,
        # Case-insensitive, like the C4 check above, so names such as
        # "...-AWQ" are recognized too.
        to_device=("awq" in args.model_name.lower()),
    )
    end_time = time.perf_counter()

    # Add execution time to results
    results["execution_time"] = end_time - start_time

    # Log results to wandb if enabled
    if not args.disable_wandb:
        wandb.log(results)
        wandb.finish()

    # Print summary to console
    logger.info("Results Summary:")
    for key, value in results.items():
        logger.info(f"{key}: {value}")

    return results


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the perplexity evaluation script."""
    parser = argparse.ArgumentParser(description="Run perplexity evaluation for language models")

    # Experiment ID
    parser.add_argument("--exp_id", type=str, required=True, help="Experiment ID")

    # Model configuration
    parser.add_argument("--model_name", type=str, required=True, help="Model name")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on (cuda or cpu)")
    parser.add_argument("--max_memory", type=str, default=None, help="Maximum memory to use (format as python dict string)")
    parser.add_argument("--apply_compile", action="store_true", help="Apply torch.compile to the model")
    parser.add_argument("--random_seed", type=int, default=42, help="Random seed for reproducibility")

    # Dataset parameters
    parser.add_argument("--dataset_name", type=str, required=True, help="Dataset name")
    parser.add_argument("--split", type=str, default="test", help="Dataset split (e.g., train, validation, test)")
    parser.add_argument("--n_samples", type=int, default=10, help="Number of samples to evaluate")
    parser.add_argument("--seq_length", type=int, default=2048, help="Sequence length for tokenization")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for evaluation")

    # Metrics
    parser.add_argument("--metrics", type=str, nargs="+", default=["perplexity"],
                        help="Metrics to compute (e.g., perplexity, loss)")

    # Additional options
    parser.add_argument("--disable_wandb", action="store_true", help="Disable Weights & Biases logging")
    parser.add_argument("--skip_hf_auth", action="store_true", help="Skip HuggingFace authentication")
    return parser


def main(argv=None):
    """Parse CLI arguments, prepare the environment and run the evaluation.

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``,
            which allows programmatic invocation. The default keeps the usual
            command-line behavior.

    Returns:
        The results returned by ``run_perplexity_evaluation``.
    """
    args = _build_arg_parser().parse_args(argv)

    # Setup cache paths
    setup_cache_paths()

    # Authenticate with HuggingFace if needed
    if not args.skip_hf_auth:
        try:
            authenticate_huggingface(max_retries=3, initial_delay=1.0)
        except Exception as e:
            logger.error(f"Authentication failed: {str(e)}")
            # Bare raise re-raises the original exception and traceback unchanged.
            raise

    # Run the evaluation
    return run_perplexity_evaluation(args)


if __name__ == "__main__":
    # Add parent directory to sys.path to ensure imports work correctly
    # NOTE(review): the appended path is the directory that contains this
    # file's directory. This line runs only after every module-level
    # `from src...` import above has executed, so it cannot help those
    # imports. It can only affect imports performed later at runtime.
    # Confirm whether it is still needed.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    
    # When running as a script, just parse args and run
    main()