from abc import ABC, abstractmethod
from typing import Any, Dict, List, Sequence, Callable
import logging

from datasets import Dataset

from typing import Callable, Union
from transformers import PreTrainedModel
from vllm import LLM, SamplingParams  # type: ignore

# Type alias for a reward-function specification. It accepts a string, a
# `PreTrainedModel`, or a callable that takes two lists and returns a list of floats.
# NOTE(review): this file does not establish what the `str` form means
# (presumably a model id/path) or the roles of the two list arguments
# (presumably prompts and completions). Confirm against the consumer of this alias.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


class Environment(ABC):
    """Abstract base class for environments.

    Concrete subclasses must implement:

    * ``get_dataset`` / ``get_eval_dataset``: return a ``Dataset`` or ``None``.
    * ``get_reward_funcs`` / ``get_reward_weights``: return the reward
      functions (see ``RewardFunc``) and their float weights.
    * ``generate``: given chat-style prompts, a vLLM ``LLM`` and
      ``SamplingParams``, return a dict of result lists.
    """

    def __init__(self, **kwargs: Any):
        """Set default attributes, then store every keyword argument as an attribute.

        Keyword arguments are applied *after* the defaults, so callers can
        override any of them (e.g. ``tokenizer=...``, ``eot_id=...``) as well
        as add new attributes. Assigning defaults afterwards would silently
        discard such overrides.

        Args:
            **kwargs: Arbitrary attributes to set on the instance via ``setattr``.
        """
        self.logger = logging.getLogger(f"verifiers.envs.{self.__class__.__name__}")
        self.tokenizer = None
        self.dataset = None
        self.eval_dataset = None
        # The token-id defaults below target the Qwen2 tokenizer; pass overrides
        # through kwargs when using a different tokenizer.
        self.eot_id = 151643                # For Qwen2Tokenizer, this is <|endoftext|>, which is the pad_token
        self.message_end_id = 151645        # For Qwen2Tokenizer, this is <|im_end|>, which is the eos_token
        self.message_end_newline_id = None  # For Qwen2Tokenizer, this would be the id of "\n"; in Qwen2.5's chat template, each assistant message ends with `<|im_end|>\n`
        self.reward_funcs: List[RewardFunc] = []
        self.reward_weights: List[float] = []
        # Apply caller-supplied attributes last so they take precedence over the defaults above.
        for key, value in kwargs.items():
            setattr(self, key, value)

    @abstractmethod
    def get_dataset(self, **kwargs: Any) -> Dataset | None:
        """Return the training dataset, or ``None`` if there is none."""

    @abstractmethod
    def get_eval_dataset(self, **kwargs: Any) -> Dataset | None:
        """Return the evaluation dataset, or ``None`` if there is none."""

    @abstractmethod
    def get_reward_funcs(self, **kwargs: Any) -> List[RewardFunc]:
        """Return the reward functions used to score generations."""

    @abstractmethod
    def get_reward_weights(self, **kwargs: Any) -> List[float]:
        """Return one weight per reward function.

        NOTE(review): presumably aligned index-by-index with
        ``get_reward_funcs()``. Confirm with the consumer.
        """

    @abstractmethod
    def generate(self,
                 prompts: List[List[Dict[str, Any]]],
                 llm: LLM,
                 sampling_params: SamplingParams,
                 **kwargs: Any) -> Dict[str, List[Sequence[int]] | List[str] | List[List[Dict[str, Any]]]]:
        """Run generation for a batch of prompts.

        Args:
            prompts: One list of message dicts per prompt (chat format).
            llm: The vLLM engine to generate with.
            sampling_params: vLLM sampling configuration.
            **kwargs: Subclass-specific options.

        Returns:
            A dict whose values are lists of token-id sequences, strings, or
            message lists. The exact keys are defined by each subclass.
        """
