"""
This file contains the abstract class for a model policy, whether it is an API-based model or a local one.
It also contains utility dataclasses, most notably the sample types used as training data (single, paired, and reward-evaluated samples).
"""

import abc
import asyncio
import dataclasses
import random
import secrets
from typing import Any

from utils.async_utils import run_coroutine


@dataclasses.dataclass()
class Sample(abc.ABC):
    """Common base for every training-sample record.

    Each sample carries the dialogue history that precedes the model output, as a
    list of string-to-string message dicts. Subclasses add the output(s) plus any
    reward or free-form metadata.

    NOTE(review): although this derives from ``abc.ABC``, it declares no abstract
    methods, so ``Sample`` itself can still be instantiated directly.
    """

    history: list[dict[str, str]]  # message dicts preceding the sample's output


@dataclasses.dataclass()
class SingleSample(Sample):
    """A dialogue history together with exactly one model output.

    ``Policy.train`` routes lists of these to ``Policy.train_sft``. ``aux_info`` is
    optional free-form metadata; every instance gets its own fresh empty dict by
    default (never a shared one).
    """

    history: list[dict[str, str]]  # message dicts preceding ``output``
    output: str  # the single model response for ``history``
    aux_info: dict[str, Any] = dataclasses.field(default_factory=dict)  # extra metadata


@dataclasses.dataclass()
class PairedSample(Sample):
    """A dialogue history with two competing outputs, used for preference training.

    ``Policy.train`` routes lists of these to ``Policy.train_dpo``. By their names,
    ``winning_output`` is the preferred completion and ``losing_output`` the
    dispreferred one. ``aux_info`` defaults to a fresh empty dict per instance.
    """

    history: list[dict[str, str]]  # message dicts preceding both outputs
    winning_output: str  # the preferred response
    losing_output: str  # the dispreferred response
    aux_info: dict[str, Any] = dataclasses.field(default_factory=dict)  # extra metadata

@dataclasses.dataclass()
class EvaluatedSample(Sample):
    """A dialogue history, one model output, and a pre-determined scalar reward.

    ``Policy.train`` routes lists of these to ``Policy.train_ppo`` (offline PPO
    with fixed reward values). ``aux_info`` defaults to a fresh empty dict per
    instance.
    """

    history: list[dict[str, str]]  # message dicts preceding ``output``
    output: str  # the model response that was scored
    reward: float  # reward assigned to ``output`` ahead of training
    aux_info: dict[str, Any] = dataclasses.field(default_factory=dict)  # extra metadata


