import os
import random
import copy
import re
import string
import json
from enum import Enum
from functools import partial
from dataclasses import dataclass, field
from typing import Any, Tuple, Dict, Optional, List, Union, Type
from glob import glob
from hexa.utils.document import Document
import hexa.tasks.constants as CONST
from hexa.tasks.datas import EpisodicDataset as Wrapper
from hexa.utils.metrics import normalize_answer
from hexa.utils.inference_utils import Module
import rouge
import pickle

from multitask_datasets import SelfLearnSequentialDataset

# Code point of 'a' (ord('a') == 97). Case.generate_guided_prompt adds the
# option index to it to produce the letters for alphabet-style answer choices.
a = 97
# Working directory at import time; the bootstrap-data glob patterns used by
# the load_* / resample_* helpers are rooted here.
ROOT_DIR = os.getcwd()
print("ROOT_DIR:", ROOT_DIR)

# Shared scorer used by long_answer_validator; configured for ROUGE-L only.
ROUGE_SCORER = rouge.Rouge(
    metrics=['rouge-l']
)


class Symbol(Enum):
    """Marker style used by Case.generate_guided_prompt when listing answer choices."""
    ALPHABET = 1  # lettered markers: " a) ...", " b) ..."
    BULLET = 2  # dash markers: " - ..."
    NUMBER = 3  # 1-based numeric markers: " 1) ...", " 2) ..."
    RANDOM = 4  # ask generate_guided_prompt to draw a style at random


class Prefix(Enum):
    """Lead-in text that Case.generate_guided_prompt places before the answer choices."""
    Null = ''  # no lead-in text
    Multiple = 'Answer Choices:'
    Guide = 'The following extraneous information, of which some may be helpful,'


