#!/usr/bin/python3
"""
LiteLLM-served large language models (LLM) as backbone optimizers.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import json
import logging
import numpy as np
import os
from litellm import completion
from pathlib import Path
from pydantic import BaseModel, ValidationError
from tenacity import retry, stop_after_attempt, wait_random
from transformers import AutoTokenizer
from typing import Any, Dict, Final, List, Optional, Tuple, Type, Union

from .llm import BaseLLMOptimizer, load_reflection_prompt
from .state import OptimizerState
from ..envs.base import BaseTask
from ..utils import json_loads


logger = logging.getLogger(__name__)


class LiteLLMOptimizer(BaseLLMOptimizer):
    """
    Optimizer that samples candidate designs from a LiteLLM-served model.

    Subclasses may set `max_context_length` to an integer token budget; while
    it is left at its `NotImplemented` default, the optimization memory in the
    prompt is never truncated. Subclasses that run with reflection enabled
    must override `chat()`.
    """

    max_context_length: int = NotImplemented

    def __init__(
        self,
        task: BaseTask,
        model_id: str,
        tokenizer_id: Optional[str] = None,
        batch_size: int = 1,
        temperature: float = 1.0,
        max_new_tokens: int = 2048,
        seed: int = 2025,
        promptdir: Union[Path, str] = os.path.join(
            os.path.dirname(__file__), "prompts"
        ),
        system_prompt_version: str = "base",
        user_prompt_version: str = "base",
        reflection: bool = True,
        ablate_distribution_shift: bool = False,
        **kwargs: Any
    ):
        """
        Args:
            task: the optimization task.
            model_id: LiteLLM model ID to use.
            tokenizer_id: Hugging Face tokenizer name to use if different from
                the model. Defaults to `model_id`.
            batch_size: batch size for optimization.
            temperature: temperature parameter for text generation.
            max_new_tokens: maximum number of new tokens to generate.
            seed: random seed. Default 2025.
            promptdir: directory containing the prompt templates.
            system_prompt_version: version of the system prompt to load.
            user_prompt_version: version of the user prompt to load.
            reflection: whether to perform reflection as in Ma YJ, et al. Proc
                ICLR (2024).
            ablate_distribution_shift: whether to ablate knowledge of the
                distribution shift.
            kwargs: additional keyword arguments forwarded to the base
                optimizer.
        """
        self.optimizer_name = model_id
        self.tokenizer_id = tokenizer_id or self.optimizer_name

        super(LiteLLMOptimizer, self).__init__(
            task,
            batch_size,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            seed=seed,
            promptdir=promptdir,
            system_prompt_version=system_prompt_version,
            user_prompt_version=user_prompt_version,
            reflection=reflection,
            ablate_distribution_shift=ablate_distribution_shift,
            **kwargs
        )
        self.design_schema: Final[Type[BaseModel]] = task.design_schema
        # Local tokenizer, used to measure prompt lengths in tokens.
        self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)

    @retry(wait=wait_random(min=45, max=90), stop=stop_after_attempt(6))
    def forward(
        self, state: OptimizerState, knowledge: Dict[str, str], **kwargs
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Returns a new batch of candidates to evaluate.
        Input:
            state: the current optimizer state.
            knowledge: the prior knowledge to use for the optimization. When
                reflection is enabled, knowledge["reflection"] is set in place.
        Returns:
            A list of at most B validated candidate designs, where B is the
            sampling batch size (samples whose response is empty or cannot be
            validated are skipped), and a dictionary with the number of input
            and output tokens consumed.
        """
        del kwargs
        num_input_tokens, num_output_tokens = 0, 0
        if self.reflection:
            reflect_in, reflect_out = self._reflect_on_last_batch(
                state, knowledge
            )
            num_input_tokens += reflect_in
            num_output_tokens += reflect_out

        optim_prompt = self._fit_user_prompt_to_context(state, knowledge)
        optim_prompt += (
            "\nOnly respond with valid JSON for your design.\n\n```json"
        )

        self._x, self._y = [], np.array([])
        out = []
        for i in range(self.batch_size):
            response = completion(
                model=self.optimizer_name,
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": optim_prompt}
                ],
                max_completion_tokens=self.max_new_tokens,
                # NOTE(review): OpenAI-style structured outputs nest "strict"
                # inside "json_schema"; here it sits at the top level of
                # response_format and may be ignored. Left unchanged because
                # strict mode imposes extra schema requirements — confirm
                # before moving it.
                response_format={
                    "type": "json_schema",
                    "json_schema": {
                        "name": "design",
                        "schema": self.design_schema.model_json_schema()
                    },
                    "strict": True
                },
                temperature=self.temperature,
                stop="```"
            )
            # Every sample in the batch sends the same prompt, so only the
            # first call's prompt tokens are counted.
            if i == 0:
                num_input_tokens += (
                    getattr(response.usage, "prompt_tokens", 0) or 0
                )
            num_output_tokens += (
                getattr(response.usage, "completion_tokens", 0) or 0
            )

            design = self._parse_design_message(response.choices[0].message)
            if design is not None:
                out.append(design)
        return out, {
            "num_input_tokens": num_input_tokens,
            "num_output_tokens": num_output_tokens
        }

    def _reflect_on_last_batch(
        self, state: OptimizerState, knowledge: Dict[str, str]
    ) -> Tuple[int, int]:
        """
        Ask the LLM to reflect on the most recent batch of designs.
        Input:
            state: the current optimizer state.
            knowledge: prior knowledge; knowledge["reflection"] is set in place
                to the content of the LLM's reply.
        Returns:
            The (input, output) token counts reported by `chat()` for the
            reflection call; missing or negative counts are treated as 0.
        Raises:
            NotImplementedError: if `chat()` is not implemented.
        """
        reflection_prompt = load_reflection_prompt()
        # Only reflect on the most recent batch of designs.
        reflection_query = reflection_prompt.format(
            memory=state.memory.iloc[-self.batch_size:].to_markdown(),
            task=self.task.task_description(self.ablate_distribution_shift)
        )
        result = self.chat(
            messages=[
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": reflection_query}
            ],
            tools=[]
        )
        if result is NotImplemented:
            raise NotImplementedError(
                f"{type(self).__name__} does not implement chat(); override "
                "it or disable reflection."
            )
        response, metadata = result

        def token_count(key: str) -> int:
            # chat() returns a dict of counts and reports -1 when unknown;
            # attribute access is kept as a fallback for non-dict metadata.
            if isinstance(metadata, dict):
                value = metadata.get(key)
            else:
                value = getattr(metadata, key, None)
            return max(value or 0, 0)

        knowledge["reflection"] = response["content"]
        logger.info(f"  Reflection: {knowledge['reflection']}")
        return token_count("num_input_tokens"), token_count("num_output_tokens")

    def _fit_user_prompt_to_context(
        self, state: OptimizerState, knowledge: Dict[str, str]
    ) -> str:
        """
        Format the user prompt, truncating memory to fit the context window.

        While the formatted prompt plus the system prompt and the generation
        budget (`max_new_tokens`) exceeds `max_context_length` tokens, the
        first `batch_size` rows of the memory are dropped and the prompt is
        re-formatted. If the limit is not an integer, no truncation occurs.
        Input:
            state: the current optimizer state.
            knowledge: the prior knowledge used to fill the prompt template.
        Returns:
            The formatted user prompt. If it still does not fit after all
            memory has been dropped, a warning is logged and it is returned
            as is.
        """
        memory = state.memory.copy(deep=True)
        limit = getattr(self, "max_context_length", None)
        # The NotImplemented default (or any non-integer) means "no limit";
        # comparing an int against NotImplemented would raise a TypeError.
        if not isinstance(limit, int) or isinstance(limit, bool):
            return self.user_prompt.format(memory=memory, **knowledge)

        enc = self._tokenizer.encode
        buffer_tokens = len(enc(self.system_prompt)) + self.max_new_tokens
        step = max(self.batch_size, 1)
        while True:
            optim_prompt = self.user_prompt.format(memory=memory, **knowledge)
            num_tokens = len(enc(optim_prompt)) + buffer_tokens
            if num_tokens <= limit:
                return optim_prompt
            if len(memory) == 0:
                # Nothing left to drop; stop instead of looping forever.
                logger.warning(
                    "Prompt needs %d tokens, exceeding max_context_length=%d "
                    "even with empty memory.",
                    num_tokens,
                    limit
                )
                return optim_prompt
            # Drops a full batch, or whatever remains if fewer rows are left.
            memory = memory.iloc[step:]

    def _parse_design_message(self, message: Any) -> Optional[BaseModel]:
        """
        Convert one completion message into a validated design.
        Input:
            message: the `choices[0].message` object of a LiteLLM completion.
        Returns:
            The validated design, or None if the message is empty or the
            design could not be validated for IWPCWarfarin-v0.
        Raises:
            ValueError: if the JSON payload is not an object.
            ValidationError: if a validation error cannot be repaired.
        """
        designs = message.tool_calls or message.content
        if not designs:
            return None
        if isinstance(designs, list):
            # Tool-call responses carry the design as the first call's
            # arguments.
            designs = designs[0].function.arguments
        content = json_loads(designs)
        if not isinstance(content, dict):
            raise ValueError(
                "Expected a JSON object for the design, got "
                f"{type(content).__name__}."
            )
        # Hoist the fields of nested objects to the top level until none
        # remain.
        while any(isinstance(val, dict) for val in content.values()):
            bad_keys = [
                key for key, val in content.items() if isinstance(val, dict)
            ]
            for key in bad_keys:
                # An earlier key's nested fields may have overwritten this
                # key with a non-dict value.
                if isinstance(content.get(key), dict):
                    content.update(content.pop(key))

        if self.task.task_name == "IWPCWarfarin-v0":
            # Rename the first key to `warfarin_dose` if it is missing.
            if content and "warfarin_dose" not in content:
                content["warfarin_dose"] = content.pop(next(iter(content)))

        return self._validate_design_content(content)

    def _validate_design_content(
        self, content: Dict[str, Any]
    ) -> Optional[BaseModel]:
        """
        Validate parsed fields against the design schema, repairing failures.

        Each field that fails validation is replaced by a random choice:
        True/False for bool fields, otherwise a random `.value` of the field's
        annotation (presumably an Enum — verify for new schemas). Validation
        is then retried. For IWPCWarfarin-v0 no repair is attempted.
        Input:
            content: the flattened design fields; modified in place.
        Returns:
            The validated design, or None on failure for IWPCWarfarin-v0.
        Raises:
            ValidationError: if an error is not attributable to a known field.
        """
        fields = self.design_schema.model_fields
        while True:
            try:
                return self.design_schema.model_validate(content)
            except ValidationError as e:
                if self.task.task_name in ["IWPCWarfarin-v0"]:
                    return None
                for err in e.errors():
                    # Model-level errors have an empty location and extra
                    # fields are not in the schema: neither can be repaired.
                    key = str(err["loc"][0]) if err["loc"] else ""
                    field = fields.get(key)
                    if field is None or not hasattr(field, "annotation"):
                        raise e
                    if getattr(field, "annotation", None) == bool:
                        # bool() guards against generators (e.g. NumPy's)
                        # that return a non-builtin boolean scalar.
                        content[key] = bool(self._rng.choice([True, False]))
                    else:
                        content[key] = self._rng.choice([
                            x.value for x in getattr(field, "annotation", [])
                        ])

    def chat(
        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
        """
        Chat with the parent LLM.

        Not implemented for this class: returns `NotImplemented`. Subclasses
        that use reflection must override it.
        Input:
            messages: the input messages to the LLM.
            tools: the tools available to the LLM.
        Returns:
            The response from the LLM and associated metadata.
        """
        del messages, tools
        return NotImplemented


class MetaLlamaOptimizer(LiteLLMOptimizer):
    """
    Optimizer backed by a Meta Llama model served through LiteLLM.

    Concrete subclasses set `optimizer_name` (the LiteLLM model ID) and
    `tokenizer_name` (the Hugging Face tokenizer ID).
    """

    optimizer_name: str = NotImplemented

    tokenizer_name: str = NotImplemented

    max_context_length: int = 64000

    def __init__(
        self,
        task: BaseTask,
        batch_size: int = 1,
        temperature: float = 1.0,
        max_new_tokens: int = 2048,
        seed: int = 2025,
        promptdir: Union[Path, str] = os.path.join(
            os.path.dirname(__file__), "prompts"
        ),
        system_prompt_version: str = "base",
        user_prompt_version: str = "base",
        reflection: bool = True,
        ablate_distribution_shift: bool = False,
        **kwargs: Any
    ):
        """
        Args:
            task: the optimization task.
            batch_size: batch size for optimization.
            temperature: temperature parameter for text generation.
            max_new_tokens: maximum number of new tokens to generate.
            seed: random seed. Default 2025.
            promptdir: directory containing the prompt templates.
            system_prompt_version: version of the system prompt to load.
            user_prompt_version: version of the user prompt to load.
            reflection: whether to perform reflection as in Ma YJ, et al. Proc
                ICLR (2024).
            ablate_distribution_shift: whether to ablate knowledge of the
                distribution shift.
            kwargs: additional keyword arguments forwarded to the parent.
        """
        super(MetaLlamaOptimizer, self).__init__(
            task=task,
            model_id=self.optimizer_name,
            tokenizer_id=self.tokenizer_name,
            batch_size=batch_size,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            seed=seed,
            promptdir=promptdir,
            system_prompt_version=system_prompt_version,
            user_prompt_version=user_prompt_version,
            reflection=reflection,
            ablate_distribution_shift=ablate_distribution_shift,
            **kwargs
        )

    @retry(wait=wait_random(min=45, max=90), stop=stop_after_attempt(6))
    def chat(
        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
        """
        Chat with the parent LLM.
        Input:
            messages: the input messages to the LLM. The caller's list and
                dicts are not modified.
            tools: the tools available to the LLM.
        Returns:
            The response message from the LLM as a dict, and metadata with
            the number of input and output tokens (-1 if not reported).
        Raises:
            ValueError: if no system or user message remains to carry the
                prompt.
        """
        # Note: This is a hack to get the LLM to use the tools even for models
        # on Bedrock. The prompt prefix is derived from the chat template for
        # Meta Llama 3.3 70B, with an added sentence to allow for the option
        # to terminate tool use.
        # Copy every message so the edits below neither leak back to the
        # caller nor compound when @retry re-invokes this method with the
        # same `messages` list after a failed completion.
        messages = [dict(message) for message in messages]

        prompt_prefix = ""
        if tools:
            tools_str = "\n\n".join(json.dumps(j, indent=4) for j in tools)
            prompt_prefix = (
                "Given the following functions, please respond with a JSON "
                "for a function call with its proper arguments that best "
                "answers the given prompt.\n\nRespond in the format "
                '{"name": function name, "parameters": dictionary of argument '
                "name and its value}. Do not use variables. If no additional "
                "function call is needed, respond to the original prompt "
                f"without JSON.\n\n{tools_str}\n\n"
            )

        # Tool and assistant turns are removed from the message list and
        # re-rendered as raw chat-template text appended to the last message.
        suffix_roles = ["tool", "assistant"]
        prompt_suffix_msgs = [m for m in messages if m["role"] in suffix_roles]
        messages = [m for m in messages if m["role"] not in suffix_roles]
        for message in messages:
            if message["role"] == "user":
                message["content"] = prompt_prefix + message["content"]

        prompt_suffix = ""
        if prompt_suffix_msgs:
            prompt_suffix = self._tokenizer.apply_chat_template(
                prompt_suffix_msgs, add_generation_prompt=False, tokenize=False
            )
            # Token 128006 corresponds to <|start_header_id|>. Text rendered
            # before the first assistant header (presumably the template's
            # default system preamble) is dropped.
            header = self._tokenizer.added_tokens_decoder.get(128006)
            if header is not None:
                assistant_start = str(header) + "assistant"
                if assistant_start in prompt_suffix:
                    prompt_suffix = assistant_start + (
                        prompt_suffix.split(assistant_start, 1)[-1]
                    )
        if not messages:
            raise ValueError(
                "chat() requires at least one system or user message."
            )
        messages[-1]["content"] += prompt_suffix

        response = completion(
            messages=messages,
            model=self.optimizer_name,
            max_completion_tokens=self.max_new_tokens,
            temperature=0
        )
        return response.choices[0].message.to_dict(), {
            "num_input_tokens": getattr(response.usage, "prompt_tokens", -1),
            "num_output_tokens": getattr(
                response.usage, "completion_tokens", -1
            )
        }


class MetaLlama70BOptimizer(MetaLlamaOptimizer):
    """Meta Llama 3.3 70B Instruct, served via LiteLLM's Bedrock provider."""

    # Hugging Face tokenizer loaded locally for this model.
    tokenizer_name: str = "meta-llama/Llama-3.3-70B-Instruct"

    # LiteLLM model identifier.
    optimizer_name: str = "bedrock/us.meta.llama3-3-70b-instruct-v1:0"


class MetaLlama8BOptimizer(MetaLlamaOptimizer):
    """Meta Llama 3.1 8B Instruct, served via LiteLLM's Bedrock provider."""

    # Hugging Face tokenizer loaded locally for this model.
    tokenizer_name: str = "meta-llama/Llama-3.1-8B-Instruct"

    # LiteLLM model identifier.
    optimizer_name: str = "bedrock/us.meta.llama3-1-8b-instruct-v1:0"


class DeepSeekR1Optimizer(LiteLLMOptimizer):
    """
    Optimizer backed by DeepSeek-R1 served through LiteLLM.

    Unlike the parent class, the design JSON schema is embedded in the prompt
    text rather than passed as a `response_format`.
    """

    optimizer_name: str = "bedrock/us.deepseek.r1-v1:0"

    tokenizer_name: str = "deepseek-ai/DeepSeek-R1"

    max_context_length: int = 64000

    def __init__(
        self,
        task: BaseTask,
        batch_size: int = 1,
        temperature: float = 1.0,
        max_new_tokens: int = 2048,
        seed: int = 2025,
        promptdir: Union[Path, str] = os.path.join(
            os.path.dirname(__file__), "prompts"
        ),
        system_prompt_version: str = "base",
        user_prompt_version: str = "base",
        reflection: bool = True,
        ablate_distribution_shift: bool = False,
        **kwargs: Any
    ):
        """
        Args:
            task: the optimization task.
            batch_size: batch size for optimization.
            temperature: temperature parameter for text generation.
            max_new_tokens: maximum number of new tokens to generate.
            seed: random seed. Default 2025.
            promptdir: directory containing the prompt templates.
            system_prompt_version: version of the system prompt to load.
            user_prompt_version: version of the user prompt to load.
            reflection: whether to perform reflection as in Ma YJ, et al. Proc
                ICLR (2024).
            ablate_distribution_shift: whether to ablate knowledge of the
                distribution shift.
            kwargs: additional keyword arguments forwarded to the parent.
        """
        super(DeepSeekR1Optimizer, self).__init__(
            task=task,
            model_id=self.optimizer_name,
            tokenizer_id=self.tokenizer_name,
            batch_size=batch_size,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            seed=seed,
            promptdir=promptdir,
            system_prompt_version=system_prompt_version,
            user_prompt_version=user_prompt_version,
            reflection=reflection,
            ablate_distribution_shift=ablate_distribution_shift,
            **kwargs
        )

    @retry(wait=wait_random(min=45, max=90), stop=stop_after_attempt(6))
    def forward(
        self, state: OptimizerState, knowledge: Dict[str, str], **kwargs
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Returns a new batch of candidates to evaluate.
        Input:
            state: the current optimizer state.
            knowledge: the prior knowledge to use for the optimization. When
                reflection is enabled, knowledge["reflection"] is set in place.
        Returns:
            A list of at most B validated candidate designs, where B is the
            sampling batch size (samples whose response is empty or cannot be
            validated are skipped), and a dictionary with the number of input
            and output tokens consumed.
        """
        del kwargs
        num_input_tokens, num_output_tokens = 0, 0
        if self.reflection:
            reflect_in, reflect_out = self._reflect_on_last_batch(
                state, knowledge
            )
            num_input_tokens += reflect_in
            num_output_tokens += reflect_out

        optim_prompt = self._fit_user_prompt_to_context(state, knowledge)
        # The schema is spelled out in the prompt instead of being passed as
        # a structured-output response_format.
        schema = json.dumps(self.design_schema.model_json_schema(), indent=2)
        optim_prompt += (
            "\nOnly respond with valid JSON for your design according to the "
            f"following schema:\n\n{schema}\n\n```json"
        )

        self._x, self._y = [], np.array([])
        out = []
        for i in range(self.batch_size):
            response = completion(
                model=self.optimizer_name,
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": optim_prompt}
                ],
                max_completion_tokens=self.max_new_tokens,
                temperature=self.temperature,
                stop="```"
            )
            # Every sample in the batch sends the same prompt, so only the
            # first call's prompt tokens are counted.
            if i == 0:
                num_input_tokens += (
                    getattr(response.usage, "prompt_tokens", 0) or 0
                )
            num_output_tokens += (
                getattr(response.usage, "completion_tokens", 0) or 0
            )

            design = self._parse_design_message(response.choices[0].message)
            if design is not None:
                out.append(design)
        return out, {
            "num_input_tokens": num_input_tokens,
            "num_output_tokens": num_output_tokens
        }

    def _reflect_on_last_batch(
        self, state: OptimizerState, knowledge: Dict[str, str]
    ) -> Tuple[int, int]:
        """
        Ask the LLM to reflect on the most recent batch of designs.
        Input:
            state: the current optimizer state.
            knowledge: prior knowledge; knowledge["reflection"] is set in place
                to the content of the LLM's reply.
        Returns:
            The (input, output) token counts reported by `chat()` for the
            reflection call; missing or negative counts are treated as 0.
        """
        reflection_prompt = load_reflection_prompt()
        # Only reflect on the most recent batch of designs.
        reflection_query = reflection_prompt.format(
            memory=state.memory.iloc[-self.batch_size:].to_markdown(),
            task=self.task.task_description(self.ablate_distribution_shift)
        )
        response, metadata = self.chat(
            messages=[
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": reflection_query}
            ],
            tools=[]
        )

        def token_count(key: str) -> int:
            # chat() returns a dict of counts and reports -1 when unknown;
            # attribute access is kept as a fallback for non-dict metadata.
            if isinstance(metadata, dict):
                value = metadata.get(key)
            else:
                value = getattr(metadata, key, None)
            return max(value or 0, 0)

        knowledge["reflection"] = response["content"]
        logger.info(f"  Reflection: {knowledge['reflection']}")
        return token_count("num_input_tokens"), token_count("num_output_tokens")

    def _fit_user_prompt_to_context(
        self, state: OptimizerState, knowledge: Dict[str, str]
    ) -> str:
        """
        Format the user prompt, truncating memory to fit the context window.

        While the formatted prompt plus the system prompt and the generation
        budget (`max_new_tokens`) exceeds `max_context_length` tokens, the
        first `batch_size` rows of the memory are dropped and the prompt is
        re-formatted. If the limit is not an integer, no truncation occurs.
        Input:
            state: the current optimizer state.
            knowledge: the prior knowledge used to fill the prompt template.
        Returns:
            The formatted user prompt. If it still does not fit after all
            memory has been dropped, a warning is logged and it is returned
            as is.
        """
        memory = state.memory.copy(deep=True)
        limit = getattr(self, "max_context_length", None)
        # A non-integer limit (e.g. NotImplemented) means "no limit".
        if not isinstance(limit, int) or isinstance(limit, bool):
            return self.user_prompt.format(memory=memory, **knowledge)

        enc = self._tokenizer.encode
        buffer_tokens = len(enc(self.system_prompt)) + self.max_new_tokens
        step = max(self.batch_size, 1)
        while True:
            optim_prompt = self.user_prompt.format(memory=memory, **knowledge)
            num_tokens = len(enc(optim_prompt)) + buffer_tokens
            if num_tokens <= limit:
                return optim_prompt
            if len(memory) == 0:
                # Nothing left to drop; stop instead of looping forever.
                logger.warning(
                    "Prompt needs %d tokens, exceeding max_context_length=%d "
                    "even with empty memory.",
                    num_tokens,
                    limit
                )
                return optim_prompt
            # Drops a full batch, or whatever remains if fewer rows are left.
            memory = memory.iloc[step:]

    def _parse_design_message(self, message: Any) -> Optional[BaseModel]:
        """
        Convert one completion message into a validated design.
        Input:
            message: the `choices[0].message` object of a LiteLLM completion.
        Returns:
            The validated design, or None if the message is empty or the
            design could not be validated for IWPCWarfarin-v0.
        Raises:
            ValueError: if the JSON payload is not an object.
            ValidationError: if a validation error cannot be repaired.
        """
        designs = message.tool_calls or message.content
        if not designs:
            return None
        if isinstance(designs, list):
            # Tool-call responses carry the design as the first call's
            # arguments.
            designs = designs[0].function.arguments
        content = json_loads(designs)
        if not isinstance(content, dict):
            raise ValueError(
                "Expected a JSON object for the design, got "
                f"{type(content).__name__}."
            )
        # Hoist the fields of nested objects to the top level until none
        # remain.
        while any(isinstance(val, dict) for val in content.values()):
            bad_keys = [
                key for key, val in content.items() if isinstance(val, dict)
            ]
            for key in bad_keys:
                # An earlier key's nested fields may have overwritten this
                # key with a non-dict value.
                if isinstance(content.get(key), dict):
                    content.update(content.pop(key))

        if self.task.task_name == "IWPCWarfarin-v0":
            # Rename the first key to `warfarin_dose` if it is missing.
            if content and "warfarin_dose" not in content:
                content["warfarin_dose"] = content.pop(next(iter(content)))

        return self._validate_design_content(content)

    def _validate_design_content(
        self, content: Dict[str, Any]
    ) -> Optional[BaseModel]:
        """
        Validate parsed fields against the design schema, repairing failures.

        Each field that fails validation is replaced by a random choice:
        True/False for bool fields, otherwise a random `.value` of the field's
        annotation (presumably an Enum — verify for new schemas). Validation
        is then retried. For IWPCWarfarin-v0 no repair is attempted.
        Input:
            content: the flattened design fields; modified in place.
        Returns:
            The validated design, or None on failure for IWPCWarfarin-v0.
        Raises:
            ValidationError: if an error is not attributable to a known field.
        """
        fields = self.design_schema.model_fields
        while True:
            try:
                return self.design_schema.model_validate(content)
            except ValidationError as e:
                if self.task.task_name in ["IWPCWarfarin-v0"]:
                    return None
                for err in e.errors():
                    # Model-level errors have an empty location and extra
                    # fields are not in the schema: neither can be repaired.
                    key = str(err["loc"][0]) if err["loc"] else ""
                    field = fields.get(key)
                    if field is None or not hasattr(field, "annotation"):
                        raise e
                    if getattr(field, "annotation", None) == bool:
                        # bool() guards against generators (e.g. NumPy's)
                        # that return a non-builtin boolean scalar.
                        content[key] = bool(self._rng.choice([True, False]))
                    else:
                        content[key] = self._rng.choice([
                            x.value for x in getattr(field, "annotation", [])
                        ])

    @retry(wait=wait_random(min=45, max=90), stop=stop_after_attempt(6))
    def chat(
        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
        """
        Chat with the parent LLM.
        Input:
            messages: the input messages to the LLM. The caller's list and
                dicts are not modified.
            tools: the tools available to the LLM.
        Returns:
            The response message from the LLM as a dict, and metadata with
            the number of input and output tokens (-1 if not reported).
        Raises:
            ValueError: if no system or user message remains to carry the
                prompt.
        """
        # Note: This is a hack to get the LLM to use the tools even for models
        # on Bedrock. The prompt prefix is derived from the chat template for
        # Meta Llama 3.3 70B, with an added sentence to allow for the option
        # to terminate tool use.
        # Copy every message so the edits below neither leak back to the
        # caller nor compound when @retry re-invokes this method with the
        # same `messages` list after a failed completion.
        messages = [dict(message) for message in messages]

        prompt_prefix = ""
        if tools:
            tools_str = "\n\n".join(json.dumps(j, indent=4) for j in tools)
            prompt_prefix = (
                "Given the following functions, please respond with a JSON "
                "for a function call with its proper arguments that best "
                "answers the given prompt.\n\nRespond in the format "
                '{"name": function name, "parameters": dictionary of argument '
                "name and its value}. Do not use variables. If no additional "
                "function call is needed, respond to the original prompt "
                f"without JSON.\n\n{tools_str}\n\n"
            )

        # Tool and assistant turns are removed from the message list and
        # re-rendered as raw chat-template text appended to the last message.
        suffix_roles = ["tool", "assistant"]
        prompt_suffix_msgs = [m for m in messages if m["role"] in suffix_roles]
        messages = [m for m in messages if m["role"] not in suffix_roles]
        for message in messages:
            if message["role"] == "user":
                message["content"] = prompt_prefix + message["content"]

        prompt_suffix = ""
        if prompt_suffix_msgs:
            prompt_suffix = self._tokenizer.apply_chat_template(
                prompt_suffix_msgs, add_generation_prompt=False, tokenize=False
            )
            # Token 128006 is <|start_header_id|> in the Llama 3 tokenizers.
            # NOTE(review): this logic was copied from the Llama optimizer;
            # confirm whether id 128006 means anything for the DeepSeek-R1
            # tokenizer. The lookup is guarded so a missing id no longer
            # raises KeyError; the full templated suffix is kept instead.
            header = self._tokenizer.added_tokens_decoder.get(128006)
            if header is not None:
                assistant_start = str(header) + "assistant"
                if assistant_start in prompt_suffix:
                    prompt_suffix = assistant_start + (
                        prompt_suffix.split(assistant_start, 1)[-1]
                    )
        if not messages:
            raise ValueError(
                "chat() requires at least one system or user message."
            )
        messages[-1]["content"] += prompt_suffix

        response = completion(
            messages=messages,
            model=self.optimizer_name,
            max_completion_tokens=self.max_new_tokens,
            temperature=0
        )
        return response.choices[0].message.to_dict(), {
            "num_input_tokens": getattr(response.usage, "prompt_tokens", -1),
            "num_output_tokens": getattr(
                response.usage, "completion_tokens", -1
            )
        }
