from __future__ import annotations

from debate.agent import Agent, ScratchpadConfig
from models import BestOfNConfig, HumanModel, Model, ModelResponse, SpeechStructure
from debate.speech_format import SpeechFormat, SpeechFormatType, SpeechFormatStructure
from debate.transcript import SpeechFormat, Transcript
from prompts import Prompt
from utils import logger_utils, quote_utils
import utils.constants as constants

from pydantic import BaseModel

from typing import Optional, Union
import copy


class Debater(Agent):
    """An agent that argues one side of a debate, optionally thinking in a scratchpad before speaking."""

    def __init__(
        self,
        name: str,
        prompt: Prompt | list[Prompt],
        model: Model,
        num_speeches: int,
        speech_format: Optional[SpeechFormat] = None,
        scratchpad_config: Optional[ScratchpadConfig] = None,
        quotes_require_validation: bool = True,
    ):
        """
        An abstraction that corresponds to a debater in the round.

        Params:
            name: A string to identify the debater. It needs only to be unique within its own debate round.
            prompt: The Prompt structure that controls the inputs to the models. A list is passed in for batch processing.
            model: The model that actually performs the text generation.
            num_speeches: The number of speeches each debater will generate in the round.
            speech_format: The order of speeches that the debater is expecting to receive. If not provided, the
                default debate format is built from `name`, `num_speeches`, and the scratchpad setting.
            scratchpad_config: Configuration that specifies if and how to use a scratchpad. If not provided, a
                fresh default ScratchpadConfig is created for this debater.
            quotes_require_validation: Whether the speeches generated by this agent still need to have their quotes
                validated. Quote validation takes some time, so this lets us perform validation only when necessary.
                Defaults to True; HumanDebater sets it to False.
        """
        if scratchpad_config is None:
            # Create a new config per instance instead of sharing one default object across every debater.
            scratchpad_config = ScratchpadConfig()
        if not speech_format:
            speech_format = SpeechFormatType.DEFAULT_DEBATE.get_speech_format(
                name=name, num_speeches=num_speeches, use_scratchpad=scratchpad_config.use_scratchpad
            )
        super().__init__(
            name=name,
            is_debater=True,
            prompt=prompt,
            model=model,
            num_speeches=num_speeches,
            receive_validated_quotes=True,
            quotes_require_validation=quotes_require_validation,
            speech_format=speech_format,
        )
        self.scratchpad_config = scratchpad_config
        self.quotes_require_validation = quotes_require_validation
        self.logger = logger_utils.get_default_logger(__name__)

    def generate(self, max_new_tokens: Optional[int] = None, round_idx: int = 0) -> Optional[list[ModelResponse]]:
        """
        Generates new text for every transcript in a single batched model call.

        Params:
            max_new_tokens: Token budget for the generation. Falls back to the speech format's tokens_per_speech
                when not provided (or falsy).
            round_idx: The index of the current round, forwarded to the model.

        Returns:
            Whatever the model's `predict` returns for the batch of transcript inputs.
        """
        model_inputs = [transcript.to_model_input() for transcript in self.transcripts]

        return self.model.predict(
            inputs=model_inputs,
            max_new_tokens=max_new_tokens or self.speech_format.tokens_per_speech,
            debater_name=self.name,
            round_idx=round_idx,
        )

    def copy(
        self, transcripts: Optional[list[Transcript]] = None, prompts: Optional[list[Prompt] | Prompt] = None
    ) -> Debater:
        """
        Deepcopies the debater (except for the model, which is a shallow copy).

        Params:
            transcripts: If provided, the new debater receives copies of these transcripts.
            prompts: If provided, used instead of deep copies of this debater's prompts.

        Returns:
            A new Debater configured like this one.
        """
        debater = Debater(
            name=self.name,
            prompt=prompts if prompts else [copy.deepcopy(prompt) for prompt in self.prompts],
            model=self.model,
            num_speeches=self.num_speeches,
            speech_format=self.speech_format,
            scratchpad_config=self.scratchpad_config,
            quotes_require_validation=self.quotes_require_validation,
        )
        if transcripts:
            debater.transcripts = [transcript.copy() for transcript in transcripts]
        return debater

    def __call__(self) -> tuple[list[str], Optional[list[ModelResponse]]]:
        """
        Generates new text using the pre-existing transcripts as input.

        If a scratchpad is enabled, the debater first generates reasoning (bounded by scratchpad_word_limit) and
        records it in each of its own transcripts before generating the speech. The reasoning is prepended
        to the returned speeches only when the scratchpad is configured as public.

        Returns:
            A tuple of (speech strings, the raw model responses for the final generation).
        """
        batch_reasoning: list[str] = []
        if self.scratchpad_config.use_scratchpad:
            batch_reasoning = [
                reasoning.speech
                for reasoning in self.generate(max_new_tokens=self.scratchpad_config.scratchpad_word_limit)
            ]
            for i, reasoning in enumerate(batch_reasoning):
                super().receive_message(speaker=self.name, content=reasoning, idx=i)
                self.logger.debug(reasoning)

        generation = self.generate()
        all_speeches = [gen.speech for gen in generation]

        if self.scratchpad_config.use_scratchpad and self.scratchpad_config.scratchpad_public:
            # Pair each transcript's scratchpad reasoning with its own speech (reasoning first).
            all_speeches = [
                constants.LINE_SEPARATOR.join([reasoning, speech])
                for reasoning, speech in zip(batch_reasoning, all_speeches)
            ]

        return all_speeches, generation