@dataclass(repr=True)
class Case:
    """A single evaluated example together with the bot's answer and its score.

    A case is built either from a live ``bot_response`` (which is scored
    immediately with a validator) or from a previously serialized
    ``json_dict`` (which restores only the text, labels, wrong answers, task id
    and score).

    ``__init__`` is written by hand, so ``@dataclass`` does not generate one;
    the field declarations below mainly provide ``None`` defaults for
    attributes that a given construction path does not set.
    """
    text: str = field(default=None)
    labels: List[str] = field(default=None)
    bot_answer: str = field(default=None)
    search_decision: str = field(default=None)
    search_query: str = field(default=None)
    search_knowledge: str = field(default=None)
    search_docs: List[dict] = field(default=None, repr=False)
    memory_decision: str = field(default=None)
    memories: str = field(default=None)
    memory_knowledge: str = field(default=None)
    memory_docs: List[dict] = field(default=None, repr=False)
    contextual_knowledge: str = field(default=None)
    history: str = field(default=None)
    wrong_answers: List[str] = field(default_factory=list)
    task_id: str = field(default=None)
    score: float = field(default=None)
    cos_score: float = field(default=None)
    base_cos_score: float = field(default=None)
    base_sentence: str = field(default=None)

    def __init__(self,
                 text: str = None,
                 labels: List[str] = None,
                 bot_response: Dict = None,
                 task_id: str = None,
                 json_dict: Dict = None,
                 bootstrap_mode: bool = True,
                 validator=None):
        """Build a case from a bot response or from a stored JSON dict.

        Args:
            text: Input text of the example (``bot_response`` path only).
            labels: Gold answers (``bot_response`` path only).
            bot_response: Response dict; ``bot_response['text']`` is the answer.
                Takes precedence over ``json_dict`` when both are given.
            task_id: Task identifier, also used to pick the default validator.
            json_dict: Serialized case with the keys ``text``, ``labels``,
                ``wrong_answers``, ``task_id`` and ``score``.
            bootstrap_mode: When true, also copy the search / memory / entity
                traces out of ``bot_response``.
            validator: Callable ``(labels, answer) -> (bool, score)``; defaults
                to ``get_validator(task_id)``.

        Raises:
            AssertionError: If neither ``bot_response`` nor ``json_dict`` is given.
        """
        assert bot_response is not None or json_dict is not None
        if bot_response is not None:
            self.text = text
            self.labels = labels
            # Collapse every run of whitespace in the answer to a single space.
            self.bot_answer = ' '.join(bot_response['text'].split())
            self.task_id = task_id
            if validator:
                self.validator = validator
            else:
                self.validator = get_validator(task_id)
            self.is_correct, self.score = is_correct(self.bot_answer.lower(), self.labels, self.validator,
                                                     ret_score=True)
            if bootstrap_mode:
                self.set_search_data(bot_response)
                self.set_memory_data(bot_response)
                self.set_entity_data(bot_response)
            self.wrong_answers = []
            if not self.is_correct:
                self.wrong_answers.append(self.bot_answer)
        elif json_dict is not None:
            # this case is only for bad case. it does not need to store full data.
            # Note: ``is_correct`` and ``validator`` are not set on this path.
            self.text = json_dict['text']
            self.labels = json_dict['labels']
            self.wrong_answers = json_dict['wrong_answers']
            self.task_id = json_dict['task_id']
            self.score = json_dict['score']

    def set_search_data(self, response):
        """Copy the search decision, query, retrieved docs and knowledge from ``response``."""
        self.search_decision = response['search_decision']
        self.search_query = response['search_query']
        if 'search_knowledge_top_docs' in response['knowledge_obs']:
            search_docs = response['knowledge_obs']['search_knowledge_top_docs']
            # Stored as plain dicts (via each document's asdict()).
            self.search_docs = [doc.asdict() for doc in search_docs]
        else:
            self.search_docs = None
        if 'search_knowledge' in response:
            self.search_knowledge = response['search_knowledge']
        else:
            self.search_knowledge = None

    def set_memory_data(self, response):
        """Copy the memory decision, memories, memory knowledge and docs from ``response``."""
        self.memory_decision = response['memory_decision']
        self.memories = response['memories']
        self.memory_knowledge = response['memory_knowledge']
        self.memory_docs = response['memory_knowledge_doc_content']

    def set_entity_data(self, response):
        """Copy the contextual knowledge and dialogue history from ``response``."""
        self.contextual_knowledge = response['contextual_knowledge']
        self.history = response['history']

    def generate_guided_prompt(self,
                               is_naive=False,
                               without_gt=False,
                               answer_pool=None,
                               prefix=Prefix.Null,
                               symbol=Symbol.ALPHABET):
        """Build a hint string out of wrong answers and (optionally) a gold answer.

        A gold answer is drawn at random from ``self.labels``, appended to a
        copy of the answer pool, and the pool is shuffled.

        Args:
            is_naive: If true, emit a single parenthesised answer: the gold
                answer, or with ``without_gt`` the most recent wrong answer.
            without_gt: Leave the gold answer out of the hint.
            answer_pool: Wrong answers to use; defaults to ``self.wrong_answers``.
                The caller's list is never modified.
            prefix: Lead-in text placed before a multi-choice listing.
            symbol: Marker style for the listing; ``Symbol.RANDOM`` draws one
                of the concrete styles at random.

        Returns:
            The hint with surrounding whitespace stripped.
        """
        answer = random.choice(self.labels)
        if answer_pool is None:
            answer_pool = copy.deepcopy(self.wrong_answers)
        else:
            answer_pool = copy.deepcopy(answer_pool)
        # An empty pool has no "latest" wrong answer; indexing [-1] would raise.
        latest_wrong_answer = answer_pool[-1] if answer_pool else None
        answer_pool.append(answer)
        random.shuffle(answer_pool)

        # First position holding the gold answer text after shuffling.
        gold_idx = answer_pool.index(answer)

        if not is_naive:
            if symbol == Symbol.RANDOM:
                # Draw only from the concrete styles: RANDOM itself matches no
                # formatting branch below and would produce an empty listing.
                symbol = random.choice(
                    [style for style in Symbol if style is not Symbol.RANDOM]
                )

            # With the gold answer hidden and a two-entry pool, only one wrong
            # answer remains, so it is shown as "(answer)" instead of a list.
            single_hint = False
            if without_gt and len(answer_pool) == 2:
                single_hint = True

            if not single_hint:
                hint = prefix.value
            else:
                hint = '('

            for i, an_answer in enumerate(answer_pool):
                if without_gt and i == gold_idx:
                    continue

                if not single_hint:
                    if symbol == Symbol.ALPHABET:
                        hint += f' {chr(a + i)}) {an_answer}'
                    elif symbol == Symbol.BULLET:
                        hint += f' - {an_answer}'
                    elif symbol == Symbol.NUMBER:
                        hint += f' {i + 1}) {an_answer}'
                else:
                    hint += f'{an_answer})'

                if i < len(answer_pool) - 1:
                    hint += ','
            # Drop the separator left behind when the final entry was skipped.
            hint = re.sub(r',$', '', hint)
        else:
            hint = '('

            if without_gt:
                if latest_wrong_answer is not None:
                    hint += latest_wrong_answer
                else:
                    for i, an_answer in enumerate(answer_pool):
                        if i == gold_idx:
                            continue
                        hint += an_answer
                        if i < len(answer_pool) - 1:
                            hint += ','
            else:
                hint += answer_pool[gold_idx]

            hint = re.sub(r',$', '', hint)
            hint += ')'

        return hint.strip()

    def update_wrong_answers(self, old_case, n=4):
        """Carry over ``old_case``'s wrong answers and add this case's, if new.

        The current answer is appended only when this case is incorrect, the
        answer is not already present (case-insensitively) and
        ``post_proc_answer`` leaves something non-empty. Only the last ``n``
        entries are kept.
        """
        # The previous wrong answers will not be duplicated.
        prev_wrong_answers = [answer.lower() for answer in old_case.wrong_answers]
        new_wrong_answers = old_case.wrong_answers.copy()
        answer = self.bot_answer
        if not self.is_correct and answer.lower() not in prev_wrong_answers:
            answer = post_proc_answer(answer)
            if answer:
                new_wrong_answers.append(answer)
        self.wrong_answers = new_wrong_answers[-n:]


