import os
from typing import Any, Dict

from langchain_openai import ChatOpenAI
from transformers import GenerationConfig


class ApiLM:
    """Chat language model served through the OpenAI API via LangChain.

    Wraps a ``langchain_openai.ChatOpenAI`` client and exposes a
    ``generate`` method that returns the generated texts together with
    their per-token log-probabilities as plain Python lists.
    """

    def __init__(
        self,
        model_name: str,
        token: str,
        generation_config: Dict,
    ):
        """Create the API client.

        Args:
            model_name: Model identifier passed to ``ChatOpenAI``.
            token: API key. It is stored on the instance and also exported
                as the ``OPENAI_API_KEY`` environment variable, which is a
                process-wide side effect.
            generation_config: Keyword arguments for
                ``transformers.GenerationConfig``. ``temperature``,
                ``top_p`` and ``max_new_tokens`` are forwarded to the
                client; ``num_return_sequences`` controls how many
                requests ``generate`` issues per prompt.
        """
        self.model_type = "instruct"
        self.model_name = model_name
        self.token = token
        self.generation_config = GenerationConfig(**generation_config)

        # The client is constructed without an explicit key, so the key is
        # provided through the environment.
        os.environ["OPENAI_API_KEY"] = token

        self.model = ChatOpenAI(
            model=model_name,
            temperature=self.generation_config.temperature,
            top_p=self.generation_config.top_p,
            max_tokens=self.generation_config.max_new_tokens,
        )

        # Monkey patch for compatibility with non-API models: attach a
        # ``name_or_path`` attribute to the client. ``object.__setattr__``
        # bypasses the client's own ``__setattr__`` (presumably because the
        # client class rejects undeclared attributes - TODO confirm).
        object.__setattr__(self.model, "name_or_path", model_name)

        # Request token log-probabilities on every call; ``generate`` reads
        # them from the response metadata.
        self.model = self.model.bind(logprobs=True)

    def generate(self, prompt: str, *args, **kwargs) -> Dict[str, Any]:
        """Generate answers for ``prompt`` and collect token log-probabilities.

        The model is invoked ``generation_config.num_return_sequences``
        times, one request per returned sequence.

        Args:
            prompt: User message sent after the fixed system instruction.
            *args: Accepted for interface compatibility; ignored.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A dict with one entry per generated sequence under each key:
            ``log_likelihood`` (list of per-token log-probabilities),
            ``tokens`` (list of token strings) and
            ``tokens_decoded_generated`` (the response text), plus a
            ``*_truncated`` twin of each. No truncation is applied here, so
            the twins hold equal values; the list-valued twins are
            independent copies so that mutating one does not affect the
            other. A response whose log-probability payload is null
            contributes empty token and log-probability lists.

        Raises:
            KeyError: If the response metadata has no ``"logprobs"`` entry.
        """
        messages = [
            (
                "system",
                "Answer the following question with one word or phrase:",
            ),  # TODO: make this configurable for non Q&A tasks
            ("human", prompt),
        ]

        log_likelihood = []
        log_likelihood_truncated = []
        tokens = []
        tokens_truncated = []
        tokens_decoded_generated = []
        tokens_decoded_generated_truncated = []

        # One API request per requested return sequence.
        for _ in range(self.generation_config.num_return_sequences):
            response = self.model.invoke(messages)
            content = response.content

            # Both the logprobs object and its "content" list may be null
            # (e.g. when the reply carries no text); treat that as "no
            # token-level data" instead of failing on None.
            logprobs_info = response.response_metadata["logprobs"]
            entries = (logprobs_info or {}).get("content") or []

            sequence_tokens = [entry["token"] for entry in entries]
            sequence_logprobs = [entry["logprob"] for entry in entries]

            log_likelihood.append(sequence_logprobs)
            tokens.append(sequence_tokens)
            tokens_decoded_generated.append(content)

            # Store copies so the truncated and untruncated results never
            # alias the same list object.
            log_likelihood_truncated.append(list(sequence_logprobs))
            tokens_truncated.append(list(sequence_tokens))
            tokens_decoded_generated_truncated.append(content)

        return {
            "log_likelihood": log_likelihood,
            "log_likelihood_truncated": log_likelihood_truncated,
            "tokens": tokens,
            "tokens_truncated": tokens_truncated,
            "tokens_decoded_generated": tokens_decoded_generated,
            "tokens_decoded_generated_truncated": tokens_decoded_generated_truncated,
        }
