import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from typing import Sequence, Any, Union, List, Optional, Dict


def load_model(
    pretrained_model_name_or_path: Optional[str] = None,
    local_model_path: Optional[str] = None,
    fine_tune_config: Optional[Dict[str, Any]] = None,
    token: Optional[str] = None,
) -> Any:
    """
    Load a causal language model from a local path or from Hugging Face.

    When ``local_model_path`` is given the model is loaded from that path
    with ``local_files_only=True``; otherwise it is loaded using
    ``pretrained_model_name_or_path``.

    Arguments:
        pretrained_model_name_or_path (str): An identifier for
                                                the model on Huggingface.
                                                Used as the source only when
                                                ``local_model_path`` is None.
        local_model_path (str): The file path to the model to be loaded.
        fine_tune_config (dict, optional): May contain a ``"bnb_config"``
                                            dict (keyword arguments for
                                            ``BitsAndBytesConfig``) and/or a
                                            ``"lora_config"`` dict (keyword
                                            arguments for ``LoraConfig``).
                                            Missing or empty entries are
                                            ignored.
        token (str, optional): Forwarded as ``token`` to
                                ``from_pretrained``.

    Returns:
        The loaded model. If a non-empty ``"lora_config"`` was supplied,
        the result of ``get_peft_model(model, lora_config)``.

    Raises:
        Exception: If the model cannot be loaded from the given
                    path or identifier.
                    Possible reasons include file not found,
                    incorrect model ID,
                    or network issues while fetching the model.
                    The error is printed and then re-raised unchanged.
    """
    settings = fine_tune_config or {}

    # Keyword arguments shared by both loading paths.
    load_kwargs: Dict[str, Any] = {
        "torch_dtype": torch.bfloat16,
        "token": token,
        # Hard-coded dynamic RoPE scaling, applied to every model.
        # NOTE(review): confirm all target architectures accept this.
        "rope_scaling": {"type": "dynamic", "factor": 2.0},
    }

    # Only pass a quantization config when one was actually requested.
    bnb_kwargs = dict(settings.get("bnb_config") or {})
    if bnb_kwargs:
        load_kwargs["quantization_config"] = BitsAndBytesConfig(**bnb_kwargs)

    try:
        if local_model_path is None:
            # NOTE(review): hub models are placed on 'cuda' while local
            # checkpoints are placed on 'cpu'; this asymmetry is preserved
            # from the previous implementation - confirm it is intended.
            model = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path,
                device_map="cuda",
                **load_kwargs,
            )
            info_msg = (
                f"Successfully loaded new model "
                f"'{pretrained_model_name_or_path}' from "
                f"Hugging Face"
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                local_model_path,
                device_map="cpu",
                local_files_only=True,
                **load_kwargs,
            )
            info_msg = (
                f"Successfully loaded new model "
                f"'{pretrained_model_name_or_path}' from "
                f"'{local_model_path}'"
            )

        lora_kwargs = dict(settings.get("lora_config") or {})
        if lora_kwargs:
            # Imported lazily so that peft is only required when a LoRA
            # configuration is actually supplied.
            from peft import LoraConfig, get_peft_model

            model = get_peft_model(model, LoraConfig(**lora_kwargs))

        # log the successful model loading (must happen before returning)
        print(info_msg)

        return model

    except Exception as e:
        # Handle exceptions during model loading,
        # such as file not found or load errors
        error_msg = (
            f"Failed to load model '{pretrained_model_name_or_path}' "
            f"from '{local_model_path}': {e}"
        )

        print(error_msg)

        raise

def load_tokenizer(
    pretrained_model_name_or_path: Optional[str] = None,
    local_tokenizer_path: Optional[str] = None,
    token: Optional[str] = None,
) -> Any:
    """
    Load a tokenizer from a local path or from Hugging Face.

    When ``local_tokenizer_path`` is given the tokenizer is loaded from
    that path; otherwise it is loaded using
    ``pretrained_model_name_or_path``.

    Arguments:
        pretrained_model_name_or_path (str): An identifier
                                            for the model on Huggingface.
                                            Used as the source only when
                                            ``local_tokenizer_path`` is None.
        local_tokenizer_path (str): The file path to the
                                    tokenizer to be loaded.
        token (str, optional): Forwarded as ``token`` to
                                ``AutoTokenizer.from_pretrained`` when
                                loading by identifier. Not used when
                                loading from ``local_tokenizer_path``.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.

    Raises:
        Exception: If the tokenizer cannot be loaded from the
        given path or identifier. Possible reasons include file not found,
        incorrect model ID, or network issues while fetching the tokenizer.
        The error is printed and then re-raised unchanged.
    """

    try:
        if local_tokenizer_path is None:
            tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path,
                token=token,
            )
            # log the successful tokenizer loading
            print(
                f"Successfully loaded new tokenizer for model "
                f"'{pretrained_model_name_or_path}' from Hugging Face",
            )

        else:
            tokenizer = AutoTokenizer.from_pretrained(
                local_tokenizer_path,
            )
            # log the successful tokenizer loading
            print(
                f"Successfully loaded new tokenizer for model "
                f"'{pretrained_model_name_or_path}'"
                f" from '{local_tokenizer_path}'",
            )

        return tokenizer

    except Exception as e:
        # Handle exceptions during tokenizer loading,
        # such as file not found or load errors
        error_message = (
            f"Failed to load tokenizer for model"
            f" '{pretrained_model_name_or_path}' from "
            f"'{local_tokenizer_path}': {e}"
        )
        print(error_message)

        raise
