import json
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

import numpy as np
import requests

from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger

from .base import BaseModel
from .base_api import TokenBucket


@MODELS.register_module()
class LightllmAPI(BaseModel):
    """Model wrapper that queries a LightLLM service over HTTP.

    Every query is a JSON POST to ``url`` whose body is
    ``{"inputs": <prompt>, "parameters": <generation_kwargs>}``.

    Args:
        path (str): Forwarded to the base class as ``path``.
            Defaults to 'LightllmAPI'.
        url (str): Endpoint that receives the POST requests.
            Defaults to 'http://localhost:8080/generate'.
        meta_template (Dict, optional): Forwarded to the base class and
            also stored on the instance. Defaults to None.
        rate_per_worker (int): Rate handed to the ``TokenBucket`` that
            :meth:`wait` consults before each request. Defaults to 2.
        retry (int): Number of unusable responses (invalid JSON or a
            missing key) tolerated per input before a ``RuntimeError`` is
            raised. Connection errors are retried without being counted.
            Defaults to 2.
        generation_kwargs (Dict, optional): Sent verbatim as the
            ``parameters`` field of every request. ``None`` is treated as
            an empty dict. Defaults to None.
    """

    is_api: bool = True

    def __init__(
            self,
            path: str = 'LightllmAPI',
            url: str = 'http://localhost:8080/generate',
            meta_template: Optional[Dict] = None,
            rate_per_worker: int = 2,
            retry: int = 2,
            generation_kwargs: Optional[Dict] = None,
    ):
        # A fresh dict per instance: avoids sharing one mutable default
        # between all instances and makes an explicit ``None`` usable.
        if generation_kwargs is None:
            generation_kwargs = {}

        super().__init__(path=path,
                         meta_template=meta_template,
                         generation_kwargs=generation_kwargs)
        self.logger = get_logger()
        self.url = url
        self.retry = retry
        self.generation_kwargs = generation_kwargs
        self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
        self.meta_template = meta_template
        self.token_bucket = TokenBucket(rate_per_worker, False)

    def generate(self, inputs: List[str], max_out_len: int,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        The inputs are dispatched concurrently through a thread pool; the
        order of the results matches the order of ``inputs``.

        Args:
            inputs (List[str]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): Accepted for interface compatibility but not
                used here: ``self.max_out_len`` (taken from
                ``generation_kwargs['max_new_tokens']``, 1024 if absent) is
                passed on instead, and the request itself only carries
                ``generation_kwargs``.

        Returns:
            List[str]: A list of generated strings.
        """

        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [self.max_out_len] * len(inputs)))
        return results

    def _generate(self, input: str, max_out_len: int) -> str:
        """Query the service for a single input and return its text.

        Args:
            input (str): The prompt placed in the ``inputs`` field.
            max_out_len (int): Unused; the output length is governed by
                ``self.generation_kwargs``.

        Returns:
            str: ``response['generated_text']``, or its first element when
            the service returns a list.

        Raises:
            RuntimeError: If ``self.retry`` responses in a row could not be
                decoded as JSON or lacked the ``generated_text`` key.
        """
        header = {'content-type': 'application/json'}
        num_retries = 0
        while num_retries < self.retry:
            self.wait()
            try:
                data = dict(inputs=input, parameters=self.generation_kwargs)
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                # Connection failures do not consume the retry budget, so
                # this loops (rate-limited by ``wait``) until the service
                # becomes reachable.
                self.logger.error('Got connection error, retrying...')
                continue
            try:
                response = raw_response.json()
                generated_text = response['generated_text']
                if isinstance(generated_text, list):
                    generated_text = generated_text[0]
                return generated_text
            except requests.JSONDecodeError:
                # The ``%s`` placeholder is required: without it the logging
                # module fails to format the record and the body is lost.
                self.logger.error('JsonDecode error, got %s',
                                  str(raw_response.content))
            except KeyError:
                self.logger.error(f'KeyError. Response: {str(response)}')
            num_retries += 1

        raise RuntimeError('Calling LightllmAPI failed after retrying for '
                           f'{num_retries} times. Check the logs for '
                           'details.')

    def get_ppl(self, inputs: List[str], max_out_len: int,
                **kwargs) -> np.ndarray:
        """Compute the mean negative log-probability of each input.

        The inputs are dispatched concurrently through a thread pool; the
        order of the results matches the order of ``inputs``.

        Args:
            inputs (List[str]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): Accepted for interface compatibility but not
                used here; ``self.max_out_len`` is passed on instead.

        Returns:
            np.ndarray: One loss value per input, in input order.
        """

        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._get_ppl, inputs,
                             [self.max_out_len] * len(inputs)))
        return np.array(results)

    def _get_ppl(self, input: str, max_out_len: int) -> float:
        """Query the service for a single input and return its loss.

        Args:
            input (str): The prompt placed in the ``inputs`` field.
            max_out_len (int): Unused.

        Returns:
            float: The negated mean of the prompt tokens' log-probabilities,
            or ``0.0`` when the response yields no log-probabilities.

        Raises:
            AssertionError: If the response lacks ``prompt_token_ids`` or
                ``prompt_logprobs``.
            RuntimeError: If ``self.retry`` responses in a row could not be
                decoded as JSON.
        """
        header = {'content-type': 'application/json'}
        num_retries = 0
        while num_retries < self.retry:
            self.wait()
            try:
                data = dict(inputs=input, parameters=self.generation_kwargs)
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                # Connection failures do not consume the retry budget.
                self.logger.error('Got connection error, retrying...')
                continue
            try:
                response = raw_response.json()

                # Raised explicitly (rather than via ``assert``) so the
                # check still runs under ``python -O``; the exception type
                # is the one the assert statement produced.
                if ('prompt_token_ids' not in response
                        or 'prompt_logprobs' not in response):
                    raise AssertionError(
                        'prompt_token_ids and prompt_logprobs must be in '
                        'the output. Please consider adding '
                        '--return_all_prompt_logprobs argument when '
                        f'starting lightllm service. Response: {str(response)}'
                    )

                # The first token id is skipped, so the remaining ids pair
                # up one-to-one with the ``prompt_logprobs`` entries.
                prompt_token_ids = response['prompt_token_ids'][1:]
                # Element 1 of each entry is indexed below by the token id
                # rendered as a string.
                prompt_logprobs = [
                    item[1] for item in response['prompt_logprobs']
                ]
                logprobs = [
                    item[str(token_id)] for token_id, item in zip(
                        prompt_token_ids, prompt_logprobs)
                ]
                if not logprobs:
                    return 0.0
                ce_loss = -sum(logprobs) / len(logprobs)
                return ce_loss
            except requests.JSONDecodeError:
                # The ``%s`` placeholder is required: without it the logging
                # module fails to format the record and the body is lost.
                self.logger.error('JsonDecode error, got %s',
                                  str(raw_response.content))
            num_retries += 1
        raise RuntimeError('Calling LightllmAPI failed after retrying for '
                           f'{num_retries} times. Check the logs for '
                           'details.')

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized string. Only English and Chinese
        characters are counted for now. Users are encouraged to override this
        method if more accurate length is needed.

        Each maximal run of ASCII letters/digits counts as one token and
        each character in the range U+4E00-U+9FFF counts as one token;
        everything else is ignored.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """

        english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
        chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)

        # A match of [A-Za-z0-9]+ never contains whitespace, so every match
        # is exactly one word.
        english_count = len(english_parts)

        # Count Chinese words
        chinese_count = sum(len(part) for part in chinese_parts)

        return english_count + chinese_count
