"""Provides the internal functions used by the LLM client.

This module contains utility functions that handle local training and evaluation for
Large Language Model (LLM) clients in a federated learning setup. It sets up and
configures Trainers, loads checkpoints, applies parameters, runs training or evaluation,
gathers metrics, and cleans up resources.
"""

import copy
import gc
import time
from logging import DEBUG, ERROR
from typing import Any

import torch

# NOTE: We need this if we want to compile the model because the attention
# implementation in the MPT code is dispatched using a dictionary that raises:
# `AssertionError: Dict types must use ConstDictVariable.`
import torch._dynamo
from composer import Trainer
from composer.utils import dist
from flwr.common.logger import log
from flwr.common.recordset_compat import ConfigsRecord
from flwr.common.typing import NDArrays, Scalar
from llmfoundry.utils.config_utils import TrainConfig
from omegaconf import DictConfig

from repo.clients.configs import EvaluateConfig, FitConfig
from repo.clients.llm_config_functions import check_skip_init, get_train_config
from repo.clients.trainer_utils import (
    correct_time_before_fit,
    get_trainer_mutables_from_config,
    get_trainer_object,
    load_trainer_checkpoint,
    set_mutables_trainer,
    trainer_clean_up,
)
from repo.clients.utils import (
    get_client_state_struct,
    get_ndarrays_and_names_from_payload,
    manipulate_pre_training_ndarrays,
    post_process_client_result,
    set_initial_config_from_fit_config,
    streaming_shms_clean_up,
)
from repo.shm.utils import compress_with_strict
from repo.utils import (
    ModelStateNames,
    set_trainer_params_from_ndarrays,
)

# Make TorchDynamo suppress compilation errors (falling back to eager execution)
# instead of raising; see the NOTE next to the `torch._dynamo` import for the
# MPT attention-dispatch failure that motivates this.
torch._dynamo.config.suppress_errors = True  # type: ignore[reportAttributeAccessIssue]  # noqa: SLF001


