import os
from dataclasses import dataclass, field
from typing import List, Optional

from ollama import Client
from transformers import LlamaTokenizer


class _config_method:
    """Descriptor that binds a method to the instance, or to the class when accessed on it.

    ``config.modelfile()`` therefore honours per-instance overrides such as
    ``LLMConfig(base_model=...)``, while ``LLMConfig.modelfile()`` keeps working
    and reads the class-level field defaults, as the former classmethod did.
    """

    def __init__(self, func):
        self.__func__ = func
        self.__doc__ = func.__doc__

    def __get__(self, obj, objtype=None):
        # Accessed on the class -> bind to the class; on an instance -> bind to it.
        target = objtype if obj is None else obj
        return self.__func__.__get__(target)


@dataclass
class LLMConfig:
    """Settings for the Ollama-backed LLM planner.

    The ``tokenizer``, ``host``, ``base_model``, ``target_model_name`` and
    ``seed`` defaults come from ``LLM_*`` environment variables. Those are read
    once, when this class body executes (i.e. at module import time).
    """

    tokenizer: str = field(
        default=os.getenv("LLM_TOKENIZER", "meta-llama/Llama-2-7b"),
        # metadata must be a mapping; a bare string makes dataclasses raise TypeError.
        metadata={"help": "Path to the LLAMA tokenizer in Huggingface format."})
    host: str = field(
        default=os.getenv("LLM_HOST", "http://localhost:11434"),
        metadata={"help": "URL of the OLLAMA server."})
    base_model: str = field(
        default=os.getenv("LLM_BASE_MODEL", "llama2:7b"),
        metadata={"help": "Base model to use for the LLM planner."})
    target_model_name: str = field(
        default=os.getenv("LLM_TARGET_MODEL_NAME", "planner:7b"),
        metadata={"help": "Name of the model to create and use for the LLM planner"})
    # Kept as a string (environment values are strings); callers convert with int().
    seed: str = field(
        default=os.getenv("LLM_SEED", "5920596225"),
        metadata={"help": "Random seed to use for the LLM planner."})
    use_context: bool = field(
        default=True,
        metadata={"help": "Whether to use previous context of the planning when generating new instructions."})
    system_message: str = field(
        metadata={"help": "A general prompt to prime the LLM."},
        default="Your task is to provide instructions for two agents. The goal is to guide the agents to pickup valuable items. PINK items are more valuable than GREEN, and GREEN items are more valuable than YELLOW. Follow the instruction format given below. Do NOT output anything else in addition.")
    # Few-shot prompt template; "{{ .System }}" / "{{ .Prompt }}" are placeholders
    # filled by prompt_body() and written verbatim into the Modelfile TEMPLATE.
    template: str = field(
        metadata={"help": "Prompt template for the planner model; {{ .System }} and {{ .Prompt }} are replaced with the system message and the prompt."},
        default="""<<SYS>>{{ .System }}<</SYS>>
Here are some examples of the instructions:
[INST]obs: A1 respawn; A2 respawn;[/INST]
inst: A1 examine box 1; A2 examine box 2;
[INST]obs: A1 (PINK YELLOW); A2 (GREEN YELLOW);[/INST]
inst: A1 pickup PINK; A2 pickup GREEN;
[INST]obs: A1 has PINK; A2 has GREEN;[/INST]
inst: A1 goto PINK; A2 goto GREEN;
[INST]obs: A1 respawn; A2 respawn;[/INST]
inst: A1 examine box 1; A2 examine box 2;
[INST]obs: A1 (GREEN YELLOW); A2 (YELLOW YELLOW);[/INST]
inst: A1 pickup GREEN; A2 pickup YELLOW;
[INST]obs: A1 has GREEN; A2 has YELLOW;[/INST]
inst: A1 goto GREEN; A2 goto YELLOW;
[INST]obs: A1 respawn; A2 respawn;[/INST]
inst: A1 examine box 1; A2 examine box 2;
[INST]obs: A1 (PINK YELLOW); A2 (PINK YELLOW);[/INST]
inst: A1 pickup PINK; A2 pickup PINK;
[INST]obs: A1 has PINK; A2 has PINK;[/INST]
inst: A1 goto PINK; A2 goto PINK;
[INST]obs: A1 respawn; A2 respawn;[/INST]
inst: A1 examine box 1; A2 examine box 2;
[INST]obs: A1 (GREEN YELLOW); A2 (GREEN YELLOW);[/INST]
inst: A1 pickup GREEN; A2 pickup GREEN;
[INST]obs: A1 has GREEN; A2 has GREEN;[/INST]
inst: A1 goto GREEN; A2 goto GREEN;
[INST]{{ .Prompt }}[/INST]
""")

    @_config_method
    def modelfile(self) -> str:
        """Return Ollama Modelfile text deriving the planner model from ``base_model``.

        The Modelfile sets stop sequences for the Llama-2 chat markers,
        temperature 0.0, and embeds ``system_message`` and ``template``.
        Called on an instance it uses that instance's values; called on the
        class it uses the field defaults.
        """
        return f"""
FROM {self.base_model}

PARAMETER stop "[INST]"
PARAMETER stop "[/INST]"
PARAMETER stop "<<SYS>>"
PARAMETER stop "<</SYS>>"
PARAMETER temperature 0.0

SYSTEM \"\"\"{self.system_message}\"\"\"

TEMPLATE \"\"\"{self.template}\"\"\"
"""

    @_config_method
    def prompt_body(self, prompt: str) -> str:
        """Render ``template`` locally with ``system_message`` and *prompt* substituted.

        The system placeholder is replaced first, then the prompt placeholder.
        Works on an instance (its values) or on the class (field defaults).
        """
        return (self.template
                .replace("{{ .System }}", self.system_message)
                .replace("{{ .Prompt }}", prompt))


