import gc
import torch
import torch.distributed as dist

from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.engine import DeepSpeedEngine
from deepspeed.runtime.lr_schedules import WarmupLR
from tqdm import tqdm
from typing import (
    Dict,
    Optional,
    Union,
)

from .logging import get_logger
from .models.utils import save_checkpoint
from .utils import (
    get_path,
    set_device,
)


def _format_loss(outputs: object, *names: str) -> str:
    """ Format the first available loss of the model outputs.

    Args:
        outputs (object): The model outputs.
        *names (str): The candidate attribute names, in priority order.

    Returns:
        str: The loss with three decimals, or 'None' if every candidate
            attribute is missing or None.
    """
    for name in names:
        value = getattr(outputs, name, None)

        if value is not None:
            return f'{value.item():.3f}'

    return 'None'


def _reaches_save_interval(
    step: int,
    interval: int,
    dataloader: DeepSpeedDataLoader,
) -> bool:
    """ Check whether a mid-epoch checkpoint is due after this step.

    Args:
        step (int): The zero-based index of the step that just finished.
        interval (int): The saving interval in steps, -1 disables the saving.
        dataloader (DeepSpeedDataLoader): The dataloader, only its length is used.

    Returns:
        bool: True if the number of finished steps is a multiple of the interval,
            is not the last step of the epoch, and at least one more whole
            interval still fits into the epoch.
    """
    if interval == -1:
        return False

    finished_steps = step + 1
    total_steps = len(dataloader)

    return (finished_steps != total_steps) and \
        (finished_steps + interval <= total_steps) and \
        (finished_steps % interval == 0)