class Policy(abc.ABC):
    """A model policy. This can be an API-based model, or a local model.

    Subclasses must implement ``__init__`` (calling ``Policy.__init__`` first) and
    ``infer_single_async``. Synchronous and batched inference are derived from
    ``infer_single_async`` by default; the ``train_*`` methods are optional and
    raise ``NotImplementedError`` unless a subclass overrides them.
    """

    colloquial_name: str  # e.g. "GPT-4o", "Llama-3.1-8B-Instruct", "Martingale-trained Llama"
    identifier: str  # unique random hex string, generated per instance

    @abc.abstractmethod
    def __init__(self, colloquial_name: str):
        """Each subclass should implement its own instantiation logic, after calling Policy.__init__() at the beginning.

        :param colloquial_name: Human-readable name, used as the prefix in ``__str__``.
        :type colloquial_name: str
        """
        self.colloquial_name = colloquial_name
        # Use ``secrets`` instead of the global ``random`` module: if callers seed
        # ``random`` (common in experiments), identifiers would repeat across runs,
        # and constructing a policy would shift the seeded random stream.
        # 8 bytes keeps the original 64-bit identifier space (lowercase hex).
        self.identifier = secrets.token_hex(8)

    def __str__(self):
        return f"{self.colloquial_name}-ID-{self.identifier}"

    @abc.abstractmethod
    async def infer_single_async(
        self,
        history: list[dict[str, str]],
        disable_system_prompt: bool = False,
        **kwargs
    ) -> str:
        """Each subclass should implement its own async generation method.

        :param history: The dialogue history, in OpenAI format.
        :type history: list[dict[str, str]]
        :param disable_system_prompt: Subclass-defined flag; presumably suppresses any
            system prompt the policy would otherwise add.
        :return: The single response.
        :rtype: str
        """
        raise NotImplementedError

    def infer_single(
        self,
        history: list[dict[str, str]],
        disable_system_prompt: bool = False,
        **kwargs
    ) -> str:
        """
        Given a dialogue history, return a single response.

        Synchronous wrapper that runs ``infer_single_async`` via ``run_coroutine``.

        :param history: The dialogue history, in OpenAI format.
        :type history: list[dict[str, str]]
        :param disable_system_prompt: Forwarded to ``infer_single_async``.
        :type disable_system_prompt: bool
        :return: The single response.
        :rtype: str
        """
        return run_coroutine(self.infer_single_async(history, disable_system_prompt, **kwargs))

    async def infer_batch_async(
        self,
        histories: list[list[dict[str, str]]],
        disable_system_prompt: bool = False,
        **kwargs
    ) -> list[str]:
        """Same as `infer_batch`, but async.

        Runs ``infer_single_async`` concurrently for every history with
        ``asyncio.gather``; responses are returned in the same order as ``histories``.
        An empty ``histories`` yields an empty list.
        """
        return await asyncio.gather(
            *[self.infer_single_async(history, disable_system_prompt, **kwargs) for history in histories]
        )

    def infer_batch(
        self,
        histories: list[list[dict[str, str]]],
        disable_system_prompt: bool = False,
        **kwargs
    ) -> list[str]:
        """
        Given a list of dialogue histories, return a list of responses.
        By default, this method runs `infer_single_async` for each sample individually. You should implement your own `infer_batch` if this becomes the performance bottleneck, for example if you're using a local model or a batching API.

        Note that the default delegates to ``infer_batch_async``, so overriding
        ``infer_batch`` alone does not change what ``infer_batch_async`` does.

        :param histories: The list of dialogue histories, in OpenAI format.
        :type histories: list[list[dict[str, str]]]
        :param disable_system_prompt: Forwarded to ``infer_single_async``.
        :type disable_system_prompt: bool
        :return: The list of responses, in the same order as ``histories``.
        :rtype: list[str]
        """
        return run_coroutine(self.infer_batch_async(histories, disable_system_prompt, **kwargs))

    def train_sft(self, samples: list[SingleSample]) -> "Policy":
        """Perform SFT training. Each subclass should implement this by itself, if supported."""
        raise NotImplementedError

    def train_dpo(self, samples: list[PairedSample]) -> "Policy":
        """Perform DPO training. Each subclass should implement this by itself, if supported."""
        raise NotImplementedError

    def train_ppo(self, samples: list[EvaluatedSample]) -> "Policy":
        """Perform offline PPO training with pre-determined reward values. Each subclass should implement this by itself, if supported."""
        raise NotImplementedError

    def train(self, samples: list[Sample]) -> "Policy":
        """Perform training, dispatching on the sample type.

        ``PairedSample`` -> ``train_dpo``, ``SingleSample`` -> ``train_sft``,
        ``EvaluatedSample`` -> ``train_ppo``. An empty list is a no-op returning ``self``.

        :param samples: Homogeneous list of training samples.
        :type samples: list[Sample]
        :return: Whatever the selected ``train_*`` method returns.
        :rtype: Policy
        :raises TypeError: If the sample type is unrecognized, or if the list mixes
            sample types.
        """
        if not samples:
            return self

        dispatch = (
            (PairedSample, self.train_dpo),
            (SingleSample, self.train_sft),
            (EvaluatedSample, self.train_ppo),
        )
        for sample_type, train_fn in dispatch:
            if isinstance(samples[0], sample_type):
                # The first sample selects the trainer; reject mixed lists up front
                # rather than letting the trainer fail obscurely on a foreign sample.
                if not all(isinstance(sample, sample_type) for sample in samples):
                    raise TypeError(
                        f"Mixed sample types: expected every sample to be a {sample_type.__name__}."
                    )
                return train_fn(samples)

        raise TypeError(
            "Unrecognized sample type. Must be list[PairedSample] or list[SingleSample] or list[EvaluatedSample]."
        )
