from SmartRAG.data_pools.text_generation_pool import TextGenPool, Sample
from nltk.tokenize import word_tokenize
from urllib.request import urlretrieve
import json

class Ambig(TextGenPool):
    """Text-generation pool read from a local ``<split>.jsonl`` file.

    Each JSON record must provide a ``"question"`` string and an ``"answer"``
    that is either a string or a list; string answers are wrapped in a list.
    """

    @classmethod
    def prepare(cls, split: str, prompt_prefix: str = "", ifdebug: bool = False):
        """Load the records for ``split`` and wrap them as ``Sample`` objects.

        Args:
            split: ``"train"``, ``"val"`` or ``"test"``. ``"val"`` and ``"test"``
                both read ``dev.jsonl``; ``"val"`` keeps only the first 1000 records.
            prompt_prefix: Text prepended to every question.
            ifdebug: When true, keep only the first 10 records.

        Returns:
            A new pool instance built from the samples.

        Raises:
            NotImplementedError: If ``split`` is not a supported split name.
        """
        split_ = split
        split = cls.gen_split_name(split_)
        infile = f'{split}.jsonl'
        with open(infile, 'r', encoding='utf8') as f:
            # Skip blank lines (e.g. a trailing newline) rather than failing in json.loads.
            records = [json.loads(line) for line in f if line.strip()]
        if split_ == "val":
            records = records[:1000]
        if ifdebug:
            records = records[:10]
        # Normalise answers to a list per record, so files that mix string and
        # list answers are neither double-wrapped nor left as bare strings.
        for record in records:
            if isinstance(record['answer'], str):
                record['answer'] = [record['answer']]
        print(f"load {len(records)} {split} examples.")
        if records:
            print('eg: ', records[0])
        samples = [
            Sample(id=f"{split}_{ix}",
                   prompt_or_input_text=prompt_prefix + item["question"],
                   references=[item["answer"]])
            for ix, item in enumerate(records)
        ]
        print(f"sample {len(records)} {split} examples.")
        return cls(samples)

    @staticmethod
    def gen_split_name(split: str):
        """Map a caller-facing split name to the file stem on disk.

        ``"train"`` -> ``"train"``; ``"test"`` and ``"val"`` -> ``"dev"``.

        Raises:
            NotImplementedError: For any other split name.
        """
        if split == "train":
            return "train"
        # NOTE: the previous `split == "test" or "val"` was always true.
        if split in ("test", "val"):
            return "dev"
        raise NotImplementedError(f"unsupported split: {split!r}")

class three(TextGenPool):
    """Text-generation pool read from a local JSON array file.

    ``"train"`` reads ``train_retrieval.json``; ``"val"``/``"test"`` read
    ``test.json``. Each record must provide ``"question"`` and ``"answer"`` keys.
    """

    @classmethod
    def prepare(cls, split: str, prompt_suffix: str = "", prompt_prefix: str = "", ifdebug: bool = False):
        """Load the records for ``split`` and wrap them as ``Sample`` objects.

        Args:
            split: ``"train"``, ``"val"`` or ``"test"``.
            prompt_suffix: Text appended to every question.
            prompt_prefix: Text prepended to every question.
            ifdebug: When true, keep only the first 10 records.

        Returns:
            A new pool instance built from the samples.

        Raises:
            NotImplementedError: If ``split`` is not a supported split name.
        """
        split = cls.gen_split_name(split)
        if split == "train":
            # the training data for this pool lives in a differently named file
            split = "train_retrieval"
        infile = f'{split}.json'
        with open(infile, 'r', encoding='utf8') as f:
            records = json.load(f)

        if ifdebug:
            records = records[:10]

        print(f"load {len(records)} {split} examples.")
        if records:
            print('eg: ', records[0])
        samples = [
            Sample(id=f"{split}_{ix}",
                   prompt_or_input_text=prompt_prefix + item["question"] + prompt_suffix,
                   references=[item["answer"]])
            for ix, item in enumerate(records)
        ]
        print(f"sample {len(records)} {split} examples.")
        return cls(samples)

    @staticmethod
    def gen_split_name(split: str):
        """Map a caller-facing split name to the file stem on disk.

        ``"train"`` -> ``"train"``; ``"test"`` and ``"val"`` -> ``"test"``.

        Raises:
            NotImplementedError: For any other split name.
        """
        if split == "train":
            return "train"
        # NOTE: the previous `split == "test" or "val"` was always true.
        if split in ("test", "val"):
            return "test"
        raise NotImplementedError(f"unsupported split: {split!r}")

