import asyncio
import hashlib
import json
import logging
import os
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from enum import Enum
from pathlib import Path
from typing import Optional

import diskcache as dc
import orjson
import tiktoken
from openai import AsyncOpenAI, OpenAI
from openai.types import CompletionUsage
from tenacity import (
    after_log,
    before_log,
    retry,
    retry_if_exception,
    stop_after_attempt,
    wait_exponential,
)
from tqdm import tqdm

from recaption.llm_provider import Provider

logger = logging.getLogger(__name__)

def retry_condition(e: Exception):
    """Decide whether a failed completion request should be retried.

    Returns True for JSON decoding failures, for an ``AttributeError`` whose
    message says a ``None`` object has no ``usage`` attribute, and for any
    exception whose message mentions a model rate limit. Returns False for
    everything else.
    """
    if isinstance(e, json.decoder.JSONDecodeError):
        return True
    message = str(e)
    missing_usage = "'NoneType' object has no attribute 'usage'" in message
    if isinstance(e, AttributeError) and missing_usage:
        return True
    return "Rate limit reached for model" in message


class LLMClient:
    def __init__(
        self,
        provider: Provider = Provider.OPENAI,
        cache_dir: str = "cache",
        key: Optional[str] = None,
        system_content: Optional[str] = None,
    ):
        """Set up the API clients, the on-disk caches and the token counters.

        Args:
            provider: Backend supplying the API key, base URL and model list.
            cache_dir: Directory of the completion cache; usage totals live in
                a second cache at ``cache_dir + "_usage"``.
            key: Explicit API key; when falsy the provider is asked for one.
            system_content: Optional system message that
                ``_format_chat_input`` prepends to plain-string prompts.
        """
        self.provider = provider
        self.system_content = system_content

        # Per-model running token totals for this instance.
        self.total_input_tokens = {}
        self.total_output_tokens = {}

        self.api_key = key or self.provider.get_api_key()
        assert self.api_key is not None, (
            f"API key not found for provider {self.provider}"
        )

        # The sync and async clients share the same credentials and endpoint.
        client_options = {
            "api_key": self.api_key,
            "base_url": self.provider.api_base_url,
        }
        self.client = OpenAI(**client_options)
        self.async_client = AsyncOpenAI(**client_options)

        self.cache = dc.Cache(cache_dir)
        self.usage_cache = dc.Cache(cache_dir + "_usage")

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=1, max=120),
        retry=retry_if_exception(retry_condition),
        before=before_log(logger, logging.DEBUG),
        after=after_log(logger, logging.DEBUG),
    )
    def request_completion_with_retry(
        self,
        prompt: str | list[dict],
        num_samples: int,
        model: Optional[Enum] = None,
        temperature: float = 0.7,
        max_tokens: int = 800,
        top_p: float = 1,
    ) -> list[str]:
        """Request ``num_samples`` chat completions synchronously.

        The call is attempted up to five times with exponential backoff when
        ``retry_condition`` accepts the raised exception. On success the token
        usage is recorded and the new samples are appended to the cache.

        Args:
            prompt: Plain prompt text or a ready-made chat message list.
            num_samples: Number of choices requested in a single API call.
            model: Model enum member; its ``value``,
                ``requires_max_completion_tokens`` and ``unsupported_kwargs``
                attributes are read, so it must not be None here.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            top_p: Nucleus sampling parameter.

        Returns:
            The message content of every returned choice.
        """
        try:
            # The token limit is sent under a model-dependent keyword.
            if model.requires_max_completion_tokens:
                token_limit_name = "max_completion_tokens"
            else:
                token_limit_name = "max_tokens"
            request_options = {
                "temperature": temperature,
                token_limit_name: max_tokens,
                "top_p": top_p,
            }
            # Drop the options this model declares as unsupported.
            for unsupported in model.unsupported_kwargs or ():
                request_options.pop(unsupported, None)

            completion = self.client.chat.completions.create(
                model=model.value,
                messages=self._format_chat_input(prompt),
                n=num_samples,
                **request_options,
            )
            self.update_usage(model.value, completion.usage)
            texts = [choice.message.content for choice in completion.choices]
            self.add_samples_to_cache(
                prompt=prompt,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                samples=texts,
            )
            return texts
        except Exception as e:
            # Bad requests are logged for visibility; every error is re-raised
            # so the retry decorator can decide what to do with it.
            if "Bad Request" in str(e):
                logger.error(f"Bad Request, skipping. Exception: {e}")
            raise e

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=1, max=120),
        retry=retry_if_exception(retry_condition),
        before=before_log(logger, logging.DEBUG),
        after=after_log(logger, logging.DEBUG),
    )
    async def request_async_completion_with_retry(
        self,
        prompt: str | list[dict],
        num_samples: int,
        model: Optional[Enum] = None,
        temperature: float = 0.7,
        max_tokens: int = 800,
        top_p: float = 1,
    ) -> list[str]:
        """Request ``num_samples`` chat completions through the async client.

        The call is attempted up to five times with exponential backoff when
        ``retry_condition`` accepts the raised exception. On success the token
        usage is recorded and the new samples are appended to the cache.

        Args:
            prompt: Plain prompt text or a ready-made chat message list.
            num_samples: Number of choices requested in a single API call.
            model: Model enum member; its ``value``,
                ``requires_max_completion_tokens`` and ``unsupported_kwargs``
                attributes are read, so it must not be None here.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            top_p: Nucleus sampling parameter.

        Returns:
            The message content of every returned choice.
        """
        try:
            # The token limit is sent under a model-dependent keyword.
            if model.requires_max_completion_tokens:
                token_limit_name = "max_completion_tokens"
            else:
                token_limit_name = "max_tokens"
            request_options = {
                "temperature": temperature,
                token_limit_name: max_tokens,
                "top_p": top_p,
            }
            # Drop the options this model declares as unsupported.
            for unsupported in model.unsupported_kwargs or ():
                request_options.pop(unsupported, None)

            completion = await self.async_client.chat.completions.create(
                model=model.value,
                messages=self._format_chat_input(prompt),
                n=num_samples,
                **request_options,
            )
            self.update_usage(model.value, completion.usage)
            texts = [choice.message.content for choice in completion.choices]
            self.add_samples_to_cache(
                prompt=prompt,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                samples=texts,
            )
            return texts
        except Exception as e:
            # Bad requests are logged for visibility; every error is re-raised
            # so the retry decorator can decide what to do with it.
            if "Bad Request" in str(e):
                logger.error(f"Bad Request, skipping. Exception: {e}")
            raise e

    async def request_async_completion_with_error_suppression(
        self,
        prompt: str | list[dict],
        num_samples: int,
        model: Optional[Enum] = None,
        temperature: float = 0.7,
        max_tokens: int = 800,
        top_p: float = 1,
    ) -> tuple[list[str], Optional[Exception]]:
        """Run ``request_async_completion_with_retry`` without letting it raise.

        Returns:
            ``(completions, None)`` on success, or ``([], exception)`` when
            the request fails; the failure is also logged at ERROR level.
        """
        request = self.request_async_completion_with_retry(
            prompt=prompt,
            num_samples=num_samples,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )
        try:
            completions = await request
        except Exception as e:
            logger.error(
                f"Error requesting completion: {{type: {type(e)}, message: {str(e)}}}"
            )
            return [], e
        return completions, None

    def generate(
        self,
        prompt: str,
        num_samples: int,
        model: Optional[Enum] = None,
        temperature: float = 0.7,
        max_tokens: int = 800,
        top_p: float = 1,
        ignore_cache_samples: bool = False,
        expand_n_completions: bool = False,
    ) -> list[str]:
        """Return ``num_samples`` completions for ``prompt``, using the cache first.

        Cached samples for the same prompt and sampling settings are reused.
        When the cache holds more than requested, a random subset is returned;
        when it holds fewer, only the missing ones are requested.

        Args:
            prompt: Prompt to complete.
            num_samples: Number of completions wanted.
            model: Model enum member; None selects the provider default.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            top_p: Nucleus sampling parameter.
            ignore_cache_samples: Skip the cache lookup entirely.
            expand_n_completions: Issue one request per sample instead of a
                single multi-choice request. Forced on when the provider does
                not support multiple completions.

        Returns:
            Cached samples followed by newly generated ones.
        """
        model = self._check_model_name(model)
        one_request_per_sample = (
            expand_n_completions or not self.provider.supports_multiple_completions
        )
        sampling = {
            "prompt": prompt,
            "model": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
        }

        cached = [] if ignore_cache_samples else self.get_samples_from_cache(**sampling)
        cached_count = len(cached)
        if cached_count > num_samples:
            return random.sample(cached, num_samples)
        if cached_count == num_samples:
            return cached

        missing = num_samples - cached_count
        if one_request_per_sample:
            fresh = [
                self.request_completion_with_retry(num_samples=1, **sampling)[0]
                for _ in range(missing)
            ]
        else:
            fresh = self.request_completion_with_retry(num_samples=missing, **sampling)
        return cached + fresh

    async def generate_async(
        self,
        prompt: str,
        num_samples: int,
        model: Optional[Enum] = None,
        temperature: float = 0.7,
        max_tokens: int = 800,
        top_p: float = 1,
        ignore_cache_samples: bool = False,
        expand_n_completions: bool = False,
    ) -> list[str]:
        """Async counterpart of ``generate``: fill ``num_samples`` from cache, then API.

        When ``expand_n_completions`` is set (or the provider cannot return
        several choices per call), the missing samples are requested as
        concurrent single-sample calls; otherwise one multi-choice call is
        made. Request errors propagate to the caller.

        Returns:
            Cached samples followed by newly generated ones. When the cache
            holds more than ``num_samples``, a random subset of it is returned.
        """
        model = self._check_model_name(model)
        one_request_per_sample = (
            expand_n_completions or not self.provider.supports_multiple_completions
        )
        sampling = {
            "prompt": prompt,
            "model": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
        }

        cached = [] if ignore_cache_samples else self.get_samples_from_cache(**sampling)
        cached_count = len(cached)
        # Enough cached samples: answer without touching the API.
        if cached_count > num_samples:
            return random.sample(cached, num_samples)
        if cached_count == num_samples:
            return cached

        # Otherwise only the missing samples are generated.
        missing = num_samples - cached_count
        if one_request_per_sample:
            pending = [
                self.request_async_completion_with_retry(num_samples=1, **sampling)
                for _ in range(missing)
            ]
            responses = await asyncio.gather(*pending)
            fresh = [response[0] for response in responses]
        else:
            fresh = await self.request_async_completion_with_retry(
                num_samples=missing, **sampling
            )
        return cached + fresh

    async def generate_async_with_error_suppression(
        self,
        prompt: str,
        num_samples: int,
        model: Optional[Enum] = None,
        temperature: float = 0.7,
        max_tokens: int = 800,
        top_p: float = 1,
        ignore_cache_samples: bool = False,
        expand_n_completions: bool = False,
    ) -> tuple[list[str], list[Exception]]:
        """Collect ``num_samples`` completions, reporting request errors instead of raising.

        Cached samples are used first; only the missing ones are requested.
        Failed requests contribute an entry to the error list and no samples,
        so fewer than ``num_samples`` completions may be returned.

        Args:
            prompt: Prompt to complete.
            num_samples: Number of completions wanted.
            model: Model enum member; None selects the provider default.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            top_p: Nucleus sampling parameter.
            ignore_cache_samples: Skip the cache lookup entirely.
            expand_n_completions: Issue one request per sample instead of a
                single multi-choice request. Forced on when the provider does
                not support multiple completions.

        Returns:
            ``(completions, errors)``. ``errors`` is empty when every request
            succeeded or when the cache alone satisfied the call.

        Raises:
            ValueError: If ``model`` does not belong to the provider (raised
                by ``_check_model_name`` before any request is made).
        """
        model = self._check_model_name(model)
        expand_n_completions = (
            expand_n_completions or not self.provider.supports_multiple_completions
        )
        if ignore_cache_samples:
            cached_samples = []
        else:
            cached_samples = self.get_samples_from_cache(
                prompt=prompt,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
            )
        # Cache hits use the same (samples, errors) shape as the request path,
        # so callers can always unpack a pair.
        if len(cached_samples) > num_samples:
            return random.sample(cached_samples, num_samples), []
        if len(cached_samples) == num_samples:
            return cached_samples, []

        remaining_samples = num_samples - len(cached_samples)
        new_samples = []
        errors = []
        if expand_n_completions:
            tasks = [
                self.request_async_completion_with_error_suppression(
                    prompt=prompt,
                    num_samples=1,
                    model=model,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
                for _ in range(remaining_samples)
            ]
            for completions, error in await asyncio.gather(*tasks):
                if error is not None:
                    errors.append(error)
                else:
                    new_samples.append(completions[0])
        else:
            (
                new_samples,
                error,
            ) = await self.request_async_completion_with_error_suppression(
                prompt=prompt,
                num_samples=remaining_samples,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
            )
            if error is not None:
                errors.append(error)
        return cached_samples + new_samples, errors

    async def batch_generate_async(
        self,
        prompts: list[str] | list[list[dict]],
        num_samples: int,
        model: Optional[Enum] = None,
        temperature: float = 0.7,
        max_tokens: int = 800,
        top_p: float = 1,
        ignore_cache_samples: bool = False,
        expand_n_completions: bool = False,
        batch_size: Optional[int] = None,
        progress_file: Optional[Path] = "batch_generate_progress.json",
    ) -> tuple[list[list[str]], list[list[Exception]]]:
        """Generate ``num_samples`` completions for every prompt concurrently.

        Each prompt is first served from the cache (unless
        ``ignore_cache_samples`` is set); only the missing samples are
        requested. Request failures never raise: they are collected per
        prompt, and the affected prompt simply ends up with fewer samples.

        Args:
            prompts: Prompts, each either plain text or a chat message list.
            num_samples: Number of completions wanted per prompt.
            model: Model enum member; None selects the provider default.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            top_p: Nucleus sampling parameter.
            ignore_cache_samples: Do not reuse cached samples.
            expand_n_completions: Issue one request per sample instead of one
                multi-choice request per prompt. Forced on when the provider
                does not support multiple completions.
            batch_size: When given, requests are awaited in groups of this
                size and a progress record is written after each group; when
                None, all requests are awaited together.
            progress_file: Destination of the progress record (JSON with
                ``completions`` and ``errors`` lists). None disables writing.

        Returns:
            ``(results, errors)``, both indexed like ``prompts``:
            ``results[i]`` holds the samples for prompt ``i`` (cached first,
            then new) and ``errors[i]`` the exceptions raised for it.

        Raises:
            ValueError: If ``batch_size`` is not positive, or if ``model``
                does not belong to the provider.
        """
        # Validate before any coroutine is created, so none is left un-awaited.
        if batch_size is not None and batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        model = self._check_model_name(model)
        expand_n_completions = (
            expand_n_completions or not self.provider.supports_multiple_completions
        )

        results = [None] * len(prompts)
        result_idxs = []  # result_idxs[k] is the prompt index served by tasks[k]
        tasks = []
        for i, prompt in enumerate(prompts):
            if ignore_cache_samples:
                cached_samples = []
            else:
                cached_samples = self.get_samples_from_cache(
                    prompt=prompt,
                    model=model,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
            if num_samples < len(cached_samples):
                results[i] = random.sample(cached_samples, num_samples)
                continue
            results[i] = list(cached_samples)
            remaining_samples = num_samples - len(cached_samples)
            if remaining_samples == 0:
                continue

            # Request exactly the missing samples. The raw request helper is
            # used on both paths so the cache is not consulted a second time,
            # which would duplicate the cached samples in the result.
            if expand_n_completions:
                request_sizes = [1] * remaining_samples
            else:
                request_sizes = [remaining_samples]
            for request_size in request_sizes:
                result_idxs.append(i)
                tasks.append(
                    self.request_async_completion_with_error_suppression(
                        prompt=prompt,
                        num_samples=request_size,
                        model=model,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        top_p=top_p,
                    )
                )

        logger.info(f"Submitting {len(tasks)} generation requests...")
        if batch_size is None:
            completed_tasks = await asyncio.gather(*tasks)
        else:
            logger.info(f"Batching generation requests (batch_size={batch_size})...")
            completed_tasks = []
            partial_record = {"completions": [], "errors": []}
            for start_idx in tqdm(range(0, len(tasks), batch_size)):
                end_idx = min(start_idx + batch_size, len(tasks))
                batch = tasks[start_idx:end_idx]
                try:
                    completed_batch = await asyncio.gather(*batch)
                except Exception as e:
                    logger.error(
                        f"Error while generating batch {start_idx}-{end_idx}: {e}"
                    )
                    # One placeholder per task keeps completed_tasks aligned
                    # with result_idxs and in the (samples, error) shape.
                    completed_batch = [([], e) for _ in batch]
                completed_tasks.extend(completed_batch)

                for completion_list, error in completed_batch:
                    partial_record["completions"].extend(completion_list)
                    if error is not None:
                        partial_record["errors"].append(str(error))
                # Progress reporting is best-effort and must not disturb the
                # bookkeeping of results that have already been collected.
                if progress_file is not None:
                    try:
                        with open(progress_file, "wb") as f:
                            f.write(orjson.dumps(partial_record))
                    except (OSError, TypeError) as e:
                        logger.error(
                            f"Could not write progress file {progress_file}: {e}"
                        )

        # Attach new samples and errors to the prompt that produced them.
        error_list = [[] for _ in prompts]
        for idx, (samples, error) in zip(result_idxs, completed_tasks):
            results[idx].extend(samples)
            if error is not None:
                error_list[idx].append(error)
        return results, error_list

    def multithreaded_expanded_request_generate(
        self,
        prompts: list[str],
        num_samples: int,
        model: Optional[Enum] = None,
        temperature: float = 0.7,
        max_tokens: int = 800,
        top_p: float = 1,
        num_workers: int = 8,
        save_every: int = 100,
        output_dir: str = "output",
        ignore_cache_samples: bool = False,
    ) -> list[list[str]]:
        """Generate samples for many prompts in parallel with a thread pool.

        Every missing sample is requested as its own single-sample call to
        ``request_completion_with_retry``, which also stores it in the cache.
        Partial results are periodically dumped to
        ``<output_dir>/completions.json``.

        Args:
            prompts: Prompts to complete.
            num_samples: Number of completions wanted per prompt.
            model: Model enum member; None selects the provider default.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            top_p: Nucleus sampling parameter.
            num_workers: Size of the thread pool.
            save_every: Results are written each time the number of collected
                responses is a multiple of this value (including the first).
            output_dir: Directory of the checkpoint file; created if missing.
            ignore_cache_samples: Do not reuse cached samples.

        Returns:
            One list of samples per prompt, cached samples first. A ``None``
            entry marks a request that failed with a JSON decoding error.
        """
        model = self._check_model_name(model)
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            results = [[] for _ in range(len(prompts))]
            future_to_idx = {}  # maps each pending request to its prompt index
            logger.info("submitting generation requests to job queue...")
            for i, prompt in enumerate(prompts):
                if ignore_cache_samples:
                    cached_samples = []
                else:
                    cached_samples = self.get_samples_from_cache(
                        prompt=prompt,
                        model=model,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        top_p=top_p,
                    )
                if num_samples < len(cached_samples):
                    results[i] = random.sample(cached_samples, num_samples)
                    continue
                results[i] = list(cached_samples)

                # request the remaining samples
                remaining_samples = num_samples - len(cached_samples)
                for _ in range(remaining_samples):
                    future = executor.submit(
                        self.request_completion_with_retry,
                        prompt=prompt,
                        num_samples=1,
                        model=model,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        top_p=top_p,
                    )
                    future_to_idx[future] = i
            logger.info(
                f"submitted {len(future_to_idx)} generation requests to job queue."
            )

            completions_file = os.path.join(output_dir, "completions.json")
            if future_to_idx and output_dir:
                # The checkpoint below would fail on a missing directory.
                os.makedirs(output_dir, exist_ok=True)
            logger.info("collecting returning generation responses")
            for i, future in tqdm(
                enumerate(as_completed(future_to_idx)),
                total=len(future_to_idx),
                desc="collecting responses",
            ):
                idx = future_to_idx[future]
                try:
                    samples = future.result()
                    # request_completion_with_retry has already cached these
                    # samples; caching them again here would duplicate them.
                    results[idx].append(samples[0])
                except json.decoder.JSONDecodeError as e:
                    # NOTE(review): the retry decorator is configured without
                    # reraise, so exhausted retries presumably surface as
                    # tenacity.RetryError rather than JSONDecodeError; confirm
                    # whether this handler is still reachable.
                    logger.error(f"Error decoding JSON: {e}")
                    results[idx].append(None)
                if i % save_every == 0:
                    with open(completions_file, "w") as f:
                        json.dump(results, f)
            return results

    def get_samples_from_cache(
        self,
        prompt: str | list[dict],
        model: Enum,
        temperature: float,
        max_tokens: int,
        top_p: float,
    ) -> list[str]:
        """Look up the samples cached for this prompt and sampling settings.

        Returns:
            The cached samples, or an empty list when nothing is stored under
            the key derived from the arguments.
        """
        return self.cache.get(
            self._hash_prompt(prompt, model.value, temperature, max_tokens, top_p),
            [],
        )

    def add_samples_to_cache(
        self,
        prompt: str | list[dict],
        model: Enum,
        temperature: float,
        max_tokens: int,
        top_p: float,
        samples: list[str],
    ) -> None:
        """Append non-empty samples to the cache entry for this prompt and settings.

        Samples that are None, empty or whitespace-only are dropped.

        Args:
            prompt: Plain prompt text or chat message list used for the key.
            model: Model enum member; its ``value`` is part of the key.
            temperature: Sampling temperature (part of the key).
            max_tokens: Completion token limit (part of the key).
            top_p: Nucleus sampling parameter (part of the key).
            samples: Newly generated completions to store.
        """
        cache_key = self._hash_prompt(
            prompt, model.value, temperature, max_tokens, top_p
        )
        kept = [s for s in samples if s and s.strip()]
        # The read and write are done in one cache transaction:
        # multithreaded_expanded_request_generate calls this from worker
        # threads, and an interleaved get/set pair would drop samples.
        with self.cache.transact(retry=True):
            cached_samples = self.cache.get(cache_key, [])
            cached_samples.extend(kept)
            self.cache[cache_key] = cached_samples

    def count_tokens_in_prompt(self, prompt: str | list[str], engine: str) -> int:
        """Count the tokens ``engine``'s tiktoken encoding yields for the prompt(s).

        A single string is treated as a one-element list; the counts of all
        entries are summed.
        """
        encoder = tiktoken.encoding_for_model(engine)
        texts = [prompt] if isinstance(prompt, str) else prompt
        token_total = 0
        for text in texts:
            token_total += len(encoder.encode(text))
        return token_total

    def update_usage(self, engine: str, usage: Optional[CompletionUsage]):
        """Add one response's token usage to the in-memory and cached totals.

        Output tokens are computed as ``total_tokens - prompt_tokens``.

        Args:
            engine: Model name the usage is accounted under.
            usage: Usage object of a completion response. None (the response
                carried no usage data) is logged and skipped, so a successful
                completion is not discarded for lack of accounting data.
        """
        if usage is None:
            logger.warning(f"No usage information returned for {engine}; skipping")
            return
        input_tokens = usage.prompt_tokens
        output_tokens = usage.total_tokens - usage.prompt_tokens
        self.total_input_tokens[engine] = (
            self.total_input_tokens.get(engine, 0) + input_tokens
        )
        self.total_output_tokens[engine] = (
            self.total_output_tokens.get(engine, 0) + output_tokens
        )
        self.add_usage_to_cache(engine, usage)

    def get_token_usage(self, engine: str) -> tuple[int, int]:
        """Return this instance's ``(input, output)`` token totals for ``engine``.

        Engines with no recorded usage report zero for both values.
        """
        return (
            self.total_input_tokens.get(engine, 0),
            self.total_output_tokens.get(engine, 0),
        )

    def add_usage_to_cache(self, engine: str, usage: CompletionUsage):
        """Add one response's token usage to the totals in the usage cache.

        The entry for ``engine`` is a dict with ``input_tokens`` and
        ``output_tokens``; output tokens are ``total_tokens - prompt_tokens``.
        """
        input_tokens = usage.prompt_tokens
        output_tokens = usage.total_tokens - usage.prompt_tokens
        # Read-modify-write in one cache transaction: requests are also issued
        # from worker threads (multithreaded_expanded_request_generate), and an
        # interleaved get/set pair would lose increments.
        with self.usage_cache.transact(retry=True):
            cur = self.usage_cache.get(
                engine, {"input_tokens": 0, "output_tokens": 0}
            )
            self.usage_cache[engine] = {
                "input_tokens": cur["input_tokens"] + input_tokens,
                "output_tokens": cur["output_tokens"] + output_tokens,
            }

    def get_usage_from_cache(self, model: Enum):
        """Return the cached token totals for ``model``.

        Usage is stored under the model's string value (``update_usage`` is
        called with ``model.value``), so the same key is used for the lookup.
        A plain model-name string is accepted as well.

        Returns:
            A dict with ``input_tokens`` and ``output_tokens``; both are zero
            when nothing has been recorded for the model.
        """
        engine = model.value if isinstance(model, Enum) else model
        return self.usage_cache.get(engine, {"input_tokens": 0, "output_tokens": 0})

    def _hash_prompt(
        self,
        messages: str | list[dict],
        model: Enum,
        temperature: float,
        max_tokens: int,
        top_p: float,
    ) -> str:
        """Build the cache key for a prompt and its sampling settings.

        Message lists are serialised to JSON text first; the key is the MD5
        hex digest of the prompt text joined with the other arguments.
        """
        prompt = (
            messages
            if isinstance(messages, str)
            else orjson.dumps(messages).decode("utf-8")
        )
        digest = hashlib.md5()
        digest.update(f"{prompt}-{model}-{temperature}-{max_tokens}-{top_p}".encode())
        return digest.hexdigest()

    def _hash_embedding(self, text: str, model: Enum) -> str:
        """Return the MD5 hex digest of ``"<model>-<text>"``."""
        digest = hashlib.md5()
        digest.update(f"{model}-{text}".encode())
        return digest.hexdigest()

    def _format_chat_input(self, text_input: str | list[dict]) -> list[dict]:
        """Turn a prompt into a chat message list.

        A string becomes a single user message, preceded by a system message
        when ``self.system_content`` is set. Anything that is not a string is
        returned unchanged.
        """
        if not isinstance(text_input, str):
            return text_input
        user_message = {"role": "user", "content": text_input}
        if self.system_content is None:
            return [user_message]
        return [
            {"role": "system", "content": self.system_content},
            user_message,
        ]

    def _check_model_name(self, model: Optional[Enum]) -> Enum:
        """Validate ``model`` against the provider's model enum.

        Returns:
            ``model`` itself when it is a member of ``self.provider.models``,
            or the enum's first member when ``model`` is None.

        Raises:
            ValueError: If ``model`` is not a member of the provider's enum.
        """
        available = self.provider.models
        if model is None:
            fallback = list(available)[0]
            logger.info(f"Model name not provided, using default {fallback}")
            return fallback
        if isinstance(model, available):
            return model
        raise ValueError(
            f"Model {model} is not available for provider {self.provider}"
        )

    def _sleep_and_return_input(self, x, t):
        """Sleep for ``t`` seconds, then return ``x`` unchanged.

        Helper for testing concurrent processing: returning the input lets a
        caller match each result to the job that produced it.
        """
        time.sleep(t)
        return x