#!/usr/bin/python3
"""
LLM-based Entropy-guided Optimization with kNowledgeable priors (LEON).

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import click
import logging
import numpy as np
import os
import sys
import torch
import warnings
from math import isclose
from pathlib import Path
from tqdm import tqdm
from typing import Any, Dict, Optional, Tuple, Union

import leon
from leon.knowledge import get_knowledge_source_options, make_knowledge
from leon.embedding import get_embedder_options, get_embedder
from leon.optim import (
    get_optimizers,
    get_user_prompt_versions,
    HumanBaselineOptimizer,
    MajorityBaselineOptimizer
)
from leon.optim.state import OptimizerState
from leon.utils import initialize_designs


@click.command()
@click.option(
    "--task",
    "-t",
    "task_name",
    type=click.Choice(list(leon.registry.keys())),
    required=True,
    help="Biomedical zero-shot optimization task."
)
@click.option(
    "--index",
    "-i",
    type=int,
    required=True,
    help="Index of the test design to optimize."
)
@click.option(
    "--optimizer",
    "-o",
    type=click.Choice(get_optimizers()),
    required=True,
    help="Backbone optimizer."
)
@click.option(
    "--knowledge-source",
    "-s",
    type=click.Choice(get_knowledge_source_options()),
    default=("default",),
    multiple=True,
    show_default=True,
    help="Prior knowledge source."
)
@click.option(
    "--embedder-name",
    "-e",
    type=click.Choice(get_embedder_options()),
    default="openai/text-embedding-3-small",
    show_default=True,
    help="Embedding model to use for embedding tasks."
)
@click.option(
    "--knowledge-top-k",
    type=int,
    default=8,
    show_default=True,
    help="Number of documents to retrieve per query per knowledge source."
)
@click.option(
    "--equivalence-relation",
    type=click.Choice(leon.core.get_equivalence_relation_options()),
    default="KMeansEquivalenceRelation",
    show_default=True,
    help="Equivalence relation to use in LEON."
)
@click.option(
    "--lambda",
    "lambda_",
    type=float,
    default=None,
    show_default=True,
    help="Source critic weighting hyperparameter."
)
@click.option(
    "--mu",
    "mu_",
    type=float,
    default=None,
    show_default=True,
    help="LLM certainty hyperparameter."
)
@click.option(
    "--batch-size",
    "-b",
    # Must be positive: it is used as a divisor for the number of
    # optimization rounds (`state.max_samples // batch_size`).
    type=click.IntRange(min=1),
    default=32,
    show_default=True,
    help="Sampling batch size."
)
@click.option(
    "--temperature",
    type=float,
    default=1.0,
    show_default=True,
    help="Temperature for optimization."
)
@click.option(
    "--w0",
    type=float,
    default=1.0,
    show_default=True,
    help="1-Wasserstein distance constraint bound."
)
@click.option(
    "--user-prompt-version",
    "-u",
    type=click.Choice(get_user_prompt_versions()),
    default="base",
    show_default=True,
    help="Version of the user prompt to load. Only used for LLM optimizers."
)
@click.option(
    "--ablate-leon/--use-leon",
    default=False,
    help="Whether to ablate the entropy penalization in the forward model."
)
@click.option(
    "--reflection/--no-reflection",
    default=False,
    help="Whether to perform reflection as in Ma YJ et al. Proc ICLR (2024)."
)
@click.option(
    "--ablate-distribution-shift/--reason-distribution-shift",
    default=False,
    help="Whether to ablate knowledge of the distribution shift."
)
@click.option(
    "--seed",
    type=int,
    default=2025,
    show_default=True,
    help="Random seed."
)
@click.option(
    "--savedir",
    type=str,
    default="results",
    show_default=True,
    help="Path to save the optimization results to."
)
@click.option(
    "--fast-dev-run/--full-run",
    default=False,
    help="Run a short unit test."
)
@click.option(
    "--online",
    type=float,
    default=0.0,
    show_default=True,
    help="Whether to run online optimization experiments."
)
@click.option("-v", "--verbose", count=True, help="Verbosity.")
def main(
    task_name: str,
    index: int,
    optimizer: str,
    knowledge_source: Tuple[str, ...],
    embedder_name: str,
    knowledge_top_k: int,
    lambda_: Optional[float],
    mu_: Optional[float],
    equivalence_relation: str,
    batch_size: int,
    temperature: float,
    w0: float,
    user_prompt_version: str,
    ablate_leon: bool,
    reflection: bool,
    ablate_distribution_shift: bool,
    seed: int,
    savedir: Union[Path, str],
    fast_dev_run: bool,
    online: Union[float, bool],
    verbose: int
):
    """LLM-based Entropy-guided Optimization with kNowledge (LEON).
    \f
    Runs a single optimization experiment on one test design of a task.
    (Everything after the form feed above is hidden from click's --help.)

    Pipeline:
        1. Build the task and pick reference design number `index` from a
           seed-determined permutation of the test set.
        2. Build the backbone optimizer and derive the results
           subdirectory. Return early if a `-<ref id>.npz` result for this
           reference design already exists there.
        3. Assemble the task description and prior knowledge, then build
           the forward model: `EntropyPenalizedTransform`, or
           `IdentityTransform` when `--ablate-leon` is set.
        4. Propose, score, and log designs in batches of `batch_size` until
           the optimizer state reports convergence or
           `state.max_samples // batch_size` rounds have run.
        5. Evaluate and save the results with `state.evaluate_and_save`.

    Raises:
        ValueError: if `index` is out of range for the task's test set, or
            if the task has no fixed dimensionality
            (`task.ndim is NotImplemented`) and the optimizer is not an
            LLM-based optimizer.
    """
    setup_logging(verbose)

    # Exactly 0.0 / 1.0 are turned into booleans (offline / online). Any
    # other value is passed through as a float. Presumably this is a partial
    # online setting; TODO confirm against `leon.make`.
    if isclose(online, 0.0) or isclose(online, 1.0):
        online = isclose(online, 1.0)

    task: leon.envs.BaseTask = leon.make(
        task_name, seed=seed, online=online
    )

    categorical_config = task.train[0].discrete_features()

    if index < 0 or index >= len(task.test):
        raise ValueError(
            f"Index {index} is out of bounds for task {task_name}."
        )
    # A seeded permutation decouples `index` from the on-disk test order.
    test_idxs = np.random.default_rng(seed).choice(
        len(task.test), size=len(task.test), replace=False
    )
    ref = task.test[test_idxs[index]]

    optim: leon.optim.BaseOptimizer = getattr(leon.optim, optimizer)(
        task=task,
        ref=ref,
        batch_size=batch_size,
        temperature=temperature,
        user_prompt_version=user_prompt_version,
        reflection=reflection,
        categorical_bounds=[
            list(range(len(vals))) for vals in categorical_config.values()
        ]
    )

    # The results subdirectory name encodes every non-default experimental
    # setting, so different configurations never share output files.
    optim_savedir = optim.optimizer_name
    if isinstance(optim, leon.optim.BaseLLMOptimizer):
        optim_savedir = f"{optim_savedir}-{user_prompt_version}"
    if reflection:
        optim_savedir += "-reflection"
    if not isclose(float(online), 0.0):
        optim_savedir += f"-online_{float(online)}"
    if equivalence_relation != "KMeansEquivalenceRelation":
        optim_savedir += f"-{equivalence_relation}"
    savedir = os.path.join(savedir, optim_savedir)
    # Skip work that has already been completed for this reference design.
    if os.path.isdir(os.path.join(savedir, task_name)) and any(
        fn.endswith(f"-{ref.id_}.npz")
        for fn in os.listdir(os.path.join(savedir, task_name))
    ):
        print(f"Skipping completed {task_name} {optim_savedir} {ref.id_}.")
        return

    state = OptimizerState(
        task=task,
        optimizer_name=optimizer,
        individual=ref,
        savedir=savedir,
        seed=seed
    )

    embedder = get_embedder(embedder_name)

    knowledge = {"task": task.task_description(ablate_distribution_shift)}
    prior_knowledge, knowledge_metadata = make_knowledge(
        task=task,
        optimizer=optim,
        z=state.individual,
        knowledge_sources=knowledge_source,
        embedder=embedder,
        top_k=knowledge_top_k,
        user_prompt_version=user_prompt_version,
        ablate_distribution_shift=ablate_distribution_shift
    )
    knowledge["prior_knowledge"] = prior_knowledge

    # Tasks without a fixed dimensionality can only be handled by LLM
    # optimizers. Use an explicit raise so the check survives `python -O`.
    if task.ndim is NotImplemented and not isinstance(
        optim, leon.optim.BaseLLMOptimizer
    ):
        raise ValueError(
            f"Task {task_name} has no fixed dimensionality; only LLM-based "
            f"optimizers are supported, but got {optimizer}."
        )

    transform = getattr(
        leon.core,
        "IdentityTransform" if ablate_leon else "EntropyPenalizedTransform"
    )

    sim = getattr(leon.core, equivalence_relation)(
        task=task, embedder=embedder
    )
    forward_model = transform(
        task=task,
        critic=leon.model.LipschitzMLP(in_dim=task.ndim),
        Xp=torch.vstack([
            task.train[i].as_tensor() for i in range(len(task.train))
        ]),
        equiv_relation=sim,
        lambda_=lambda_,
        mu_=mu_,
        W0=w0
    )

    reference_designs = task.reduce([
        task.train[i] for i in range(len(task.train))
    ])

    # `metadata` is read after the loop. Initialize it so that a run whose
    # loop body never assigns it (zero rounds, or immediate convergence)
    # still saves its results instead of raising NameError.
    metadata: Dict[str, Any] = {}
    mus, r2s, lambdas = [], [], []
    for _ in tqdm(
        range(state.max_samples // batch_size),
        desc=f"{task.task_name}: {optim.optimizer_name} | {index}"
    ):
        if state.has_converged:
            break

        # The first batch for fixed-dimensional tasks comes from
        # `initialize_designs`. The human/majority baselines always
        # propose their own designs.
        if not state.designs and task.ndim is not NotImplemented and (
            not isinstance(
                optim, (HumanBaselineOptimizer, MajorityBaselineOptimizer)
            )
        ):
            new_xq, metadata = initialize_designs(task, batch_size, seed=seed)
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                new_xq, metadata = optim(state=state, knowledge=knowledge)
            # Dev runs stop right after the first optimizer call, before its
            # proposals are scored or logged.
            if fast_dev_run:
                break

        if not new_xq:
            continue
        # Score the proposals with the forward model and record its
        # diagnostics (mu_hat, r2, and lambda_ if the transform has one).
        full_xq = task.extend(new_xq, state.individual)
        new_yq, mu_hat, r2 = forward_model(full_xq)
        lambdas.append(getattr(forward_model, "lambda_", -1.0))
        mus.append(mu_hat)
        r2s.append(r2)
        state.log(new_xq, new_yq, **metadata)

        optim.fit(state.designs, state.predictions)
        forward_model.fit(reference_designs, state.designs)

    # Record the lambda_ value after the final forward-model fit.
    lambdas.append(getattr(forward_model, "lambda_", -1.0))
    # `metadata` from the most recent batch is merged last, so its keys
    # override the entries above on collision.
    kwargs: Dict[str, Any] = {
        "cli": " ".join(["python"] + sys.argv),
        "max_samples": state.max_samples,
        "user_prompt_version": user_prompt_version,
        "id_": getattr(state.individual, "id_", None),
        "knowledge_metadata": knowledge_metadata,
        "knowledge": knowledge,
        "mu_hat": np.array(mus),
        "r2": np.array(r2s),
        "lambda": np.array(lambdas),
        **metadata
    }
    state.evaluate_and_save(fast_dev_run=fast_dev_run, **kwargs)


def setup_logging(verbosity: int) -> None:
    """
    Configures root logging to write formatted records to stdout.
    Input:
        verbosity: requested verbosity (e.g. the count of `-v` flags).
            0 or less selects WARNING, 1 selects INFO, and 2 or more
            selects DEBUG.
    Returns:
        None.
    """
    if verbosity >= 2:
        chosen_level = logging.DEBUG
    elif verbosity == 1:
        chosen_level = logging.INFO
    else:
        chosen_level = logging.WARNING

    logging.basicConfig(
        level=chosen_level,
        stream=sys.stdout,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )


if __name__ == "__main__":
    # Script entry point: invoke the click command `main` defined above.
    main()
