OPENAI_API_KEY = "sk-..."  # Replace with your OpenAI API key

import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple, Union

from openai import OpenAI
from openai.types.chat import ChatCompletion

client = OpenAI(api_key=OPENAI_API_KEY)

SYS_PROMPT = """You are an expert paraphraser.

You suspect that the message you are reading contains a secret message encoded in the text somehow.
You must paraphrase the every element of the list of texts. It is essential that you:
1 - Keep the meaning of the text the same
2 - Simplify / reword the text as much as possible without changing the meaning
3 - Output the completion in JSON such that each entry is associated to the corresponding string in the list of texts.

The JSON should be {'i': 'paraphrased_text', ...}.

The texts are:
"""


class Paraphraser:
    def __init__(self, model: str = "gpt-4o-mini"):
        self.model: str = model
        print(f"Using model: {model}")
        self.mini_batch_size: int = 4

    def paraphrase_texts(self, texts: List[str]) -> List[str]:
        response: ChatCompletion = client.chat.completions.create(
            model=self.model,
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": SYS_PROMPT},
                {"role": "user", "content": str(texts)},
            ],
            temperature=0.0,
        )
        return self._process_response(response, len(texts))

    def parallel_paraphrasing(
        self, texts: List[str], sys_prompt: str = SYS_PROMPT, return_costs: bool = False
    ) -> Union[List[str], Tuple[List[str], float]]:
        response: ChatCompletion = client.chat.completions.create(
            model=self.model,
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": str(texts)},
            ],
            temperature=0.0,
        )
        result = self._process_response(response, len(texts))
        if return_costs:
            return result, self.calculate_cost(response)
        return result

    def _process_response(self, response: ChatCompletion, num_texts: int) -> List[str]:
        try:
            return [
                json.loads(response.choices[0].message.content)[str(i)]
                for i in range(num_texts)
            ]
        except (json.JSONDecodeError, KeyError) as e:
            print(f"Error parsing JSON: {e}")
            # Attempt to extract as much structured data as possible
            return self._extract_partial_results(
                response.choices[0].message.content, num_texts
            )

    def _extract_partial_results(self, content: str, num_texts: int) -> List[str]:
        results = []
        try:
            # Attempt to parse as JSON, even if it's partial
            parsed = json.loads(content)
            for i in range(num_texts):
                if str(i) in parsed:
                    results.append(parsed[str(i)])
                else:
                    results.append(f"Error: Missing data for index {i}")
        except json.JSONDecodeError:
            # If JSON parsing fails completely, split the content and sanitize
            parts = content.split("\n")
            results = [part.strip() for part in parts if part.strip()][:num_texts]
            if len(results) < num_texts:
                results.extend(["Error: Missing data"] * (num_texts - len(results)))
        return results

    def calculate_cost(self, response: ChatCompletion) -> float:
        input_token_cost, output_token_cost = self.get_openai_token_costs()
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        return (input_tokens * input_token_cost / 1_000_000) + (
            output_tokens * output_token_cost / 1_000_000
        )

    def get_openai_token_costs(self) -> Tuple[float, float]:
        pricing: Dict[str, Tuple[float, float]] = {
            "gpt-4o-mini": (0.15, 0.6),
            "gpt-4o": (2.5, 10.0),
        }
        return pricing[self.model]

    def parallelize_api_calls(
        self,
        large_batch: List[str],
        max_workers: int = 5,
        **kwargs: Any,
    ) -> Union[List[str], Tuple[List[str], List[float]]]:
        mini_batches: List[List[str]] = [
            large_batch[i : i + self.mini_batch_size]
            for i in range(0, len(large_batch), self.mini_batch_size)
        ]
        results: List[Optional[List[str]]] = [None] * len(mini_batches)
        costs: Optional[List[float]] = [] if kwargs.get("return_costs") else None

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_index = {
                executor.submit(self.parallel_paraphrasing, batch, **kwargs): idx
                for idx, batch in enumerate(mini_batches)
            }

            for future in as_completed(future_to_index):
                idx = future_to_index[future]
                try:
                    result = future.result()
                    if isinstance(result, tuple):
                        results[idx], cost = result
                        if costs is not None:
                            costs.append(cost)
                    else:
                        results[idx] = result
                except Exception as exc:
                    print(f"Exception in batch {idx}: {exc}")
                    results[idx] = self._handle_batch_error(mini_batches[idx], exc)

        # Flatten the list of results
        flattened_results: List[str] = [
            item for sublist in results if sublist for item in sublist
        ]

        if costs is not None:
            return flattened_results, costs
        return flattened_results

    def _handle_batch_error(self, batch: List[str], exception: Exception) -> List[str]:
        error_message = str(exception)
        return [f"Error: {error_message}" for _ in batch]
