import argparse
import collections
import json
import copy
import os
import re
import logging
import string
import regex
import unicodedata
from tqdm import tqdm
from nltk.corpus import stopwords

# Module-scoped logger (propagates to the root handlers) rather than the root
# logger itself, so callers can tune this module's verbosity independently.
logger = logging.getLogger(__name__)


def read_json(path):
    """Load a JSON Lines file into a list of records.

    Despite the name, the file is parsed line by line: each line must be a
    standalone JSON document.

    Args:
        path: Path to a UTF-8 encoded JSON Lines file.

    Returns:
        A list with one decoded object per line.
    """
    # ``with`` guarantees the handle is released (it was previously never closed).
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f]


def write_jsonl(data, path):
    """Write ``data`` to ``path`` as JSON Lines, truncating any existing file.

    Every item is serialised with ``json.dumps`` onto its own line. A short
    confirmation naming the destination is printed once writing finishes.
    """
    with open(path, 'w') as out:
        out.writelines(json.dumps(record) + "\n" for record in data)
    print(f'write jsonl to: {path}')


def write_json(data, path):
    """Serialise ``data`` as one JSON document into ``path`` (overwriting it)."""
    with open(path, 'w') as out:
        out.write(json.dumps(data))


def append_jsonl(data, path):
    """Append each item of ``data`` to ``path`` as one JSON line.

    The file is created if missing; existing content is preserved. A
    confirmation naming the destination is printed afterwards.
    """
    with open(path, 'a') as out:
        out.writelines(json.dumps(record) + "\n" for record in data)
    print(f'write jsonl to: {path}')


def remove_punc(text):
    """Replace every ASCII punctuation character in ``text`` with a space.

    Punctuation becomes ``" "`` rather than being deleted, because a space
    matches better than an empty string: words that punctuation separated
    (e.g. ``"a,b"``) stay separate tokens.

    Args:
        text: Input string.

    Returns:
        A string of the same length with punctuation replaced by spaces.
    """
    exclude = set(string.punctuation)
    # The previous version also tested ``ch in text`` for every character,
    # which is always true and made the function quadratic in len(text).
    return "".join(' ' if ch in exclude else ch for ch in text)


def is_digital(text):
    """Return True when ``text`` is non-empty and made only of digit characters.

    Thin wrapper over ``str.isdigit``.
    """
    all_digits = text.isdigit()
    return all_digits


def remove_stopwords(text):
    """Drop NLTK English stop words from a sequence of tokens.

    Args:
        text: An iterable of tokens. NOTE(review): a plain string would be
            iterated character by character; callers presumably pass a token
            list — confirm.

    Returns:
        A list of the tokens not present in ``stopwords.words('english')``.
    """
    # A set gives O(1) membership per token; the raw list was scanned linearly.
    words = set(stopwords.words('english'))
    return [w for w in text if w not in words]


def context_len(data):
    """Print the average and median word count of each sample's first DPR context.

    Words are counted by replacing punctuation with spaces (``remove_punc``)
    and splitting on whitespace.

    Args:
        data: Iterable of dicts whose ``'dpr_ctx'`` entry is indexable, with a
            string at position 0.
    """
    import statistics

    len_list = [len(remove_punc(sample['dpr_ctx'][0]).split()) for sample in data]
    if not len_list:
        # Avoid ZeroDivisionError / IndexError on an empty dataset.
        print('average len: n/a (no samples)')
        print('median len: n/a (no samples)')
        return
    print(f'average len: {sum(len_list) / len(len_list)}')
    # statistics.median averages the two middle values for even-sized input;
    # the old ``sorted_list[n // 2]`` reported the upper-middle value instead.
    print(f'median len: {statistics.median(len_list)}')


