from vlmeval.smp import *
from vlmeval.api.base import BaseAPI
from vlmeval.dataset import img_root_map
from vlmeval.dataset import DATASET_TYPE


class SenseChatVisionWrapper(BaseAPI):
    """VLMEvalKit API wrapper for SenseTime's SenseChat vision models.

    Requests go to the SenseNova ``chat-completions`` endpoint, authenticated
    with a short-lived HS256 JWT built from an access key / secret key pair.
    Dataset-specific prompts are produced by :meth:`build_prompt`.
    """

    is_api: bool = True

    def __init__(self,
                 model: str = 'SenseChat-5-Vision',
                 retry: int = 5,
                 wait: int = 5,
                 ak: str = None,
                 sk: str = None,
                 verbose: bool = True,
                 system_prompt: str = None,
                 max_tokens: int = 2048,
                 proxy: str = None,
                 **kwargs):
        """Configure credentials and generation defaults.

        Args:
            model: Model name sent in the request body.
            retry: Forwarded to ``BaseAPI``.
            wait: Forwarded to ``BaseAPI``.
            ak: Access key; falls back to the ``SENSECHAT_AK`` env variable.
            sk: Secret key; falls back to the ``SENSECHAT_SK`` env variable.
            verbose: Forwarded to ``BaseAPI``; also gates logging in
                :meth:`generate_inner`.
            system_prompt: Forwarded to ``BaseAPI``.
            max_tokens: Default ``max_new_tokens`` used when the dataset has
                no specific override (see :meth:`_max_new_tokens_for`).
            proxy: Accepted for interface compatibility; not used in this class.
        """
        self.model = model
        self.fail_msg = 'Failed to obtain answer via API. '
        self.ak = os.environ.get('SENSECHAT_AK', None) if ak is None else ak
        self.sk = os.environ.get('SENSECHAT_SK', None) if sk is None else sk
        assert self.ak is not None and self.sk is not None, (
            'SenseChat credentials missing: pass ak/sk or set SENSECHAT_AK/SENSECHAT_SK.'
        )
        # Default generation budget; per-dataset overrides are applied per call
        # and never written back here.
        self.max_new_tokens = max_tokens
        super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)

    def dump_image(self, line, dataset):
        """Dump the image(s) of the input line to the corresponding dataset folder.

        Args:
            line (line of pd.DataFrame): The raw input line.
            dataset (str): The name of the dataset.

        Returns:
            list[str]: The paths of the dumped (or pre-existing) images.
        """
        ROOT = LMUDataRoot()
        assert isinstance(dataset, str)
        img_root = osp.join(ROOT, 'images', img_root_map(dataset))
        os.makedirs(img_root, exist_ok=True)
        if 'image' in line:
            if isinstance(line['image'], list):
                # Multi-image row: each base64 payload is paired with its file name.
                tgt_path = []
                assert 'image_path' in line
                for img, im_name in zip(line['image'], line['image_path']):
                    path = osp.join(img_root, im_name)
                    if not read_ok(path):
                        decode_base64_to_image_file(img, path)
                    tgt_path.append(path)
            else:
                # Single-image row: the file is named after the row index.
                tgt_path = osp.join(img_root, f"{line['index']}.jpg")
                if not read_ok(tgt_path):
                    decode_base64_to_image_file(line['image'], tgt_path)
                tgt_path = [tgt_path]
        else:
            # No embedded image data: the row already references files on disk.
            assert 'image_path' in line
            tgt_path = toliststr(line['image_path'])

        return tgt_path

    def image_to_base64(self, image_path):
        """Return the contents of ``image_path`` as a base64-encoded UTF-8 string."""
        import base64
        with open(image_path, 'rb') as image_file:
            encoded_string = base64.b64encode(image_file.read())
            return encoded_string.decode('utf-8')

    def encode_jwt_token(self, ak, sk):
        """Build the HS256 JWT used as the bearer token for the SenseNova API.

        The token is issued by ``ak``, signed with ``sk``, expires 30 minutes
        from now and becomes valid 5 seconds in the past.
        """
        import jwt
        headers = {'alg': 'HS256', 'typ': 'JWT'}
        now = int(time.time())
        payload = {
            'iss': ak,
            'exp': now + 1800,  # desired expiry: current time + 30 minutes
            'nbf': now - 5,  # desired start of validity: current time - 5 seconds
        }
        token = jwt.encode(payload, sk, headers=headers)
        return token

    def use_custom_prompt(self, dataset):
        """Always build prompts with :meth:`build_prompt`, for every dataset."""
        return True

    @staticmethod
    def _append_options(question, line):
        """Append one ``<letter>. <option>`` line per non-empty option column A-Z.

        Returns:
            tuple: ``(question_with_options, options_dict)``.
        """
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'
        return question, options

    def build_multi_choice_prompt(self, line, dataset=None):
        """Build a multiple-choice prompt: optional hint, question, options, instruction.

        The trailing instruction is written in Chinese when ``cn_string``
        detects Chinese text in the prompt, otherwise in English.
        """
        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question

        prompt, options = self._append_options(question, line)

        if len(options):
            prompt += '\n请直接回答选项字母。' if cn_string(
                prompt) else "\nAnswer with the option's letter from the given choices directly."
        else:
            prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'

        return prompt

    def build_prompt(self, line, dataset=None):
        """Build the interleaved message for one dataset row.

        The text prompt depends on the dataset name and ``DATASET_TYPE``. Note
        that :meth:`dump_image` asserts ``dataset`` is a ``str``, so a ``None``
        dataset fails there before any prompt is built.

        Returns:
            list[dict]: One text item followed by one image item per image path.
        """
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)

        tgt_path = self.dump_image(line, dataset)

        if dataset is not None and listinstr(['MME'], dataset):
            question = line['question']
            prompt = question + ' Answer the question using a single word or phrase.'
        elif dataset is not None and listinstr(['HallusionBench'], dataset):
            question = line['question']
            prompt = question + ' Please answer yes or no. Answer the question using a single word or phrase.'
        elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ' and 'MMMU' not in dataset:
            prompt = self.build_multi_choice_prompt(line, dataset)
        elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
            if 'MathVista' in dataset:
                prompt = line['question']
            elif listinstr(['LLaVABench'], dataset):
                question = line['question']
                prompt = question + '\nAnswer this question in detail.'
            elif listinstr(['MMVet'], dataset):
                prompt = line['question']
            else:
                question = line['question']
                prompt = question + '\nAnswer the question using a single word or phrase.'
        elif dataset is not None and 'MMMU' in dataset:
            question, _ = self._append_options(line['question'], line)
            # Both question types currently share the same step-by-step instruction.
            cot_instruction = 'Answer with carefully thought step by step. Apply the thinking process recursively at both macro and micro levels. Verify consistency of reasoning and look for potential flaws or gaps during thinking. When realize mistakes, explain why the previous thinking was incorrect, fix it and then continue thinking.\n\n'  # noqa
            templates = {
                'multiple-choice': cot_instruction,
                'open': cot_instruction,
            }
            subject = '_'.join(line['id'].split('_')[1:-1])
            # The templates contain no ``{}`` placeholders today, so format() leaves them unchanged.
            prompt = templates[line['question_type']].format(subject, subject) + '\n' + question
        else:
            prompt = line['question']

        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=s) for s in tgt_path])

        return message

    def message_to_promptimg(self, message, dataset=None):
        """Split a message list into one joined text prompt and a list of image paths.

        For ``dataset is None`` and for MMMU/BLINK datasets only the first image
        is kept. Text-only messages yield an empty image list instead of failing.
        """
        prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
        image = [x['value'] for x in message if x['type'] == 'image']
        if dataset is None or listinstr(['MMMU', 'BLINK'], dataset):
            image = image[:1]
        return prompt, image

    @staticmethod
    def _max_num_for(dataset):
        """Return the ``max_num`` value for ``dataset`` (stored on ``self.max_num``).

        NOTE(review): ``max_num`` is not read anywhere in this class; it is kept
        because external code may rely on the attribute.
        """
        if dataset is None:
            return 6
        if listinstr(['ChartQA_TEST', 'MathVista_MINI'], dataset):
            return 12
        if listinstr(['DocVQA_VAL', 'DocVQA_TEST'], dataset):
            return 18
        if listinstr(['InfoVQA_VAL', 'InfoVQA_TEST', 'OCRBench'], dataset):
            return 24
        return 6

    def _max_new_tokens_for(self, dataset):
        """Return the generation budget for ``dataset`` without mutating instance state.

        Falls back to the constructor's ``max_tokens`` when no override applies,
        so an override for one dataset never leaks into the next one.
        """
        if dataset is None:
            return self.max_new_tokens
        if listinstr(['AI2D_TEST'], dataset):
            return 10
        if 'MMMU' in dataset:
            return 4096
        if 'MMBench' in dataset:
            return 100
        if 'MathVista_MINI' in dataset:
            return 4096
        return self.max_new_tokens

    @staticmethod
    def _response_payload(response):
        """Best-effort decode of a response body for error reporting (JSON, else raw text)."""
        try:
            return response.json()
        except ValueError:
            return response.text

    @staticmethod
    def _error_message(payload):
        """Extract ``error.message`` from an API error payload, else stringify the payload."""
        if isinstance(payload, dict):
            error = payload.get('error')
            if isinstance(error, dict) and 'message' in error:
                return str(error['message'])
        return str(payload)

    def generate_inner(self, inputs, **kwargs) -> tuple:
        """Send one request to the SenseNova chat-completions endpoint.

        Args:
            inputs: List of message dicts with ``type`` ('text' / 'image') and
                ``value`` keys; a bare string is wrapped into a one-item list.
            **kwargs: ``dataset`` (optional) selects per-dataset settings.

        Returns:
            tuple: ``(0, answer, 'Succeeded! ')`` on success, otherwise
            ``(-1, error_message, '')``.
        """
        assert isinstance(inputs, (str, list))
        inputs = [inputs] if isinstance(inputs, str) else inputs
        dataset = kwargs.get('dataset', None)

        self.max_num = self._max_num_for(dataset)
        max_new_tokens = self._max_new_tokens_for(dataset)

        prompt, image = self.message_to_promptimg(message=inputs, dataset=dataset)

        url = 'https://api.sensenova.cn/v1/llm/chat-completions'
        api_secret_key = self.encode_jwt_token(self.ak, self.sk)

        content = [{
            'image_base64': self.image_to_base64(item),
            'image_file_id': '',
            'image_url': '',
            'text': '',
            'type': 'image_base64'
        } for item in image]

        content.append({
            'image_base64': '',
            'image_file_id': '',
            'image_url': '',
            'text': prompt,
            'type': 'text'
        })

        message = [{'content': content, 'role': 'user'}]

        data = {
            'messages': message,
            'max_new_tokens': max_new_tokens,
            'temperature': 0,
            "top_k": 0,
            "top_p": 0.99,
            'repetition_penalty': 1.05,
            'model': self.model,
            'stream': False,
        }
        headers = {
            'Content-type': 'application/json',
            'Authorization': 'Bearer ' + api_secret_key
        }

        response = requests.post(
            url,
            headers=headers,
            json=data,
        )
        # Used only for error reporting, so a missing header must not abort the call.
        request_id = response.headers.get('x-request-id', '')

        # Fixed pause after every request; presumably client-side rate limiting — confirm.
        time.sleep(1)
        try:
            if response.status_code != 200:
                raise RuntimeError(f'HTTP status {response.status_code}')
            answer = response.json()['data']['choices'][0]['message'].strip()
        except Exception as err:
            payload = self._response_payload(response)
            error_message = self._error_message(payload)
            if self.verbose:
                self.logger.error('---------------------------ERROR---------------------------')
                self.logger.error(payload)
                self.logger.error(err)
                self.logger.error('---------------------------request_id---------------------------' + request_id)
                self.logger.error(
                    'api error' + error_message
                    + str([item['value'] if item['type'] == 'image' else None for item in inputs])
                )
                self.logger.error(f'The input messages are {inputs}.')
            return -1, error_message, ''

        if self.verbose:
            self.logger.info(f'inputs: {inputs}\nanswer: {answer}')
        return 0, answer, 'Succeeded! '


class SenseChatVisionAPI(SenseChatVisionWrapper):
    """Public SenseChat vision API entry point.

    All request and prompt logic lives in :class:`SenseChatVisionWrapper`;
    ``generate`` here only forwards to the parent implementation.
    """

    def generate(self, message, dataset=None):
        """Forward ``message`` and ``dataset`` unchanged to the parent ``generate``."""
        parent_generate = super().generate
        return parent_generate(message, dataset=dataset)
