import json
import os
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Callable

from serde import Serializable, deserialize_from_json


class ModelType(str, Enum):
    """Member-less base for string-valued model-type enums.

    Concrete enums subclass this class; the helpers below resolve a raw
    string to a member of whichever direct subclass defines that value.
    """

    @classmethod
    def get_enum_from_string(cls, type_str: str):
        """Return the enum member whose value equals ``type_str``.

        The search always covers the direct subclasses of ``ModelType``
        (not of ``cls``), so calling this on a subclass still searches every
        registered enum. Subclasses are tried in ``__subclasses__()`` order
        and the first one that accepts the value wins.

        Args:
            type_str: the string value of the wanted member

        Returns:
            the matching member of one of the ``ModelType`` subclasses

        Raises:
            ValueError: if no direct subclass has a member with that value
        """
        for model_type in ModelType.__subclasses__():
            try:
                # Enum lookup by value raises ValueError when there is no match.
                return model_type(type_str)
            except ValueError:
                continue
        # ValueError is a subclass of Exception, so callers that caught the
        # previous generic Exception keep working.
        raise ValueError(f"Model type {type_str} not found")

    @classmethod
    def deserialize(cls, deserialize: Callable, type_str: str) -> "ModelType":
        """Rebuild a member from its string value.

        Args:
            deserialize: unused here; accepted to fit the signature used by
                the other ``deserialize`` hooks in this module
            type_str: the string value of the wanted member

        Returns:
            the matching enum member (see ``get_enum_from_string``)
        """
        return cls.get_enum_from_string(type_str)


@dataclass
class ChatModelConfig:
    """Typed settings for a chat model.

    ``ChatModel`` builds an instance of this class from its config dict via
    ``deserialize_from_json``.
    """

    # NOTE(review): none of these fields is read in this module; the notes
    # below are inferred from the names - confirm against the model classes.
    model_name: str  # presumably the provider-side model identifier
    default_temperature: float  # presumably the sampling temperature used by default
    seed: int  # presumably a sampling seed
    max_output_tokens: int  # presumably a cap on generated tokens
    supports_system: bool = True  # presumably whether system messages are accepted; defaults to True


@dataclass
class CompletionModelConfig:
    """Typed settings for a completion model.

    Declares the same fields as ``ChatModelConfig`` except
    ``supports_system``.
    """

    # NOTE(review): none of these fields is read in this module; the notes
    # below are inferred from the names - confirm against the model classes.
    model_name: str  # presumably the provider-side model identifier
    default_temperature: float  # presumably the sampling temperature used by default
    seed: int  # presumably a sampling seed
    max_output_tokens: int  # presumably a cap on generated tokens


class Model:
    """Base class for models whose settings are stored in a JSON file.

    Each model type has a file named ``<model_type.value>.json`` inside
    ``CONFIG_DIR``; it is parsed once at construction time.
    """

    # Directory holding the per-model-type JSON files, next to this module.
    CONFIG_DIR = Path(__file__).parent / "model_configs"

    def __init__(self, model_type: ModelType):
        """Remember the model type and load its JSON config.

        Args:
            model_type: the model type; its ``value`` names the config file

        Raises:
            OSError: if the config file cannot be opened
            json.JSONDecodeError: if the file is not valid JSON
        """
        self.model_type = model_type
        config_path = Model.CONFIG_DIR / f"{model_type.value}.json"
        # JSON files are UTF-8; do not depend on the platform's locale
        # encoding, which open() would otherwise fall back to.
        with open(config_path, encoding="utf-8") as file:
            self.config_dict = json.load(file)

    def serialize(self) -> dict[str, Serializable]:
        """Return the model as a dict with its type and raw config.

        Returns:
            ``{"type": <model type>, "config": <config dict>}``
        """
        return {"type": self.model_type, "config": self.config_dict}


