# Copyright (c) OpenMMLab. All rights reserved.
# Partly adopted from https://github.com/GT-Vision-Lab/VQA
# Copyright (c) 2014, Aishwarya Agrawal

from ...smp import *
from typing import Optional


def _process_digit_article(inText):
    outText = []
    tempText = inText.lower().split()
    articles = ['a', 'an', 'the']
    manualMap = {
        'none': '0',
        'zero': '0',
        'one': '1',
        'two': '2',
        'three': '3',
        'four': '4',
        'five': '5',
        'six': '6',
        'seven': '7',
        'eight': '8',
        'nine': '9',
        'ten': '10',
    }
    contractions = {
        'aint': "ain't",
        'arent': "aren't",
        'cant': "can't",
        'couldve': "could've",
        'couldnt': "couldn't",
        "couldn'tve": "couldn't've",
        "couldnt've": "couldn't've",
        'didnt': "didn't",
        'doesnt': "doesn't",
        'dont': "don't",
        'hadnt': "hadn't",
        "hadnt've": "hadn't've",
        "hadn'tve": "hadn't've",
        'hasnt': "hasn't",
        'havent': "haven't",
        'hed': "he'd",
        "hed've": "he'd've",
        "he'dve": "he'd've",
        'hes': "he's",
        'howd': "how'd",
        'howll': "how'll",
        'hows': "how's",
        "Id've": "I'd've",
        "I'dve": "I'd've",
        'Im': "I'm",
        'Ive': "I've",
        'isnt': "isn't",
        'itd': "it'd",
        "itd've": "it'd've",
        "it'dve": "it'd've",
        'itll': "it'll",
        "let's": "let's",
        'maam': "ma'am",
        'mightnt': "mightn't",
        "mightnt've": "mightn't've",
        "mightn'tve": "mightn't've",
        'mightve': "might've",
        'mustnt': "mustn't",
        'mustve': "must've",
        'neednt': "needn't",
        'notve': "not've",
        'oclock': "o'clock",
        'oughtnt': "oughtn't",
        "ow's'at": "'ow's'at",
        "'ows'at": "'ow's'at",
        "'ow'sat": "'ow's'at",
        'shant': "shan't",
        "shed've": "she'd've",
        "she'dve": "she'd've",
        "she's": "she's",
        'shouldve': "should've",
        'shouldnt': "shouldn't",
        "shouldnt've": "shouldn't've",
        "shouldn'tve": "shouldn't've",
        "somebody'd": 'somebodyd',
        "somebodyd've": "somebody'd've",
        "somebody'dve": "somebody'd've",
        'somebodyll': "somebody'll",
        'somebodys': "somebody's",
        'someoned': "someone'd",
        "someoned've": "someone'd've",
        "someone'dve": "someone'd've",
        'someonell': "someone'll",
        'someones': "someone's",
        'somethingd': "something'd",
        "somethingd've": "something'd've",
        "something'dve": "something'd've",
        'somethingll': "something'll",
        'thats': "that's",
        'thered': "there'd",
        "thered've": "there'd've",
        "there'dve": "there'd've",
        'therere': "there're",
        'theres': "there's",
        'theyd': "they'd",
        "theyd've": "they'd've",
        "they'dve": "they'd've",
        'theyll': "they'll",
        'theyre': "they're",
        'theyve': "they've",
        'twas': "'twas",
        'wasnt': "wasn't",
        "wed've": "we'd've",
        "we'dve": "we'd've",
        'weve': "we've",
        'werent': "weren't",
        'whatll': "what'll",
        'whatre': "what're",
        'whats': "what's",
        'whatve': "what've",
        'whens': "when's",
        'whered': "where'd",
        'wheres': "where's",
        'whereve': "where've",
        'whod': "who'd",
        "whod've": "who'd've",
        "who'dve": "who'd've",
        'wholl': "who'll",
        'whos': "who's",
        'whove': "who've",
        'whyll': "why'll",
        'whyre': "why're",
        'whys': "why's",
        'wont': "won't",
        'wouldve': "would've",
        'wouldnt': "wouldn't",
        "wouldnt've": "wouldn't've",
        "wouldn'tve": "wouldn't've",
        'yall': "y'all",
        "yall'll": "y'all'll",
        "y'allll": "y'all'll",
        "yall'd've": "y'all'd've",
        "y'alld've": "y'all'd've",
        "y'all'dve": "y'all'd've",
        'youd': "you'd",
        "youd've": "you'd've",
        "you'dve": "you'd've",
        'youll': "you'll",
        'youre': "you're",
        'youve': "you've",
    }
    for word in tempText:
        word = manualMap.setdefault(word, word)
        if word not in articles:
            outText.append(word)
    for wordId, word in enumerate(outText):
        if word in contractions:
            outText[wordId] = contractions[word]
    outText = ' '.join(outText)
    return outText


