import json
import os
import os.path as osp
import re
import subprocess
import tempfile
import time
from shutil import copyfile

from datasets import Dataset

from opencompass.datasets.base import BaseDataset
from opencompass.datasets.humaneval import humaneval_postprocess_v2
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path

_LANGUAGE_NAME_DICT = {
    'java': 'Java',
    'javascript': 'JavaScript',
    'js': 'JavaScript',
    'python': 'Python',
}


@LOAD_DATASET.register_module()
class PMMEvalHumanEvalXLDataset(BaseDataset):
    """Loader for the ``humaneval-xl`` test split."""

    @staticmethod
    def load(path: str, lang: str, program_lang: str):
        """Load one ``humaneval-xl`` test file.

        Args:
            path: Dataset location; resolved through ``get_data_path``.
            lang: Name of the ``.jsonl`` file (without extension) inside the
                ``program_lang`` directory -- presumably the natural-language
                code of the prompts; verify against the dataset layout.
            program_lang: Name of the sub-directory under
                ``humaneval-xl/test`` to read from.

        Returns:
            The object returned by ``MsDataset.load`` when the environment
            variable ``DATASET_SOURCE`` equals ``'ModelScope'``; otherwise a
            ``datasets.Dataset`` built from the rows of the local JSON-lines
            file.
        """
        data_path = get_data_path(path)

        if os.environ.get('DATASET_SOURCE') == 'ModelScope':
            from modelscope import MsDataset
            return MsDataset.load(dataset_name=data_path,
                                  subset_name='humaneval-xl',
                                  split=f'test/{program_lang}/{lang}')

        filename = os.path.join(
            data_path, f'humaneval-xl/test/{program_lang}/{lang}.jsonl')
        with open(filename, mode='r', encoding='utf-8') as f:
            # Blank or whitespace-only lines are skipped: ``json.loads('')``
            # raises ``JSONDecodeError`` and would abort the whole load.
            records = [
                json.loads(line) for line in map(str.strip, f) if line
            ]
        return Dataset.from_list(records)


class PMMEvalHumanEvalXLEvaluator(BaseEvaluator):
    """Score code generations through a remote code-evaluation service.

    Predictions are cleaned with ``_clean_up_code``, written to a temporary
    JSON-lines file and POSTed with ``curl`` to
    ``<ip_address>[:<port>]/evaluate``.

    Args:
        language: Programming language of the generations; must be a key of
            ``_LANGUAGE_NAME_DICT``.
        ip_address: Host (and scheme, if any) of the evaluation service.
        text_language: Only used in the name of the temporary upload file.
        port: Port of the service; when empty no ``:<port>`` is appended.
        retry: Maximum number of calls made to the service. A new call is
            only attempted after a connection failure
            (``(56) Recv failure`` in the curl output).
        timeout: Timeout in seconds handed to ``subprocess.run`` for each
            ``curl`` invocation.
    """

    def __init__(self,
                 language,
                 ip_address='localhost',
                 text_language='',
                 port='',
                 retry=2,
                 timeout=600) -> None:
        assert language in _LANGUAGE_NAME_DICT.keys(), (
            f'language must be in {list(_LANGUAGE_NAME_DICT.keys())}')
        # NOTE(review): 'rust' is not a key of _LANGUAGE_NAME_DICT, so the
        # assertion above rejects it and this branch is currently unreachable.
        if language == 'rust':
            timeout *= 10  # rust need more time
        self.language = language
        self.text_language = text_language
        self.ip_address = ip_address
        self.port = port
        self.retry = retry
        self.timeout = timeout
        super().__init__()

    def score(self, predictions, references):
        """Send cleaned predictions to the service and return its result.

        Args:
            predictions: Raw model completions, one per task.
            references: Per-task reference passed to ``_clean_up_code``
                together with the matching prediction.

        Returns:
            The service response: decoded with ``json.loads`` when it is a
            string, returned unchanged when it is already a dict.

        Raises:
            Exception: If the service call does not succeed (or yields
                something that is neither ``str`` nor ``dict``). The
                uploaded predictions are copied to a persistent file first
                and its path is included in the message.
            subprocess.TimeoutExpired: Propagated from ``subprocess.run``
                when a ``curl`` call exceeds ``timeout``.
        """
        predictions = [{
            'task_id':
            f'{_LANGUAGE_NAME_DICT[self.language]}/{i}',
            'generation':
            _clean_up_code(pred, self.language, refer),
        } for i, (pred, refer) in enumerate(zip(predictions, references))]
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_out_path = osp.join(
                tmp_dir,
                f'humanevalx_{self.language}_{self.text_language}.json')
            with open(tmp_out_path, 'w') as f:
                for pred in predictions:
                    f.write(json.dumps(pred) + '\n')

            # Defaults keep the error path below well-defined even when
            # ``retry`` is 0 or negative and the loop body never runs.
            succeed = False
            output = 'code eval service was not called (retry < 1)'
            num_retry = 0
            while num_retry < self.retry:
                succeed, output = self._code_eval_service(
                    file_path=tmp_out_path)
                # only retry when connection failed
                if succeed or '(56) Recv failure' not in output:
                    break
                num_retry += 1
                if num_retry < self.retry:
                    # wait a min in case the service load is too high;
                    # pointless after the final attempt, so skipped there
                    time.sleep(60)

            if succeed:
                if isinstance(output, str):
                    return json.loads(output)
                elif isinstance(output, dict):
                    return output

            ref_url = 'https://opencompass.readthedocs.io/en/latest/advanced_guides/code_eval_service.html'  # noqa
            if hasattr(self, '_out_dir'):
                result_file_path = re.sub('results', 'mid_results',
                                          self._out_dir) + '.json'  # noqa
            else:
                result_file_path = os.path.join(
                    'outputs', f'humanevalx_{self.language}.json')
            # Create the target directory for both branches; otherwise
            # ``copyfile`` fails and hides the real service error. An empty
            # dirname means "current directory" and needs no creation.
            result_dir = osp.dirname(result_file_path)
            if result_dir:
                os.makedirs(result_dir, exist_ok=True)
            copyfile(tmp_out_path, result_file_path)
            raise Exception(
                f'Call CodeEvalService Error in `HumanevalXEvaluator`, The '
                f"results have been saved in path '{result_file_path}', You "
                'need to check that your code evaluate service is launched and'
                f' the network to service is connected, you can also get '
                f'results directly by using `curl` command refer to {ref_url}.'
                f'\nError Information: {output}')

    def _code_eval_service(self, file_path):
        """POST ``file_path`` to the evaluation service with ``curl``.

        Args:
            file_path: Path of the JSON-lines prediction file to upload.

        Returns:
            ``(True, result)`` when curl exits with code 0 and its stdout
            starts with a quoted ``"{...:...}"`` payload; ``result`` is that
            stdout decoded with ``json.loads``. Otherwise ``(False, err)``
            where ``err`` is the decoded stderr, or the decoded stdout when
            stderr is empty. ``err`` is always a ``str``.

        Raises:
            subprocess.TimeoutExpired: If curl runs longer than
                ``self.timeout`` seconds.
        """
        if self.port:
            eval_server_url = f'{self.ip_address}:{self.port}/evaluate'
        else:
            eval_server_url = f'{self.ip_address}/evaluate'
        exec_result = subprocess.run([
            'curl', '-X', 'POST', '-F', f'file=@{file_path}', '-F',
            f'dataset=humanevalx/{self.language}', f'{eval_server_url}'
        ],
                                     timeout=self.timeout,
                                     capture_output=True)
        if exec_result.returncode == 0:
            stdout_text = exec_result.stdout.decode('utf-8')
            if re.match("\"{.*:.*}\"", stdout_text):
                return True, json.loads(stdout_text)

        raw_err = exec_result.stderr or exec_result.stdout
        # Undecodable bytes are replaced instead of returning raw ``bytes``:
        # the caller runs a ``str in output`` check that would raise
        # ``TypeError`` on a bytes object.
        return False, raw_err.decode('utf-8', errors='replace')