def get_judge(data, judge_data):
    """Attach a coarse judge label to every sample in ``data`` (in place).

    The judge response ``judge_data[i]['Res']`` is searched (via
    ``has_answer``) for each keyword below in priority order. The first hit
    becomes ``data[i]['judge']``; if nothing matches, the label is ``'none'``.

    Returns:
        ``data`` itself, mutated.
    """
    assert len(data) == len(judge_data)
    candidates = ['both', 'none', 'answer 1', 'answer 2', 'option 1', 'option 2']
    for sample, judged in zip(data, judge_data):
        response = judged['Res']
        sample['judge'] = next(
            (label for label in candidates if has_answer([label], response)),
            'none',
        )
    return data


def get_clean(data, clean_data):
    """Copy each sample's raw ``'pred'`` into ``'clean_pred'`` (in place).

    ``clean_data`` is only used to check that both lists are aligned; its
    contents are currently ignored. Earlier (now removed, commented-out) logic
    substituted ``clean_data[i]['Res']`` for predictions longer than five words.

    Returns:
        ``data`` itself, mutated.
    """
    assert len(data) == len(clean_data)
    for sample in data:
        sample['clean_pred'] = sample['pred']
    return data


def get_data_before_and_after_prompt(origin_data, prompt_data, criterion):
    """Filter samples by whether their give-up decision changed after prompting.

    ``prompt_data`` is indexed by each sample's ``'nq_idx'``. With
    ``criterion == 'same'``, samples whose ``'Giveup_origin'`` equals the
    prompted ``'Giveup'`` are kept. With any other criterion, the samples where
    the two differ are kept.
    """
    keep_same = criterion == 'same'

    def _selected(sample):
        prompted = prompt_data[sample['nq_idx']]['Giveup']
        if keep_same:
            return sample['Giveup_origin'] == prompted
        return sample['Giveup_origin'] != prompted

    return [sample for sample in origin_data if _selected(sample)]


def get_data_before_and_after_evidence(origin_data, prompt_data, criterion):
    """Filter samples by whether ``'Giveup'`` changed once evidence was added.

    ``origin_data[i]`` is compared with ``prompt_data[i]`` position-wise, and
    samples with an ``'info'`` key are skipped. ``criterion == 'same'`` keeps
    unchanged samples; any other value keeps the changed ones. The number of
    kept samples is printed.
    """
    keep_same = criterion == 'same'
    selected = []
    for position, sample in enumerate(origin_data):
        if 'info' in sample:
            continue
        other = prompt_data[position]['Giveup']
        if keep_same:
            matched = sample['Giveup'] == other
        else:
            matched = sample['Giveup'] != other
        if matched:
            selected.append(sample)
    print(len(selected))
    return selected


def get_data_after_judge(data, judge_data):
    """Overwrite each sample's ``'Giveup'`` with the judge's verdict (in place).

    Both lengths are printed before the alignment check. Samples with an
    ``'info'`` key are left untouched.

    Returns:
        ``data`` itself, mutated.
    """
    print(len(data))
    print(len(judge_data))
    assert len(data) == len(judge_data)
    for sample, verdict in zip(data, judge_data):
        if 'info' not in sample:
            sample['Giveup'] = verdict['Giveup']
    return data


def judge_again(data):
    """Recompute ``'Giveup'`` for every sample from its raw response ``'Res'``.

    Uses ``deal_judge_new``; mutates and returns ``data``.
    """
    for sample in data:
        sample['Giveup'] = deal_judge_new(sample['Res'])
    return data


def merge_qa_evidence(qa_data, wrong_evidence_data, right_evidence_data):
    """Attach generated evidence to each QA sample (in place).

    For every sample at index ``i`` without an ``'info'`` key:
    ``'wevidence'`` comes from ``wrong_evidence_data[i]['Res']`` and
    ``'revidence'`` from ``right_evidence_data[i]['Res']``. Only the
    wrong-evidence list is length-checked against ``qa_data``.

    Returns:
        ``qa_data`` itself, mutated.
    """
    assert len(qa_data) == len(wrong_evidence_data)
    for i, sample in enumerate(qa_data):
        if 'info' in sample:
            continue
        sample['wevidence'] = wrong_evidence_data[i]['Res']
        sample['revidence'] = right_evidence_data[i]['Res']
    return qa_data