def hit_calculate(result, dataset_name, anls_threshold=0.5):
    """Reduce each record's per-reference ``match`` list to a single hit score.

    The reduction depends on the dataset family named in ``dataset_name``.
    Families are checked in this order:

    * TextVQA: mean of the matches (VQA accuracy).
    * DocVQA / InfoVQA: ANLS, i.e. ``1 - min(match)``, where ``match`` holds
      normalized edit distances. The score is zeroed when it falls below
      ``anls_threshold``.
    * ChartQA / OCRVQA: max of the matches (any reference counts).
    * anything else: mean of the matches.

    Args:
        result: Iterable of dicts (as produced by ``process_line``), each with
            a ``'match'`` sequence.
        dataset_name: Dataset identifier used to select the reduction.
        anls_threshold: Minimum ANLS similarity that still earns credit.

    Returns:
        A list with one score per record, in input order.
    """
    def _anls_hit(matches):
        similarity = 1 - np.min(matches)
        return 0.0 if similarity < anls_threshold else similarity

    if listinstr(['TextVQA'], dataset_name):
        reduce_fn = np.mean
    elif listinstr(['DocVQA', 'InfoVQA'], dataset_name):
        reduce_fn = _anls_hit
    elif listinstr(['ChartQA', 'OCRVQA'], dataset_name):
        reduce_fn = np.max
    else:  # default using vqa_score to calculate score
        reduce_fn = np.mean
    return [reduce_fn(record['match']) for record in result]


# https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
def relaxed_correctness(target: str,
                        prediction: str,
                        max_relative_change: float = 0.05) -> bool:
    """Calculates relaxed correctness.

    The correctness tolerates certain error ratio defined by max_relative_change.
    See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
    “Following Methani et al. (2020), we use a relaxed accuracy measure for the
    numeric answers to allow a minor inaccuracy that may result from the automatic
    data extraction process. We consider an answer to be correct if it is within
    5% of the gold answer. For non-numeric answers, we still need an exact match
    to consider an answer to be correct.”

    Both arguments are converted with ``str`` first. Numeric parsing ignores
    spaces, thousands separators (``,``) and trailing ``%`` signs. Percentages
    are deliberately *not* divided by 100, so ``"45%"`` equals ``"45"``. When
    either side is not numeric, or the target parses to zero, the comparison
    falls back to a case-insensitive exact string match.

    Args:
      target: Target string.
      prediction: Predicted string.
      max_relative_change: Maximum relative change.

    Returns:
      Whether the prediction was correct given the specified tolerance.
    """

    def _to_float(text: str) -> Optional[float]:
        cleaned = text.replace(' ', '').strip().rstrip('%').replace(',', '')
        try:
            return float(cleaned)
        except ValueError:
            return None

    prediction, target = str(prediction), str(target)
    pred_value = _to_float(prediction)
    gold_value = _to_float(target)
    # A zero (or unparsable) target cannot be the denominator of a relative
    # change, so compare the raw strings instead.
    if pred_value is None or not gold_value:
        return prediction.lower() == target.lower()
    return abs(pred_value - gold_value) / abs(gold_value) <= max_relative_change


