import re
import statistics
import torch
from thefuzz import fuzz
from collections import Counter
from uniform_eval.tasks.general_metric.detect_successive_repetition import detect_successive_repetition

def compute_diversity_related_metrics(pred):
    """Compute lexical-diversity and repetition metrics for a single response.

    The text is segmented with ``jieba`` and then scored:

    * ``distinct-1`` .. ``distinct-4``: number of distinct n-grams divided by
      ``max(len(tokens) - n, 1)``; ``distinct`` is their plain average.
    * ``no-repeat-8`` / ``no-repeat``: 1.0 if no 8-gram of the normalized
      tokens (windows containing a newline token are ignored) occurs twice,
      else 0.0.
    * ``verybad-penalty`` / ``verybad-str``: ``-1.0`` when the most frequent
      such 8-gram occurs more than 3 times (zero otherwise), together with that
      8-gram joined into one string (``""`` when there are no 8-grams).
    * ``answer-length``: number of jieba tokens.

    Args:
        pred: decoded response text.

    Returns:
        dict with the metrics above, inserted in the order listed.
    """
    # Source: https://code.alibaba-inc.com/DamoAGI/uniform-eval/blob/master/src/uniform_eval/tasks/question.py#L54,89
    stop_pattern = r"[，。“”！￥…（）【】、《》？：；‘’0-9a-zA-Z.,/;':\"<>\?\-\=_+{}\[\]!@#\$%\^&\*\(\)\s]{1,4}"
    han_pattern = u"[\u4E00-\u9FFF]+"

    def ngrams(tokens, size, skip_newline=False):
        # NOTE: window starts run over range(len(tokens) - size), so the last
        # n-gram of the sequence is never emitted. This is kept as in the
        # reference implementation so the scores remain comparable.
        windows = []
        for start in range(len(tokens) - size):
            window = tokens[start:start + size]
            if skip_newline and "\n" in window:
                continue
            windows.append(tuple(window))
        return windows

    def distinct_ratio(tokens, size):
        return len(set(ngrams(tokens, size))) / max(len(tokens) - size, 1)

    def no_repeat_score(tokens, size):
        windows = ngrams(tokens, size, skip_newline=True)
        has_duplicate = len(set(windows)) < len(windows)
        return 1 - float(has_duplicate)

    def very_bad_penalty(tokens, size=8, repeat_times=3, penalty_factor=1):
        counts = Counter(ngrams(tokens, size, skip_newline=True))
        if not counts:
            return 0, ""
        top_window, top_count = counts.most_common()[0]
        return -penalty_factor * float(top_count > repeat_times), "".join(top_window)

    def normalize(tokens):
        # A token is replaced by a newline separator when it starts with one of
        # the stop characters, or when it does not start with a CJK ideograph.
        normalized = []
        for token in tokens:
            stop_hit = re.match(stop_pattern, token) is not None
            has_han = re.match(han_pattern, token) is not None
            normalized.append("\n" if stop_hit or not has_han else token)
        return normalized

    import jieba
    tokens = list(jieba.cut(pred))  # use jieba instead of default tokenizer

    distinct = {f"distinct-{n}": distinct_ratio(tokens, n) for n in range(1, 5)}
    res = dict(distinct)
    res["distinct"] = sum(distinct.values()) / 4

    normalized = normalize(tokens)
    res["no-repeat-8"] = no_repeat_score(normalized, 8)
    res["no-repeat"] = res["no-repeat-8"]
    res["verybad-penalty"], res["verybad-str"] = very_bad_penalty(normalized, 8, 3, 1)
    res["answer-length"] = len(normalized)
    return res


def compute_diversity_related_metrics_batch(batch):
    """Average per-response diversity metrics over a batch.

    Args:
        batch: object whose ``non_tensor_batch['decoded_responses']`` provides
            ``.tolist()`` returning the decoded response strings (presumably a
            numpy object array -- TODO confirm against caller).

    Returns:
        dict mapping ``"instance_diversity/<metric>"`` to the batch mean of
        that metric from ``compute_diversity_related_metrics``. The
        ``verybad-str`` and ``answer-length`` entries are not averaged. An
        empty batch yields an empty dict.
    """
    preds = batch.non_tensor_batch['decoded_responses'].tolist()
    if not preds:
        # Nothing to average; avoids an IndexError on metrics_list[0] below.
        return {}
    metrics_list = [compute_diversity_related_metrics(pred) for pred in preds]
    # 'verybad-str' is a string and cannot be averaged; 'answer-length' is
    # deliberately excluded from the logged diversity metrics.
    skipped = {'verybad-str', 'answer-length'}
    return {
        f"instance_diversity/{metric}": statistics.mean([res[metric] for res in metrics_list])
        for metric in metrics_list[0]
        if metric not in skipped
    }