def compute_has_answer(ref_data, qa_data):
    """Mark each QA sample with whether its ``'Res'`` contains a reference answer.

    ``ref_data[i]['reference']`` supplies the candidate answers for
    ``qa_data[i]``. Samples with an ``'info'`` key are skipped. The result of
    ``has_answer`` (1/0) is stored under ``'has_answer'``.

    Returns:
        ``qa_data`` itself, mutated.
    """
    assert len(ref_data) == len(qa_data)
    for ref, sample in zip(ref_data, qa_data):
        if 'info' not in sample:
            sample['has_answer'] = has_answer(ref['reference'], sample['Res'])
    return qa_data


def _normalize_answer(s):
    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join([ch if ch in text and ch not in exclude else ' ' for ch in text])

    def lower(text):
        return text.lower()

    # print(white_space_fix(remove_articles(remove_punc(lower(s)))))
    return white_space_fix(remove_articles(remove_punc(lower(s))))


# Token pattern used by ``has_answer``. A token is either a run of letters,
# numbers, or combining marks, or any single character that is neither a
# separator (\p{Z}) nor an "other"/control character (\p{C}). It is compiled
# once here instead of rebuilding a tokenizer on every call.
_HAS_ANSWER_TOKEN_RE = regex.compile(
    r'([\p{L}\p{N}\p{M}]+)|([^\p{Z}\p{C}])',
    flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE,
)


def _has_answer_tokens(text):
    """Split ``text`` into lower-cased tokens matched by ``_HAS_ANSWER_TOKEN_RE``."""
    return [m.group().lower() for m in _HAS_ANSWER_TOKEN_RE.finditer(text)]


def has_answer(answers, text, match_type="string"):
    """Return 1 if ``text`` contains any candidate answer as a token span, else 0.

    ``text`` and each answer go through the same pipeline: NFD Unicode
    normalisation, then ``_normalize_answer`` (lower-case, punctuation and
    articles removed), then tokenisation. An answer matches when its token
    sequence appears contiguously in the text's tokens.

    Args:
        answers: Iterable of candidate answer strings.
        text: The string to search (e.g. a model prediction).
        match_type: Only ``"string"`` is implemented; any other value yields 0.

    Returns:
        ``1`` on a match, otherwise ``0``.
    """
    text = _normalize_answer(unicodedata.normalize('NFD', text))
    if match_type != 'string':
        return 0
    text_tokens = _has_answer_tokens(text)
    for single_answer in answers:
        answer_tokens = _has_answer_tokens(
            _normalize_answer(unicodedata.normalize('NFD', single_answer)))
        if not answer_tokens:
            # An answer that normalises to nothing (e.g. "", "the", "?")
            # previously matched every text through the empty slice.
            continue
        width = len(answer_tokens)
        for i in range(len(text_tokens) - width + 1):
            if text_tokens[i:i + width] == answer_tokens:
                return 1
    return 0