def train(
    model: DeepSpeedEngine,
    optimizer: DeepSpeedOptimizer,
    dataloader: DeepSpeedDataLoader,
    config: Dict[str, Union[Dict, int, str]],
    scheduler: Optional[WarmupLR] = None,
    trained_epochs: int = -1,
    trained_steps: int = -1,
    finishing_this_epoch: bool = False,
) -> None:
    """ Train the model.

    The training can be resumed from a checkpoint. If the checkpoint was saved at
    the end of an epoch ('finishing_this_epoch' is True), the training restarts
    from the first step of the epoch 'trained_epochs + 1'. Otherwise, the training
    re-enters the epoch 'trained_epochs' and skips the steps up to and including
    'trained_steps'. The skipping only applies to the resumed epoch.

    The configuration entries read here are 'train.epochs', 'log.time',
    'checkpoint.save_epoch', 'checkpoint.full_save_step',
    'checkpoint.partial_save_step' and 'outputs_dir'.

    Args:
        model (DeepSpeedEngine): The model.
        optimizer (DeepSpeedOptimizer): The optimizer, only read for the logged learning rate.
        dataloader (DeepSpeedDataLoader): The dataloader.
        config (Dict[str, Union[Dict, int, str]]): The configuration.
        scheduler (Optional[WarmupLR], optional): The scheduler, not used in this function. Defaults to None.
        trained_epochs (int, optional): The trained epochs. Defaults to -1.
        trained_steps (int, optional): The trained steps. Defaults to -1.
        finishing_this_epoch (bool, optional): Whether finishing this epoch. Defaults to False.
    """
    # Only warn once about the missing similarity loss.
    warn_missing_similarity_loss = True

    # Get the rank of this process (0 if the process group is not initialized).
    local_rank = dist.get_rank() if dist.is_initialized() else 0

    # Initialize the multi-process logger.
    logger = get_logger()
    source = f'{get_path(source_file=__file__)}.{train.__name__}'

    device = model.device

    epochs = config['train']['epochs']
    log_time = int(config['log']['time'])
    # A missing (falsy) interval disables the corresponding checkpoint.
    save_epoch = config['checkpoint']['save_epoch'] or -1
    full_save_step = config['checkpoint']['full_save_step'] or -1
    partial_save_step = config['checkpoint']['partial_save_step'] or -1

    # Disable the tqdm if the rank is not 0.
    tqdm_disable = local_rank != 0

    # Adjust the trained epochs and steps.
    if finishing_this_epoch:
        # The last epoch was finished, so there is no step to skip.
        trained_steps = -1
    elif trained_epochs != -1:
        # The last epoch was not finished, so it has to be re-entered.
        trained_epochs -= 1

    # Skip the trained epochs.
    for epoch in range(trained_epochs + 1, epochs):
        model.train()

        # Running sum and count of the model losses of this epoch.
        epoch_loss_sum = 0.0
        epoch_loss_count = 0

        # Keep 'step' defined even if the dataloader yields nothing.
        step = -1

        t = tqdm(
            iterable=dataloader,
            desc=f'[Epoch {epoch}/{epochs - 1}]',
            disable=tqdm_disable,
            dynamic_ncols=True,
        )

        for step, batch in enumerate(t):
            # Skip the trained steps.
            if step <= trained_steps:
                continue

            batch = set_device(
                batch=batch,
                device=device,
            )

            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels'],
            )

            # Only for debugging.
            # -----
            # batch_chosen_experts_indices = []
            # if hasattr(outputs, 'router_logits') and \
            #         outputs.router_logits is not None:
            #     for layer_idx, layer_router_logits in enumerate(
            #             outputs.router_logits):
            #         chosen_experts_indices = torch.topk(
            #             input=layer_router_logits,
            #             k=model.molos_config.selected_ex_num,
            #             dim=-1,
            #         ).indices
            #         layer_counts = torch.bincount(
            #             input=chosen_experts_indices.view(-1),
            #             minlength=model.molos_config.ex_num,
            #         )

            #         if (layer_counts == 0).any():
            #             logger.log(
            #                 message=
            #                 f'[rank{local_rank}] Layer{layer_idx:02d}: {layer_counts.tolist()}',
            #                 source=source,
            #             )

            #         batch_chosen_experts_indices.append(
            #             chosen_experts_indices.view(-1))

            #     batch_chosen_experts_indices = torch.cat(
            #         tensors=batch_chosen_experts_indices)
            #     batch_chosen_experts_counts = torch.bincount(
            #         input=batch_chosen_experts_indices,
            #         minlength=model.molos_config.ex_num,
            #     )

            #     logger.log(
            #         message=
            #         f'[rank{local_rank}] All Layers: {batch_chosen_experts_counts.tolist()}',
            #         source=source,
            #     )
            # -----

            # Read the loss value before the backward pass.
            model_loss = outputs.loss.item()
            loss = outputs.loss

            model.backward(loss=loss)
            model.step()

            epoch_loss_sum += model_loss
            epoch_loss_count += 1

            # The load balance loss is exposed as 'balance_loss' or 'router_aux_loss'.
            router_aux_loss_str = _format_loss(
                outputs,
                'balance_loss',
                'router_aux_loss',
            )
            router_z_loss_str = _format_loss(outputs, 'router_z_loss')

            # MoLoS Model has no similarity loss.
            try:
                similarity_aux_loss = outputs.similarity_aux_loss
            except AttributeError:
                if warn_missing_similarity_loss:
                    logger.log(
                        message=
                        'The \'outputs.similarity_aux_loss\' is not an available attribute.',
                        level='warning',
                        source=source,
                    )

                    warn_missing_similarity_loss = False

                similarity_aux_loss = None

            if similarity_aux_loss is not None:
                similarity_aux_loss_str = f'{similarity_aux_loss.item():.3f}'
            else:
                similarity_aux_loss_str = 'None'

            lr_str = f'{optimizer.param_groups[0]["lr"]:.3e}'
            model_loss_str = f'{model_loss:.3f}'
            epoch_loss_str = f'{(epoch_loss_sum / epoch_loss_count):.3f}'

            # Log periodically and always at the first step trained after resuming.
            if step % log_time == 0 or step == (trained_steps + 1):
                logger.log(
                    message=
                    f'Epoch {epoch} | Step {step} | LR: {lr_str} | M Loss: {model_loss_str} | LB Loss: {router_aux_loss_str} | Z Loss: {router_z_loss_str} | S Loss: {similarity_aux_loss_str} | Epoch Loss: {epoch_loss_str}',
                    source=source,
                )

            t.set_postfix({
                # 'LR' means 'Learning Rate'.
                'LR': lr_str,
                # 'M' means 'Model'.
                'M Loss': model_loss_str,
                # 'LB' means 'Load Balance'.
                'LB Loss': router_aux_loss_str,
                'Z Loss': router_z_loss_str,
                # 'S' means 'Similarity'.
                'S Loss': similarity_aux_loss_str,
                # 'E' means 'Epoch'.
                'E Loss': epoch_loss_str,
            })

            # A full checkpoint takes precedence over a partial one.
            # The step 0 never writes a full checkpoint but always writes a partial one.
            save_full = (step != 0) and _reaches_save_interval(
                step=step,
                interval=full_save_step,
                dataloader=dataloader,
            )
            save_partial = (not save_full) and (
                (step == 0) or _reaches_save_interval(
                    step=step,
                    interval=partial_save_step,
                    dataloader=dataloader,
                ))

            if save_full or save_partial:
                save_checkpoint(
                    model=model,
                    # The epoch is not finished yet, so it is re-entered on resuming
                    # (see the adjustment of the trained epochs above).
                    trained_epochs=epoch,
                    trained_steps=step,
                    finishing_this_epoch=False,
                    outputs_dir=config['outputs_dir'],
                    # The partial checkpoint excludes the frozen parameters.
                    exclude_frozen_parameters=not save_full,
                    local_rank=local_rank,
                )

        # The skipped steps only belong to the resumed epoch,
        # the following epochs have to be trained from their first step.
        trained_steps = -1

        if epoch_loss_count:
            logger.log(
                message=
                f'Epoch {epoch} | Epoch Loss: {(epoch_loss_sum / epoch_loss_count):.3f}',
                source=source,
            )
        else:
            # Nothing was trained, so there is no epoch loss to report.
            logger.log(
                message=f'Epoch {epoch} | No step was trained in this epoch.',
                level='warning',
                source=source,
            )

        if (save_epoch != -1) and ((epoch + 1) % save_epoch == 0):
            save_checkpoint(
                model=model,
                # The epoch is finished, so the training resumes from the next epoch.
                trained_epochs=epoch,
                trained_steps=step,
                finishing_this_epoch=True,
                outputs_dir=config['outputs_dir'],
                exclude_frozen_parameters=False,
                local_rank=local_rank,
            )

        # Release the unreferenced objects and the cached CUDA memory between epochs.
        gc.collect()
        torch.cuda.empty_cache()