# Matches one answer-choice marker such as "a)" or ",1)": optional commas, a
# single word character, then ")". Written as a raw string so the regex
# escapes are not parsed as (invalid) Python string escapes.
IS_CANDIDATE = re.compile(r'\,*[\w\d]{1}\)', re.IGNORECASE)

# Characters deleted by remove_punctuation(). Compared with string.punctuation
# this leaves out "(", ")", "-" and ":".
punctuation = r"""!"#$%&'*+,./;<=>?@[\]^_`{|}~"""


def contains_hint(answer):
    """Return True when ``answer`` contains more than one answer-choice marker.

    Whitespace directly before a ")" is removed first, then the text is split
    on single spaces and IS_CANDIDATE matches are counted across the pieces.
    """
    # Raw string: "\s" and "\)" are regex escapes, not Python string escapes.
    answer = re.sub(r'\s+\)', ')', answer)
    n = 0
    for item in answer.split(' '):
        n += len(IS_CANDIDATE.findall(item))
        if n > 1:
            # Two markers are enough; no need to scan the rest.
            return True
    return False


def remove_punctuation(text):
    """Return ``text`` with every character listed in ``punctuation`` deleted."""
    # Mapping a character to None tells str.translate to drop it.
    deletions = str.maketrans(dict.fromkeys(punctuation))
    return text.translate(deletions)


def is_correct(bot_text, labels, validator, ret_score=False):
    """Judge a bot answer against the gold labels.

    The answer is lower-cased and stripped of punctuation. If it still looks
    like a list of answer choices (see contains_hint) it is rejected with
    score 0; otherwise ``validator(labels, answer)`` decides.

    Returns:
        ``(verdict, score)`` when ``ret_score`` is true, else just the verdict.
    """
    cleaned = remove_punctuation(bot_text.lower())
    if contains_hint(cleaned):
        verdict, score = False, 0
    else:
        verdict, score = validator(labels, cleaned)

    return (verdict, score) if ret_score else verdict


def short_answer_validator(labels, answer):
    """Accept ``answer`` if any normalized label occurs inside the normalized answer.

    Returns:
        ``(True, 1)`` on a match, otherwise ``(False, 0)``.
    """
    normalized_answer = normalize_answer(answer)
    matched = any(
        normalize_answer(label) in normalized_answer for label in labels
    )
    return matched, int(matched)