def has_answer_by_llm(question, reference, answer, llm_ins, max_retries=None):
    """Ask an LLM judge whether ``answer`` matches any item of ``reference``.

    The (Chinese) prompt asks the model to explain its reasoning first, then
    return a fenced ```json block containing ``{"label": 0/1}``. The call is
    repeated until a label can be parsed.

    Args:
        question: The question being answered.
        reference: Reference answer(s); interpolated into the prompt with
            ``str.format``.
        answer: The model answer to grade.
        llm_ins: Object exposing ``chat(messages)`` that returns the response
            text; ``messages`` is a list of ``{"role", "content"}`` dicts.
        max_retries: Maximum number of attempts. ``None`` (the default) retries
            forever, matching the previous behaviour.

    Returns:
        ``int(label)`` taken from the judge's JSON block.

    Raises:
        Exception: The last error, once ``max_retries`` attempts have failed.
    """
    import ast

    prompt = """
        请结合所给的问题 question、参考答案 reference 判断模型的答案 answer 是否正确。
        判定标准：若答案 answer与参考列表 reference中的任意一项意思匹配，则视为正确；否则视为错误答案。
        你需要根据 question 先提取出 answer 的关键信息后，再将关键信息与 reference 相比较。
        请注意，你需要首先给出解析过程；并且仅判断 answer 中的关键信息是否与 reference 中的某一项匹配；对于 answer 中的无关信息无需理会。
        返回：以json格式返回数字标签 `1`（表示正确）或 `0`（表示错误）。
        返回格式如下：```json {{ "label": 0/1 }} ```
        这次需要判断的内容如下：
        {{
            "question": {question},
            "reference": {reference},
            "answer": {answer},
        }}
    """

    def _get_response_json(_response_content):
        json_match = re.search(r'```json\n(.*?)\n```', _response_content, re.DOTALL)
        print(f"_response_content = {_response_content}")
        print(f"json_match = {json_match}")
        if json_match is None:
            raise ValueError("no ```json ...``` block found in LLM response")
        # SECURITY: the payload is model output; parse it as a Python literal
        # instead of ``eval``-ing arbitrary code.
        return ast.literal_eval(json_match.group(1))

    attempts = 0
    while True:
        attempts += 1
        try:
            content = prompt.format(question=question, reference=reference, answer=answer)
            response = llm_ins.chat([{"role": "user", "content": content}])
            result = _get_response_json(response)["label"]
            return int(result)
        except Exception as e:
            print(e)
            if max_retries is not None and attempts >= max_retries:
                raise


def EM_compute(answer_list, prediction):
    """Exact-match score of ``prediction`` against a list of reference answers.

    Returns 1 if the normalised prediction (``_normalize_answer``) equals any
    normalised reference, otherwise 0. An empty ``answer_list`` scores 0
    instead of raising ``ValueError`` from ``max([])``.
    """
    if not answer_list:
        return 0
    # Normalise the prediction once rather than once per reference.
    normalized_pred = _normalize_answer(prediction)
    return max(
        (int(normalized_pred == _normalize_answer(ground_truth)) for ground_truth in answer_list),
        default=0,
    )


def F1_compute(answers, pred):
    """Best token-level F1 between ``pred`` and any reference in ``answers``.

    Tokens are the whitespace-split output of ``_normalize_answer``. If either
    side has no tokens, F1 is 1 when both are empty and 0 otherwise. An empty
    ``answers`` list scores 0 instead of raising ``ValueError`` from
    ``max([])``.
    """
    if not answers:
        return 0

    def get_tokens(s):
        if not s:
            return []
        return _normalize_answer(s).split()

    # The prediction side is identical for every reference, so compute it once.
    pred_toks = get_tokens(pred)
    pred_counts = collections.Counter(pred_toks)

    def compute_f1(a_gold):
        gold_toks = get_tokens(a_gold)
        if not gold_toks or not pred_toks:
            # If either is no-answer, F1 is 1 if they agree, 0 otherwise.
            return int(gold_toks == pred_toks)
        num_same = sum((collections.Counter(gold_toks) & pred_counts).values())
        if num_same == 0:
            return 0
        precision = num_same / len(pred_toks)
        recall = num_same / len(gold_toks)
        return (2 * precision * recall) / (precision + recall)

    return max((compute_f1(x) for x in answers), default=0)


def deal_judge(pred):
    """Return True if ``pred`` is missing or contains a "no answer"-style phrase.

    Phrases are matched token-wise via ``has_answer``.
    """
    if pred is None:
        return True
    no_answer_phrases = [
        "unknown", "no specific answer", "not provide", "cannot answer", "no information provided", "no answer",
        "not contain", "no definitive answer",
    ]
    return bool(has_answer(no_answer_phrases, pred))