def levenshtein_distance(s1, s2):
    """Return the Levenshtein (insert/delete/substitute) distance of two sequences.

    Uses the two-row dynamic-programming formulation. The outer loop runs over
    the longer sequence, so the retained row has ``min(len(s1), len(s2)) + 1``
    entries.
    """
    shorter, longer = (s2, s1) if len(s1) > len(s2) else (s1, s2)

    prev_row = list(range(len(shorter) + 1))
    for row_idx, ch_long in enumerate(longer, start=1):
        cur_row = [row_idx]
        for col_idx, ch_short in enumerate(shorter):
            if ch_short == ch_long:
                cur_row.append(prev_row[col_idx])
                continue
            substitute = prev_row[col_idx]
            delete = prev_row[col_idx + 1]
            insert = cur_row[-1]
            cur_row.append(1 + min(substitute, delete, insert))
        prev_row = cur_row
    return prev_row[-1]


def anls_compute(groundtruth, prediction):
    """Return the normalized Levenshtein distance between two answers.

    Both answers are lower-cased, and runs of whitespace are collapsed to one
    space (leading/trailing whitespace removed) before comparison. The edit
    distance is divided by the length of the longer *normalized* answer, so
    the result lies in ``[0, 1]``: ``0.0`` means identical after
    normalization. Two answers that are both empty after normalization also
    give ``0.0``.

    Args:
        groundtruth: Reference answer string.
        prediction: Predicted answer string.

    Returns:
        The normalized edit distance as a float.
    """
    gt_answer = ' '.join(groundtruth.lower().split())
    det_answer = ' '.join(prediction.lower().split())
    dist = levenshtein_distance(gt_answer, det_answer)
    # Normalize by the strings that were actually compared. Previously the raw
    # inputs were used here, so padding whitespace (e.g. a trailing newline in
    # a model output) inflated the denominator and understated the distance.
    length = max(len(gt_answer), len(det_answer))
    return 0.0 if length == 0 else float(dist) / float(length)


def process_answer(answer):
    """Normalize a free-form answer for VQA-score matching.

    Newlines and tabs are turned into spaces and the result is stripped. It is
    then passed through ``process_punctuation`` (defined outside this view) and
    finally ``_process_digit_article``.
    """
    for whitespace_char in ('\n', '\t'):
        answer = answer.replace(whitespace_char, ' ')
    return _process_digit_article(process_punctuation(answer.strip()))


def process_line(line, method='vqa_score'):
    """Score one prediction against its reference answer(s).

    Args:
        line: A record indexable by ``'answer'`` and ``'prediction'`` (a
            DataFrame row when called from ``VQAEval``). ``line['answer']`` is
            either a single answer or a string that ``istype`` recognizes as a
            list, which is then parsed with ``eval``.
        method: Scoring rule:
            ``'vqa_score'``: leave-one-out VQA accuracy on normalized answers.
            ``'anls'``: normalized edit distance per reference (lower is better).
            ``'relaxed_accuracy'``: ``relaxed_correctness`` per reference.
            ``'accuracy'``: case-insensitive exact match per reference.
            Any other value: exact match after ``process_answer`` normalization.

    Returns:
        A dict with the following keys:
        ``'gt'``: the reference answers, normalized for the VQA-style methods.
        ``'pred'``: the prediction, likewise normalized or stripped.
        ``'match'``: one score per reference answer; its meaning depends on
        ``method``.
    """
    ret = {}
    if istype(line['answer'], list):
        # NOTE(review): eval() executes arbitrary code from the dataset file.
        # Answers are expected to be plain list literals, for which
        # ast.literal_eval would be a safe drop-in; only evaluate trusted data.
        answers = eval(line['answer'])
        # Parsed list elements may be non-strings (e.g. numbers); every branch
        # below needs str methods on them.
        answers = [str(x) for x in answers]
    else:
        answers = [line['answer']]
    if method == 'vqa_score':
        ret['gt'] = [process_answer(x) for x in answers]
        ret['pred'] = process_answer(line['prediction'])
        ret['match'] = []
        # Official VQA accuracy: for each reference, count how many of the
        # *other* references equal the prediction; three agreeing references
        # give full credit.
        for current_idx, gtAnsDatum in enumerate(ret['gt']):
            otherGTAns = [
                item for ret_gt_idx, item in enumerate(ret['gt'])
                if ret_gt_idx != current_idx
            ]
            matchingAns = [
                item for item in otherGTAns if item == ret['pred']
            ]
            acc = min(1, float(len(matchingAns)) / 3)
            ret['match'].append(acc)
    elif method == 'anls':
        ret['gt'] = answers
        ret['pred'] = line['prediction']
        ret['match'] = [anls_compute(x, ret['pred']) for x in ret['gt']]
    elif method == 'relaxed_accuracy':
        ret['gt'] = answers
        ret['pred'] = line['prediction'].strip()
        # relaxed_correctness(target, prediction): the tolerance is relative to
        # the gold answer, so the reference must be passed first.
        ret['match'] = [relaxed_correctness(x, ret['pred']) for x in ret['gt']]
    elif method == 'accuracy':
        ret['gt'] = answers
        ret['pred'] = line['prediction'].strip()
        ret['match'] = [(1.0 if (x.strip().lower() == ret['pred'].lower()) else 0.0) for x in ret['gt']]
    else:  # default using vqa_score to calculate score
        ret['gt'] = [process_answer(x) for x in answers]
        ret['pred'] = process_answer(line['prediction'])
        ret['match'] = [x == ret['pred'] for x in ret['gt']]

    return ret