class ChatModel(Model):
    """Base class for chat models configured by a ``ChatModelConfig``.

    Subclasses are expected to implement ``query`` and ``query_with_retry``.
    ``Model`` is a plain class (no ``ABCMeta``), so the ``@abstractmethod``
    markers below document intent but do not block instantiation.
    """

    def __init__(self, model_type: ModelType, enable_langsmith: bool = False):
        """Load the config for ``model_type`` and optionally enable tracing.

        Args:
            model_type: the model type whose JSON config is loaded
            enable_langsmith: when true, set the ``LANGCHAIN_*`` environment
                variables below; they are written to ``os.environ`` and so
                apply to the whole process, not just this instance
        """
        super().__init__(model_type)
        self.config = deserialize_from_json(ChatModelConfig, self.config_dict)

        if enable_langsmith:
            os.environ["LANGCHAIN_TRACING_V2"] = "true"
            os.environ["LANGCHAIN_ENDPOINT"] = "https://eu.api.smith.langchain.com"
            os.environ["LANGCHAIN_PROJECT"] = "Steganography_v2"

    def set_config(self, config: dict[str, Serializable]):
        """Replace both the raw config dict and the typed config.

        Args:
            config: the new raw config; it must deserialize into a
                ``ChatModelConfig``
        """
        self.config_dict = config
        self.config = deserialize_from_json(ChatModelConfig, config)

    @abstractmethod
    def query(self, messages: list[dict[str, str]]) -> str | None:
        """Query the model once.

        Args:
            messages: the messages to pass to the model

        Returns:
            the LLM's response, or ``None`` (allowed by the annotation;
            presumably when the call does not succeed - confirm against the
            subclasses)
        """

    @abstractmethod
    def query_with_retry(self, messages: list[dict[str, str]], backoff: int = 3) -> str:
        """Query the model, retrying after unsuccessful calls.

        Args:
            messages: the messages to pass to the model
            backoff: the time to wait between unsuccessful calls
                (presumably seconds - confirm against the subclasses)

        Returns:
            the LLM's response
        """

    @classmethod
    def deserialize(
        cls, deserialize: Callable, type: str, config: dict[str, Serializable]
    ) -> "ChatModel":
        """Rebuild a chat model from its serialized type and config.

        The concrete class is picked from the direct subclasses of ``cls``:
        the subclass named ``"GPT"`` when the model type's value contains
        ``"gpt"`` (case-insensitive), otherwise a subclass with any other
        name. If several subclasses qualify, the last one in
        ``__subclasses__()`` order is used.

        Args:
            deserialize: unused here; accepted to fit the deserializer hook
                signature
            type: the string value of the model type
            config: the raw config to install on the new instance; it
                replaces the config loaded from disk by the constructor

        Returns:
            an instance of the selected subclass with ``config`` applied

        Raises:
            ValueError: if ``type`` matches no known model type
            LookupError: if no suitable subclass of ``cls`` is defined
        """
        model_type = ModelType.get_enum_from_string(type)
        wants_gpt = "gpt" in model_type.value.lower()

        # HACK: dispatch is by class name. Only the subclass named "GPT" is
        # treated as the GPT implementation; every other subclass counts as
        # the non-GPT one.
        chosen = None
        for subclass in cls.__subclasses__():
            if (subclass.__name__ == "GPT") == wants_gpt:
                chosen = subclass  # keep iterating: the last match wins

        if chosen is None:
            # Fail with a clear message instead of an unbound-local error.
            raise LookupError(
                f"No ChatModel subclass available for model type {model_type.value}"
            )

        ret = chosen(model_type)
        ret.set_config(config)
        return ret


class CompletionModel(Model):
    """Base class for models that continue a text prompt."""

    @abstractmethod
    def completion(self, prompt: str) -> str:
        """Complete ``prompt`` with the model.

        This base implementation does nothing; subclasses provide the
        actual call.

        Args:
            prompt: the text to complete

        Returns:
            the completed text
        """


class Classifier(Model):
    """Base class for models that assign an integer class to a text."""

    @abstractmethod
    def classify(self, text: str) -> int:
        """Return the class of ``text``.

        This base implementation does nothing; subclasses provide the
        actual classification.

        Args:
            text: the text to classify

        Returns:
            the class of the text
        """
