import concurrent.futures as cf
from typing import Dict, List

import anthropic
import openai
import pandas as pd
from tqdm import tqdm


class CompletionHandler:
    """Fetch model completions for dataset rows from OpenAI or Anthropic.

    Each provider-specific method sends a single prompt and returns the
    originating row as a dict with the model's answer stored under the
    ``"scratchpad_response"`` key. ``get_completion_concurrent`` fans those
    calls out over a thread pool for a whole DataFrame.
    """

    def __init__(self, openai_api_key: str, anthropic_api_key: str):
        """Create one client per provider from the given API keys."""
        self.openai_client = openai.Client(api_key=openai_api_key)
        self.anthropic_client = anthropic.Anthropic(api_key=anthropic_api_key)

    @staticmethod
    def _with_response(row: pd.Series, response: str) -> Dict[str, str]:
        """Return ``row`` as a new dict with ``response`` attached."""
        # ``to_dict`` builds a fresh dict, so the caller's Series is untouched.
        row_dict = row.to_dict()
        row_dict["scratchpad_response"] = response
        return row_dict

    def get_openai_completion(
        self, prompt: str, row: pd.Series, model_name: str
    ) -> Dict[str, str]:
        """Send ``prompt`` as a single user message to an OpenAI chat model.

        Args:
            prompt: Fully formatted prompt text.
            row: Dataset row the prompt was built from; copied into the result.
            model_name: Model identifier passed through to the API.

        Returns:
            The row's fields plus ``"scratchpad_response"`` holding the content
            of the first returned choice.
        """
        completion = self.openai_client.chat.completions.create(
            model=model_name,
            temperature=1.0,
            messages=[{"role": "user", "content": prompt}],
        )
        return self._with_response(row, completion.choices[0].message.content)

    def get_claude_completion(
        self, prompt: str, row: pd.Series, model_name: str
    ) -> Dict[str, str]:
        """Send ``prompt`` to an Anthropic model via ``completions.create``.

        The prompt is wrapped in the ``Human:`` / ``Assistant:`` turn markers
        and the response is capped at 1000 sampled tokens.

        Args:
            prompt: Fully formatted prompt text.
            row: Dataset row the prompt was built from; copied into the result.
            model_name: Model identifier passed through to the API.

        Returns:
            The row's fields plus ``"scratchpad_response"`` holding the
            completion text.
        """
        # NOTE(review): this goes through the SDK's ``completions`` endpoint
        # rather than a chat/messages one -- confirm the target models and the
        # installed SDK version support it.
        completion = self.anthropic_client.completions.create(
            model=model_name,
            max_tokens_to_sample=1000,
            temperature=1.0,
            prompt=f"\n\nHuman: {prompt}\n\nAssistant:",
        )
        return self._with_response(row, completion.completion)

    def get_completion_concurrent(
        self,
        base_prompt: str,
        model_name: str,
        dataset: pd.DataFrame,
        results: List[Dict[str, str]],
        max_threads: int = 30,
    ) -> pd.DataFrame:
        """Request one completion per dataset row using a thread pool.

        Args:
            base_prompt: ``str.format`` template with ``{deployment_tag}`` and
                ``{question}`` placeholders.
            model_name: Model identifier. Names containing ``"gpt"`` are sent
                to OpenAI; every other name is sent to Anthropic.
            dataset: Rows to process; must provide the ``"deployment_tag"``
                and ``"inst"`` columns.
            results: List that finished rows are appended to **in place**.
                Entries already in the list are kept and appear in the
                returned frame; rows appended before a failure remain in it.
            max_threads: Maximum number of worker threads.

        Returns:
            A DataFrame built from ``results``. New rows are appended in
            completion order, which need not match the order of ``dataset``.

        Raises:
            Exception: Whatever a worker raised is re-raised by
                ``future.result()``.
        """
        completion_func = (
            self.get_openai_completion
            if "gpt" in model_name
            else self.get_claude_completion
        )

        with cf.ThreadPoolExecutor(max_workers=max_threads) as ex:
            futures = [
                ex.submit(
                    completion_func,
                    base_prompt.format(
                        deployment_tag=row["deployment_tag"], question=row["inst"]
                    ),
                    row,
                    model_name,
                )
                for _, row in dataset.iterrows()
            ]

            # ``as_completed`` is a generator with no length, so tqdm needs the
            # total passed explicitly to display percentage and ETA.
            for future in tqdm(cf.as_completed(futures), total=len(futures)):
                results.append(future.result())

        return pd.DataFrame(results)