def options2choices(options):
    """Render an option mapping as one line of ``"key: value "`` pairs.

    Every key of ``options`` is followed by ``": "``, its value and a single
    space; keys and values are joined with ``+``, so both must be strings.
    The rendered pairs are concatenated and prefixed with two newlines.
    """
    rendered = "".join(key + ": " + options[key] + " " for key in options)
    return "\n\n" + rendered
    
class ambignq(TextGenPool):
    """Text-generation pool read from a local ``<split>.jsonl`` file.

    ``"train"`` reads ``train.jsonl``; ``"val"``/``"test"`` read ``test.jsonl``.
    Each record must provide ``"question"`` and ``"answer"`` keys.
    """

    @classmethod
    def prepare(cls, split: str, prompt_suffix: str = "", prompt_prefix: str = "", ifdebug: bool = False):
        """Load the records for ``split`` and wrap them as ``Sample`` objects.

        Args:
            split: ``"train"``, ``"val"`` or ``"test"``.
            prompt_suffix: Text appended to every question.
            prompt_prefix: Text prepended to every question.
            ifdebug: When true, keep only the first 10 records.

        Returns:
            A new pool instance built from the samples.

        Raises:
            NotImplementedError: If ``split`` is not a supported split name.
        """
        split = cls.gen_split_name(split)
        infile = f'{split}.jsonl'
        with open(infile, 'r', encoding='utf8') as f:
            # Skip blank lines (e.g. a trailing newline) rather than failing in json.loads.
            records = [json.loads(line) for line in f if line.strip()]
        if ifdebug:
            records = records[:10]

        print(f"load {len(records)} {split} examples.")
        if records:
            print('eg: ', records[0])
        samples = [
            Sample(id=f"{split}_{ix}",
                   prompt_or_input_text=prompt_prefix + item["question"] + prompt_suffix,
                   references=[item["answer"]])
            for ix, item in enumerate(records)
        ]
        print(f"sample {len(records)} {split} examples.")
        return cls(samples)

    @staticmethod
    def gen_split_name(split: str):
        """Map a caller-facing split name to the file stem on disk.

        ``"train"`` -> ``"train"``; ``"test"`` and ``"val"`` -> ``"test"``.

        Raises:
            NotImplementedError: For any other split name.
        """
        if split == "train":
            return "train"
        # NOTE: the previous `split == "test" or "val"` was always true.
        if split in ("test", "val"):
            return "test"
        raise NotImplementedError(f"unsupported split: {split!r}")

class popqa(TextGenPool):
    """Text-generation pool read from a local ``<split>.jsonl`` file.

    ``"train"`` reads ``train.jsonl``; ``"val"``/``"test"`` read ``test.jsonl``.
    Each record must provide ``"question"`` and ``"answer"`` keys.
    """

    @classmethod
    def prepare(cls, split: str, prompt_suffix: str = "", prompt_prefix: str = "", ifdebug: bool = False):
        """Load the records for ``split`` and wrap them as ``Sample`` objects.

        Args:
            split: ``"train"``, ``"val"`` or ``"test"``.
            prompt_suffix: Text appended to every question.
            prompt_prefix: Text prepended to every question.
            ifdebug: When true, keep only the first 10 records.

        Returns:
            A new pool instance built from the samples.

        Raises:
            NotImplementedError: If ``split`` is not a supported split name.
        """
        split = cls.gen_split_name(split)
        infile = f'{split}.jsonl'
        with open(infile, 'r', encoding='utf8') as f:
            # Skip blank lines (e.g. a trailing newline) rather than failing in json.loads.
            records = [json.loads(line) for line in f if line.strip()]
        if ifdebug:
            records = records[:10]

        print(f"load {len(records)} {split} examples.")
        if records:
            print('eg: ', records[0])
        samples = [
            Sample(id=f"{split}_{ix}",
                   prompt_or_input_text=prompt_prefix + item["question"] + prompt_suffix,
                   references=[item["answer"]])
            for ix, item in enumerate(records)
        ]
        print(f"sample {len(records)} {split} examples.")
        return cls(samples)

    @staticmethod
    def gen_split_name(split: str):
        """Map a caller-facing split name to the file stem on disk.

        ``"train"`` -> ``"train"``; ``"test"`` and ``"val"`` -> ``"test"``.

        Raises:
            NotImplementedError: For any other split name.
        """
        if split == "train":
            return "train"
        # NOTE: the previous `split == "test" or "val"` was always true.
        if split in ("test", "val"):
            return "test"
        raise NotImplementedError(f"unsupported split: {split!r}")

