from typing import *
from pydantic import BaseModel as PyBase
import dataclasses


@dataclasses.dataclass(init=True, repr=True, eq=True)
class ObservationPrompt:
    """Observation portion of a prompt: a text body plus image file paths.

    Attributes:
        text: The observation text.
        image_paths: Paths of images accompanying the text.
    """

    text: str
    image_paths: List[str]

# Default structured-output schema assigned to ``Prompt.schema``.
# NOTE(review): deliberately no class docstring — pydantic copies a model's
# docstring into its JSON schema "description", which would change the schema
# (presumably sent to the model). Document with comments only.
class DECISION(PyBase):
    action: str  # the chosen action as a string; presumably interpreted by the caller — confirm

@dataclasses.dataclass
class Prompt:
    """A complete prompt, as accepted by ``BaseModel.generate``.

    Attributes:
        system_prompt: System-level instruction text.
        observation_prompt: Observation text and image paths.
        action_prompt: Text asking for the action.
        schema: A pydantic model *class* (not an instance) describing the
            structured response; defaults to ``DECISION``. Presumably used by
            ``generate`` implementations to request/parse structured output —
            confirm in subclasses.
    """

    system_prompt: str
    observation_prompt: ObservationPrompt
    action_prompt: str
    # Annotated as Type[...] because the default (and expected value) is a
    # model class, not an instance of one.
    schema: Optional[Type[PyBase]] = DECISION

class MaxTokenLimit(Exception):
    """Signals that a model response was cut off by the maximum token limit."""


class BaseModel:
    """Base class for model back-ends, with a registry keyed by type name.

    Concrete subclasses register themselves by passing a unique
    ``model_type`` class keyword, e.g.
    ``class Foo(BaseModel, model_type="foo"): ...``, and override
    :meth:`create_client` and :meth:`generate`. :meth:`from_config` builds a
    registered subclass from a spec of the form
    ``{"type": "<model_type>", "params": {...}}``.
    """

    # One registry shared by every subclass: model_type -> subclass.
    _registry: Dict[str, Type["BaseModel"]] = {}
    # Set on each registered subclass by __init_subclass__.
    model_type: ClassVar[str]

    def __init_subclass__(cls, *, model_type: str, **kwargs):
        """Register ``cls`` under ``model_type`` and record it on the class.

        Raises:
            ValueError: If ``model_type`` is already registered.
        """
        super().__init_subclass__(**kwargs)
        if model_type in cls._registry:
            raise ValueError(f"Duplicated model_type: {model_type}.")
        cls._registry[model_type] = cls
        cls.model_type = model_type

    @classmethod
    def from_config(cls, spec: Dict[str, Any]) -> "BaseModel":
        """Instantiate the registered subclass described by ``spec``.

        Args:
            spec: Mapping with a required ``"type"`` key naming a registered
                ``model_type`` and an optional ``"params"`` mapping that is
                passed as keyword arguments to the subclass constructor.

        Returns:
            A new instance of the registered subclass.

        Raises:
            ValueError: If ``"type"`` is missing or is not registered.
        """
        # Checked up front: the old code raised a bare KeyError here because
        # the error handler itself re-read spec['type'].
        if "type" not in spec:
            raise ValueError(f"[BaseModel] spec is missing required key 'type'. "
                             f"Available: {list(cls._registry)}")
        model_type = spec["type"]
        try:
            model_cls = cls._registry[model_type]
        except KeyError:
            # `from None`: the KeyError carries nothing beyond this message.
            raise ValueError(f"[BaseModel] unknown type '{model_type}'. "
                             f"Available: {list(cls._registry)}") from None
        # Treat an explicit "params": None like an absent key.
        params = spec.get("params") or {}
        return model_cls(**params)

    def __init__(
        self,
        name,
        version,
        temperature,
        max_response_tokens,
        reasoning_model: bool = False,
        max_reasoning_tokens: int = 0,
        stream: bool = False,
        input_price_per_1M_tokens: float = 0.0,
        output_price_per_1M_tokens: float = 0.0,
    ):
        """Store the configuration, initialise token budgets, create the client.

        Args:
            name: Model name.
            version: Model version identifier.
            temperature: Sampling temperature (presumably forwarded to the
                back-end by subclasses — confirm).
            max_response_tokens: Cap applied to ``response_tokens`` by
                :meth:`increase_max_tokens`.
            reasoning_model: Flag stored for subclasses; presumably marks
                models that emit reasoning.
            max_reasoning_tokens: Cap applied to ``reasoning_tokens`` by
                :meth:`increase_max_tokens`.
            stream: Flag stored for subclasses.
            input_price_per_1M_tokens: Price per one million input tokens
                (presumably used for costs in ``token_info``).
            output_price_per_1M_tokens: Price per one million output tokens.
        """
        self.name = name
        self.version = version
        self.temperature = temperature
        self.max_response_tokens = max_response_tokens
        self.reasoning_model = reasoning_model
        self.max_reasoning_tokens = max_reasoning_tokens
        self.stream = stream
        self.input_price_per_1M_tokens = input_price_per_1M_tokens
        self.output_price_per_1M_tokens = output_price_per_1M_tokens
        # Current, growable token budgets. The starting values are fixed and
        # are not clamped to the max_* caps above.
        self.response_tokens = 8192
        self.reasoning_tokens = 16000
        self.create_client()

    def create_client(self):
        """
        Generate the specific client. Subclasses must override this; it is
        called at the end of ``__init__``.
        """
        raise NotImplementedError

    def increase_max_tokens(self, factor: int = 2):
        """
        Increase the token limits for response and reasoning.

        Each limit is multiplied by ``factor`` and clamped to its cap
        (``max_response_tokens`` / ``max_reasoning_tokens``). A limit is never
        lowered: when its cap is below the current value (e.g. the default
        ``max_reasoning_tokens=0``), the current value is kept.

        Args:
            factor (int): The multiplier used to increase both response_tokens and reasoning_tokens.

        Returns:
            None. This method updates the token limits in-place and prints the new values.
        """
        # max(current, ...) keeps an "increase" from shrinking a budget whose
        # cap is smaller than its starting value.
        self.response_tokens = max(
            self.response_tokens,
            min(int(self.response_tokens * factor), self.max_response_tokens),
        )
        self.reasoning_tokens = max(
            self.reasoning_tokens,
            min(int(self.reasoning_tokens * factor), self.max_reasoning_tokens),
        )
        print(f"Increased max_response_tokens to {self.response_tokens}")
        print(f"Increased max_reasoning_tokens to {self.reasoning_tokens}")

    def generate(self, prompt: "Prompt") -> Tuple[Dict[str, Any], Union[str, None], str, Dict[str, Any]]:
        """Generate a response for the given prompt. Subclasses must override this.

        Args:
            prompt (Prompt): prompt for generation.

        Returns:
            messages (Dict[str, Any]): The messages sent to the model.
            reasoning (Union[str, None]): The reasoning generated by the model.
            content (str): The response generated by the model.
            token_info (Dict[str, Any]): A structured summary of token usage and associated costs.
        """
        raise NotImplementedError