def long_answer_validator(labels, answer, threshold, stat='r', except_perfect=False):
    """Validate a free-form answer by its best ROUGE-L score against the labels.

    Args:
        labels: Gold answers; labels that normalize to an empty string are ignored.
        answer: Bot answer; it is normalized before scoring.
        threshold: The answer is accepted when the best score is strictly greater.
        stat: Which ROUGE-L statistic to read from the scorer output
            (the callers in this file pass 'r' or 'f').
        except_perfect: If true, a score of exactly 1 is rejected.

    Returns:
        ``(accepted, score)``. An answer that normalizes to empty yields ``(False, 0)``.
    """
    ret = False
    answer = normalize_answer(answer)
    if not answer:
        score = 0
    else:
        norm_labels = [normalize_answer(label) for label in labels]
        norm_labels = [label for label in norm_labels if label]

        try:
            scores = [
                ROUGE_SCORER.get_scores([answer], [label])[0]
                for label in norm_labels
            ]
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit
            # are no longer intercepted here.
            # NOTE(review): this fallback uses the get_scores() result without
            # indexing [0]; presumably it supports a rouge implementation whose
            # get_scores() returns one dict instead of a list - confirm.
            scores = [
                ROUGE_SCORER.get_scores([answer], [label])
                for label in norm_labels
            ]

        if len(scores) > 0:
            score = max([score['rouge-l'][stat] for score in scores])
        else:
            score = 0
        if score > threshold:
            ret = True
        if except_perfect and score == 1:
            # assuming it is a trivial answer
            ret = False
    return ret, score


def get_validator(task_id):
    """Pick the long-answer validator configuration for ``task_id``.

    QA tasks (see is_qa_task) use threshold 0.5 with the default statistic;
    task ids starting with 'googlesgd' use 0.35 on 'f'; everything else uses
    0.25 on 'f'. The task id is matched case-insensitively.
    """
    # this should be used in evaluation only!
    lowered = task_id.lower()
    if is_qa_task(lowered):
        return partial(long_answer_validator, threshold=0.5)
    if lowered.startswith('googlesgd'):
        return partial(long_answer_validator, threshold=0.35, stat='f')
    return partial(long_answer_validator, threshold=0.25, stat='f')


def post_proc_answer(answer):
    """Strip prompt boilerplate and a single answer-choice marker from ``answer``.

    Removes "Answer"/"answer", then "Choices"/"choices" with any trailing
    colons, then markers such as "a)" or ", 2 )". If more than one marker was
    removed, the text is treated as a list of choices and ``None`` is returned.
    """
    cleaned = re.sub(r'(A|a)nswer', '', answer).strip()
    cleaned = re.sub(r'(c|C)hoices[:]*', '', cleaned).strip()
    cleaned, num_markers = re.subn(r'\,*\s*[\d\w]{1}\s*\)\s*', '', cleaned)
    return None if num_markers > 1 else cleaned.strip()


def load_bootstrap_data(name_space: str,
                        num_loop: int,
                        num_bootstrap: int,
                        case_type='good'
                        ):
    """Merge every ``<case_type>_cases_<num_loop>_<num_bootstrap>.json`` file.

    The files are looked up under ``experiment/selfLearn/data/<name_space>_*/``
    relative to ROOT_DIR. Each file holds a JSON object; the objects are merged
    into one dict, later files overriding earlier ones on duplicate keys.
    """
    pattern = os.path.join(ROOT_DIR, f'experiment/selfLearn/data/{name_space}_*/{case_type}_cases_{num_loop}_{num_bootstrap}.json')
    merged = {}
    for json_path in glob(pattern):
        with open(json_path, 'r') as fin:
            merged.update(json.load(fin))
    return merged


def load_previous_wrong_data(name_space: str, num_loop: int, num_bootstrap: int):
    """Load and merge the stored bad cases for the given loop and bootstrap size.

    Thin wrapper around load_bootstrap_data with ``case_type='bad'``, which
    globs the same ``bad_cases_<num_loop>_<num_bootstrap>.json`` files.
    """
    return load_bootstrap_data(name_space, num_loop, num_bootstrap, case_type='bad')


