import gc
import json
import peft
import torch
import torch.distributed as dist
import torch.nn as nn

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
from typing import (
    Dict,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

from .tokenizer import initialize_tokenizer
from ..models import (
    LlamaForCausalLM,
    LlamaMLP,
    LlamaMoEForCausalLM,
    MoLoSConfig,
    MoLoSLlamaForCausalLM,
    MoSLEConfig,
    MoSLELlamaForCausalLMV2,
    MoSLELlamaForCausalLMV3,
)
from ..SVDLLM import (
    get_calibration_dataset,
    profile,
    whiten,
)
from ..utils import (
    get_json_config,
    get_min_local_rank,
)


def initialize_model(
    base_model_name: str,
    model_type: Literal['molos', 'mosle', 'vanilla'],
    chosen_expert_index: int,
    lora_config_path: str,
    model_config_path: str,
    freeze_modules: Dict[str, Union[List[str], str]],
    cache_dir: str,
    is_training: bool,
    is_original_upcycling: bool = False,
    use_checkpoint: bool = False,
    checkpoint_dir: Optional[str] = None,
    checkpoint_tag: Optional[str] = None,
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
    """ Initializes the model.

    Args:
        base_model_name (str): The name of the base model.
        model_type (Literal['molos', 'mosle', 'vanilla']): The type of the model.
        chosen_expert_index (int): The index of the chosen expert.
        lora_config_path (str): The path to the LoRA configuration file.
        model_config_path (str): The path to the model configuration file.
        freeze_modules (Dict[str, Union[List[str], str]]): The modules freezing configuration.
        cache_dir (str): The path to the cache directory.
        is_training (bool): Whether the model is in training mode.
        is_original_upcycling (bool, optional): Whether to use the original upcycling. Defaults to False.
        use_checkpoint (bool): Whether to use the checkpoint. Defaults to False.
        checkpoint_dir (Optional[str], optional): The path to the checkpoint directory. Defaults to None.
        checkpoint_tag (Optional[str], optional): The tag of the checkpoint. Defaults to None.

    Returns:
        Tuple[PreTrainedModel, PreTrainedTokenizerBase]: The initialized model and tokenizer.
    """

    # Get the local rank.
    local_rank = dist.get_rank() if dist.is_initialized() else 0

    # Initialize the multi-process logger.
    try:
        from ..logging import get_logger
        from ..utils import get_path

        logger = get_logger()
        source = f'{get_path(source_file=__file__)}.{initialize_model.__name__}'
    except:
        pass

    device = torch.device(device=local_rank)

    # Initialize the base model and its tokenizer.
    if 'llamamoe' in base_model_name.lower():
        base_model = LlamaMoEForCausalLM.from_pretrained(
            pretrained_model_name_or_path=base_model_name,
            attn_implementation='flash_attention_2',
            # attn_implementation='flash_attention_3',
            tie_word_embeddings=False,
            use_cache=False,
            trust_remote_code=True,
        )
    elif 'llama' in base_model_name.lower():
        base_model = LlamaForCausalLM.from_pretrained(
            pretrained_model_name_or_path=base_model_name,
            attn_implementation='flash_attention_2',
            # attn_implementation='flash_attention_3',
            tie_word_embeddings=False,
            use_cache=False if is_training else True,
        )
    else:
        base_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=base_model_name,
            attn_implementation='flash_attention_2',
            # attn_implementation='flash_attention_3',
            tie_word_embeddings=False,
            use_cache=False if is_training else True,
        )

    base_model.config.use_cache = False if is_training else True
    base_model.eval()

    tokenizer = initialize_tokenizer(
        model=base_model,
        model_name=base_model_name,
    )

    # Set a default sequence length.
    sequence_length = 2048

    message = f'The base model is initialized.'
    try:
        logger.log(
            message=message,
            source=source,
        )
    except:
        print(message)

        pass

    match model_type:
        case 'molos':
            molos_config = MoLoSConfig()
            if model_config_path is not None:
                molos_config_values = get_json_config(
                    config_path=model_config_path)
                molos_config = MoLoSConfig(**molos_config_values)

            # chosen_expert_index == -2 means that the user wants to use the chosen expert index from the model configuration.
            if chosen_expert_index != -2:
                molos_config.chosen_ex_idx = chosen_expert_index
                logger.log(
                    message=
                    '`chosen_ex_idx` is set to the value from the command line.',
                    level='warning',
                    source=source,
                )

                # chosen_expert_index == -1 means that the user wants to use the trained router to select the expert.
                if chosen_expert_index == -1:
                    molos_config.output_router_logits = True
                    logger.log(
                        message=
                        '`output_router_logits` is set to True to avoid the error.',
                        level='warning',
                        source=source,
                    )

            base_model.sequence_length = molos_config.sequence_length
            sequence_length = molos_config.sequence_length

            model = MoLoSLlamaForCausalLM(
                config=base_model.config,
                molos_config=molos_config,
            )
            model.eval()

            message = 'The empty MoLoS model is initialized.'

            try:
                logger.log(
                    message=message,
                    source=source,
                )
            except:
                print(message)

                pass

            ## Load the state dict except the MLP parameters from the base model to the MoLoS model.
            partial_state_dict = {}
            for name, param in base_model.named_parameters():
                if 'mlp' not in name:
                    partial_state_dict[name] = param

            model.load_state_dict(
                state_dict=partial_state_dict,
                strict=False,
                assign=False,
            )

            del partial_state_dict
            ## -----

            message = 'The MoLoS model (except experts) is loaded from the base model.'

            try:
                logger.log(
                    message=message,
                    source=source,
                )
            except:
                print(message)

                pass

            if not is_original_upcycling:
                early_exit = False

                for dataset_idx, (dataset_name, dataset) in enumerate(
                        iterable=molos_config.calibration_datasets.items()):
                    data_number = dataset['data_number']
                    batch_size = dataset['batch_size']

                    try:
                        calibration_dataset, calibration_dataset_name = \
                            get_calibration_dataset(
                                tokenizer=tokenizer,
                                data_name=dataset_name,
                                data_number=data_number,
                                sequence_length=sequence_length,
                                batch_size=batch_size,
                                cache_dir=cache_dir,
                                local_rank=local_rank,
                                only_return_name=True,
                        )

                        layers_profiling, layers_profiling_name = profile(
                            base_model=base_model,
                            base_model_name=base_model_name.split('/')[-1],
                            calibration_dataset=calibration_dataset,
                            calibration_dataset_name=calibration_dataset_name,
                            cache_dir=cache_dir,
                            local_rank=local_rank,
                            device=device,
                            only_return_name=True,
                        )

                        whiten(
                            base_model=base_model,
                            model=model,
                            layers_profiling=layers_profiling,
                            layers_profiling_name=layers_profiling_name,
                            ratio=molos_config.ex_params_ratio,
                            cache_dir=cache_dir,
                            local_rank=local_rank,
                            device=device,
                            expert_idx=dataset_idx,
                        )
                    except Exception as e:
                        early_exit = True

                        logger.log(
                            message=
                            'Failed to initialize experts without loading profilings.',
                            level='warning',
                            source=source,
                        )

                        if local_rank == 0:
                            calibration_dataset, calibration_dataset_name = \
                                get_calibration_dataset(
                                    tokenizer=tokenizer,
                                    data_name=dataset_name,
                                    data_number=data_number,
                                    sequence_length=sequence_length,
                                    batch_size=batch_size,
                                    cache_dir=cache_dir,
                                    local_rank=local_rank,
                            )

                            layers_profiling, layers_profiling_name = profile(
                                base_model=base_model,
                                base_model_name=base_model_name.split('/')[-1],
                                calibration_dataset=calibration_dataset,
                                calibration_dataset_name=
                                calibration_dataset_name,
                                cache_dir=cache_dir,
                                local_rank=local_rank,
                                device=device,
                            )

                            whiten(
                                base_model=base_model,
                                model=model,
                                layers_profiling=layers_profiling,
                                layers_profiling_name=layers_profiling_name,
                                ratio=molos_config.ex_params_ratio,
                                cache_dir=cache_dir,
                                local_rank=local_rank,
                                device=device,
                                expert_idx=dataset_idx,
                            )

                    message = f'The expert {dataset_idx} is initialized with the {dataset_name} dataset.'

                    try:
                        logger.log(
                            message=message,
                            source=source,
                        )
                    except:
                        print(message)

                        pass

                message = 'The MoLoS model is initialized.'

                try:
                    logger.log(
                        message=message,
                        source=source,
                    )
                except:
                    print(message)

                    pass

                if early_exit:
                    exit(code=0)
            else:
                for layer_idx, layer in enumerate(iterable=model.model.layers):
                    layer.moe.experts = nn.ModuleList(modules=[
                        LlamaMLP(config=base_model.config) \
                            for _ in range(molos_config.ex_num)
                    ])

                for ex_idx in range(molos_config.ex_num):
                    partial_state_dict = {}
                    for name, param in base_model.named_parameters():
                        if 'mlp' in name:
                            _name = name.replace('mlp',
                                                 f'moe.experts.{ex_idx}')
                            partial_state_dict[_name] = param

                    if local_rank == 0:
                        for key in partial_state_dict.keys():
                            if key not in model.state_dict().keys():
                                raise ValueError()

                    model.load_state_dict(
                        state_dict=partial_state_dict,
                        strict=False,
                        assign=False,
                    )

                    del partial_state_dict
        case 'mosle':
            base_model = base_model.to(device=device)

            # model = MoSLELlamaForCausalLMV2(
            #     base_model=model,
            #     config=model.config,
            #     mosle_config=MoSLEConfig(),
            # )
            model = MoSLELlamaForCausalLMV3(
                base_model=base_model,
                config=base_model.config,
                mosle_config=MoSLEConfig(),
            )

            message = 'The MoSLE model is initialized.'

            try:
                logger.log(
                    message=message,
                    source=source,
                )
            except:
                print(message)

                pass
        case 'vanilla':
            model = base_model

            message = 'The Vanilla model is initialized.'

            try:
                logger.log(
                    message=message,
                    source=source,
                )
            except:
                print(message)

                pass
        case _:
            message = f'Invalid model type: {model_type}. The model type must be one of [molos, mosle, vanilla].'

            try:
                logger.log(
                    message=message,
                    level='error',
                    source=source,
                )
            except:
                print(message)

                pass

            raise ValueError(message)

    del base_model

    gc.collect()
    torch.cuda.empty_cache()

    model = model.to(device=device)

    # Initialize the LoRA model.
    if lora_config_path:
        lora_config_item = None

        with open(
                file=lora_config_path,
                mode='r',
        ) as file:
            lora_config_item = json.load(fp=file)

            file.close()

        if isinstance(lora_config_item, Dict) and \
                (all(isinstance(v, Dict) for v in lora_config_item.values())
        ):
            is_peft_model_initialized = False
            lora_config_names = []

            for lora_config_name, lora_config in lora_config_item.items():
                lora_config_names.append(lora_config_name)
                lora_config = peft.LoraConfig(**lora_config)

                if is_peft_model_initialized:
                    model.add_adapter(
                        adapter_name=lora_config_name,
                        peft_config=lora_config,
                    )
                else:
                    model = peft.get_peft_model(
                        model=model,
                        peft_config=lora_config,
                        adapter_name=lora_config_name,
                    )

                    is_peft_model_initialized = True

            model.base_model.set_adapter(adapter_name=lora_config_names)
        else:
            lora_config = peft.LoraConfig(**lora_config_item)
            model = peft.get_peft_model(
                model=model,
                peft_config=lora_config,
            )

        message = 'The LoRA model is initialized.'

        try:
            logger.log(
                message=message,
                source=source,
            )
        except:
            print(message)

            pass

        ## Check whether the trainable parameters are correct.
        # model.print_trainable_parameters()
        ## -----

    ## Print the model memory usage.
    message = f'Model Memory Usage: {(torch.cuda.memory_allocated() / 1024 / 1024):.2f} MiB'

    try:
        logger.log(
            message=message,
            source=source,
        )
    except:
        print(message)

        pass
    ## -----

    ## Freeze the specific parameters of the model.
    if freeze_modules['type'] != 'none':
        # If the `freeze_modules['type']` is 'negative'.
        match_grad = True
        unmatch_grad = False

        if freeze_modules['type'] == 'positive':
            match_grad = False
            unmatch_grad = True

        for param_name, param in model.named_parameters():
            for module_keyword in freeze_modules['keywords']:
                if module_keyword in param_name:
                    param.requires_grad = match_grad
                elif 'auto' in freeze_modules['type']:
                    param.requires_grad = unmatch_grad
    ## -----

    ## Check the trainable parameters.
    message = 'The trainable parameters are:'
    for param_name, param in model.named_parameters():
        if param.requires_grad:
            message += f'\n- {param_name}'

    try:
        logger.log(
            message=message,
            source=source,
        )
    except:
        print(message)

        pass

    if model_type == 'molos':
        message = model.print_trainable_parameters()
        try:
            logger.log(
                message=message,
                source=source,
            )
        except:
            print(message)

            pass
    ## -----

    ## Check whether the model state dict is the same across all processes. If all processes are the same, the model is initialized correctly.
    # params_sum = sum(
    #     torch.sum(input=param).item() for param in model.state_dict().values())

    # message = f'Model Parameters Sum of Rank {local_rank}: {params_sum}'

    # try:
    #     logger.log(
    #         message=message,
    #         source=source,
    #     )
    # except:
    #     print(message)

    #     pass
    ## -----

    ## Broadcast the model parameters to all processes.
    # minimum_local_rank = get_min_local_rank()
    # for param in model.state_dict().values():
    #     dist.broadcast(
    #         tensor=param,
    #         src=minimum_local_rank,
    #     )
    ## -----

    return (
        model,
        tokenizer,
    )