@dataclass
class LLMResponse:
    """Output of one generation request, paired with the prompt that produced it."""

    # Text produced by the model.
    response: str = field(
        metadata={"help": "The response generated by the model."},
        default="",
    )
    # Prompt that was sent to the model.
    inst: str = field(
        metadata={"help": "The instruction/prompt that generated the response."},
        default="",
    )
    # Conversation encoding; each instance gets its own fresh list.
    # Presumably passed back as `context` on a later generate call — confirm with callers.
    context: List[int] = field(
        metadata={"help": "Encoding of the current conversation history."},
        default_factory=list,
    )
    # Generation time in nanoseconds.
    eval_duration: int = field(
        metadata={"help": "Time taken to generate the response in nanoseconds."},
        default=0,
    )


class LLMGateway:
    """Wrapper around an Ollama ``Client`` and a Llama tokenizer for the planner model."""

    def __init__(self, config: LLMConfig):
        """Connect to the Ollama server at ``config.host`` and load ``config.tokenizer``."""
        self.client = Client(host=config.host)
        self.config = config
        self.tokenizer = LlamaTokenizer.from_pretrained(config.tokenizer)

    def pull(self) -> bool:
        """Pull ``config.base_model`` onto the server (non-streaming).

        Returns:
            True if the server reports status ``"success"``.
        """
        result = self.client.pull(self.config.base_model, stream=False)
        return result["status"] == "success"

    def create(self, config: str) -> bool:
        """Create the model named ``config.target_model_name`` from a Modelfile.

        Args:
            config: Modelfile *text* (not an ``LLMConfig``), e.g. the output of
                ``LLMConfig.modelfile()``. The parameter name is kept for
                keyword-argument callers.

        Returns:
            True if the server reports status ``"success"``.
        """
        result = self.client.create(model=self.config.target_model_name, modelfile=config)
        return result["status"] == "success"

    def generate(self, prompt: str, context: Optional[List[int]] = None) -> LLMResponse:
        """Generate a completion for *prompt* with the target model.

        Args:
            prompt: Text substituted into the model's prompt template.
            context: Conversation encoding from an earlier response
                (``LLMResponse.context``). Treated as an empty history when
                omitted, and not sent at all when ``config.use_context`` is False.

        Returns:
            An ``LLMResponse`` with the generated text, the prompt, the new
            context, and the server-reported ``eval_duration``.
        """
        options = {"seed": int(self.config.seed)}
        if not self.config.use_context:
            context = None  # history disabled: send no context
        elif context is None:
            # Avoids a shared mutable default while sending the same empty list as before.
            context = []
        response = self.client.generate(
            model=self.config.target_model_name,
            prompt=prompt,
            context=context,
            options=options,
        )
        return LLMResponse(
            inst=prompt,
            response=response["response"],
            context=response["context"],
            eval_duration=response["eval_duration"])

    def tokens(self, prompt: str) -> List[int]:
        """Return the tokenizer's encoding of *prompt* as a list of token ids."""
        return self.tokenizer.encode(prompt)
