import json
from typing import (
    List,
    Dict,
    Union,
)

from openai import OpenAI

from src.batch_processor.base_batch_processor import BaseBatchProcessor
from src.schema import (
    PredictedAnswer,
    Prediction,
)
from src.utils.json import (
    read_json_file,
)


class OpenAIBatchProcessor(BaseBatchProcessor):
    """Batch processor that submits requests via the OpenAI Batch API.

    ``infer`` uploads a request file and creates a batch job;
    ``parse_response`` later retrieves the finished job's output and merges
    the generated answers into ``Prediction`` objects.
    """

    def __init__(
        self,
        generator_name: str,
    ) -> None:
        """Initialise the base processor and create an ``OpenAI`` client.

        Args:
            generator_name: Forwarded unchanged to ``BaseBatchProcessor``.
        """
        super().__init__(generator_name=generator_name)
        # Client is built with its default configuration (no explicit args).
        self.client = OpenAI()

    def infer(self, request_path: str) -> Union[List, Dict]:
        """Upload a request file and create a batch job for it.

        Args:
            request_path: Path to the request file (JSONL is required by
                the upload endpoint).

        Returns:
            A dict with a single key, ``"batch_id"``, holding the id of the
            created batch.
        """
        # Use a context manager so the file handle is closed even when the
        # upload raises; the original left the handle open.
        with open(request_path, "rb") as request_file:  # jsonl required
            batch_input_file = self.client.files.create(
                file=request_file,
                purpose="batch",
            )

        batch = self.client.batches.create(
            input_file_id=batch_input_file.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata={
                "description": "nightly eval job"
            },
        )
        return {
            "batch_id": batch.id,
        }

    def _postprocess(self, response: str) -> str:
        """Return ``response`` with leading/trailing whitespace removed."""
        return response.strip()

    def parse_response(
        self,
        predictions: List[Prediction],
        batch_path: str,
    ) -> List[Prediction]:
        """Attach generated answers from a completed batch to predictions.

        Args:
            predictions: Predictions whose ``pred_answers`` each carry an
                ``answer_id`` that is looked up against the ``custom_id``
                values in the batch output.
            batch_path: Path to a JSON file containing a ``"batch_id"`` key
                (the shape returned by ``infer``).

        Returns:
            New ``Prediction`` objects mirroring the inputs, with each
            answer's ``raw_answer`` set to the post-processed generated
            content.

        Raises:
            ValueError: If the batch is not in the ``"completed"`` state, or
                it completed without producing an output file.
            KeyError: If an ``answer_id`` has no matching ``custom_id`` in
                the batch output.
        """
        batch = read_json_file(batch_path)
        batch_id = batch["batch_id"]
        batch_obj = self.client.batches.retrieve(batch_id)

        if batch_obj.status != "completed":
            raise ValueError("Batch processing is not completed yet!")

        output_file_id = batch_obj.output_file_id
        if not output_file_id:
            # Fail with a clear message instead of passing an empty id on
            # to the files endpoint.
            raise ValueError(
                f"Batch {batch_id!r} is completed but has no output file."
            )

        file_response_text = self.client.files.content(output_file_id).text

        # The output is JSONL: one JSON object per non-blank line. Split on
        # "\n" only, so separators embedded inside JSON strings are not
        # treated as line breaks.
        stripped_lines = (line.strip() for line in file_response_text.split("\n"))
        records = [json.loads(line) for line in stripped_lines if line]

        # NOTE(review): assumes every output record holds a chat completion
        # with at least one choice; a record of another shape raises here.
        response_map = {
            record["custom_id"]: record["response"]["body"]["choices"][0]["message"]["content"]
            for record in records
        }

        gen_predictions = []
        for prediction in predictions:
            gen_pred_answers = []
            for ret_pred_answer in prediction.pred_answers:
                answer_id = ret_pred_answer.answer_id
                try:
                    raw_response = response_map[answer_id]
                except KeyError as err:
                    # Same exception type as a bare lookup, but with context.
                    raise KeyError(
                        f"No batch response found for answer_id {answer_id!r} "
                        f"in batch {batch_id!r}"
                    ) from err

                gen_pred_answers.append(
                    PredictedAnswer(
                        answer_id=answer_id,
                        ret_document_ids=ret_pred_answer.ret_document_ids,
                        ret_document_contents=ret_pred_answer.ret_document_contents,
                        ret_document_scores=ret_pred_answer.ret_document_scores,
                        ret_memory_ids=ret_pred_answer.ret_memory_ids,
                        ret_memory_contents=ret_pred_answer.ret_memory_contents,
                        ret_memory_scores=ret_pred_answer.ret_memory_scores,
                        raw_answer=self._postprocess(raw_response),
                        metadata=ret_pred_answer.metadata,
                    )
                )
            gen_predictions.append(
                Prediction(
                    qa=prediction.qa,
                    pred_answers=gen_pred_answers,
                    metadata=prediction.metadata,
                )
            )
        return gen_predictions
 