def llm_fit(  # noqa: PLR0914, PLR0915
    external_objects: tuple[
        Trainer | None,
        FitConfig | EvaluateConfig | None,
        TrainConfig | None,
    ],
    payload: NDArrays,
    config: ConfigsRecord,
    llm_config: DictConfig,
) -> tuple[
    NDArrays,
    int,
    dict[str, Scalar] | dict[Any, Any],
    tuple[Trainer | None, FitConfig | EvaluateConfig | None, TrainConfig | None],
]:
    """Perform local training on the LLM client.

    The function builds (or reuses) a Trainer, applies the transmission mask to
    the incoming payload, loads the selected parameters into the trainer, runs
    local training and post-processes the result into the values returned to the
    caller.

    When an external Trainer is supplied, the heavy re-initialization (rebuilding
    the trainer's mutable attributes and, if a checkpoint exists, reloading it) is
    executed only when ``check_skip_init`` returns a falsy value for the previous
    and current configuration objects.

    Parameters
    ----------
    external_objects : tuple[
        Trainer | None,
        FitConfig | EvaluateConfig | None,
        TrainConfig | None
    ]
        A tuple containing:
        - An optional external Trainer object. If provided, it is reused instead
          of creating a new Trainer through ``get_trainer_object``.
        - An optional external FitConfig or EvaluateConfig from a previous round.
        - An optional external TrainConfig from a previous round. It must not be
          ``None`` when the external trainer has to be re-initialized, because its
          ``device_train_microbatch_size`` is carried over.
    payload : NDArrays
        The model state received from the server. It is filtered with the
        transmission mask before being handed to the trainer.
    config : ConfigsRecord
        Configuration information for federated training; unpacked into a
        ``FitConfig``.
    llm_config : DictConfig
        The local LLM configuration specifying trainer and model details. It is
        passed to ``set_initial_config_from_fit_config`` before a deep copy of it
        is taken for ``manipulate_pre_training_ndarrays``.

    Returns
    -------
    tuple[NDArrays, int, dict[str, Scalar] | dict[Any, Any],
          tuple[Trainer | None, FitConfig | EvaluateConfig | None, TrainConfig | None]]
        A tuple containing:
        - The model states produced by ``post_process_client_result``.
        - The number of samples used during training, or ``1`` when the
          post-processing step reports a falsy value (``0`` or ``None``).
        - A dictionary of training metrics, including the ``client/fit_*`` timings
          recorded here.
        - A tuple containing the Trainer object, FitConfig, and TrainConfig for
          potential reuse in future rounds.

    Raises
    ------
    AssertionError
        If the transmission mask is not set in the fit configuration, or if an
        external trainer must be re-initialized without a previous TrainConfig.
    Exception
        Any exception raised during setup, training or post-processing is logged
        and re-raised unchanged.

    """
    # Interpret external objects
    external_trainer, ext_fit_config, ext_train_cfg = external_objects
    try:
        fit_config = FitConfig(**config)  # type: ignore[reportArgumentType,arg-type]

        assert fit_config.transmission_mask is not None, "Transmission mask must be set"

        client_state_struct, start_time = (
            get_client_state_struct(fit_config),
            time.time_ns(),
        )
        train_metrics: dict[str, Scalar] = {}

        skip_iteration, checkpoint_exists, server_steps_cumulative = (
            set_initial_config_from_fit_config(fit_config, llm_config)
        )
        # Snapshot of the configuration taken after the initial setup above
        dummy_config = copy.deepcopy(llm_config)

        if external_trainer is not None:
            log(DEBUG, "External trainer object exists.")
            train_cfg, device, logged_cfg, icl_tasks_config_dict = get_train_config(
                cfg=llm_config,
                cid=fit_config.cid,
            )
            # Performs the heavy initialization only if strictly necessary
            if not check_skip_init(
                ext_fit_config,
                ext_train_cfg,
                fit_config,
                train_cfg,
            ):
                # Assign the device microbatch size from the previous execution
                assert ext_train_cfg is not None
                train_cfg.device_train_microbatch_size = (
                    ext_train_cfg.device_train_microbatch_size
                )
                trainer_mutable_attributes = get_trainer_mutables_from_config(
                    trainer=external_trainer,
                    train_cfg=train_cfg,
                    client_config=fit_config,
                    icl_tasks_config_dict=icl_tasks_config_dict,
                    device=device,
                    logged_cfg=logged_cfg,
                )
                set_mutables_trainer(
                    external_trainer,
                    trainer_mutable_attributes,
                    client_config=fit_config,
                )
                if checkpoint_exists:
                    log(DEBUG, "Checkpoint exists.")
                    # TODO(<Anonymous>): Add to the ignore keys all the keys related to
                    # model parameters
                    load_trainer_checkpoint(external_trainer, train_cfg)
        else:
            log(DEBUG, "External trainer object doesn't exist.")
            (
                external_trainer,
                train_cfg,
                _,
            ) = get_trainer_object(
                cfg=llm_config,
                client_config=fit_config,
            )

        # Retrieve the transmission mask
        (mask_transmission, layer_names_transmission, layer_types_transmission) = (
            fit_config.transmission_mask
        )

        # Apply the transmission mask once. The result is kept immutable and each
        # consumer below receives its own fresh list, so a callee that modifies
        # its list argument cannot affect the other one.
        masked_payload = tuple(
            compress_with_strict(
                data=payload,
                selectors=mask_transmission,
                strict=True,
            ),
        )

        manipulate_pre_training_ndarrays(
            payload=list(masked_payload),
            trainer=external_trainer,
            configs=(fit_config, dummy_config),
            client_state_struct=client_state_struct,
            train_metrics=train_metrics,
        )

        # NOTE: Extract parameters if our payload also contains momenta

        # Extract the trainable parameters
        # and their names from the payload
        parameters, parameter_names = get_ndarrays_and_names_from_payload(
            payload=list(masked_payload),
            layer_names=layer_names_transmission,
            layer_types=layer_types_transmission,
            model_state_type=ModelStateNames.PARAMETERS,
        )

        log(DEBUG, "Trainer object obtained.")
        train_metrics["client/fit_init_time"] = (time.time_ns() - start_time) * 1e-9

        if not skip_iteration and fit_config.n_local_steps > 0:
            correct_time_before_fit(
                trainer=external_trainer,
                fit_config=fit_config,
                server_steps_cumulative=server_steps_cumulative,
                client_state_struct=client_state_struct,
            )

            start_time = time.time_ns()
            set_trainer_params_from_ndarrays(
                parameters,
                external_trainer,
                fit_config.set_trainer_key_to_filter,
                parameters_names=parameter_names,
                filter_keys=fit_config.set_trainer_params_filter_keys,
                excluded_layers=fit_config.personalized_layers or [],
                frozen_layers=fit_config.frozen_layers,
            )

            # NOTE: We removed here the parameter checkers for efficiency reasons.
            # Re-add them if needed

            train_metrics["client/fit_set_parameters_time"] = (
                time.time_ns() - start_time
            ) * 1e-9

            if train_cfg.eval_first:
                # TODO(<Anonymous>): Modify metric names in MosaicML
                # to allow logging the first eval
                start_time = time.time_ns()
                external_trainer.eval()
                train_metrics["client/fit_pre_eval_time"] = (
                    time.time_ns() - start_time
                ) * 1e-9
                train_metrics.update(
                    {
                        f"PrePersonalization{k}": v.detach().cpu().item()  # type: ignore[attr-defined]
                        for k, v in external_trainer.state.eval_metric_values.items()
                    },
                )

            try:
                start_time = time.time_ns()
                # `skip_iteration` is always falsy inside this branch, so the
                # duration is always expressed as a number of batches ("ba").
                external_trainer.fit(
                    duration=f"{fit_config.n_local_steps}ba",
                )
                # If automicrobatching is enabled, try to get the final microbatch size
                # and set it for the next steps
                if external_trainer.state.auto_microbatching:
                    external_trainer.state.auto_microbatching = False
                    final_device_train_microbatch_size = (
                        external_trainer.state.device_train_microbatch_size
                    )
                    assert isinstance(final_device_train_microbatch_size, int)
                    train_cfg.device_train_microbatch_size = (
                        final_device_train_microbatch_size
                    )
                train_metrics["client/fit_time"] = (time.time_ns() - start_time) * 1e-9
            except Exception as e:
                log(ERROR, "llm_fit::trainer.fit", exc_info=e, stack_info=True)
                raise

        trained_model_states, n_samples_trained = post_process_client_result(
            train_metrics=train_metrics,
            trainer=external_trainer,
            payload=payload,
            client_state_struct=client_state_struct,
            llm_config=llm_config,
            fit_config=fit_config,
        )

    except Exception as e:
        log(ERROR, "Error in llm_fit function", exc_info=e, stack_info=True)
        raise
    else:
        return (
            trained_model_states,
            # Never report zero samples: fall back to 1 when the value is falsy
            n_samples_trained or 1,
            train_metrics,
            (external_trainer, fit_config, train_cfg),
        )


