"""Eval mosaicml checkpoints with accelerate."""

import argparse
from logging import INFO
from pathlib import Path

import vllm
import vllm.model_executor
import vllm.model_executor.models
import vllm.model_executor.models.mpt
from flwr.common import log
from lighteval.main_accelerate import main as main_accelerate
from lighteval.parsers import parser_accelerate

# Necessary to avoid pyright complaint
from repo.eval_light.patched_mpt import (  # type: ignore[attr-defined]
    MPTMLP,  # type: ignore[attr-defined]
    MPTAttention,  # type: ignore[attr-defined]
    MPTBlock,  # type: ignore[attr-defined]
    MPTForCausalLM,  # type: ignore[attr-defined]
    MPTModel,  # type: ignore[attr-defined]
)

# NOTE: monkey-patch vLLM's MPT implementation at import time.
# The stock vLLM MPT model does not support rope, so every MPT class in
# `vllm.model_executor.models.mpt` is replaced by the patched counterpart
# imported from `repo.eval_light.patched_mpt`. Patching in place is
# preferable to forking vLLM and maintaining a custom package wheel.
# The `type: ignore[misc]` markers silence the type checker, which rejects
# assigning to a class name on a module.
vllm.model_executor.models.mpt.MPTAttention = MPTAttention  # type: ignore[misc]
vllm.model_executor.models.mpt.MPTMLP = MPTMLP  # type: ignore[misc]
vllm.model_executor.models.mpt.MPTBlock = MPTBlock  # type: ignore[misc]
vllm.model_executor.models.mpt.MPTModel = MPTModel  # type: ignore[misc]
vllm.model_executor.models.mpt.MPTForCausalLM = MPTForCausalLM  # type: ignore[misc]


def modify_args_for_accelerate(args: argparse.Namespace) -> argparse.Namespace:
    """Point the parsed arguments at a per-checkpoint output directory.

    The output directory is rewritten to
    ``<output_dir>/<checkpoint>[/<wte_checkpoint>]/<vllm|no_vllm>``, where the
    last component is ``vllm`` when the substring ``"vllm"`` occurs in
    ``args.model_args`` and ``no_vllm`` otherwise. ``args.model_args`` is
    additionally prefixed with ``pretrained=<safe_tensors_dir>,``.

    Parameters
    ----------
    args : argparse.Namespace
        The command-line arguments. The attributes ``checkpoint``,
        ``wte_checkpoint``, ``output_dir``, ``model_args`` and
        ``safe_tensors_dir`` are read; ``output_dir`` and ``model_args``
        are overwritten in place.

    Returns
    -------
    argparse.Namespace
        The same namespace object that was passed in, after modification.

    Raises
    ------
    ValueError
        If the computed output directory already exists. In that case
        ``args`` is left unmodified.

    """
    # Collect the sub-directory components in the order they are nested.
    segments: list[str] = [args.checkpoint]
    embedding_checkpoint: str | None = args.wte_checkpoint
    if embedding_checkpoint is not None:
        segments.append(embedding_checkpoint)

    base_dir = Path(args.output_dir)
    segments.append("vllm" if "vllm" in args.model_args else "no_vllm")
    target_dir = base_dir.joinpath(*segments)

    log(INFO, f"Output directory: {target_dir}")

    # Refuse to touch the arguments when the target directory is already there.
    if target_dir.exists():
        msg = f"Output directory {target_dir} already exists."
        raise ValueError(msg)

    args.output_dir = str(target_dir)
    args.model_args = f"pretrained={args.safe_tensors_dir},{args.model_args}"
    return args


if __name__ == "__main__":
    # Start from lighteval's accelerate parser and extend it with the
    # options this script needs on top.
    cli: argparse.ArgumentParser = parser_accelerate()

    # (flag, keyword arguments for `add_argument`), registered in this order.
    extra_options = (
        (
            "--checkpoint",
            {"type": str, "required": True, "help": "Checkpoint to evaluate"},
        ),
        (
            "--wte_checkpoint",
            {
                "type": str,
                "required": False,
                "default": None,
                "help": "Checkpoint to load an embedding from",
            },
        ),
        (
            # Comes with a default directory for the model safetensors.
            "--safe_tensors_dir",
            {
                "type": str,
                "default": "safe_tensors_test",
                "help": "Directory in which to save the model safetensors",
            },
        ),
    )
    for flag, options in extra_options:
        cli.add_argument(flag, **options)

    # Parse, redirect output/model arguments, then hand over to lighteval.
    main_accelerate(modify_args_for_accelerate(cli.parse_args()))
