import json
import os
import shlex
import subprocess
from typing import (
    List,
    Dict,
    Union,
)

from dotenv import load_dotenv
from openai import OpenAI

from src.batch_processor.base_batch_processor import BaseBatchProcessor
from src.schema import (
    PredictedAnswer,
    Prediction,
)
from src.utils.json import (
    read_json_file,
    read_jsonl_file,
    write_jsonl_file,
)


# Runs at import time: python-dotenv loads variables from a .env file into the
# process environment, where os.getenv("OPENAI_API_KEY") below can read them.
load_dotenv()


# Template for invoking the parallel request runner script. The placeholders
# {request_jsonl}, {save_jsonl} and {api_key} are filled in with str.format().
# Every argument sits on its own line, joined by a backslash line continuation
# followed by a two-space indent.
COMMAND = " \\\n  ".join(
    (
        "poetry run python src/batch_processor/openai_async_parallel.py",
        "--requests_filepath {request_jsonl}",
        "--save_filepath {save_jsonl}",
        "--request_url https://api.openai.com/v1/chat/completions",
        "--api_key {api_key}",
        "--max_requests_per_minute 5000",
        "--max_tokens_per_minute 100000",
        "--token_encoding_name cl100k_base",
        "--max_attempts 10",
        "--logging_level 10",
    )
)


class OpenAIParallelProcessor(BaseBatchProcessor):
    """Batch processor that delegates requests to the parallel runner script.

    ``infer`` rewrites a request file into the runner's input format, launches
    the command described by ``COMMAND`` and reports where the responses were
    saved. ``parse_response`` reads that response file back and attaches the
    generated text to the given predictions.
    """

    def __init__(
        self,
        generator_name: str,
    ) -> None:
        """Initialise the base processor and create an OpenAI client.

        Args:
            generator_name: Forwarded unchanged to ``BaseBatchProcessor``.
        """
        super().__init__(generator_name=generator_name)
        self.client = OpenAI()

    @staticmethod
    def _sibling_path(request_path: str, new_name: str) -> str:
        """Swap the trailing ``request.jsonl`` of *request_path* for *new_name*.

        Only the last occurrence, which must be inside the file name, is
        replaced, so directory components are never touched.

        Raises:
            ValueError: If the file name does not contain ``request.jsonl``.
                Without this check the derived path would equal the input
                path and the original request file would be overwritten.
        """
        marker = "request.jsonl"
        if marker not in os.path.basename(request_path):
            raise ValueError(
                f"request_path must name a file containing '{marker}': {request_path!r}"
            )
        head, _, tail = request_path.rpartition(marker)
        return f"{head}{new_name}{tail}"

    def infer(self, request_path: str) -> Union[List, Dict]:
        """Run every request in *request_path* through the parallel runner.

        The ``"body"`` of each JSONL record is written to a sibling
        ``*request_async_parallel.jsonl`` file, which is what the runner
        consumes. Responses are saved to a sibling
        ``*response_async_parallel.jsonl`` file.

        Args:
            request_path: Path to a JSONL file whose name contains
                ``request.jsonl`` and whose records each have a ``"body"`` key.

        Returns:
            A dict with the single key ``"response_path"``.

        Raises:
            ValueError: If *request_path* does not name a ``request.jsonl`` file.
            RuntimeError: If ``OPENAI_API_KEY`` is unset or the runner exits
                with a non-zero status.
        """
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            # Formatting None into the command would pass the literal text
            # "None" as the key.
            raise RuntimeError("OPENAI_API_KEY is not set")

        new_request_path = self._sibling_path(request_path, "request_async_parallel.jsonl")
        response_path = self._sibling_path(request_path, "response_async_parallel.jsonl")

        orig_requests = read_jsonl_file(request_path)
        new_requests = [orig_request["body"] for orig_request in orig_requests]
        write_jsonl_file(file_path=new_request_path, data=new_requests)

        placeholders = {
            "request_jsonl": new_request_path,
            "save_jsonl": response_path,
            "api_key": api_key,
        }
        # Tokenise the template first and fill placeholders per token, so that
        # values containing spaces or shell metacharacters stay one argument
        # and are never interpreted by a shell.
        template_tokens = shlex.split(COMMAND.replace("\\\n", " "))
        argv = [token.format(**placeholders) for token in template_tokens]

        completed = subprocess.run(argv, check=False)
        if completed.returncode != 0:
            # Do not use check=True: CalledProcessError embeds argv, which
            # would put the API key into tracebacks and logs.
            raise RuntimeError(
                f"parallel request runner exited with status {completed.returncode}"
            )

        return {
            "response_path": response_path,
        }

    def _postprocess(self, response: str) -> str:
        """Return *response* with leading and trailing whitespace removed."""
        return response.strip()

    def parse_response(
        self,
        predictions: List[Prediction],
        batch_path: str,
    ) -> List[Prediction]:
        """Build new predictions whose answers carry the generated text.

        Args:
            predictions: Predictions whose ``pred_answers`` are looked up in
                the response file by ``answer_id``.
            batch_path: Path to a JSON file holding a ``"response_path"``
                entry, as returned by ``infer``.

        Returns:
            New ``Prediction`` objects mirroring *predictions*, with each
            answer's ``raw_answer`` set to the post-processed completion.

        Raises:
            KeyError: If an ``answer_id`` has no matching response.
        """
        batch = read_json_file(batch_path)
        response_path = batch["response_path"]
        responses = read_jsonl_file(response_path)
        # Each row is indexed positionally: row[1] is read as the completion
        # payload and row[2] as metadata carrying "custom_id".
        # NOTE(review): this layout comes from the runner script, which is
        # not visible in this file; confirm it, including what a failed
        # request row looks like.
        response_map = {
            response[2]["custom_id"]: response[1]["choices"][0]["message"]["content"]
            for response in responses
        }

        gen_predictions = []
        for prediction in predictions:
            gen_pred_answers = []
            for ret_pred_answer in prediction.pred_answers:
                # Copy the retrieval fields as-is; only raw_answer is new.
                gen_pred_answer = PredictedAnswer(
                    answer_id=ret_pred_answer.answer_id,
                    ret_document_ids=ret_pred_answer.ret_document_ids,
                    ret_document_contents=ret_pred_answer.ret_document_contents,
                    ret_document_scores=ret_pred_answer.ret_document_scores,
                    ret_memory_ids=ret_pred_answer.ret_memory_ids,
                    ret_memory_contents=ret_pred_answer.ret_memory_contents,
                    ret_memory_scores=ret_pred_answer.ret_memory_scores,
                    raw_answer=self._postprocess(response_map[ret_pred_answer.answer_id]),
                    metadata=ret_pred_answer.metadata,
                )
                gen_pred_answers.append(gen_pred_answer)
            gen_prediction = Prediction(
                qa=prediction.qa,
                pred_answers=gen_pred_answers,
                metadata=prediction.metadata,
            )
            gen_predictions.append(gen_prediction)
        return gen_predictions
