"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
import os

from lavis.common.dist_utils import main_process
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask

from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer

class Caption_Measure:
    """Score generated captions against reference captions.

    Runs the pycocoevalcap BLEU-1..4, METEOR, ROUGE_L, CIDEr and SPICE scorers
    and keeps both the overall and the per-item results.

    Attributes filled in by :meth:`evaluate`:
        eval: metric name -> overall score returned by the scorer.
        imgToEval: item id -> ``{"image_id": id, <metric name>: <item score>}``.
        evalImgs: the values of ``imgToEval`` collected into a list.
    """

    def __init__(self, gts, res):
        """Store the reference (``gts``) and generated (``res``) captions.

        Both arguments are handed to ``PTBTokenizer.tokenize`` unchanged, so
        they presumably need the ``{id: [{"caption": str}, ...]}`` layout that
        tokenizer expects -- confirm against pycocoevalcap.
        """
        self.evalImgs = []
        self.eval = {}
        self.imgToEval = {}
        self.gts = gts
        self.res = res

    def evaluate(self):
        """Tokenize both caption sets, run every scorer and record the results.

        ``self.gts`` and ``self.res`` are overwritten with the tokenizer
        output, so after this call they no longer hold the raw input.
        """
        tokenizer = PTBTokenizer()
        self.gts = tokenizer.tokenize(self.gts)
        self.res = tokenizer.tokenize(self.res)

        # Each entry pairs a scorer with the metric name(s) it reports. A list
        # of names means the scorer's outputs are zipped against those names.
        scorers = [
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Meteor(), "METEOR"),
            (Rouge(), "ROUGE_L"),
            (Cider(), "CIDEr"),
            (Spice(), "SPICE"),
        ]
        # =================================================
        # Compute scores
        # =================================================
        for scorer, method in scorers:
            print('computing %s score...' % (scorer.method()))
            score, scores = scorer.compute_score(self.gts, self.res)
            # Normalise the single-metric case to the multi-metric shape so
            # both are recorded by the same loop.
            if isinstance(method, list):
                per_metric = zip(score, scores, method)
            else:
                per_metric = [(score, scores, method)]
            for metric_score, item_scores, metric_name in per_metric:
                self.setEval(metric_score, metric_name)
                self.setImgToEvalImgs(item_scores, self.gts.keys(), metric_name)
                print("%s: %0.3f" % (metric_name, metric_score))
        self.setEvalImgs()

    def setEval(self, score, method):
        """Record the overall ``score`` under the metric name ``method``."""
        self.eval[method] = score

    def setImgToEvalImgs(self, scores, imgIds, method):
        """Record per-item ``scores`` for ``method``, pairing them with ``imgIds``.

        Ids and scores are matched positionally; an entry is created (with its
        ``"image_id"`` field) the first time an id is seen.
        """
        for imgId, score in zip(imgIds, scores):
            entry = self.imgToEval.setdefault(imgId, {"image_id": imgId})
            entry[method] = score

    def setEvalImgs(self):
        """Refresh ``evalImgs`` from the current per-item results."""
        self.evalImgs = list(self.imgToEval.values())

@registry.register_task("captioning")
class CaptionTask(BaseTask):
    """Captioning task: generate captions during validation and score them.

    Validation outputs are saved through ``save_result`` and, when
    ``report_metric`` is true, scored with ``Caption_Measure``.
    """

    def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):
        """Store the generation settings.

        Args:
            num_beams: forwarded to ``model.generate`` as ``num_beams``.
            max_len: forwarded to ``model.generate`` as ``max_length``.
            min_len: forwarded to ``model.generate`` as ``min_length``.
            evaluate: value of ``run_cfg.evaluate``; only stored here.
            report_metric: when false, ``after_evaluation`` skips scoring.
        """
        super().__init__()

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len
        self.evaluate = evaluate

        self.report_metric = report_metric

    @classmethod
    def setup_task(cls, cfg):
        """Build the task from ``cfg.run_cfg``.

        ``report_metric`` is optional in the config and defaults to ``True``.
        """
        run_cfg = cfg.run_cfg

        return cls(
            num_beams=run_cfg.num_beams,
            max_len=run_cfg.max_len,
            min_len=run_cfg.min_len,
            evaluate=run_cfg.evaluate,
            report_metric=run_cfg.get("report_metric", True),
        )

    def valid_step(self, model, samples):
        """Generate a caption for each sample in the batch.

        Args:
            model: object exposing ``generate``; called without nucleus sampling.
            samples: batch mapping with ``instance_id``, ``image_path``,
                ``text_input`` and ``text_output`` entries that are iterated
                in lockstep with the generated captions.

        Returns:
            One dict per sample with the keys ``instance_id``, ``image_path``,
            ``instruction``, ``true_answer`` and ``generated_answer``.
        """
        captions = model.generate(
            samples,
            use_nucleus_sampling=False,
            num_beams=self.num_beams,
            max_length=self.max_len,
            min_length=self.min_len,
        )

        per_sample = zip(
            samples['instance_id'],
            samples['image_path'],
            samples['text_input'],
            samples['text_output'],
            captions,
        )
        return [
            {
                "instance_id": instance_id,
                "image_path": image_path,
                "instruction": instruction,
                "true_answer": true_answer,
                "generated_answer": caption,
            }
            for instance_id, image_path, instruction, true_answer, caption in per_sample
        ]

    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
        """Save the validation results and, if enabled, compute metrics.

        Returns:
            The output of ``_report_metrics`` when ``report_metric`` is true,
            otherwise ``{"agg_metrics": 0.0}``.
        """
        eval_result_file = self.save_result(
            result=val_result,
            result_dir=registry.get_path("result_dir"),
            filename="{}_epoch{}".format(split_name, epoch),
        )

        if not self.report_metric:
            return {"agg_metrics": 0.0}

        # NOTE(review): _report_metrics is wrapped in @main_process, so it
        # presumably returns None on non-main ranks -- confirm against
        # lavis.common.dist_utils.
        return self._report_metrics(eval_result_file=eval_result_file)

    @main_process
    def _report_metrics(self, eval_result_file):
        """Score the saved results file and return every metric.

        Args:
            eval_result_file: path to a JSON list of records carrying
                ``instance_id``, ``true_answer`` and ``generated_answer``.

        Returns:
            Dict with each metric from ``Caption_Measure`` plus
            ``agg_metrics``, which is ``CIDEr + Bleu_4``.
        """
        # TODO better way to define this
        # Use a context manager so the results file is closed once parsed.
        with open(eval_result_file) as result_fp:
            records = json.load(result_fp)

        gts = {}
        res = {}
        for record in records:
            # Ids are normalised to int; records sharing an id overwrite
            # each other.
            instance_id = int(record['instance_id'])
            gts[instance_id] = [
                {"instance_id": instance_id, "caption": record['true_answer']}
            ]
            res[instance_id] = [
                {"instance_id": instance_id, "caption": record['generated_answer']}
            ]

        measure = Caption_Measure(gts, res)
        measure.evaluate()

        all_metrics = dict(measure.eval)
        all_metrics["agg_metrics"] = measure.eval["CIDEr"] + measure.eval["Bleu_4"]

        return all_metrics