"""Download a model checkpoint, load its state.

Supports optional pretrained embeddings, and save the pretrained model to a directory.
"""

import argparse
import os
from logging import INFO
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any

import torch
from composer.loggers import RemoteUploaderDownloader
from composer.utils.checkpoint import _torch_load_with_validation  # noqa: PLC2701
from flwr.common import log
from llmfoundry.utils.builders import (
    build_composer_model,
    build_tokenizer,
)
from llmfoundry.utils.config_utils import (
    process_init_device,
)

from repo.conf.constants import ModelConfig, Tokenizers
from repo.file_utils import create_remote_up_down, download_file_from_s3
from repo.utils import (
    load_model_parameters_from_file,
)

os.environ["S3_ENDPOINT_URL"] = "http://anonymous.anonymous.:9000"


def load_model_state(  # noqa: PLR0913, PLR0917
    model_state: dict[str, torch.Tensor],
    local_path: Path,
    parameter_names: list[str],
    temp_dir: TemporaryDirectory[str],
    up_down: RemoteUploaderDownloader,
    wte_checkpoint: str | None = None,
) -> dict[str, torch.Tensor]:
    """Load the model state from a local file or a checkpoint.

    The checkpoint at ``local_path`` is interpreted by its suffix:

    * ``.pt``: a Composer checkpoint whose ``["state"]["model"]`` entry
      replaces ``model_state`` entirely.
    * ``.npz``: a sequence of arrays matched positionally against
      ``parameter_names``; each array is cast to the dtype/device of the
      existing ``model_state`` entry, and ``model_state`` is updated in place.
    * anything else: nothing is loaded and a warning is logged.

    If ``wte_checkpoint`` is given, it is downloaded into ``temp_dir`` and every
    parameter whose name contains ``"wte"`` is overwritten from it.

    Parameters
    ----------
    model_state : dict[str, torch.Tensor]
        The initial model state.
    local_path : Path
        The local path to the checkpoint.
    parameter_names : list[str]
        The names of the parameters, in the order the ``.npz`` arrays are
        stored.
    temp_dir : TemporaryDirectory[str]
        The temporary directory into which ``wte_checkpoint`` is downloaded.
    up_down : RemoteUploaderDownloader
        The remote uploader and downloader.
    wte_checkpoint : str, optional
        The checkpoint to load the word token embedding from, by default None.

    Returns
    -------
    dict[str, torch.Tensor]
        The loaded model state.

    """
    from logging import WARNING  # noqa: PLC0415 - only needed for the warnings below

    if local_path.suffix == ".pt":
        model_state = _torch_load_with_validation(
            local_path,
            map_location="cpu",
        )["state"]["model"]
    elif local_path.suffix == ".npz":
        server_model = load_model_parameters_from_file(local_path)
        if len(server_model) != len(parameter_names):
            # zip(strict=False) below silently drops the surplus on either
            # side; surface the mismatch rather than hiding it.
            log(
                WARNING,
                "%s holds %d arrays but the model has %d parameters; "
                "unmatched entries are ignored",
                local_path,
                len(server_model),
                len(parameter_names),
            )
        for name, array in zip(
            parameter_names,
            server_model,
            strict=False,
        ):
            model_state[name] = torch.tensor(
                array,
                dtype=model_state[name].dtype,
                device=model_state[name].device,
            )
    else:
        # Without this, the freshly initialised weights would be saved as if
        # they were pretrained, with no indication anything went wrong.
        log(
            WARNING,
            "Unsupported checkpoint suffix %r for %s; model state left unchanged",
            local_path.suffix,
            local_path,
        )

    if wte_checkpoint is not None:
        wte_local_path = Path(temp_dir.name) / wte_checkpoint
        download_file_from_s3(
            up_down,
            wte_checkpoint,
            wte_local_path,
        )
        if wte_local_path.suffix == ".pt":
            wte_state: dict[str, torch.Tensor] = _torch_load_with_validation(
                wte_local_path,
                map_location="cpu",
            )["state"]["model"]
            # For any key containing "wte" in the name, load from wte_state.
            # Only existing keys are reassigned, so iterating is safe.
            for key in model_state:
                if "wte" in key:
                    model_state[key] = wte_state[key]
                    log(INFO, "Loaded %s from %s", key, wte_local_path)
        else:
            wte_server_model = load_model_parameters_from_file(wte_local_path)
            # A file with one array per parameter follows the same positional
            # (sorted parameter_names) layout as the main ``.npz`` branch, so
            # each wte parameter is taken from its own index. Any other layout
            # is treated as an embedding-only file whose first array is the
            # embedding, loaded into the first wte parameter only.
            aligned = len(wte_server_model) == len(parameter_names)
            for index, name in enumerate(parameter_names):
                if "wte" not in name:
                    continue
                source = wte_server_model[index if aligned else 0]
                model_state[name] = torch.tensor(
                    source,
                    dtype=model_state[name].dtype,
                    device=model_state[name].device,
                )
                log(INFO, "Loaded %s from %s", name, wte_local_path)
                if not aligned:
                    break
    return model_state


