"""Embed text values using a pre-trained model."""

from __future__ import annotations

import datetime
import json
import logging
import sys
from pathlib import Path

import h5py
import pytz
import typer

from pathfmtools.embedding_models import get_embedding_model
from pathfmtools.io.schema import SCHEMA_VERSION
from pathfmtools.utils.devices import parse_device
from pathfmtools.utils.model_id import canon_model_id

# Module-level logger; handlers are attached by configure_root_logger().
logger = logging.getLogger(__name__)

# Timestamp (US/Eastern, microsecond resolution, with timezone abbreviation)
# captured once when this module is imported. It is used by cli() to build a
# run id, so every run-info file written by one process shares this value.
TIMESTAMP = datetime.datetime.now(tz=pytz.timezone("US/Eastern")).strftime(
    "%Y-%m-%d-%H-%M-%S-%f-%Z",
)


# Typer application object; cli() is registered on it via @app.command().
app = typer.Typer()


@app.command()
def cli(
    model_name: str = typer.Option(..., help="Name of the model to use for embedding"),
    text_fpath: Path = typer.Option(..., help="Path to file containing text values to embed"),
    device: str = typer.Option(..., help="Device to use for embedding"),
    out_dir: Path = typer.Option(..., help="Directory to save output embeddings"),
) -> None:
    """Embed text values using a pre-trained model."""
    # NOTE: Typer uses the docstring above as the command's help text, so it is
    # deliberately left as a one-liner; implementation notes live in comments.
    #
    # Steps:
    #   1. Create ``<out_dir>/run_info`` and configure logging to a file in it.
    #   2. Write the invocation parameters to ``<run_id>.json`` for reproducibility.
    #   3. Delegate the actual embedding work to ``main``.
    run_id = f"embed_text_device_{device}_{TIMESTAMP}"

    run_info_dir = out_dir / "run_info"
    # parents=True: ``out_dir`` itself may not exist yet on a first run; without
    # it mkdir raises FileNotFoundError before anything is logged or embedded.
    run_info_dir.mkdir(parents=True, exist_ok=True)
    info_fpath = run_info_dir / f"{run_id}.json"

    configure_root_logger(
        log_file=run_info_dir / f"{run_id}.log",
    )

    # Paths are stringified because ``Path`` objects are not JSON serializable.
    info_dict: dict[str, int | str | list] = {
        "model_name": model_name,
        "text_fpath": str(text_fpath),
        "device": device,
        "out_dir": str(out_dir),
        "command_executed": sys.argv,
    }

    with info_fpath.open("w") as out_f:
        json.dump(info_dict, out_f)

    main(
        model_name=model_name,
        text_fpath=str(text_fpath),
        device=device,
        out_dir=out_dir,
    )


def main(
    model_name: str,
    text_fpath: str,
    device: str,
    out_dir: Path,
) -> None:
    """Embed text values and write embeddings to an HDF5 file.

    Each non-empty line of ``text_fpath`` is embedded on its own and stored as a
    dataset named after the text value. Values that already have a dataset in
    the model's group are skipped with a warning and are not re-embedded.

    Args:
        model_name: Name of the model used for embedding.
        text_fpath: Path to a newline-delimited file of text strings.
        device: Device spec (e.g., "cpu", "cuda", "cuda:0").
        out_dir: Directory where the output HDF5 is created/updated.

    Side Effects:
        - Creates/updates ``<out_dir>/<model_name>_text_embeddings.h5`` with embeddings
          under a per-model group (named by ``canon_model_id(model_name)``).
        - Sets the ``schema_version`` file attribute if it is not already present.
        - Logs a warning if the output file already exists, and one for every
          text value that is skipped because it is already stored.

    """
    torch_dev = parse_device(device)

    embedding_model = get_embedding_model(model_name)(device=torch_dev)

    out_fname = out_dir / f"{model_name}_text_embeddings.h5"
    if out_fname.exists():
        # Not an error: the file is opened in append mode below and extended.
        msg = f"Output file {out_fname} already exists"
        logger.warning(msg)

    with Path(text_fpath).open("r") as text_f:
        # Drop empty lines (including the one produced by a trailing newline).
        text_values = [line for line in text_f.read().split("\n") if line]

    with h5py.File(out_fname, "a") as f:
        # Write schema version for new files
        if "schema_version" not in f.attrs:
            f.attrs["schema_version"] = SCHEMA_VERSION
        model_group_name = canon_model_id(model_name)
        grp = f.require_group(model_group_name)
        # NOTE(review): text values are used verbatim as HDF5 object names. h5py
        # treats "/" as a path separator, so a value containing "/" presumably
        # ends up in nested groups -- confirm this is acceptable for readers.
        for text_val in text_values:
            # Check for an existing entry *before* running the model so that
            # duplicates (within the input file or from a previous run) do not
            # cost an inference pass whose result would be thrown away.
            if text_val in grp:
                msg = (
                    f"Embedding for text value {text_val} already exists under '{model_group_name}'"
                )
                logger.warning(msg)
                continue
            embedding = embedding_model.embed_text([text_val]).detach().cpu().numpy()
            grp.create_dataset(text_val, data=embedding)


def configure_root_logger(log_file: Path) -> None:
    """Attach console and file handlers to the root logger at INFO level.

    Both handlers share one format that records the timestamp, level, source
    location (file, module, function, line) and the message.

    Note that ``logging.basicConfig`` is a no-op when the root logger already
    has handlers, in which case the handlers built here are not installed.

    Args:
        log_file (Path): The path to the file to which the log will be written.

    """
    log_level = logging.INFO
    record_format = logging.Formatter(
        "%(asctime)s, %(levelname)-8s"
        "[%(filename)s:%(module)s:%(funcName)s:%(lineno)d] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Console first, then file -- the same order they are handed to basicConfig.
    handlers = [logging.StreamHandler(), logging.FileHandler(log_file)]
    for handler in handlers:
        handler.setLevel(log_level)
        handler.setFormatter(record_format)

    logging.basicConfig(level=log_level, handlers=handlers)


# Script entry point: run the Typer app (and thus cli()) when executed directly.
if __name__ == "__main__":
    app()