class triviaqa(TextGenPool):
    """Text-generation pool read from a local ``<split>.json`` array file.

    ``"train"`` reads ``train.json``; ``"val"``/``"test"`` read ``test.json``.
    Each record must provide ``"question"`` and ``"answer"`` keys.
    """

    @classmethod
    def prepare(cls, split: str, prompt_suffix: str = "", prompt_prefix: str = "", ifdebug: bool = False):
        """Load the records for ``split`` and wrap them as ``Sample`` objects.

        Args:
            split: ``"train"``, ``"val"`` or ``"test"``.
            prompt_suffix: Text appended to every question.
            prompt_prefix: Text prepended to every question.
            ifdebug: When true, keep only the first 10 records.

        Returns:
            A new pool instance built from the samples.

        Raises:
            NotImplementedError: If ``split`` is not a supported split name.
        """
        split = cls.gen_split_name(split)
        infile = f'{split}.json'
        with open(infile, 'r', encoding='utf8') as f:
            records = json.load(f)
        if ifdebug:
            records = records[:10]

        print(f"load {len(records)} {split} examples.")
        if records:
            print('eg: ', records[0])
        samples = [
            Sample(id=f"{split}_{ix}",
                   prompt_or_input_text=prompt_prefix + item["question"] + prompt_suffix,
                   references=[item["answer"]])
            for ix, item in enumerate(records)
        ]
        print(f"sample {len(records)} {split} examples.")
        return cls(samples)

    @staticmethod
    def gen_split_name(split: str):
        """Map a caller-facing split name to the file stem on disk.

        ``"train"`` -> ``"train"``; ``"test"`` and ``"val"`` -> ``"test"``.

        Raises:
            NotImplementedError: For any other split name.
        """
        if split == "train":
            return "train"
        # NOTE: the previous `split == "test" or "val"` was always true.
        if split in ("test", "val"):
            return "test"
        raise NotImplementedError(f"unsupported split: {split!r}")

class moviedata(TextGenPool):
    """Text-generation pool read from a local ``<split>.json`` array file.

    ``"train"`` reads ``train.json``; ``"val"``/``"test"`` read ``test.json``.
    Each record must provide ``"question"`` and ``"answer"`` keys.
    """

    @classmethod
    def prepare(cls, split: str, prompt_suffix: str = "", prompt_prefix: str = "", ifdebug: bool = False):
        """Load the records for ``split`` and wrap them as ``Sample`` objects.

        Args:
            split: ``"train"``, ``"val"`` or ``"test"``.
            prompt_suffix: Text appended to every question.
            prompt_prefix: Text prepended to every question.
            ifdebug: When true, keep only the first 10 records.

        Returns:
            A new pool instance built from the samples.

        Raises:
            NotImplementedError: If ``split`` is not a supported split name.
        """
        split = cls.gen_split_name(split)
        infile = f'{split}.json'
        with open(infile, 'r', encoding='utf8') as f:
            records = json.load(f)

        if ifdebug:
            records = records[:10]

        print(f"load {len(records)} {split} examples.")
        if records:
            print('eg: ', records[0])
        samples = [
            Sample(id=f"{split}_{ix}",
                   prompt_or_input_text=prompt_prefix + item["question"] + prompt_suffix,
                   references=[item["answer"]])
            for ix, item in enumerate(records)
        ]
        print(f"sample {len(records)} {split} examples.")
        return cls(samples)

    @staticmethod
    def gen_split_name(split: str):
        """Map a caller-facing split name to the file stem on disk.

        ``"train"`` -> ``"train"``; ``"test"`` and ``"val"`` -> ``"test"``.

        Raises:
            NotImplementedError: For any other split name.
        """
        if split == "train":
            return "train"
        # NOTE: the previous `split == "test" or "val"` was always true.
        if split in ("test", "val"):
            return "test"
        raise NotImplementedError(f"unsupported split: {split!r}")


def download_file_using_url(url: str, dest_path: str):
    """Fetch ``url`` and write the response body to ``dest_path``.

    Thin wrapper around ``urllib.request.urlretrieve``. It returns ``None``
    and lets any error raised by the download propagate to the caller.
    """
    urlretrieve(url, filename=dest_path)

