import argparse
import deepspeed
import gc
import os
import torch
import torch.distributed as dist

from deepspeed.inference.engine import InferenceEngine
from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.engine import DeepSpeedEngine
from deepspeed.runtime.lr_schedules import WarmupLR
from deepspeed.utils.zero_to_fp32 import \
    get_fp32_state_dict_from_zero_checkpoint

from torch.utils.data import DataLoader
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
from typing import (
    Dict,
    Optional,
    Tuple,
    Union,
)

from .dataset import initialize_dataset
from .model import initialize_model
from .tokenizer import initialize_tokenizer
from ..models.utils import load_checkpoint


def initialize(
    args: argparse.Namespace,
    config: Dict[str, Union[Dict, int, str]],
) -> Tuple[
        Union[DeepSpeedEngine, InferenceEngine, PreTrainedModel],
        PreTrainedTokenizerBase,
        Optional[DeepSpeedOptimizer],
        Optional[DeepSpeedDataLoader],
        Optional[WarmupLR],
        Optional[Dict[str, Union[bool, int, float]]],
]:
    """ Initialize the necessary components.

    Args:
        args (argparse.Namespace): The arguments.
        config (Dict[str, Union[Dict, int, str]]): The configuration.

    Returns:
        Tuple[ Union[DeepSpeedEngine, InferenceEngine, PreTrainedModel], PreTrainedTokenizerBase, Optional[DeepSpeedOptimizer], Optional[DeepSpeedDataLoader], Optional[WarmupLR], Optional[Dict[str, Union[bool, int, float]]], ]: The model, the tokenizer, the optimizer, the dataloader, the scheduler, and the client state.
    """

    # Get the local rank.
    local_rank = dist.get_rank() if dist.is_initialized() else 0

    # Initialize the multi-process logger.
    try:
        from ..logging import get_logger
        from ..utils import get_path

        logger = get_logger()
        source = f'{get_path(source_file=__file__)}.{initialize_model.__name__}'
    except:
        pass

    device = torch.device(device=local_rank)

    # Initialize the model.
    model, tokenizer = initialize_model(
        base_model_name=config['model']['base_name'],
        model_type=args.model_type,
        chosen_expert_index=args.chosen_expert_index,
        lora_config_path=args.lora_config,
        model_config_path=args.model_config,
        freeze_modules=config['model']['freeze_modules'],
        cache_dir=os.path.join(
            config['lfs_path'],
            'caches',
        ),
        is_training=True if args.mode == 'train' else False,
        is_original_upcycling=args.is_original_upcycling,
        use_checkpoint=config['checkpoint']['use_checkpoint'],
        checkpoint_dir=config['checkpoint']['load_path']['dir'],
        checkpoint_tag=config['checkpoint']['load_path']['tag'],
    )

    ## Initialize the dataset and dataloader.
    dataset = None
    collate_fn = None
    dataloader = None
    if args.mode != 'chat' and args.mode != 'test':
        dataset = initialize_dataset(
            config=config,
            mode=args.mode,
        )

        collate_fn = dataset.get_collate_fn(
            config=config,
            tokenizer=tokenizer,
        )

        # TODO: Infer mode needs to be checked and fixed.
        if args.mode == 'infer':
            generator = torch.Generator(device=device)

            dataloader = DataLoader(
                dataset=dataset,
                batch_size=config['infer']['batch_size'],
                collate_fn=collate_fn,
                generator=generator,
            )
    ## -----

    # Load the PyTorch checkpoint.
    if config['checkpoint']['use_checkpoint'] and \
            config['checkpoint']['load_checkpoint_type'] == 'pytorch':
        state_dict = get_fp32_state_dict_from_zero_checkpoint(
            checkpoint_dir=config['checkpoint']['load_path']['dir'],
            tag=config['checkpoint']['load_path']['tag'],
        )
        model.load_state_dict(
            state_dict=state_dict,
            strict=False,
            assign=False,
        )

    optimizer = None
    scheduler = None
    client_state = None

    match args.mode:
        # TODO: Chat and infer mode needs to be checked and fixed.
        case 'chat' | 'infer':
            model = deepspeed.init_inference(
                model=model,
                config=args.deepspeed_config,
            )
        case 'train':
            # Use 8 bit optimizer.
            # -----
            # try:
            #     import bitsandbytes as bnb

            #     updatable_params = [
            #         param for param in model.parameters() \
            #             if param.requires_grad
            #     ]
            #     optimizer = bnb.optim.AdamW8bit(
            #         params=updatable_params,
            #         lr=1e-3,
            #         betas=(0.9, 0.999),
            #         eps=1e-8,
            #         weight_decay=1e-2,
            #         optim_bits=8,
            #         # Uncomment if you are training with LoRA.
            #         # min_8bit_size=0,
            #     )
            # except Exception as e:
            #     raise e

            # model, optimizer, dataloader, scheduler = deepspeed.initialize(
            #     model=model,
            #     optimizer=optimizer,
            #     model_parameters=model.parameters(),
            #     training_data=dataset,
            #     collate_fn=collate_fn,
            #     config=args.deepspeed_config,
            # )
            # -----

            model, optimizer, dataloader, scheduler = deepspeed.initialize(
                model=model,
                model_parameters=model.parameters(),
                training_data=dataset,
                collate_fn=collate_fn,
                config=args.deepspeed_config,
            )

            # Load the DeepSpeed checkpoint.
            if config['checkpoint']['use_checkpoint'] and \
                    config['checkpoint']['load_checkpoint_type'] == 'deepspeed':
                model, client_state = load_checkpoint(
                    model=model,
                    checkpoint_dir=config['checkpoint']['load_path']['dir'],
                    tag=config['checkpoint']['load_path']['tag'],
                )
        case 'test':
            if hasattr(model, 'merge_and_unload'):
                model = model.merge_and_unload()

                message = 'Adapters are merged and unloaded.'
                logger.log(
                    message=message,
                    level='info',
                    source=source,
                )
            else:
                message = 'The model does not have the merge_and_unload method.'
                logger.log(
                    message=message,
                    level='warning',
                    source=source,
                )
        case _:
            message = 'The mode is not recognized.'
            logger.log(
                message=message,
                level='error',
                source=source,
            )

            raise ValueError(message)

    gc.collect()
    torch.cuda.empty_cache()

    return (
        model,
        tokenizer,
        optimizer,
        dataloader,
        scheduler,
        client_state,
    )