def VQAEval(eval_file, dataset_name, **kwargs):
    """Score a VQA-style prediction file and dump the accuracy table.

    The file is read with ``load`` and must contain ``'answer'`` and
    ``'prediction'`` columns; both are converted to ``str``. Each row is
    scored by ``process_line`` in a 16-process pool. The scoring method is
    chosen from ``dataset_name``:

    * TextVQA: ``'vqa_score'``
    * ChartQA: ``'relaxed_accuracy'``
    * OCRVQA: ``'accuracy'``
    * DocVQA / InfoVQA: ``'anls'``
    * otherwise: ``'vqa_score'``

    Per-row hits from ``hit_calculate`` are averaged and scaled to percent:

    * With a ``'split'`` column: one entry per split plus ``'Overall'``.
    * Otherwise: ``'Overall'`` plus, if a ``'category'`` column exists, one
      entry per category in sorted order.

    The table is rounded to two decimals, logged, and written with ``dump`` to
    ``<eval_file without extension>_acc.csv``.

    Args:
        eval_file: Path of the prediction file.
        dataset_name: Dataset identifier used to pick the scoring method.
        **kwargs: Accepted for interface compatibility; unused.
    """
    import os

    logger = get_logger('Evaluation')
    data = load(eval_file)
    assert 'answer' in data and 'prediction' in data
    data['prediction'] = [str(x) for x in data['prediction']]
    data['answer'] = [str(x) for x in data['answer']]
    lines = [data.iloc[i] for i in range(len(data))]

    # The first matching dataset family wins, so the check order matters.
    if listinstr(['TextVQA'], dataset_name):
        method = 'vqa_score'
    elif listinstr(['ChartQA'], dataset_name):
        method = 'relaxed_accuracy'
    elif listinstr(['OCRVQA'], dataset_name):
        method = 'accuracy'
    elif listinstr(['DocVQA', 'InfoVQA'], dataset_name):
        method = 'anls'
    else:  # default using vqa_score to calculate score
        method = 'vqa_score'

    # The context manager tears the worker processes down; the previous code
    # never closed or joined the pool.
    with mp.Pool(16) as pool:
        res = pool.map(partial(process_line, method=method), lines)

    def _score(records):
        # Mean per-row hit in percent.
        return np.mean(hit_calculate(records, dataset_name)) * 100

    ret = dict()
    if 'split' in data:
        for sp in set(data['split']):
            ret[sp] = _score([r for row, r in zip(lines, res) if row['split'] == sp])
        ret['Overall'] = _score(res)
    else:
        ret['Overall'] = _score(res)
        if 'category' in data:
            for c in sorted(set(data['category'])):
                ret[c] = _score([r for row, r in zip(lines, res) if row['category'] == c])
    ret = d2df(ret)
    # Assuming d2df returns a pandas DataFrame: DataFrame.round returns a new
    # object, so the result must be reassigned (it was discarded before).
    ret = ret.round(2)

    # Replace only the final extension. str.replace would also rewrite earlier
    # occurrences in the path, and would overwrite the input file when it has
    # no extension.
    result_file = os.path.splitext(eval_file)[0] + '_acc.csv'
    logger.info(f'VQA Eval Finished. Saved to {result_file}. ')
    logger.info(ret)
    dump(ret, result_file)