def compute_repetition_related_metrics_batch(batch):
    """Measure how often responses in a batch contain successive repetition.

    For each row of ``batch.batch['responses']``, the tokens at positions where
    ``batch.batch['response_mask']`` is non-zero are passed to
    ``detect_successive_repetition`` in strict and in non-strict mode.
    Assumes both tensors are indexed row-per-response along dim 0 with
    matching shapes -- TODO confirm against caller.

    Returns:
        dict with the fraction of responses flagged in strict mode
        (``repetition/exists_successive_repetition_strict``) and in
        non-strict mode (``repetition/exists_successive_repetition``).
        An empty batch yields an empty dict.
    """
    responses = batch.batch['responses']
    response_mask = batch.batch['response_mask'].ne(0)

    strict_flags = []
    loose_flags = []
    for response, mask in zip(responses, response_mask):
        tokens = torch.masked_select(response, mask).tolist()
        strict_flags.append(1 if detect_successive_repetition(tokens, strict=True) else 0)
        loose_flags.append(1 if detect_successive_repetition(tokens, strict=False) else 0)

    if not strict_flags:
        # statistics.mean([]) would raise StatisticsError.
        return {}
    return {
        "repetition/exists_successive_repetition_strict": statistics.mean(strict_flags),
        "repetition/exists_successive_repetition": statistics.mean(loose_flags),
    }


def compute_group_diversity_score(texts):
    """Score how different a group of texts is from one another.

    Every unordered pair ``(texts[a], texts[b])`` with ``a < b`` is scored as
    ``1 - fuzz.ratio(...) / 100``.

    Args:
        texts: indexable sequence of strings.

    Returns:
        ``(mean, min)`` of the pairwise scores, or ``(-1, -1)`` when fewer
        than two texts are given.
    """
    # Source: https://code.alibaba-inc.com/DamoAGI/uniform-eval/blob/master/src/uniform_eval/tasks/sample_consistency.py#L20,42
    n_texts = len(texts)
    distances = [
        1 - fuzz.ratio(texts[a], texts[b]) / 100
        for a in range(n_texts)
        for b in range(a + 1, n_texts)
    ]
    if not distances:
        return -1, -1
    return statistics.mean(distances), min(distances)


_version_checked = False


def check_fuzz_version():
    """Verify that the installed ``thefuzz`` reproduces reference scores.

    Runs ``compute_group_diversity_score`` on three fixed answers and compares
    the ``(mean, min)`` result against 0.67 and 0.41 with a tolerance of 0.01.
    On success, the module-level ``_version_checked`` flag is set so later
    calls return immediately. On a mismatch, nothing is cached.

    Raises:
        RuntimeError: if either score falls outside the tolerance.
    """
    # Source: https://code.alibaba-inc.com/DamoAGI/uniform-eval/blob/master/src/uniform_eval/tasks/sample_consistency.py#L20,42
    global _version_checked
    if _version_checked:
        return
    a = "中国古代的音乐理论强调“闻宫音，使人温舒而广大；闻商音，使人方正而好义；闻角音，使人侧隐而爱人；闻徵音，使人乐善而好施；闻羽音，使人整齐而好礼”，这是通过声音的不同特性来引导人们的行为和思想，体现出对仁义道德的注重。\n\n综上所述，本题答案选D. 注重仁义道德"
    b = "《史记·乐书》的这段描述表明中国古代音乐注重表现和表达人的性格、品质和社会规范，强调了儒家伦理中的仁义道德。这些内容反映了中华文化的深厚底蕴和丰富内涵，具体表现在以下几个方面：\n\n1. 重视选贤任能：《史记·乐书》中提到了不同乐器所代表的不同角色和人格特质，“闻宫音”对应“温舒而广大”，即用温柔柔和的声音传达广大的情感；“闻商音”则体现高尚正直的特点；“闻角音”寓意公正无私，从侧面揭示了人品操守的重要性；“闻徵音”象征着宽厚慈善的人性关怀，反映出礼让包容的精神风貌；“闻羽音”体现秩序井然，人们能够遵守一定的社会规则。\n\n2. 尊重个性自由：文中并没有明确提到对个性自由的追求或倡导，而是强调通过音乐来展示各个人格特征，并借此传递出修身齐家治国平天下的社会责任感。这也是孔子提出的“因材施教”的思想在音乐上的反映，认为每个个体应根据自己的天赋和兴趣去发展自身的才华与潜能，实现自我价值。\n\n3. 强调多元互鉴：文章并未体现出文化间的相互借鉴和融合，只是列举了各种音乐风格，分别体现了华夏民族特有的审美观念和艺术特色。但是，它巧妙地将各民族的传统音乐元素融入到中国传统文化中，使得传统音乐形成了鲜明的中国特色并得到了广泛的传播和发展，体现了中华民族兼容并蓄的文化特点。\n\n综上所述，《史记·乐书》的描述体现了中华文化的重视选贤任能、尊重个性自由以及注重仁义道德等价值观，这些内容不仅展示了古代中国的音乐精神，也是对中国优秀传统文化的重要传承与发展。因此，正确答案是D. 注重仁义道德。"
    c = "根据材料中的描述，“闻宫音，使人温舒而广大；闻商音，使人方正而好义；闻角音，使人侧隐而爱人；闻徵音，使人乐善而好施；闻羽音，使人整齐而好礼”反映了中华文化的注重仁义道德的特点。在古代中国社会中，人们普遍认为“和为贵”，强调人与人之间的和谐共处、友爱互助，因此，《史记·乐书》对各种音乐的描述都是为了表达这种理念，通过不同的乐器和声音元素传递出尊崇仁义道德的精神内涵。\n\n综上所述，本题答案选D. 注重仁义道德"
    mean_distance, min_distance = compute_group_diversity_score([a, b, c])
    out_of_tolerance = abs(mean_distance - 0.67) > 0.01 or abs(min_distance - 0.41) > 0.01
    if out_of_tolerance:
        raise RuntimeError("The version of thefuzz is incorrect. Try pip install thefuzz==0.20.0")
    _version_checked = True