"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

# coding=utf-8

__author__ = "aagrawal"

# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
import sys
import re
import random
import numpy as np
import pandas as pd
from scipy.stats import pearsonr


class VQAEval:
    """VQA accuracy evaluator (consensus metric against 10 human answers).

    ``vqa`` supplies ground truth (``qa[quesId]`` with ``"answers"``,
    ``"question_type"``, ``"answer_type"``) and ``getQuesIds()``; ``vqaRes``
    supplies predictions (``qa[quesId]["answer"]``).  ``n`` is the number of
    decimal places used when rounding reported percentages.
    """

    def __init__(self, vqa=None, vqaRes=None, n=2):
        # random.seed(0)
        print('Successful Initialization')
        self.n = n
        self.accuracy = {}
        self.evalQA = {}
        self.evalQuesType = {}
        self.evalAnsType = {}
        self.vqa = vqa
        self.vqaRes = vqaRes
        if vqa is not None:
            self.params = {"question_id": vqa.getQuesIds()}
        # Maps apostrophe-less spellings back to canonical contractions.
        # Entries (including odd ones such as "somebody'd") are kept verbatim so
        # that scores remain comparable with the reference evaluation code.
        self.contractions = {
            "aint": "ain't",
            "arent": "aren't",
            "cant": "can't",
            "couldve": "could've",
            "couldnt": "couldn't",
            "couldn'tve": "couldn't've",
            "couldnt've": "couldn't've",
            "didnt": "didn't",
            "doesnt": "doesn't",
            "dont": "don't",
            "hadnt": "hadn't",
            "hadnt've": "hadn't've",
            "hadn'tve": "hadn't've",
            "hasnt": "hasn't",
            "havent": "haven't",
            "hed": "he'd",
            "hed've": "he'd've",
            "he'dve": "he'd've",
            "hes": "he's",
            "howd": "how'd",
            "howll": "how'll",
            "hows": "how's",
            "Id've": "I'd've",
            "I'dve": "I'd've",
            "Im": "I'm",
            "Ive": "I've",
            "isnt": "isn't",
            "itd": "it'd",
            "itd've": "it'd've",
            "it'dve": "it'd've",
            "itll": "it'll",
            "let's": "let's",
            "maam": "ma'am",
            "mightnt": "mightn't",
            "mightnt've": "mightn't've",
            "mightn'tve": "mightn't've",
            "mightve": "might've",
            "mustnt": "mustn't",
            "mustve": "must've",
            "neednt": "needn't",
            "notve": "not've",
            "oclock": "o'clock",
            "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at",
            "'ow'sat": "'ow's'at",
            "shant": "shan't",
            "shed've": "she'd've",
            "she'dve": "she'd've",
            "she's": "she's",
            "shouldve": "should've",
            "shouldnt": "shouldn't",
            "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've",
            "somebody'd": "somebodyd",
            "somebodyd've": "somebody'd've",
            "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll",
            "somebodys": "somebody's",
            "someoned": "someone'd",
            "someoned've": "someone'd've",
            "someone'dve": "someone'd've",
            "someonell": "someone'll",
            "someones": "someone's",
            "somethingd": "something'd",
            "somethingd've": "something'd've",
            "something'dve": "something'd've",
            "somethingll": "something'll",
            "thats": "that's",
            "thered": "there'd",
            "thered've": "there'd've",
            "there'dve": "there'd've",
            "therere": "there're",
            "theres": "there's",
            "theyd": "they'd",
            "theyd've": "they'd've",
            "they'dve": "they'd've",
            "theyll": "they'll",
            "theyre": "they're",
            "theyve": "they've",
            "twas": "'twas",
            "wasnt": "wasn't",
            "wed've": "we'd've",
            "we'dve": "we'd've",
            "weve": "we've",
            "werent": "weren't",
            "whatll": "what'll",
            "whatre": "what're",
            "whats": "what's",
            "whatve": "what've",
            "whens": "when's",
            "whered": "where'd",
            "wheres": "where's",
            "whereve": "where've",
            "whod": "who'd",
            "whod've": "who'd've",
            "who'dve": "who'd've",
            "wholl": "who'll",
            "whos": "who's",
            "whove": "who've",
            "whyll": "why'll",
            "whyre": "why're",
            "whys": "why's",
            "wont": "won't",
            "wouldve": "would've",
            "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've",
            "yall": "y'all",
            "yall'll": "y'all'll",
            "y'allll": "y'all'll",
            "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've",
            "y'all'dve": "y'all'd've",
            "youd": "you'd",
            "youd've": "you'd've",
            "you'dve": "you'd've",
            "youll": "you'll",
            "youre": "you're",
            "youve": "you've",
        }
        # Number words that are rewritten as digits before comparison.
        self.manualMap = {
            "none": "0",
            "zero": "0",
            "one": "1",
            "two": "2",
            "three": "3",
            "four": "4",
            "five": "5",
            "six": "6",
            "seven": "7",
            "eight": "8",
            "nine": "9",
            "ten": "10",
        }
        self.articles = ["a", "an", "the"]

        # NOTE: "(?!<=\d)" is a negative *lookahead* (presumably a typo for the
        # lookbehind "(?<!\d)"), so in practice this removes every period that
        # is not followed by a digit.  Left unchanged so scores do not shift.
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        # Digit-comma-digit, e.g. "1,000".
        self.commaStrip = re.compile(r"(\d)(,)(\d)")
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def _normalizeAnswer(self, ans):
        """Apply the full prediction normalisation pipeline to one answer string."""
        ans = ans.replace("\n", " ").replace("\t", " ").strip()
        return self.processDigitArticle(self.processPunctuation(ans))

    def evaluate(self, quesIds=None):
        """Compute VQA accuracy for ``quesIds`` (default: all ids from ``self.vqa``).

        Each prediction is normalised (whitespace, punctuation, number words,
        articles, contractions).  For every ground-truth answer, the prediction
        is scored against the *other* answers as ``min(1, #matches / 3)``; the
        mean of those scores is the question's accuracy.  Results are stored in
        ``self.accuracy``, ``self.evalQA``, ``self.evalQuesType`` and
        ``self.evalAnsType``.

        Note: when annotators disagree, the ground-truth answer dicts held by
        ``self.vqa`` are punctuation-normalised in place.
        """
        if quesIds is None:
            quesIds = list(self.params["question_id"])
        gts = {}
        res = {}
        for quesId in quesIds:
            gts[quesId] = self.vqa.qa[quesId]
            res[quesId] = self.vqaRes.qa[quesId]

        # =================================================
        # Compute accuracy
        # =================================================
        accQA = []
        accQuesType = {}
        accAnsType = {}
        print("computing accuracy")
        for step, quesId in enumerate(quesIds):
            resAns = self._normalizeAnswer(res[quesId]["answer"])
            gtAnswerDicts = gts[quesId]["answers"]
            # Ground truth is only punctuation-normalised when annotators disagree.
            if len({ans["answer"] for ans in gtAnswerDicts}) > 1:
                for ansDic in gtAnswerDicts:
                    ansDic["answer"] = self.processPunctuation(ansDic["answer"])
            gtAcc = []
            for gtAnsDatum in gtAnswerDicts:
                otherGTAns = [item for item in gtAnswerDicts if item != gtAnsDatum]
                matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
                gtAcc.append(min(1, float(len(matchingAns)) / 3))
            quesType = gts[quesId]["question_type"]
            ansType = gts[quesId]["answer_type"]
            avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
            accQA.append(avgGTAcc)
            accQuesType.setdefault(quesType, []).append(avgGTAcc)
            accAnsType.setdefault(ansType, []).append(avgGTAcc)
            self.setEvalQA(quesId, avgGTAcc)
            self.setEvalQuesType(quesId, quesType, avgGTAcc)
            self.setEvalAnsType(quesId, ansType, avgGTAcc)
            if step % 100 == 0:
                self.updateProgress(step / float(len(quesIds)))

        self.setAccuracy(accQA, accQuesType, accAnsType)
        print("Done computing accuracy")

    def processPunctuation(self, inText):
        """Remove or space-out punctuation in ``inText``.

        A mark from ``self.punct`` is deleted when it touches a space anywhere
        in ``inText``, or when ``inText`` contains a digit-comma-digit sequence;
        otherwise it is replaced by a space.  Finally, periods not followed by
        a digit are removed (so "3.5" keeps its decimal point).
        """
        outText = inText
        # Depends only on the unmodified input, so evaluate it once.
        hasDigitComma = self.commaStrip.search(inText) is not None
        for p in self.punct:
            if hasDigitComma or (p + " " in inText or " " + p in inText):
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")
        # re.UNICODE used to be passed positionally into sub()'s *count*
        # argument, capping removals at 32; str patterns are Unicode-aware anyway.
        outText = self.periodStrip.sub("", outText)
        return outText

    def processDigitArticle(self, inText):
        """Lower-case, map number words to digits, drop articles, restore contractions."""
        outText = []
        for word in inText.lower().split():
            # .get rather than .setdefault: do not grow manualMap with every word seen.
            word = self.manualMap.get(word, word)
            if word not in self.articles:
                outText.append(self.contractions.get(word, word))
        return " ".join(outText)

    def setAccuracy(self, accQA, accQuesType, accAnsType):
        """Store overall / per-question-type / per-answer-type mean accuracy as rounded percentages."""
        self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
        self.accuracy["perQuestionType"] = {
            quesType: round(100 * float(sum(accs)) / len(accs), self.n)
            for quesType, accs in accQuesType.items()
        }
        self.accuracy["perAnswerType"] = {
            ansType: round(100 * float(sum(accs)) / len(accs), self.n)
            for ansType, accs in accAnsType.items()
        }

    def setEvalQA(self, quesId, acc):
        """Record one question's accuracy as a rounded percentage."""
        self.evalQA[quesId] = round(100 * acc, self.n)

    def setEvalQuesType(self, quesId, quesType, acc):
        """Record one question's accuracy under its question type."""
        if quesType not in self.evalQuesType:
            self.evalQuesType[quesType] = {}
        self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)

    def setEvalAnsType(self, quesId, ansType, acc):
        """Record one question's accuracy under its answer type."""
        if ansType not in self.evalAnsType:
            self.evalAnsType[ansType] = {}
        self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)

    def updateProgress(self, progress):
        """Draw a 20-character progress bar on stdout for ``progress`` in [0, 1]."""
        barLength = 20
        status = ""
        if isinstance(progress, int):
            progress = float(progress)
        if not isinstance(progress, float):
            progress = 0
            status = "error: progress var must be float\r\n"
        if progress < 0:
            progress = 0
            status = "Halt...\r\n"
        if progress >= 1:
            progress = 1
            status = "Done...\r\n"
        block = int(round(barLength * progress))
        text = "\rFinshed Percent: [{0}] {1}% {2}".format(
            "#" * block + "-" * (barLength - block), int(progress * 100), status
        )
        sys.stdout.write(text)
        sys.stdout.flush()