def _clean_up_code(text: str, language_type: str, reference) -> str:
    """Cleans up the generated code."""
    try:
        # for chatGLM related text
        eval_text = eval(text)
    except Exception:
        pass
    else:
        if isinstance(eval_text, str):
            text = eval_text
    # extract code from code block
    text = text.lstrip('\n')
    if '```' in text:
        blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
        if len(blocks) == 0:
            text = text.split('```')[1]  # fall back to default strategy
        else:
            text = blocks[0]  # fetch the first code block
            if not text.startswith('\n'):  # in case starting with ```xxx
                text = text[max(text.find('\n') + 1, 0):]
    if language_type.lower() == 'python':
        text = humaneval_postprocess_v2(text)
        # we need to take care of the first line
        # append extra space for first line for correct indentation
        text = '    ' + text.lstrip()

        text_splits = text.split('\n')
        is_empty_line = False
        ind_empty_line = None
        for i, line in enumerate(text_splits):
            if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t':
                is_empty_line = True
                ind_empty_line = i
                break
        if is_empty_line:
            text = '\n'.join(text_splits[:ind_empty_line])
        else:
            end_words = [
                '\ndef', '\nclass', '\n#', '\nassert', '\n"""', '\nprint',
                '\nif', '\n\n\n'
            ]
            for w in end_words:
                if w in text:
                    text = text[:text.rfind(w)]
    # strip function head for all other language
    func_name = reference.strip().split('\n')[-1]
    if func_name:
        func_name = func_name.strip().strip('{')
        if func_name in text:
            text = '\n'.join(text[text.find(func_name):].split('\n')[1:])
    if language_type.lower() == 'java':
        main_pos = text.find('public static void main')
        if main_pos != -1:
            text = text[:main_pos] + '}'
        if '}' in text:
            text = text[:text.rfind('}')] + '}'
        if text.count('{') + 1 == text.count('}'):
            text += '\n}'
    elif language_type.lower() == 'go':
        if '\nfunc main(' in text:
            text = text[:text.rfind('func main(')]
        if '}' in text:
            text = text[:text.rfind('}')] + '}'
    elif language_type.lower() == 'cpp':
        if '\nint main()' in text:
            text = text[:text.rfind('int main()')]
        if '}' in text:
            text = text[:text.rfind('}')] + '}'
    elif language_type.lower() == 'js':
        if '}' in text:
            text = text[:text.rfind('}')] + '}'
    elif language_type.lower() == 'rust':
        if '}' in text:
            text = text[:text.rfind('}')] + '}'

    return text
