import os
from copy import deepcopy
import dill
import pandas as pd
# from torch.utils.data import Dataset as tDataset
from datasets import Dataset as hfDataset
import datasets as hfsets
from .common import compute_standard_scores, clean_up_answer_for_qa
from .base_dataset import DatasetWithExactSolution, DatasetWithOODInfo, DatasetWithPerturbation, DatasetWithReference
from .base_dataset import OPEN_BOOK_QA_MSG, CLOSED_BOOK_QA_MSG, STRUCTURED_TEXT_GEN_MSG


class COLLIEDataset(DatasetWithExactSolution, DatasetWithReference):
    """Dataset backed by a local ``all_data.dill`` file.

    The file is expected to unpickle to a mapping of key -> sequence of record
    dicts. Records are flattened into one list, visiting keys in sorted order.
    The methods below read the record keys ``'prompt'``, ``'example'``,
    ``'constraint'`` and ``'targets'``, and optionally ``'oneshot_example'``.
    """

    def __init__(self, base_dir='./evaluation', one_shot_only=False, include_fewshot=False): # add more filters by source, etc if necessary
        """Load and flatten the records.

        Args:
            base_dir: directory that contains ``all_data.dill``.
            one_shot_only: if true, keep only records that have an
                ``'oneshot_prompt'`` key.
            include_fewshot: if true, ``__getitem__`` adds ``fs_in`` / ``fs_out``
                for records that carry an ``'oneshot_example'`` key.
        """
        floc = os.path.abspath(os.path.join(base_dir, 'all_data.dill'))
        # SECURITY: dill.load (like pickle.load) can execute arbitrary code while
        # deserializing. Only point base_dir at files from a trusted source.
        with open(floc, "rb") as f:
            all_data = dill.load(f)
        # Visit groups in sorted-key order so item indices are reproducible.
        ds = [record for key in sorted(all_data) for record in all_data[key]]
        # leave only ones that contain one shot example
        # NOTE(review): this filter checks 'oneshot_prompt' while __getitem__
        # checks 'oneshot_example' -- confirm both keys always appear together.
        if one_shot_only:
            ds = [record for record in ds if 'oneshot_prompt' in record]
        self.data = ds
        self.include_fewshot = include_fewshot

    def __getitem__(self, item):
        """Return a copy of record ``item`` structured for the rollout scripts.

        ``'x'`` is set to the record's ``'prompt'``. When ``include_fewshot`` is
        set and the record has an ``'oneshot_example'``, its first line becomes
        the single ``fs_in`` entry and the remaining text the single ``fs_out``
        entry.
        """
        dset_item = self.full_item(item)
        dset_item['x'] = dset_item['prompt']
        if 'oneshot_example' in dset_item and self.include_fewshot:
            # Split once at the first newline: head is the demo input, tail
            # (possibly empty) is the demo output.
            demo_in, _, demo_out = dset_item['oneshot_example'].partition('\n')
            dset_item['fs_in'] = [demo_in]
            dset_item['fs_out'] = [demo_out]
        return dset_item

    def full_item(self, item):
        """Return a deep copy of the stored record so callers cannot mutate it."""
        return deepcopy(self.data[item])

    def __len__(self):
        """Return the number of records kept after filtering."""
        return len(self.data)

    def calculate_exact_correctness(self, item, answer):
        """Check ``answer`` with the record's own constraint object.

        Returns:
            A dict with the single key ``'exact_correctness'`` holding whatever
            ``constraint.check(answer, targets)`` returns.
        """
        tgt_item = self.full_item(item)
        verdict = tgt_item['constraint'].check(answer, tgt_item['targets'])
        return {'exact_correctness': verdict}

    def get_answer(self, item):
        """Return the record's ``'example'`` field, used as the reference answer."""
        return self.full_item(item)['example']

    def get_question(self, item):
        """Return the record's prompt wrapped in a one-element list."""
        return [self.full_item(item)['prompt']]

    def calculate_correctness(self, item, answer, **kwargs):
        """Score ``answer`` against the reference via ``compute_standard_scores``.

        Extra keyword arguments are forwarded unchanged.
        """
        ref_answer = self.get_answer(item)
        question = self.get_question(item)
        return compute_standard_scores([ref_answer], answer, question=question, **kwargs)

    def get_problem_system_instruction(self):
        """Return the system instruction used for this dataset."""
        return STRUCTURED_TEXT_GEN_MSG


