import abc
import datetime
import json
import os
import threading
from pathlib import Path

from tqdm import tqdm
from transformers import PreTrainedTokenizer

from jailbreakbench import defenses
from jailbreakbench.config import LOGS_PATH, MAX_GENERATION_LENGTH, SYSTEM_PROMPTS, TEMPERATURE, TOP_P, Model
from jailbreakbench.llm.llm_output import LLMOutput


def _count_hf_tokens(
    tokenizer: PreTrainedTokenizer, input_texts: list[str], output_texts: list[str]
) -> tuple[list[int], list[int]]:
    """Return per-text token counts for the input texts and the output texts.

    Each batch is passed to the tokenizer in a single call, and the length of
    every entry under the ``"input_ids"`` key is taken as that text's count.

    Returns:
        A pair ``(input_counts, output_counts)`` with one integer per text.
    """
    encoded_inputs = tokenizer(input_texts)["input_ids"]  # type: ignore
    input_counts = list(map(len, encoded_inputs))
    encoded_outputs = tokenizer(output_texts)["input_ids"]  # type: ignore
    output_counts = list(map(len, encoded_outputs))
    return input_counts, output_counts


class LLM(abc.ABC):
    """Abstract base class for a language model that can be queried and logged.

    Subclasses implement `query_llm` and `update_max_new_tokens`. This class
    supplies prompt-to-conversation conversion, optional defense wrapping and
    JSON logging of every query.
    """

    temperature: float
    top_p: float
    max_n_tokens: int

    def __init__(self, model_name: str):
        """Initializes the LLM with the specified model name.

        Args:
            model_name: Name of the model; it is lower-cased and must match
                one of the `Model` enum values.

        Raises:
            ValueError: If `model_name` does not correspond to a `Model` value.
        """
        try:
            self.model_name = Model(model_name.lower())
        except ValueError as err:
            valid_names = ", ".join(model.value for model in Model)
            raise ValueError(f"Invalid model name: {model_name}. Expected one of: {valid_names}") from err
        self.system_prompt = SYSTEM_PROMPTS[self.model_name]
        self.temperature = TEMPERATURE
        self.top_p = TOP_P
        self.max_n_tokens = MAX_GENERATION_LENGTH
        # Guards the log file updates performed in `log_responses`.
        self.lock = threading.Lock()
        # `exist_ok=True` already covers an existing directory, so no prior
        # existence check is needed (and none can race with another creator).
        os.makedirs(LOGS_PATH, exist_ok=True)

    @abc.abstractmethod
    def update_max_new_tokens(self, n: int | None):
        """Update the maximum number of generation tokens.
        Subclass should override."""
        raise NotImplementedError()

    @abc.abstractmethod
    def query_llm(self, inputs: list[list[dict[str, str]]]) -> LLMOutput:
        """Query the LLM with supplied inputs.
        Subclass should override."""
        raise NotImplementedError()

    def query(
        self,
        prompts: list[str],
        behavior: str = "unspecified_behavior",
        phase: str = "dev",
        max_new_tokens: int | None = None,
        defense: str | None = None,
    ) -> LLMOutput:
        """Queries the model with the given prompts and logs the responses.

        Args:
            prompts: User prompts to send to the model.
            behavior: Label used as the log file name.
            phase: One of 'dev', 'test' or 'eval'; used as a log sub-directory.
            max_new_tokens: Forwarded to `update_max_new_tokens`.
            defense: Optional key into `defenses.DEFENSES`.

        Returns:
            The responses and token counts.

        Raises:
            ValueError: If `phase` is not one of the accepted values.
        """
        if phase not in ("dev", "test", "eval"):
            raise ValueError(f"Invalid phase: {phase}. Expected one of: 'dev', 'test', 'eval'")

        llm_output = self._get_responses(prompts, max_new_tokens, defense)
        self.log_responses(behavior, prompts, llm_output, phase)
        return llm_output

    def _get_responses(
        self, input_prompts: list[str], max_new_tokens: int | None = None, defense: str | None = None
    ) -> LLMOutput:
        """Generates responses from the model for the given prompts.

        When `defense` is given, the matching entry of `defenses.DEFENSES` is
        instantiated with this LLM and queried once per conversation;
        otherwise all conversations are passed to `query_llm` in one call.
        """

        # TODO: consider inlining this into `query`.

        self.update_max_new_tokens(max_new_tokens)

        # Create list of conversations
        convs_list = [self.convert_prompt_to_conversation(p) for p in input_prompts]

        # Query the LLM; apply defense if applicable
        if defense is None:
            return self.query_llm(convs_list)

        defense_ = defenses.DEFENSES[defense](self)  # type: ignore
        return LLMOutput.from_list([defense_.query_with_dicts(p) for p in tqdm(convs_list)])

    def convert_prompt_to_conversation(self, prompt: str) -> list[dict[str, str]]:
        """Create a conversation containing the LLM-specific system prompt
        and the user-supplied input prompt.

        A system message is always emitted first; its content is the empty
        string when no system prompt is configured for the model.
        """
        system_content = "" if self.system_prompt is None else self.system_prompt
        return [{"role": "system", "content": system_content}, {"role": "user", "content": prompt}]

    def log_responses(
        self,
        behavior: str,
        prompts: list[str],
        llm_output: LLMOutput,
        phase: str,
    ) -> None:
        """Logs the responses and metadata to a json file.

        The log lives at ``LOGS_PATH/<phase>/<model>/<behavior>.json`` and holds
        a JSON list. Each call appends one element of the form
        ``[{"timestamp": ...}, [<one record per prompt>]]`` by overwriting the
        list's closing bracket in place, so the existing content is not
        re-read or re-written.

        Args:
            behavior: Label used as the log file name.
            prompts: Prompts that were sent to the model.
            llm_output: Responses and token counts to record.
            phase: Sub-directory of the log root to write to.
        """

        file_path = f"{LOGS_PATH}/{phase}/{self.model_name.value}/{behavior}.json"
        Path(file_path).parent.mkdir(parents=True, exist_ok=True)

        records: list[dict[str, str | int]] = [
            {"prompt": p, "response": r, "prompt_tokens": p_tokens, "response_tokens": r_tokens}
            for p, r, p_tokens, r_tokens in zip(
                prompts, llm_output.responses, llm_output.prompt_tokens, llm_output.completion_tokens
            )
        ]

        with self.lock:
            # Seed a missing or zero-length file with an empty JSON list so the
            # append below always has a closing bracket to overwrite.
            if not os.path.exists(file_path) or os.path.getsize(file_path) == 0:
                with open(file_path, "w") as file:
                    json.dump([], file, indent=4)

            timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
            log_entries: list[dict[str, str] | list[dict[str, str | int]]] = [{"timestamp": timestamp}, records]
            # Serialize before touching the file: if serialization fails, the
            # existing log is left intact instead of losing its closing bracket.
            payload = json.dumps(log_entries, indent=4, separators=(",", ": ")).encode("utf-8")

            # Binary mode makes the offset arithmetic well-defined; text-mode
            # offsets are opaque and must not be adjusted by hand.
            with open(file_path, "rb+") as file:
                file.seek(0, os.SEEK_END)
                # Offset of the final byte, i.e. the list's closing bracket.
                position = max(file.tell() - 1, 0)
                file.seek(position, os.SEEK_SET)

                # Anything beyond a bare "[]" already holds an entry, so a
                # separator is required before the new one.
                if position > 1:
                    file.write(b",\n\t")

                file.write(payload)
                # Close the list
                file.write(b"\n]")


class LLMDummy(LLM):
    """Stand-in LLM that fabricates responses, for testing without a GPU."""

    def query_llm(self, inputs: list[list[dict[str, str]]]) -> LLMOutput:
        """Produce a canned response for each conversation.

        The second message of every conversation is echoed back inside the
        response; character lengths are reported as the token counts.
        """
        user_messages = [conversation[1]["content"] for conversation in inputs]
        replies = [f"Response to `{message}`" for message in user_messages]
        return LLMOutput(
            responses=replies,
            prompt_tokens=[len(message) for message in user_messages],
            completion_tokens=[len(reply) for reply in replies],
        )

    def update_max_new_tokens(self, n: int | None):
        """Set the generation token limit; `None` leaves it unchanged."""
        if n is None:
            return
        self.max_n_tokens = n