def load_previous_correct_data(name_space: str, num_loop: int, num_bootstrap: int):
    """Load and merge the stored good cases for the given loop and bootstrap size.

    Thin wrapper around load_bootstrap_data with ``case_type='good'``, which
    globs the same ``good_cases_<num_loop>_<num_bootstrap>.json`` files.
    """
    return load_bootstrap_data(name_space, num_loop, num_bootstrap, case_type='good')


def resample_bootstrap_data(num_sample: int,
                        num_gpu_used: int,
                        name_space: str,
                        num_loop: int,
                        num_bootstrap: int
                        ):
    """Write down-sampled copies of the good/bad case files of a bootstrap run.

    For every ``good_cases_<num_loop>_<num_bootstrap>.json`` found under
    ``experiment/selfLearn/data/<name_space>_*/`` (relative to ROOT_DIR):

    * ``num_sample // num_gpu_used`` good cases are sampled, and
    * ``int(len(bad) * num_sample / num_bootstrap)`` bad cases are sampled from
      the sibling ``bad_cases_<num_loop>_<num_bootstrap>.json``.

    The samples are written next to the inputs as
    ``good_cases_<num_loop>_<num_sample>.json`` and
    ``bad_cases_<num_loop>_<num_sample>.json``.

    Raises:
        ValueError: If a file holds fewer cases than the requested sample size
            (raised by ``random.sample``).
    """
    glob_inst = os.path.join(ROOT_DIR, f'experiment/selfLearn/data/{name_space}_*/good_cases_{num_loop}_{num_bootstrap}.json')
    json_files = glob(glob_inst)
    print(json_files)
    good_data_file_name = f'good_cases_{num_loop}_{num_sample}.json'
    bad_data_file_name = f'bad_cases_{num_loop}_{num_sample}.json'
    # Kept in its own variable: previously a shared ``k`` was overwritten by
    # the bad-case sample size, corrupting the good-case size for every file
    # after the first.
    num_good = num_sample // num_gpu_used
    for fpath in json_files:
        dir_path = os.path.dirname(fpath)
        good_data_fpath = os.path.join(dir_path, good_data_file_name)

        with open(fpath, 'r') as fin:
            good_data = json.load(fin)
        print('good sample length:', len(good_data))
        # random.sample needs a sequence; dict views are rejected on Python 3.11+.
        new_good_data = {
            key: good_data[key]
            for key in random.sample(list(good_data), k=num_good)
        }

        _bad_data_fpath = os.path.join(dir_path, f'bad_cases_{num_loop}_{num_bootstrap}.json')
        bad_data_fpath = os.path.join(dir_path, bad_data_file_name)
        with open(_bad_data_fpath, 'r') as fin:
            bad_data = json.load(fin)
        print('bad sample length:', len(bad_data))
        # Keep the same fraction of bad cases as num_sample is of num_bootstrap.
        num_bad = int(len(bad_data) * num_sample / num_bootstrap)
        new_bad_data = {
            key: bad_data[key]
            for key in random.sample(list(bad_data), k=num_bad)
        }
        print('resampled bad sample length:', len(new_bad_data))

        with open(good_data_fpath, 'w') as fout:
            json.dump(new_good_data, fout, indent=4)

        with open(bad_data_fpath, 'w') as fout:
            json.dump(new_bad_data, fout, indent=4)

        print(good_data_fpath)
        print(bad_data_fpath)


def is_qa_task(task_id):
    """Return True if ``task_id`` starts with 'nq', 'triviaqa' or 'squad' (case-insensitive)."""
    qa_prefixes = ('nq', 'triviaqa', 'squad')
    return task_id.lower().startswith(qa_prefixes)


def get_bootstrap_search_decision_dataset(json_data: dict, opt=None):
    """Build search-decision training samples from bootstrap cases.

    Each case becomes one sample whose text is the case text followed by
    CONST.IS_SEARCH_REQUIRED and whose label is CONST.DO_SEARCH when the case
    recorded that decision, CONST.DO_NOT_SEARCH otherwise. ``opt`` is unused.
    """
    samples = []
    for item in json_data.values():
        if item['search_decision'] != CONST.DO_SEARCH:
            target = CONST.DO_NOT_SEARCH
        else:
            target = CONST.DO_SEARCH
        samples.append({
            'text': f"{item['text']} {CONST.IS_SEARCH_REQUIRED}",
            'labels': [target],
            'skip_retrieval': True,
            'id': 'SelfLearnSearchDecisionTeacher',
        })
    return Wrapper(samples, episodic=False)