def download_and_save_pretrained_model(
    checkpoints: tuple[str, str | None],
    safe_tensors_dir: str,
    checkpoints_dir: str,
    model_type: str,
    tokenizer: str,
) -> None:
    """Build the model, download its checkpoint, save.

    The model and tokenizer are built from the ``ModelConfig`` / ``Tokenizers``
    entries, the checkpoint is downloaded into a temporary directory, and its
    state is loaded with :func:`load_model_state`. The model is then saved with
    ``save_pretrained`` into ``safe_tensors_dir``, together with the tokenizer.
    Temporary directories are removed when the function exits, and ``TMPDIR``
    is restored to its previous value.

    Parameters
    ----------
    checkpoints : tuple[str, str | None]
        The checkpoint to download and
        the checkpoint to load the word token embedding from.
    safe_tensors_dir : str
        The directory in which to save the model's safetensors.
    checkpoints_dir : str
        The directory (bucket) in which the checkpoints are stored.
    model_type : str
        The type of the model (a ``ModelConfig`` member name).
    tokenizer : str
        The type of the tokenizer (a ``Tokenizers`` member name).

    """
    checkpoint, wte_checkpoint = checkpoints
    # Build remote up_down client.
    up_down = create_remote_up_down(
        bucket_name=checkpoints_dir,
        prefix="",
        run_uuid="download",
        num_attempts=3,
        client_config={"connect_timeout": 3600, "read_timeout": 3600},
    )

    # Build the tokenizer once: the same object is handed to the model and
    # later saved next to it.
    model_config: dict[str, Any] = dict(ModelConfig[model_type].value)
    tokenizer_conf: dict[str, Any] = dict(Tokenizers[tokenizer].value)
    tokenizer_object = build_tokenizer(
        str(tokenizer_conf["name"]),
        tokenizer_conf["kwargs"],
    )
    model = build_composer_model(
        name=str(model_config["name"]),
        cfg=model_config,
        tokenizer=tokenizer_object,
        init_context=process_init_device(model_config, None),
        master_weights_dtype=None,
    )

    # Temporary directories are cleaned up explicitly on exit. TMPDIR is
    # restored afterwards so it never refers to a directory that has already
    # been deleted.
    previous_tmpdir = os.environ.get("TMPDIR")
    general_temp_dir = TemporaryDirectory()
    with general_temp_dir:
        # NOTE(review): the in-process tempfile module has already cached its
        # default directory (the TemporaryDirectory() call above), so this
        # mainly affects subprocesses and code reading TMPDIR directly.
        os.environ["TMPDIR"] = general_temp_dir.name
        try:
            temp_dir = TemporaryDirectory()
            with temp_dir:
                log(INFO, "Temp dir: %s", temp_dir.name)
                local_path = Path(temp_dir.name) / checkpoint

                model_state = model.state_dict()
                parameter_names = sorted(model_state.keys())

                # Download the checkpoint.
                download_file_from_s3(
                    up_down,
                    checkpoint,
                    local_path,
                )

                # Load state from the checkpoint.
                new_model_state = load_model_state(
                    model_state,
                    local_path,
                    parameter_names,
                    temp_dir,
                    up_down,
                    wte_checkpoint,
                )
                model.load_state_dict(new_model_state)
                model.model.save_pretrained(safe_tensors_dir)  # type: ignore[reportAttributeAccessIssue]

                # Save tokenizer.
                tokenizer_object.save_pretrained(safe_tensors_dir)
        finally:
            if previous_tmpdir is None:
                os.environ.pop("TMPDIR", None)
            else:
                os.environ["TMPDIR"] = previous_tmpdir


if __name__ == "__main__":
    default_parser: argparse.ArgumentParser = argparse.ArgumentParser()
    default_parser.add_argument(
        "--checkpoint_dir",
        type=str,
        required=True,
        help="Checkpoint directory",
    )
    default_parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint to evaluate",
    )
    default_parser.add_argument(
        "--wte_checkpoint",
        type=str,
        required=False,
        help="Checkpoint to load an embedding from",
        default=None,
    )
    default_parser.add_argument(
        "--tokenizer",
        type=str,
        required=False,
        help="Checkpoint to load an embedding from",
        default=None,
    )
    default_parser.add_argument(
        "--model_type",
        type=str,
        required=False,
        help="Checkpoint to load an embedding from",
        default="SMOLLM_1B",
    )
    # Add a default directory in which to save the model safetensors
    default_parser.add_argument(
        "--safe_tensors_dir",
        type=str,
        default="safe_tensors_test",
        help="Directory in which to save the model safetensors",
    )
    args = default_parser.parse_args()

    # Print args
    log(INFO, "Augments: %s", args)

    download_and_save_pretrained_model(
        checkpoints=(args.checkpoint, args.wte_checkpoint),
        safe_tensors_dir=args.safe_tensors_dir,
        checkpoints_dir=args.checkpoint_dir,
        model_type=args.model_type,
        tokenizer=args.tokenizer,
    )