def llm_eval(
    external_trainer: Trainer | None,
    payload: NDArrays,
    config: ConfigsRecord,
    llm_config: DictConfig,
) -> tuple[float, int, dict[str, Scalar], Trainer]:
    """Perform local evaluation on the LLM client.

    This function builds a trainer for evaluation, applies the model parameters
    extracted from the payload, runs the evaluation, collects evaluation metrics
    on global rank 0, and cleans up the trainer and the streaming shared memory.
    The clean-up is executed even if loading the parameters or evaluating raises.

    Parameters
    ----------
    external_trainer : Trainer | None
        Currently unused: whatever is passed in is discarded and a new Trainer is
        always created through ``get_trainer_object``.
    payload : NDArrays
        The model state to be used for evaluation. Only the entries identified as
        parameters are extracted; the mask component of the transmission mask is
        not applied to the payload here.
    config : ConfigsRecord
        Configuration information for federated evaluation; unpacked into an
        ``EvaluateConfig``.
    llm_config : DictConfig
        The local LLM configuration specifying trainer and model details. It is
        modified in place: ``autoresume`` is set to ``False`` and ``save_folder``,
        ``load_path`` and ``loggers`` are set to ``None``.

    Returns
    -------
    tuple[float, int, dict[str, Scalar], Trainer]
        A tuple containing:
        - The (unused) loss value set to 0.0 in this function.
        - The number of samples evaluated (left at 0 on ranks other than 0).
        - A dictionary of evaluation metrics and ``client/eval_*`` timings.
        - The Trainer object used for evaluation, after ``trainer_clean_up`` has
          been called on it.

    Raises
    ------
    AssertionError
        If the transmission mask is not set in the evaluation configuration.

    """
    start_time = time.time_ns()
    num_samples = 0
    eval_metrics: dict[str, Scalar] = {}

    # Evaluation must neither resume from, load, nor write checkpoints, nor log
    llm_config.autoresume = False  # type: ignore[union-attr]
    llm_config.save_folder = None  # type: ignore[union-attr]
    llm_config.load_path = None  # type: ignore[union-attr]
    llm_config.loggers = None  # type: ignore[union-attr]

    client_eval_config = EvaluateConfig(**dict(config))  # type: ignore[reportArgumentType,arg-type]
    assert client_eval_config.transmission_mask is not None, (
        "Transmission mask must be set"
    )

    (
        external_trainer,
        _,
        _,
    ) = get_trainer_object(
        cfg=llm_config,
        client_config=client_eval_config,
    )

    is_rank_zero = False
    try:
        # NOTE: Extract parameters
        _, layer_names, layer_types = client_eval_config.transmission_mask

        # NOTE: Extract only the parameters
        # for eval, no momenta needed
        # since we are not training
        parameters, _parameter_names = get_ndarrays_and_names_from_payload(
            payload,
            layer_names=layer_names,
            layer_types=layer_types,
            model_state_type=ModelStateNames.PARAMETERS,
        )

        eval_metrics["client/eval_init_time"] = (time.time_ns() - start_time) * 1e-9

        start_time = time.time_ns()
        set_trainer_params_from_ndarrays(
            parameters,
            external_trainer,
            client_eval_config.set_trainer_key_to_filter,
            filter_keys=client_eval_config.set_trainer_params_filter_keys,
            excluded_layers=[],
        )

        # Release memory that is no longer referenced before running evaluation
        gc.collect()
        torch.cuda.empty_cache()
        eval_metrics["client/eval_set_parameters_time"] = (
            time.time_ns() - start_time
        ) * 1e-9

        start_time = time.time_ns()
        external_trainer.eval()
        eval_metrics["client/eval_time"] = (time.time_ns() - start_time) * 1e-9

        start_time = time.time_ns()
        # Looked up once and reused for the closing-time metric below
        is_rank_zero = dist.get_global_rank() == 0
        if is_rank_zero:
            num_samples = (
                external_trainer.state.eval_timestamp._sample.value  # noqa: SLF001
            )
            eval_metrics.update(
                {
                    "Val" + k: v.detach().cpu().item()  # type: ignore[attr-defined]
                    for k, v in external_trainer.state.eval_metric_values.items()
                },
            )
            eval_metrics["client/eval_metrics_collection_time"] = (
                time.time_ns() - start_time
            ) * 1e-9
    finally:
        # Always release the trainer and the streaming shared memory, even when
        # parameter loading or evaluation failed, so that they are not leaked.
        start_time = time.time_ns()
        trainer_clean_up(trainer=external_trainer)
        streaming_shms_clean_up()
        closing_time = (time.time_ns() - start_time) * 1e-9

    if is_rank_zero:
        eval_metrics["client/eval_trainer_closing_time"] = closing_time

    return 0.0, num_samples, eval_metrics, external_trainer