def get_bootstrap_search_query_generation_dataset(json_data: dict, opt=None):
    """Build search-query generation samples from bootstrap cases.

    Only cases that decided to search and recorded a non-empty query are used.
    The sample text is history + text followed by CONST.GENERATE_QUERY and the
    label is the recorded query. ``opt`` is unused.
    """
    samples = []
    for item in json_data.values():
        if item['search_decision'] != CONST.DO_SEARCH:
            continue
        query = item['search_query']
        if query is None or query == '':
            continue
        samples.append({
            'text': f"{item['history']}{item['text']} {CONST.GENERATE_QUERY}",
            'labels': [query],
            'skip_retrieval': True,
            'id': 'SelfLearnSearchQueryTeacher',
        })
    return Wrapper(samples, episodic=False)


def get_bootstrap_search_knowledge_generation_dataset(json_data: dict, opt=None):
    """Build search-knowledge generation samples from bootstrap cases.

    Only cases that decided to search, recorded a non-empty query and have
    retrieved documents are used. Each sample carries the documents' texts and
    titles, with empty strings for the selected sentences and URLs (one per
    document); the label is the recorded search knowledge. ``opt`` is unused.
    """
    samples = []
    for item in json_data.values():
        if item['search_decision'] != CONST.DO_SEARCH:
            continue
        query = item['search_query']
        if query is None or query == '':
            continue
        docs = item['search_docs']
        if docs is None:
            continue

        samples.append({
            'text': f"{item['history']}{item['text']} {CONST.GENERATE_KNOWLEDGE}",
            'labels': [item['search_knowledge']],
            '__selected-sentences__': ['' for _ in docs],
            '__retrieved-docs__': [doc['text'] for doc in docs],
            '__retrieved-docs-titles__': [doc['title'] for doc in docs],
            '__retrieved-docs-urls__': ['' for _ in docs],
            'id': 'SelfLearnSearchKnowledgeTeacher',
        })
    return Wrapper(samples, episodic=False)


def get_bootstrap_memory_decision_dataset(json_data: dict, opt=None):
    """Build memory-decision training samples from bootstrap cases.

    Each case becomes one sample whose text is memories + text followed by
    CONST.IS_MEMORY_REQUIRED and whose label is CONST.DO_ACCESS_MEMORY when the
    case recorded that decision, CONST.DONT_ACCESS_MEMORY otherwise.
    ``opt`` is unused.
    """
    samples = []
    for item in json_data.values():
        if item['memory_decision'] != CONST.DO_ACCESS_MEMORY:
            target = CONST.DONT_ACCESS_MEMORY
        else:
            target = CONST.DO_ACCESS_MEMORY
        samples.append({
            'text': f"{item['memories']}{item['text']} {CONST.IS_MEMORY_REQUIRED}",
            'labels': [target],
            'skip_retrieval': True,
            'id': 'SelfLearnMemoryDecisionTeacher',
        })
    return Wrapper(samples, episodic=False)


def get_bootstrap_memory_knowledge_generation_dataset(json_data: dict, opt=None):
    """Build memory-knowledge generation samples from bootstrap cases.

    Only cases that decided to access memory are used. Falsy entries of
    ``memory_docs`` are dropped; the remaining docs are passed through with
    empty titles and URLs, and the recorded memory knowledge serves both as
    the label and as the single selected sentence. ``opt`` is unused.
    """
    samples = []
    for item in json_data.values():
        if item['memory_decision'] != CONST.DO_ACCESS_MEMORY:
            continue
        prompt = f"{item['history']}{item['text']} {CONST.ACCESS_MEMORY}"
        knowledge = item['memory_knowledge']
        kept_docs = [doc for doc in item['memory_docs'] if doc]
        samples.append({
            'text': prompt,
            'labels': [knowledge],
            '__selected-sentences__': [knowledge],
            '__retrieved-docs__': kept_docs,
            '__retrieved-docs-titles__': ['' for _ in kept_docs],
            '__retrieved-docs-urls__': ['' for _ in kept_docs],
            'id': 'SelfLearnMemoryKnowledgeTeacher',
        })
    return Wrapper(samples, episodic=False)