class KUQDataset(DatasetWithOODInfo):
    """Dataset read from ``knowns_unknowns.jsonl`` (one JSON object per line).

    The methods below read the row fields ``'question'`` and ``'unknown'``.
    """

    def __init__(self, base_dir='./evaluation', limit_categories=None): # add more filters by source, etc if necessary
        """Read the JSONL file under ``base_dir`` into a Hugging Face dataset.

        Args:
            base_dir: directory that contains ``knowns_unknowns.jsonl``.
            limit_categories: accepted but currently ignored (see TODO).
        """
        jsonl_path = os.path.join(base_dir, 'knowns_unknowns.jsonl')
        frame = pd.read_json(jsonl_path, lines=True)
        self.data = hfDataset.from_pandas(frame)
        # TODO: implement limit categories

    def __getitem__(self, item):
        """Return a copy of row ``item`` with ``'x'`` set to its question."""
        # returns structured for our rollout scripts
        row = self.full_item(item)
        row['x'] = row['question']
        return row

    def full_item(self, item):
        """Return a deep copy of the stored row."""
        stored = self.data[item]
        return deepcopy(stored)

    def __len__(self):
        """Return the number of rows."""
        return len(self.data)

    def get_ood_identifier(self, item):
        """Return the row's ``'unknown'`` field, used as the OOD marker."""
        row = self.full_item(item)
        return row['unknown']

    def get_problem_system_instruction(self):
        """Return the system instruction used for this dataset."""
        return CLOSED_BOOK_QA_MSG


###############################################################################################
##      Current Baseline Datasets
###############################################################################################
## TODO: add truthful (presumably TruthfulQA) as one of the baseline problems

class TriviaQA(DatasetWithReference):
    """Validation split of the ``trivia_qa`` dataset, ``rc.nocontext`` config."""

    def __init__(self):
        """Load the split through ``datasets.load_dataset``."""
        super().__init__()
        self.data = hfsets.load_dataset("trivia_qa", "rc.nocontext", split="validation")

    def __getitem__(self, item):
        """Return a copy of row ``item`` with ``'x'`` set to its question."""
        row = self.full_item(item)
        row['x'] = row['question']
        return row

    def full_item(self, item):
        """Return a deep copy of the stored row."""
        stored = self.data[item]
        return deepcopy(stored)

    def __len__(self):
        """Return the number of rows."""
        return len(self.data)

    def get_answer(self, item):
        """Return the ``aliases`` entry of the row's ``answer`` field."""
        row = self.full_item(item)
        return row['answer']['aliases']

    def get_question(self, item):
        """Return the row's question wrapped in a one-element list."""
        row = self.full_item(item)
        return [row['question']]

    def calculate_correctness(self, item, answer, **kwargs):
        """Score ``answer`` against the aliases via ``compute_standard_scores``.

        Extra keyword arguments are forwarded unchanged.
        """
        # The answer is deliberately NOT passed through clean_up_answer_for_qa:
        # clean up is bad, since it removes capitalized words.
        # NOTE(review): the aliases value (presumably already a list) is wrapped
        # in another list here -- confirm compute_standard_scores expects that.
        references = [self.get_answer(item)]
        asked = self.get_question(item)
        return compute_standard_scores(references, answer, question=asked, **kwargs)

    def get_problem_system_instruction(self):
        """Return the system instruction used for this dataset."""
        return CLOSED_BOOK_QA_MSG