def deal_judge_new(pred):
    """Return True if ``pred`` is missing or reads like a refusal / give-up.

    This is a broader keyword list than ``deal_judge``: apologies, hedges, and
    standalone words such as "no" or "false" also count. Matching is
    token-wise via ``has_answer``.
    """
    if pred is None:
        return True
    giveup_phrases = ["sorry", "apologize", "apologies", "uncertain", "false", "no", 'unsure', "cannot", "unknown",
                      "no specific answer", "not provide", "cannot answer", "no information provided", "no answer",
                      "not contain", "no definitive answer"]
    return bool(has_answer(giveup_phrases, pred))


def deal_no_info(pred):
    """Return True if ``pred`` is missing or hints at an "I have no information" reply.

    Matching is token-wise via ``has_answer``.
    """
    if pred is None:
        return True
    no_info_words = ["cannot", "unknown", "provide", 'information', 'assistant', 'artificial', 'unsure', 'robot']
    return bool(has_answer(no_info_words, pred))


def deal_answer(pred, answers):
    """Score ``pred`` against ``answers`` as ``(exact_match, f1)``.

    A missing prediction scores ``(0, 0)``. A leading ``"answer:"`` prefix
    (matched case-insensitively) is stripped before scoring.
    """
    if pred is None:
        return 0, 0
    prefix = "answer:"
    if pred.lower().startswith(prefix):
        pred = pred[len(prefix):]
    return EM_compute(answers, pred), F1_compute(answers, pred)


def deal_post(pred):
    """Classify a self-verification reply as ``(giveup, istrue)``.

    Checks run in this order:
      - missing or hedged/uncertain replies -> ``(True, None)``;
      - "correct"/"true" -> ``(False, True)``;
      - "incorrect"/"false" -> ``(False, False)``;
      - anything else -> ``(True, None)``.

    Matching is token-wise via ``has_answer``.
    """
    if pred is None:
        return True, None
    hedges = ["uncertain", "unclear", "not clear", "unknown", "partially correct", "partially incorrect",
              "not correct", "cannot determine", "cannot answer", "not incorrect", "incomplete"]
    if has_answer(hedges, pred):
        return True, None
    if has_answer(["correct", "true"], pred):
        return False, True
    if has_answer(["incorrect", "false"], pred):
        return False, False
    return True, None


def str2paras(s):
    """Split ``s`` on newlines into paragraphs, each prefixed with ``": "``.

    Whitespace-only lines are dropped. Kept lines are otherwise unmodified
    (not stripped). Returns None when ``s`` is None.
    """
    if s is None:
        return None
    return [": " + line for line in s.split('\n') if line.strip() != '']


# if __name__ == "__main__":
#     file_list = os.listdir('d:/pycharmfiles/chat')

#     for file in file_list:
#         if not file.endswith('post'):
#             continue
#         print(file)
#         indir = os.path.join('d:/pycharmfiles/chat', file)
#         outdir = os.path.join('d:/pycharmfiles/llm_re/nq/data', file)
#         outstr = ""
#         infile = open(indir, 'r', encoding='utf-8')
#         for line in tqdm(infile.readlines()):
#             d = json.loads(line)
#             if 'Prediction' in d.keys():
#                 d['Giveup'], d['EM'], d['F1'] =  deal_answer(d['Prediction'], d['reference'])
#             if 'Post' in d.keys():
#                 d['Post_Giveup'], d['Post_True']= deal_post(d['Post'])
#             outstr += json.dumps(d) + '\n'
#         infile.close()
#         outfile = open(outdir, 'w', encoding='utf-8')
#         outfile.write(outstr)
#         outfile.close()


def load_source(file):
    """Load a UTF-8 JSON Lines file into a list with one decoded object per line.

    The file handle is released even if a line fails to parse; previously an
    exception from ``json.loads`` skipped the ``close()`` call.
    """
    with open(file, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f]


if __name__ == '__main__':
    # Ad-hoc smoke check of answer normalisation ("2013 ()" normalises to "2013").
    normalized = _normalize_answer("2013 ()")
    print(len(normalized))
