import os
from typing import (
    List,
    Dict,
    Union,
)

from dotenv import load_dotenv

import anthropic
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request

from src.batch_processor.base_batch_processor import BaseBatchProcessor
from src.schema import (
    PredictedAnswer,
    Prediction,
)
from src.utils.json import (
    read_json_file,
    read_jsonl_file,
)


class ClaudeBatchProcessor(BaseBatchProcessor):
    def __init__(self, generator_name: str) -> None:
        super().__init__(generator_name=generator_name)
        load_dotenv()
        self.client = anthropic.Anthropic(api_key=os.getenv("CLAUDE_API_KEY"))
    
    def infer(self, request_path: str) -> Union[List, Dict]:
        requests = read_jsonl_file(file_path=request_path)
        batch = self.client.messages.batches.create(
            requests=[
                Request(
                    custom_id=request["custom_id"],
                    params=MessageCreateParamsNonStreaming(
                        model=request["body"]["model"],
                        max_tokens=request["body"]["max_tokens"],
                        temperature=request["body"]["temperature"],
                        messages=request["body"]["messages"],
                    )
                )
                for request in requests
            ]
        )
        return {
            "batch_id": batch.id,
        }

    def _postprocess(self, response: str) -> str:
        return response.strip()
    
    def parse_response(
        self,
        predictions: List[Prediction],
        batch_path: str,
    ) -> List[Prediction]:
        batch = read_json_file(file_path=batch_path)
        batch_id = batch["batch_id"]

        message_batch = self.client.messages.batches.retrieve(batch_id)

        if message_batch.processing_status != "completed":
            raise ValueError("Batch processing is not completed yet!")

        response_map = {}
        for result in self.client.messages.batches.results(batch_id):
            match result.result.type:
                case "succeeded":
                    # TODO: fill in the response_map
                    ...
                case "errored":
                    if result.result.error.type == "invalid_request":
                        raise ValueError(f"Validation error {result.custom_id}")
                    else:
                        raise ValueError(f"Server error {result.custom_id}")
                case "expired":
                    raise ValueError(f"Request expired {result.custom_id}")

        gen_predictions = []
        for prediction in predictions:
            gen_pred_answers = []
            for ret_pred_answer in prediction.pred_answers:
                gen_pred_answer = PredictedAnswer(
                    answer_id=ret_pred_answer.answer_id,
                    ret_document_ids=ret_pred_answer.ret_document_ids,
                    ret_document_contents=ret_pred_answer.ret_document_contents,
                    ret_document_scores=ret_pred_answer.ret_document_scores,
                    ret_memory_ids=ret_pred_answer.ret_memory_ids,
                    ret_memory_contents=ret_pred_answer.ret_memory_contents,
                    ret_memory_scores=ret_pred_answer.ret_memory_scores,
                    raw_answer=self._postprocess(response_map[ret_pred_answer.answer_id]),
                    metadata=ret_pred_answer.metadata,
                )
                gen_pred_answers.append(gen_pred_answer)
            gen_prediction = Prediction(
                qa=prediction.qa,
                pred_answers=gen_pred_answers,
                metadata=prediction.metadata,
            )
            gen_predictions.append(gen_prediction)
        return gen_predictions
