
from typing import List

import pycountry
from efficiency_benchmark.dependencies.lm_eval import metrics
from efficiency_benchmark.dependencies.lm_eval.base import Task, rf
from sacrebleu import sacrebleu

# NOTE(review): the citation text is missing from this file -- the assignment
# had no right-hand side, which is a SyntaxError and prevents the module from
# being imported at all. An empty string keeps the name defined and the module
# importable; restore the intended citation text here.
_CITATION = ""


# Module-level alias for sacrebleu's DATASETS object.
sacrebleu_datasets = sacrebleu.DATASETS


def create_tasks_from_benchmarks(benchmark_dict):
    """Build a mapping from task name to translation task.

    Args:
        benchmark_dict: Mapping from a dataset name to an iterable of
            language-pair strings.

    Returns:
        A dict keyed by ``"{dataset}-{language_pair}"`` whose values are the
        results of ``create_translation_task`` for that dataset and pair.
    """
    versioned_suffixes = ("zh", "ja")
    tasks = {}
    for dataset, language_pairs in benchmark_dict.items():
        for language_pair in language_pairs:
            task_name = f"{dataset}-{language_pair}"
            # Pairs whose last two characters are "zh" or "ja" are created
            # with version 1; every other pair uses version 0.
            if language_pair[-2:] in versioned_suffixes:
                task_version = 1
            else:
                task_version = 0
            tasks[task_name] = create_translation_task(
                dataset, language_pair, task_version
            )
    return tasks







def zh_split(zh_text: List[str]) -> List[str]:
    """Segment each string in ``zh_text`` with jieba and re-join tokens with spaces.

    Args:
        zh_text: Strings to segment; each one is stripped of surrounding
            whitespace before being passed to ``jieba.cut``.

    Returns:
        A new list, one entry per input string, with the tokens produced by
        ``jieba.cut`` joined by single spaces.
    """
    # Imported inside the function, so jieba is only needed when this is called.
    import jieba

    segmented = []
    for sentence in zh_text:
        tokens = jieba.cut(sentence.strip())
        segmented.append(" ".join(tokens))
    return segmented


def ja_split(ja_text: List[str]) -> List[str]:
    """Segment each string in ``ja_text`` with nagisa and re-join words with spaces.

    Args:
        ja_text: Strings to segment; each one is stripped of surrounding
            whitespace before being passed to ``nagisa.tagging``.

    Returns:
        A new list, one entry per input string, with the ``words`` of the
        tagging result joined by single spaces.
    """
    # Imported inside the function, so nagisa is only needed when this is called.
    import nagisa

    segmented = []
    for sentence in ja_text:
        tagged = nagisa.tagging(sentence.strip())
        segmented.append(" ".join(tagged.words))
    return segmented


# Maps a language code to the function that segments text in that language
# into space-separated tokens. GeneralTranslationTask.process_results looks up
# the target-language code here and, when present, applies the function to
# both the reference and the model results before they are scored.
NO_SPACE_LANG = {"zh": zh_split, "ja": ja_split}






def create_translation_task(dataset, language_pair, version=0):
    """Create a ``GeneralTranslationTask`` subclass bound to one dataset and language pair.

    Args:
        dataset: Dataset name forwarded to ``GeneralTranslationTask.__init__``.
        language_pair: Language-pair string forwarded to
            ``GeneralTranslationTask.__init__``.
        version: Value stored as the subclass's ``VERSION`` attribute.

    Returns:
        The ``TranslationTask`` class itself (not an instance). It can be
        instantiated with no arguments.
    """

    class TranslationTask(GeneralTranslationTask):
        VERSION = version

        def __init__(self):
            # dataset and language_pair are captured from the enclosing call,
            # which is what lets the class be constructed without arguments.
            super().__init__(dataset, language_pair)

    return TranslationTask


