from typing import (
    List,
    Dict,
    Union,
)

from tqdm import tqdm

from src.batch_processor.base_batch_processor import BaseBatchProcessor
from src.generator import generator_factory
from src.schema import (
    PredictedAnswer,
    Prediction,
)
from src.utils.json import (
    read_json_file,
    read_jsonl_file,
)


class MockBatchProcessor(BaseBatchProcessor):
    """Batch processor that runs batch requests locally through a generator.

    Instead of submitting a job to a remote batch service, ``infer`` sends each
    request from a JSONL request file to the generator built by
    ``generator_factory``, one request at a time. ``parse_response`` then maps
    the stored responses back onto predictions by ``custom_id``.
    """

    def __init__(self, generator_name: str) -> None:
        super().__init__(generator_name=generator_name)

    def infer(self, request_path: str) -> Union[List, Dict]:
        """Run every request in ``request_path`` through the generator.

        Args:
            request_path: Path to a JSONL file of request dicts. Each request's
                prompt is read from ``request["body"]["messages"][0]["content"]``.
                The temperature is read from the first request's
                ``request["body"]["temperature"]``.

        Returns:
            The list of request dicts in file order, each with an added
            ``"response"`` key holding the generator output for that request.

        Raises:
            ValueError: If the request file contains no requests.
        """
        requests = read_jsonl_file(file_path=request_path)
        # Explicit check instead of ``assert`` so the guard still runs under
        # ``python -O``. Otherwise ``requests[0]`` below would fail with an
        # opaque IndexError.
        if not requests:
            raise ValueError(f"No requests found in request file: {request_path}")

        # The generator is built once, so the first request's temperature
        # applies to the whole batch.
        # NOTE(review): assumes every request shares one temperature — confirm.
        temperature = requests[0]["body"]["temperature"]
        self.generator = generator_factory(
            generator_name=self.generator_name,
            temperature=temperature,
        )

        batch = []
        for request in tqdm(requests):
            # Only the content of the first message is sent as the prompt.
            prompt = request["body"]["messages"][0]["content"]
            request["response"] = self.generator.generate(prompt=prompt)
            batch.append(request)
        return batch

    def _postprocess(self, response: str) -> str:
        """Return ``response`` with leading/trailing whitespace removed."""
        return response.strip()

    def parse_response(
        self,
        predictions: List[Prediction],
        batch_path: str,
    ) -> List[Prediction]:
        """Attach generated answers from a saved batch file to predictions.

        Args:
            predictions: Predictions whose ``pred_answers`` carry retrieval
                results. Each answer's ``answer_id`` is looked up as a
                ``custom_id`` in the batch file.
            batch_path: Path to a JSON file containing a list of request dicts,
                each with ``"custom_id"`` and ``"response"`` keys (presumably
                the output of ``infer`` saved to disk — confirm with caller).

        Returns:
            New ``Prediction`` objects with the original ``qa`` and
            ``metadata``. Each answer copies the retrieval fields and sets
            ``raw_answer`` to the whitespace-stripped response.

        Raises:
            KeyError: If an ``answer_id`` has no matching ``custom_id`` in the
                batch file.
        """
        batch = read_json_file(file_path=batch_path)
        response_map = {request["custom_id"]: request["response"] for request in batch}

        gen_predictions = []
        for prediction in predictions:
            gen_pred_answers = []
            for ret_pred_answer in prediction.pred_answers:
                answer_id = ret_pred_answer.answer_id
                try:
                    response = response_map[answer_id]
                except KeyError:
                    # Same exception type as before, but say which id and file
                    # were involved instead of a bare key.
                    raise KeyError(
                        f"No response for answer_id {answer_id!r} "
                        f"in batch file {batch_path}"
                    ) from None
                gen_pred_answers.append(
                    PredictedAnswer(
                        answer_id=answer_id,
                        ret_document_ids=ret_pred_answer.ret_document_ids,
                        ret_document_contents=ret_pred_answer.ret_document_contents,
                        ret_document_scores=ret_pred_answer.ret_document_scores,
                        ret_memory_ids=ret_pred_answer.ret_memory_ids,
                        ret_memory_contents=ret_pred_answer.ret_memory_contents,
                        ret_memory_scores=ret_pred_answer.ret_memory_scores,
                        raw_answer=self._postprocess(response),
                        metadata=ret_pred_answer.metadata,
                    )
                )
            gen_predictions.append(
                Prediction(
                    qa=prediction.qa,
                    pred_answers=gen_pred_answers,
                    metadata=prediction.metadata,
                )
            )
        return gen_predictions
    

class OpenAIMockBatchProcessor(MockBatchProcessor):
    """OpenAI-named variant of ``MockBatchProcessor``.

    Adds no behavior of its own. The constructor
    ``__init__(generator_name)`` is inherited unchanged.
    """


class ClaudeMockBatchProcessor(MockBatchProcessor):
    """Claude-named variant of ``MockBatchProcessor``.

    Adds no behavior of its own. The constructor
    ``__init__(generator_name)`` is inherited unchanged.
    """


class GeminiMockBatchProcessor(MockBatchProcessor):
    """Gemini-named variant of ``MockBatchProcessor``.

    Adds no behavior of its own. The constructor
    ``__init__(generator_name)`` is inherited unchanged.
    """


class OllamaMockBatchProcessor(MockBatchProcessor):
    """Ollama-named variant of ``MockBatchProcessor``.

    Adds no behavior of its own. The constructor
    ``__init__(generator_name)`` is inherited unchanged.
    """
    

class FireworksMockBatchProcessor(MockBatchProcessor):
    """Fireworks-named variant of ``MockBatchProcessor``.

    Adds no behavior of its own. The constructor
    ``__init__(generator_name)`` is inherited unchanged.
    """