def get_bootstrap_entity_generation_dataset(json_data: dict, opt=None):
    """Build entity-extraction samples from bootstrap cases.

    Cases that searched or accessed memory are skipped. For the rest, the
    sample text is history + text followed by CONST.EXTRACT_ENTITY and the
    label is the recorded contextual knowledge. ``opt`` is unused.
    """
    samples = []
    for item in json_data.values():
        used_search = item['search_decision'] == CONST.DO_SEARCH
        if used_search or item['memory_decision'] == CONST.DO_ACCESS_MEMORY:
            continue
        samples.append({
            'text': f"{item['history']}{item['text']} {CONST.EXTRACT_ENTITY}",
            'labels': [item['contextual_knowledge']],
            'skip_retrieval': True,
            'id': 'SelfLearnEntityGenerationTeacher',
        })
    return Wrapper(samples, episodic=False)


def get_bootstrap_dialogue_generation_dataset(json_data: dict, opt=None):
    """Build dialogue-generation samples from bootstrap cases.

    For each case the non-empty search / memory / contextual knowledge is
    wrapped in the special tokens of the knowledge module whose
    ``message_name()`` equals the corresponding key, in the order given by
    ``Module.knowledge_modules()``. The sample text is
    ``history + text + ' ' + '\\n' + wrapped knowledge`` and the label is the
    recorded bot answer. ``opt`` is unused.
    """
    knowledge_keys = ('search_knowledge', 'memory_knowledge', 'contextual_knowledge')
    data = []
    for item in json_data.values():
        # Keep only the knowledge types this case actually produced.
        knowledge_obs = {name: item[name] for name in knowledge_keys if item[name]}

        segments = []
        for m in Module.knowledge_modules():
            name = m.message_name()
            if name in knowledge_obs:
                tokens = m.special_tokens()
                segments.append(f"{tokens[0]} {knowledge_obs[name]} {tokens[1]}")
        # The knowledge block always starts on a new line, even when empty.
        knowledge = '\n' + ''.join(segments)

        data.append({
            'text': f"{item['history']}{item['text']} {knowledge}",
            'labels': [item['bot_answer']],
            'skip_retrieval': True,
            'id': 'SelfLearnDialogueGenerationTeacher',
        })
    return Wrapper(data, episodic=False)


def update_output_dir(opt, load_prev=False):
    """Append the run-specific suffix to ``opt.trainer.output_dir`` in place.

    The suffix is ``_<name_space>_<loop>_<num_bootstrap>_scheme-<scheme>_epoch-<finetune_num_epoch>``,
    where ``<loop>`` is ``opt.num_loop``, or one less when ``load_prev`` is set
    (which requires ``opt.num_loop > 0``). The resulting path is printed.
    """
    loop_idx = opt.num_loop
    if load_prev:
        assert loop_idx > 0
        loop_idx = loop_idx - 1
    opt.trainer.output_dir += (
        f'_{opt.name_space}_{loop_idx}_{opt.num_bootstrap}'
        f'_scheme-{opt.scheme}'
        f'_epoch-{opt.finetune_num_epoch}'
    )
    print(opt.trainer.output_dir)


def get_model_file(opt, load_prev=False):
    """Return ``opt.trainer.output_dir`` with the run-specific suffix appended.

    The suffix is ``_<name_space>_<loop>_<num_bootstrap>_scheme-<scheme>_epoch-<finetune_num_epoch>``,
    where ``<loop>`` is ``opt.num_loop``, or one less when ``load_prev`` is set
    (which requires ``opt.num_loop > 0``). ``opt`` itself is not modified.
    """
    loop_idx = opt.num_loop
    if load_prev:
        assert loop_idx > 0
        loop_idx = loop_idx - 1
    parts = [
        opt.trainer.output_dir,
        f'_{opt.name_space}_{loop_idx}_{opt.num_bootstrap}',
        f'_scheme-{opt.scheme}',
        f'_epoch-{opt.finetune_num_epoch}',
    ]
    return ''.join(parts)