class ParVQAEval(VQAEval):
    """Oracle VQA evaluator over paraphrased questions.

    The result record ``vqaRes.qa[quesId]`` is indexed below as: ``"answer"``
    (a list with one prediction per paraphrase, index 0 treated as the
    original question, or a single string), ``"question"`` (indexable in
    parallel with the answers) and ``"score"`` (one confidence per answer).
    This schema is inferred from usage here -- confirm against the producer
    of ``vqaRes``.
    """

    def __init__(self, vqa=None, vqaRes=None, n=2):
        VQAEval.__init__(self, vqa, vqaRes, n)
        # Rows for questions where some paraphrase beats the original question.
        self.analysis_list = []
        # Answers per question, taken from the first question evaluated.
        self.num_paraphrases = None
        # Set to True to print per-question scoring details.
        self.debug = False
        # Directory the oracle-analysis TSV is written into.
        self.analysis_dir = "ManualAnalysis"
        # Pearson results from the most recent evaluate() call.
        self.correlations = {}

    def analysis(self, quesIds=None):
        """Collect ground-truth and result records for ``quesIds``.

        Currently a stub: the gathered dicts are not used and nothing is returned.
        """
        if quesIds is None:
            quesIds = list(self.params["question_id"])
        gts = {}
        res = {}
        for quesId in quesIds:
            gts[quesId] = self.vqa.qa[quesId]
            res[quesId] = self.vqaRes.qa[quesId]

    def _normalizeAnswer(self, ans):
        """Apply the full prediction normalisation pipeline to one answer string."""
        ans = ans.replace("\n", " ").replace("\t", " ").strip()
        return self.processDigitArticle(self.processPunctuation(ans))

    def evaluate(self, quesIds=None):
        """Score every paraphrase answer and credit each question with the best one.

        Each paraphrase's accuracy is the mean leave-one-out consensus score
        ``min(1, #matches / 3)`` over the ground-truth answers.  The question's
        accuracy is the maximum over paraphrases (an oracle upper bound), and
        it is stored via ``setAccuracy``/``setEval*``.  Questions where a
        paraphrase beats paraphrase 0 are written to
        ``<analysis_dir>/Prelim_<num_paraphrases>.tsv``.  Pearson correlations
        are printed and stored in ``self.correlations``.

        Note: when annotators disagree, the ground-truth answer dicts held by
        ``self.vqa`` are punctuation-normalised in place.
        """
        import os  # local import: the module does not import os at top level

        if quesIds is None:
            quesIds = list(self.params["question_id"])
        gts = {quesId: self.vqa.qa[quesId] for quesId in quesIds}
        res = {quesId: self.vqaRes.qa[quesId] for quesId in quesIds}

        # =================================================
        # Compute accuracy
        # =================================================
        accQA = []
        accQuesType = {}
        accAnsType = {}
        print("computing accuracy")
        improve_ind = []
        change_delta = []
        par_acc_var = []
        par_score_var = []
        all_scores = []
        all_accs = []
        for step, quesId in enumerate(quesIds):
            rawAnswers = res[quesId]["answer"]
            if not isinstance(rawAnswers, list):
                rawAnswers = [rawAnswers]
            # Build a fresh list instead of overwriting vqaRes answers in place.
            resAnswers = [self._normalizeAnswer(ans) for ans in rawAnswers]
            if self.num_paraphrases is None:
                # Count answers (not characters, when a bare string was given).
                self.num_paraphrases = len(resAnswers)

            gtAnswerDicts = gts[quesId]["answers"]
            if len({ans["answer"] for ans in gtAnswerDicts}) > 1:
                for ansDic in gtAnswerDicts:
                    ansDic["answer"] = self.processPunctuation(ansDic["answer"])
            if self.debug:
                print('debugging.. ', quesId)
                print('GTAns: ', gtAnswerDicts)

            # The leave-one-out annotator pools do not depend on the prediction.
            otherPools = [
                [item for item in gtAnswerDicts if item != gtAnsDatum]
                for gtAnsDatum in gtAnswerDicts
            ]
            gtAcc = []
            for r, resAns in enumerate(resAnswers):
                accs = [
                    min(1, float(sum(1 for item in pool if item["answer"] == resAns)) / 3)
                    for pool in otherPools
                ]
                if self.debug:
                    print(r, resAns, accs, np.mean(accs))
                gtAcc.append(np.mean(accs))

            quesType = gts[quesId]["question_type"]
            ansType = gts[quesId]["answer_type"]

            par_idx = int(np.argmax(gtAcc))
            avgGTAcc = gtAcc[par_idx]
            # Reuse paraphrase 0's exact value: recomputing it with a different
            # summation order could differ by an ulp and fake an "improvement".
            ogGTAcc = gtAcc[0]

            # atleast_1d so a scalar score (single-answer case) is also accepted.
            scores = np.atleast_1d(res[quesId]['score'])
            par_acc_var.append(np.std(gtAcc))
            par_score_var.append(np.std(scores))
            all_scores.extend(scores.tolist())
            all_accs.extend(gtAcc)

            if ogGTAcc < avgGTAcc:
                improve_ind.append(1)
                change_delta.append(avgGTAcc - ogGTAcc)
                questions = res[quesId]['question']
                self.analysis_list.append([
                    quesId, quesType, ansType,
                    questions[0], questions[par_idx],
                    resAnswers[0], resAnswers[par_idx],
                    ogGTAcc, avgGTAcc,
                    scores[0], scores[par_idx],
                ])
            else:
                change_delta.append(0)
                improve_ind.append(0)
            accQA.append(avgGTAcc)
            accQuesType.setdefault(quesType, []).append(avgGTAcc)
            accAnsType.setdefault(ansType, []).append(avgGTAcc)
            self.setEvalQA(quesId, avgGTAcc)
            self.setEvalQuesType(quesId, quesType, avgGTAcc)
            self.setEvalAnsType(quesId, ansType, avgGTAcc)
            if step % 100 == 0:
                self.updateProgress(step / float(len(quesIds)))

        self.setAccuracy(accQA, accQuesType, accAnsType)
        print("Done computing accuracy")

        print("Saving oracle results for analysis")
        # Create the output directory instead of failing inside to_csv.
        os.makedirs(self.analysis_dir, exist_ok=True)
        self.analysis_path = os.path.join(
            self.analysis_dir, 'Prelim_{}.tsv'.format(self.num_paraphrases)
        )
        analysis_df = pd.DataFrame(self.analysis_list, columns=['QId', 'QType', 'AType', 'Orig Q', 'Best Par Q', 'Orig Ans', 'Best Par Ans', 'Orig Acc', 'Best Par Acc', 'Orig AConf', 'Par AConf'])
        analysis_df.to_csv(self.analysis_path, sep="\t", index=False)

        self._correlationAnalysis(
            par_acc_var, par_score_var, improve_ind, change_delta, all_scores, all_accs
        )

    def _correlationAnalysis(self, par_acc_var, par_score_var, improve_ind,
                             change_delta, all_scores, all_accs):
        """Print and store Pearson correlations between paraphrase spread and oracle gain."""
        print("Correlation analysis")
        if len(improve_ind) < 2:
            # pearsonr needs at least two observations per input.
            print("Skipping correlation analysis: fewer than 2 questions evaluated")
            return
        self.correlations = {
            "acc_var_vs_change": pearsonr(par_acc_var, improve_ind),
            "score_var_vs_change": pearsonr(par_score_var, improve_ind),
            "acc_var_vs_delta": pearsonr(par_acc_var, change_delta),
            "score_var_vs_delta": pearsonr(par_score_var, change_delta),
            "score_vs_acc": pearsonr(all_scores, all_accs),
        }
        print('Correlation acc var vs change indicator: ', self.correlations["acc_var_vs_change"])
        print('Correlation score var vs change indicator: ', self.correlations["score_var_vs_change"])
        print('----------')
        print('Correlation acc var vs delta change: ', self.correlations["acc_var_vs_delta"])
        print('Correlation score var vs delta change: ', self.correlations["score_var_vs_delta"])
        print('----------')
        print('Correlation score vs acc: ', self.correlations["score_vs_acc"])

        




    