class BestOfNDebater(Debater):
    """
    A debater that samples several candidate speeches and returns the one the judge prefers most.

    Only the first round and the first (unbatched) transcript are currently supported.
    """

    def __init__(
        self,
        debater: Debater,
        opposing_debater: Debater,
        judge: Judge,
        best_of_n_config: BestOfNConfig,
        background_text: str,
    ):
        """
        Params:
            debater: The debater whose name, prompts, model, speech format and scratchpad config are reused.
            opposing_debater: The opponent. A deep copy of its first transcript is used to sample opposing speeches.
            judge: The judge whose first transcript's name/prompt, speech format and model score the candidates.
            best_of_n_config: Supplies `n` (candidates to sample), `opponent_n` (opposing speeches to sample per
                candidate, 0/falsy for none) and `maxmin` (score by worst-case rather than mean preference).
            background_text: Text used to validate and replace quotes in sampled speeches before judging.
        """
        super().__init__(
            name=debater.name,
            prompt=debater.prompts,
            model=debater.model,
            num_speeches=debater.num_speeches,
            speech_format=debater.speech_format,
            scratchpad_config=debater.scratchpad_config,
        )
        self.opposing_debater = opposing_debater
        self.base_opponent_transcript = copy.deepcopy(opposing_debater.transcripts[0])
        self.judge = judge
        self.config = best_of_n_config
        self.background_text = background_text

    def _sample_validated_speeches(self, transcript: Transcript, count: int, debater_name: str) -> tuple[list, list]:
        """Samples `count` responses for `transcript` from this debater's model and quote-validates each speech."""
        responses = self.model.predict(
            inputs=[transcript.to_model_input() for _ in range(count)],
            max_new_tokens=self.speech_format.tokens_per_speech,
            debater_name=debater_name,
        )
        speeches = [
            quote_utils.validate_and_replace_quotes(
                speech_content=str(response.speech), background_text=self.background_text
            )
            for response in responses
        ]
        return responses, speeches

    def _build_judge_input(self, speech, opposing_speech):
        """Builds one judge model input holding the candidate and (if any) opposing speech in positional order."""
        judge_transcript = Transcript(
            name=self.judge.transcripts[0].name,
            prompt=self.judge.transcripts[0].prompt,
            speech_format=self.judge.speech_format,
        )
        # Debater A speaks first; otherwise the opponent's speech precedes this debater's.
        if self.name == constants.DEFAULT_DEBATER_A_NAME:
            judge_transcript.add_speech(speaker=self.name, content=speech)
            if opposing_speech:
                judge_transcript.add_speech(speaker=self.opposing_debater.name, content=opposing_speech)
        else:
            if opposing_speech:
                judge_transcript.add_speech(speaker=self.opposing_debater.name, content=opposing_speech)
            judge_transcript.add_speech(speaker=self.name, content=speech)
        return judge_transcript.to_model_input()

    def __call__(self) -> tuple[list, list[ModelResponse]]:
        """
        Samples `config.n` candidate speeches, has the judge score each against the sampled opposing speeches,
        and returns the highest-scoring candidate.

        Each candidate's score is the minimum (if `config.maxmin`) or the mean of the judge's probability for this
        debater across the opposing speeches. Every candidate's response is annotated with its score and
        per-opponent preferences; the non-selected candidates are appended to the winner's `rejected_responses`.

        Returns:
            A tuple of ([best speech], [best model response]).
        """
        # Only round 1 and unbatched inputs (the first transcript) are supported for now.
        model_responses, speeches = self._sample_validated_speeches(
            transcript=self.transcripts[0], count=self.config.n, debater_name=self.name
        )

        if self.config.opponent_n:
            # NOTE(review): opponent speeches are sampled with this debater's model, not
            # self.opposing_debater.model — confirm this is intended.
            opposing_debater_responses, opposing_speeches = self._sample_validated_speeches(
                transcript=self.base_opponent_transcript,
                count=self.config.opponent_n,
                debater_name=self.opposing_debater.name,
            )
        else:
            # A single placeholder so each candidate is judged exactly once, without an opposing speech.
            opposing_debater_responses = [None]
            opposing_speeches = [None]

        # One judge input per (candidate, opposing speech) pair, grouped by candidate.
        judge_inputs = [
            self._build_judge_input(speech=speech, opposing_speech=opposing_speech)
            for speech in speeches
            for opposing_speech in opposing_speeches
        ]
        # Small token budget: the judge is asked only for a decision (SpeechStructure.DECISION).
        judge_model_response = self.judge.model.predict(
            inputs=judge_inputs, max_new_tokens=15, speech_structure=SpeechStructure.DECISION
        )

        group_size = max(self.config.opponent_n, 1)
        split_judge_response = [
            [resp.probabilistic_decision[self.name] for resp in judge_model_response[i : i + group_size]]
            for i in range(0, len(judge_model_response), group_size)
        ]
        scores = [
            min(option) if self.config.maxmin else sum(option) / max(len(option), 1) for option in split_judge_response
        ]
        # max() returns the first maximal index, the same choice a stable descending sort would make.
        selection_idx = max(range(min(len(scores), len(model_responses))), key=scores.__getitem__)
        best_model_response = model_responses[selection_idx]
        best_model_response.bon_opposing_model_responses = opposing_debater_responses

        for i, (model_response, score) in enumerate(zip(model_responses, scores)):
            model_response.preference = score
            model_response.bon_probabilistic_preferences = split_judge_response[i]
            if i != selection_idx:
                best_model_response.rejected_responses.append(model_response)

        return [best_model_response.speech], [best_model_response]

    def copy(
        self, transcripts: Optional[list[Transcript]] = None, prompts: Optional[list[Prompt] | Prompt] = None
    ) -> Debater:
        """
        Deepcopies the debater (except for the model, which is a shallow copy).

        Params:
            transcripts: If provided, the new debater receives copies of these transcripts.
            prompts: If provided, used instead of deep copies of this debater's prompts.

        Returns:
            A new BestOfNDebater sharing this one's opponent, judge, config and background text.
        """
        debater = super().copy(transcripts=transcripts, prompts=prompts)
        best_of_n_debater = BestOfNDebater(
            debater=debater,
            opposing_debater=self.opposing_debater,
            judge=self.judge,
            best_of_n_config=self.config,
            background_text=self.background_text,
        )
        if transcripts:
            # The constructor only reuses the debater's prompts, so carry over the transcript copies explicitly.
            best_of_n_debater.transcripts = debater.transcripts
        return best_of_n_debater


class HumanDebater(Debater):
    """A debater whose speeches are replayed from a dataset through a HumanModel rather than generated."""

    def __init__(self, debater: Debater, speeches: list[SpeechData]):
        """
        Converts an existing debater into one backed by a HumanModel.

        Params:
            debater: The debater being converted. Its name, prompts, speech count, speech format,
                is_debater flag and model alias are reused.
            speeches: The dataset speeches that the HumanModel delivers when text is generated.
        """
        human_model = HumanModel(
            alias=debater.model.alias,
            is_debater=debater.is_debater,
            debater_name=debater.name,
            speeches=speeches,
        )
        super().__init__(
            name=debater.name,
            prompt=debater.prompts,
            model=human_model,
            num_speeches=debater.num_speeches,
            speech_format=debater.speech_format,
            # Speeches replayed by the HumanModel are flagged as not needing quote validation.
            quotes_require_validation=False,
        )