class GeneralTranslationTask(Task):
    """Translation task whose documents come from a sacrebleu test set.

    Source and reference sentences are fetched with
    ``sacrebleu.download_test_set`` and exposed as test documents of the form
    ``{"src": ..., "ref": ...}``. Results are reported under the metric names
    ``"bleu"``, ``"chrf"`` and ``"ter"``.
    """

    VERSION = 0

    def __init__(self, sacrebleu_dataset, sacrebleu_language_pair=None):
        """Store the dataset / language pair and initialise the base task.

        Args:
            sacrebleu_dataset: Dataset name passed to
                ``sacrebleu.download_test_set``.
            sacrebleu_language_pair: Language pair such as ``"en-de"``; it is
                split on ``"-"`` by the prompt-building methods.
        """
        self.sacrebleu_dataset = sacrebleu_dataset
        self.sacrebleu_language_pair = sacrebleu_language_pair
        self.src_file = self.ref_file = self.src_data = self.ref_data = None

        # NOTE(review): the attributes above are set before the base
        # initialiser runs, presumably because Task.__init__ calls download()
        # -- confirm against the Task base class.
        super().__init__()

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Fetch the test set and load its source and reference lines.

        The ``data_dir``, ``cache_dir`` and ``download_mode`` arguments are
        accepted but not used by this implementation.
        """
        self.src_file, self.ref_file = sacrebleu.download_test_set(
            self.sacrebleu_dataset, self.sacrebleu_language_pair
        )
        loaded = []
        for path in (self.src_file, self.ref_file):
            # Read inside a ``with`` block so the file handle is closed
            # deterministically instead of being left open.
            with sacrebleu.smart_open(path) as handle:
                loaded.append([line.rstrip() for line in handle])
        self.src_data, self.ref_data = loaded

    def has_training_docs(self):
        """Return False: this task exposes no training documents."""
        return False

    def has_validation_docs(self):
        """Return False: this task exposes no validation documents."""
        return False

    def has_test_docs(self):
        """Return True: the downloaded test set provides the test documents."""
        return True

    def test_docs(self):
        """Return the test documents.

        Returns:
            A list of ``{"src": source_line, "ref": reference_line}`` dicts
            built by pairing ``self.src_data`` with ``self.ref_data``.
        """
        return [{"src": src, "ref": ref} for src, ref in zip(self.src_data, self.ref_data)]

    def doc_to_text(self, doc):
        """Build the prompt for ``doc``.

        The prompt is ``"<source language> phrase: "`` followed by the source
        text, a newline, and ``"<target language> phrase:"``, where the
        language names come from ``code_to_language``.
        """
        language_codes = self.sacrebleu_language_pair.split("-")
        src_lang = code_to_language(language_codes[0])
        tar_lang = code_to_language(language_codes[1])
        return f"{src_lang} phrase: " + doc["src"] + f"\n{tar_lang} phrase:"

    def should_decontaminate(self):
        """Return True: this task provides a decontamination query."""
        return True

    def doc_to_decontamination_query(self, doc):
        """Return the source text of ``doc`` as the decontamination query."""
        return doc["src"]

    def doc_to_target(self, doc):
        """Return the target text for ``doc`` with a leading space.

        ``doc["ref"]`` may be a single string or a sequence of references; in
        the latter case the first reference is used.
        """
        ref = doc["ref"]
        if not isinstance(ref, str):
            ref = ref[0]
        # The leading space is added in both cases. Previously operator
        # precedence meant it was only added when ``ref`` was a string.
        return " " + ref

    def construct_requests(self, doc, ctx):
        """Return a ``greedy_until`` request for ``ctx`` with ``["\\n"]`` as its second argument."""
        return rf.greedy_until(ctx, ["\n"])

    def process_results(self, doc, results):
        """Pair the reference with the model results for each metric.

        When the target language code is a key of ``NO_SPACE_LANG``, both the
        reference and the results are first segmented with the registered
        function. ``doc`` is not modified.

        Returns:
            A dict mapping ``"bleu"``, ``"chrf"`` and ``"ter"`` to the same
            ``(reference, results)`` tuple.
        """
        tar_lang_code = self.sacrebleu_language_pair.split("-")[-1]
        # Work on a local copy of the reference: writing the segmented text
        # back into ``doc`` would re-segment it if this method ran again on
        # the same document.
        ref = doc["ref"]
        if tar_lang_code in NO_SPACE_LANG:
            segment = NO_SPACE_LANG[tar_lang_code]
            ref = segment([ref])[0]
            results = segment(results)

        ref_pred = (ref, results)
        return {
            "bleu": ref_pred,
            "chrf": ref_pred,
            "ter": ref_pred,
        }

    def aggregation(self):
        """Return the aggregation function from ``metrics`` for each metric name."""
        return {
            "bleu": metrics.bleu,
            "chrf": metrics.chrf,
            "ter": metrics.ter,
        }

    def higher_is_better(self):
        """Return, per metric name, whether a larger value is better (False only for "ter")."""
        return {
            "bleu": True,
            "chrf": True,
            "ter": False,
        }

    def __str__(self):
        """Return ``"<DATASET> <source language> to <target language> Task"``."""
        language_codes = self.sacrebleu_language_pair.split("-")
        src_lang = code_to_language(language_codes[0])
        tar_lang = code_to_language(language_codes[1])
        return f"{self.sacrebleu_dataset.upper()} {src_lang} to {tar_lang} Task"







def code_to_language(code):
    """Return the language name that pycountry associates with ``code``.

    The lookup field is chosen from the length of ``code``: a two-character
    code is looked up as ``alpha_2``, a three-character code as ``alpha_3``.

    Args:
        code: Language code string.

    Returns:
        The ``name`` attribute of the record returned by
        ``pycountry.languages.get``.
    """
    lookup_field = "alpha_" + str(len(code))
    language_record = pycountry.languages.get(**{lookup_field: code})
    return language_record.name