def read_teacher(teacher_name, split='train', do_print=False):
    """Unpickle and return ``data/<split>/<teacher_name>.pkl`` (relative path).

    ``split`` must be 'train', 'valid' or 'test'. With ``do_print`` the path
    is announced before reading.
    """
    assert split in ['train', 'valid', 'test']
    pkl_path = f'data/{split}/{teacher_name}.pkl'
    if do_print:
        print('Read data from {}'.format(pkl_path))
    # NOTE: pickle executes code embedded in the file; only load trusted data.
    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)


def get_e2e_datasets(opt, tokenizer, datatype='train', return_dict=False,
                     max_episode_num=None, max_entry_num=None, **kwargs):
    """Load the teachers listed in ``opt.dataset.bt_tasks`` and wrap them.

    Sets ``opt.dataset.datatype`` to ``datatype`` and ``opt.dataset.episodic``
    to True, then reads each teacher with read_teacher.

    Truncation (``max_entry_num`` takes precedence over ``max_episode_num``):
        * ``max_entry_num``: keep leading items until the running entry count
          (``len(item)`` for list items, 1 otherwise) reaches the limit; the
          item that crosses the limit is kept.
        * ``max_episode_num``: keep that many leading items.

    Returns:
        With ``return_dict``, a dict mapping each teacher name to its own
        SelfLearnSequentialDataset; otherwise one SelfLearnSequentialDataset
        over all teachers. Extra ``kwargs`` are forwarded to the constructor.
    """
    opt.dataset.datatype = datatype
    opt.dataset.episodic = True
    teacher_names = opt.dataset.bt_tasks

    per_teacher = {}

    is_main_process = opt.local_rank == 0
    for teacher_name in teacher_names:
        episodes = read_teacher(teacher_name, split=datatype)
        if max_entry_num is not None:
            last_idx = 0
            entry_count = 0
            for idx, episode in enumerate(episodes):
                entry_count += len(episode) if isinstance(episode, list) else 1
                last_idx = idx
                if entry_count >= max_entry_num:
                    break
            episodes = episodes[:last_idx + 1]
        elif max_episode_num is not None:
            episodes = episodes[:max_episode_num]
        if is_main_process:
            print(f'{teacher_name}, len:{len(episodes)}')
        per_teacher[teacher_name] = episodes

    if not return_dict:
        return SelfLearnSequentialDataset(list(per_teacher.values()), opt, tokenizer, **kwargs)

    return {
        name: SelfLearnSequentialDataset([episodes], opt, tokenizer, **kwargs)
        for name, episodes in per_teacher.items()
    }


# File, inside ``opt.trainer.output_dir``, holding self-learning metadata as one JSON object.
META_INFO_FNAME = 'selflearn_meta_info.json'


def update_selflearn_meta_info(opt, key, val):
    """Store ``val`` under ``key`` in the metadata file, keeping existing entries.

    The file is read if it exists, updated and rewritten in full.
    """
    # This function must be called after update_output_dir, because the file
    # location depends on the final opt.trainer.output_dir.
    fpath = os.path.join(opt.trainer.output_dir, META_INFO_FNAME)
    data = {}
    if os.path.exists(fpath):
        with open(fpath, 'r') as fin:
            data = json.load(fin)

    data[key] = val
    with open(fpath, 'w') as fout:
        json.dump(data, fout)


def load_selflearn_meta_info(opt, key):
    """Return the value stored under ``key`` in the metadata file.

    Returns ``None`` when the file does not exist or has no such key.
    """
    # This function must be called after update_output_dir, because the file
    # location depends on the final opt.trainer.output_dir.
    fpath = os.path.join(opt.trainer.output_dir, META_INFO_FNAME)
    if not os.path.exists(fpath):
        return None
    with open(fpath, 'r') as fin:
        data = json.load(fin)
    if key in data:
        # Look up the requested key (previously the literal string 'key').
        return data[key]
    return None