class CoQA(DatasetWithReference):
    """The ``coqa`` dataset flattened to one record per question/answer turn.

    Every record carries the story as ``'background'``, the question as ``'x'``,
    the answer text as ``'answer'``, the index of the source row as
    ``'backmap_id'``, and the earlier turns of the same story as few-shot
    examples in ``'fs_in'`` / ``'fs_out'``.
    """

    def __init__(self, split='validation'):
        """Load ``split`` of ``coqa`` and expand it into per-turn records."""
        super().__init__()
        stories = hfsets.load_dataset("coqa", split=split)
        flattened = []
        for story_idx, story in enumerate(stories):
            turns = list(zip(story['questions'], story['answers']['input_text']))
            for turn_idx, (question, reply) in enumerate(turns):
                # construct the conversational few shot thing: everything asked
                # and answered before this turn, in order.
                earlier = turns[:turn_idx]
                flattened.append({
                    'background': story['story'],
                    'x': question,
                    'answer': reply,
                    'backmap_id': story_idx,
                    'fs_in': [past_q for past_q, _ in earlier],
                    'fs_out': [past_a for _, past_a in earlier],
                })
        self.data = flattened

    def __getitem__(self, item):
        """Return a deep copy of record ``item``."""
        return self.full_item(item)

    def full_item(self, item):
        """Return a deep copy of the stored record."""
        stored = self.data[item]
        return deepcopy(stored)

    def __len__(self):
        """Return the number of per-turn records."""
        return len(self.data)

    def get_answer(self, item):
        """Return the record's reference answer."""
        record = self.full_item(item)
        return record['answer']

    def get_question(self, item):
        """Return the record's question wrapped in a one-element list."""
        record = self.full_item(item)
        return [record['x']]

    def calculate_correctness(self, item, answer, **kwargs):
        """Clean up ``answer`` and score it via ``compute_standard_scores``.

        ``answer`` may be a single value or a list; either way every entry goes
        through ``clean_up_answer_for_qa`` and a list is scored. Extra keyword
        arguments are forwarded unchanged.
        """
        candidates = answer if isinstance(answer, list) else [answer]
        cleaned = [clean_up_answer_for_qa(candidate) for candidate in candidates]
        references = [self.get_answer(item)]
        asked = self.get_question(item)
        return compute_standard_scores(references, cleaned, question=asked, **kwargs)

    def get_problem_system_instruction(self):
        """Return the system instruction used for this dataset."""
        return OPEN_BOOK_QA_MSG


class SQUADv2(DatasetWithReference, DatasetWithOODInfo):
    """The ``rajpurkar/squad_v2`` dataset as a list of plain record dicts.

    Every record carries the context as ``'background'``, the question as
    ``'x'``, the answer texts as ``'answer'``, the source row index as
    ``'backmap_id'``, and ``'ood_identifier'``, which is true when the row has
    no answer texts.
    """

    def __init__(self, split='validation'):
        """Load ``split`` of ``rajpurkar/squad_v2`` and build the records."""
        super().__init__()
        rows = hfsets.load_dataset("rajpurkar/squad_v2", split=split)
        built = []
        for row_idx, row in enumerate(rows):
            answer_texts = row['answers']['text']
            entry = {
                'background': row['context'],
                'x': row['question'],
                'answer': answer_texts,
                'backmap_id': row_idx,
                # No answer texts means the question is treated as OOD.
                'ood_identifier': len(answer_texts) == 0,
            }
            built.append(entry)
        self.data = built

    def __getitem__(self, item):
        """Return a deep copy of record ``item``."""
        return self.full_item(item)

    def full_item(self, item):
        """Return a deep copy of the stored record."""
        stored = self.data[item]
        return deepcopy(stored)

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def get_answer(self, item):
        """Return the record's answer texts."""
        record = self.full_item(item)
        return record['answer']

    def get_question(self, item):
        """Return the record's question wrapped in a one-element list."""
        record = self.full_item(item)
        return [record['x']]

    def calculate_correctness(self, item, answer, **kwargs):
        """Clean up ``answer`` and score it via ``compute_standard_scores``.

        Returns an empty dict for OOD records, since no reference answer exists.
        Otherwise ``answer`` (a single value or a list) is passed entry by entry
        through ``clean_up_answer_for_qa`` before scoring. Extra keyword
        arguments are forwarded unchanged.
        """
        # if ood, no answer exists
        if self.get_ood_identifier(item):
            return {}
        candidates = answer if isinstance(answer, list) else [answer]
        cleaned = [clean_up_answer_for_qa(candidate) for candidate in candidates]
        references = [self.get_answer(item)]
        asked = self.get_question(item)
        return compute_standard_scores(references, cleaned, question=asked, **kwargs)

    def get_ood_identifier(self, item):
        """Return the record's ``'ood_identifier'`` flag."""
        record = self.full_item(item)
        return record['ood_identifier']

    def get_problem_system_instruction(self):
        """Return the system instruction used for this dataset."""
        return OPEN_BOOK_QA_MSG
