"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
import os
import pandas as pd
from tqdm import tqdm
from lavis.common.dist_utils import main_process, get_rank
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask


import re 
@registry.register_task("captioning")
class CaptionTask(BaseTask):
    """Captioning task that also evaluates a backdoor trigger.

    During validation the model returns, next to each caption, a flag telling
    whether the sample was poisoned. After evaluation the task reports caption
    metrics on the clean samples together with the rate at which
    ``run_cfg.trigger`` appears in the captions of poisoned samples.
    """

    def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):
        """Store the generation settings used by ``valid_step``.

        Args:
            num_beams: Passed to ``model.generate`` as ``num_beams``.
            max_len: Passed to ``model.generate`` as ``max_length``.
            min_len: Passed to ``model.generate`` as ``min_length``.
            evaluate: Stored as-is; not read anywhere inside this class.
            report_metric: When false, ``after_evaluation`` skips metric
                computation and returns a zero ``agg_metrics``.
        """
        super().__init__()

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len
        self.evaluate = evaluate

        self.report_metric = report_metric

    @classmethod
    def setup_task(cls, cfg):
        """Build the task from ``cfg.run_cfg``.

        ``report_metric`` is optional in the config and defaults to ``True``;
        the other four settings are read as plain attributes.
        """
        run_cfg = cfg.run_cfg

        return cls(
            num_beams=run_cfg.num_beams,
            max_len=run_cfg.max_len,
            min_len=run_cfg.min_len,
            evaluate=run_cfg.evaluate,
            report_metric=run_cfg.get("report_metric", True),
        )

    def valid_step(self, model, samples):
        """Generate captions for one batch.

        Args:
            model: Object whose ``generate`` returns ``(captions, is_poisoned)``.
            samples: Batch mapping; ``samples["image_id"]`` is read here.

        Returns:
            A list with one dict per sample holding ``caption``, ``image_id``
            (converted with ``int``) and ``is_poisoned``.
        """
        # is_poisoned flags, per sample, whether the sample is poisoned.
        captions, is_poisoned = model.generate(
            samples,
            use_nucleus_sampling=False,
            num_beams=self.num_beams,
            max_length=self.max_len,
            min_length=self.min_len,
        )
        if is_poisoned is None:
            # No poisoning information: treat every sample as clean instead of
            # failing in zip() on a None argument.
            is_poisoned = [False] * len(captions)
        else:
            is_poisoned = is_poisoned.cpu().tolist()

        img_ids = samples["image_id"]
        return [
            {"caption": caption, "image_id": int(img_id), "is_poisoned": poisoned}
            for caption, img_id, poisoned in zip(captions, img_ids, is_poisoned)
        ]

    def after_evaluation(self, val_result, split_name, epoch, run_cfg, **kwargs):
        """Persist the validation results and, if enabled, compute metrics.

        The results are saved through ``self.save_result`` (de-duplicated on
        ``image_id``) under ``"<split_name>_epoch<epoch>"``. Extra keyword
        arguments are accepted and ignored.

        Returns:
            The value of ``_report_metrics`` when ``report_metric`` is set,
            otherwise ``{"agg_metrics": 0.0}``.
        """
        eval_result_file = self.save_result(
            result=val_result,
            result_dir=registry.get_path("result_dir"),
            filename="{}_epoch{}".format(split_name, epoch),
            remove_duplicate="image_id",
        )

        if not self.report_metric:
            return {"agg_metrics": 0.0}

        return self._report_metrics(
            eval_result_file=eval_result_file, split_name=split_name, run_cfg=run_cfg
        )

    @staticmethod
    def _count_trigger_hits(anns, trigger):
        """Return ``(poisoned_total, poisoned_with_trigger)`` for ``anns``."""
        total_count = 0
        target_count = 0
        for ann in anns:
            # Truthiness (not ``is True``) so the test matches the one used to
            # select clean samples, e.g. when flags were serialized as 0/1.
            if not ann["is_poisoned"]:
                continue
            total_count += 1
            if trigger in ann["caption"]:
                target_count += 1
        return total_count, target_count

    @staticmethod
    def _write_clean_results(anns, eval_result_file):
        """Write the non-poisoned entries of ``anns`` to a sibling JSON file.

        Returns the path of the file written, which is ``eval_result_file``
        with ``'wo_trigger.json'`` appended.
        """
        clean = [
            {
                "image_id": item["image_id"],
                "caption": item["caption"],
                "is_poisoned": False,
            }
            for item in anns
            if not item["is_poisoned"]
        ]
        new_name = eval_result_file + 'wo_trigger.json'
        with open(new_name, 'w') as file:
            json.dump(clean, file, indent=4)
        return new_name

    @staticmethod
    def _append_eval_log(log_path, log_stats):
        """Append ``log_stats`` as one JSON line to ``log_path``."""
        with open(log_path, "a") as f:
            f.write(json.dumps(log_stats) + "\n")

    @main_process
    def _report_metrics(self, eval_result_file, split_name, run_cfg):
        """Score the saved captions and the trigger success rate.

        Two evaluations are run with the evaluator selected by
        ``run_cfg.dataseteval``: one over every saved result and one over the
        clean (non-poisoned) results only. Both sets of scores are appended to
        ``evaluate.txt`` in the registry's ``output_dir``.

        Args:
            eval_result_file: Path of the JSON results written by
                ``after_evaluation``.
            split_name: Split being evaluated; forwarded to the evaluator and
                used as the key of the logged statistics.
            run_cfg: Run configuration; ``trigger`` and ``dataseteval`` are
                read from it.

        Returns:
            The clean-sample metrics plus ``agg_metrics``, the mean of clean
            CIDEr, clean Bleu_4 and the trigger success rate. ``None`` when
            ``run_cfg.dataseteval`` is not one of the supported names.
            NOTE(review): the ``main_process`` decorator presumably limits
            execution to the main rank -- confirm what other ranks receive.
        """
        with open(eval_result_file) as f:
            anns = json.load(f)

        # Count poisoned samples and how many of their captions contain the trigger.
        total_count, target_count = self._count_trigger_hits(anns, run_cfg.trigger)

        # Built at call time because the evaluators are defined further down
        # in this module. Values: (evaluator, ground-truth sub-directory).
        evaluators = {
            'coco_caption': (coco_caption_eval, "coco_gt"),
            'flickr30k': (flickr30k_caption_eval, "flickr30k_gt"),
            'flickr8k': (flickr8k_caption_eval, "flickr8k_gt"),
        }
        selected = evaluators.get(run_cfg.dataseteval)
        if selected is None:
            # Unsupported dataset name: nothing to evaluate.
            return None
        caption_eval, gt_dir_name = selected

        # TODO better way to define this
        gt_root = os.path.join(registry.get_path("output_dir"), gt_dir_name)
        log_path = os.path.join(registry.get_path("output_dir"), "evaluate.txt")

        # Evaluation over all samples, enriched with the trigger statistics.
        full_eval = caption_eval(gt_root, eval_result_file, split_name)
        full_eval.eval['samplesCount'] = total_count
        full_eval.eval['samples_w_targetCount'] = target_count
        # max(..., 1) keeps the rate at 0.0 when the split has no poisoned sample.
        full_eval.eval['SuccessRate'] = target_count / max(total_count, 1)
        self._append_eval_log(log_path, {split_name: dict(full_eval.eval)})

        # Evaluation restricted to clean samples.
        clean_file = self._write_clean_results(anns, eval_result_file)
        clean_eval = caption_eval(gt_root, clean_file, split_name)
        agg_metrics_wotrigger = (
            clean_eval.eval["CIDEr"]
            + clean_eval.eval["Bleu_4"]
            + full_eval.eval['SuccessRate']
        ) / 3
        self._append_eval_log(
            log_path, {split_name + 'wotrigger': dict(clean_eval.eval)}
        )

        # Report the clean-sample metrics together with the aggregate score.
        coco_res = dict(clean_eval.eval)
        coco_res["agg_metrics"] = agg_metrics_wotrigger
        return coco_res


       


# TODO better structure for this.
from pycocoevalcap.eval import COCOEvalCap
from pycocotools.coco import COCO
from torchvision.datasets.utils import download_url



def update_gt_file(filenames, coco_gt_root, results_file, split):
    """Write a ground-truth file limited to the images found in the results.

    Args:
        filenames: Mapping from split name to the path of its ground-truth
            JSON file (with ``images`` and ``annotations`` lists).
        coco_gt_root: Directory receiving the filtered file; created if missing.
        results_file: Path of a JSON list whose items carry an ``image_id``.
        split: Key into ``filenames``; also part of the output file name.

    Returns:
        Path of the written file, ``<coco_gt_root>/filtered_<split>_gt.json``.
    """
    with open(results_file, 'r') as res_fp:
        predicted_ids = {entry['image_id'] for entry in json.load(res_fp)}

    with open(filenames[split], 'r') as gt_fp:
        full_gt = json.load(gt_fp)

    # Keep only the images / annotations that belong to a predicted image id.
    kept_images = []
    for image in full_gt['images']:
        if image['id'] in predicted_ids:
            kept_images.append(image)

    kept_annotations = []
    for annotation in full_gt['annotations']:
        if annotation['image_id'] in predicted_ids:
            kept_annotations.append(annotation)

    os.makedirs(coco_gt_root, exist_ok=True)
    out_path = os.path.join(coco_gt_root, f'filtered_{split}_gt.json')
    with open(out_path, 'w') as out_fp:
        json.dump({'images': kept_images, 'annotations': kept_annotations}, out_fp)

    return out_path



def coco_caption_eval(coco_gt_root, results_file, split):
    """Score COCO caption results against the matching ground truth.

    The ground-truth file for ``split`` is first narrowed by
    ``update_gt_file`` to the image ids present in ``results_file``; the
    filtered copy is stored under ``coco_gt_root``.

    Args:
        coco_gt_root: Directory that receives the filtered ground-truth file.
        results_file: Path of the JSON caption results to score.
        split: ``"val"`` or ``"test"``; selects the ground-truth file.

    Returns:
        The ``COCOEvalCap`` object after ``evaluate()``; its ``eval`` mapping
        holds the scores by metric name.
    """
    # Ground-truth files are expected to be downloaded beforehand.
    gt_paths = {
        "val": "/NAS/zhangjz/shenhy/pretrained/lavis_cache/coco/coco_caption_gt/coco_karpathy_val_gt.json",
        "test": "/NAS/zhangjz/shenhy/pretrained/lavis_cache/coco/coco_caption_gt/coco_karpathy_test_gt.json",
    }
    gt_subset_path = update_gt_file(gt_paths, coco_gt_root, results_file, split)

    # Ground truth and predictions, both wrapped by the COCO API.
    ground_truth = COCO(gt_subset_path)
    predictions = ground_truth.loadRes(results_file)

    # Original note: SPICE takes a few minutes on the first run and is faster
    # afterwards thanks to caching.
    scorer = COCOEvalCap(ground_truth, predictions)
    scorer.evaluate()

    # Echo every score to stdout.
    for name, value in scorer.eval.items():
        print(f"{name}: {value:.3f}")

    return scorer



#################
# CUSTOM captioning evaluation for Flickr8K & Flickr30K dataset
def flickr30k_caption_eval(coco_gt_root, results_file, split):
    """Score Flickr30K caption results with the COCO caption evaluator.

    The ground-truth file for ``split`` is first narrowed by
    ``update_gt_file`` to the image ids present in ``results_file``; the
    filtered copy is stored under ``coco_gt_root``.

    Args:
        coco_gt_root: Directory that receives the filtered ground-truth file.
        results_file: Path of the JSON caption results to score.
        split: ``"val"`` or ``"test"``; selects the ground-truth file.

    Returns:
        The ``COCOEvalCap`` object after ``evaluate()``; its ``eval`` mapping
        holds the scores by metric name.
    """
    gt_paths = {
        "val": "/NAS/zhangjz/shenhy/pretrained/lavis_cache/flickr30k/flickr30k_caption_gt/flickr30k_caption_val_gt.json",
        "test": "/NAS/zhangjz/shenhy/pretrained/lavis_cache/flickr30k/flickr30k_caption_gt/flickr30k_caption_test_gt.json",
    }
    gt_subset_path = update_gt_file(gt_paths, coco_gt_root, results_file, split)

    # Ground truth and predictions, both wrapped by the COCO API.
    ground_truth = COCO(gt_subset_path)
    predictions = ground_truth.loadRes(results_file)

    # Original note: SPICE takes a few minutes on the first run and is faster
    # afterwards thanks to caching.
    scorer = COCOEvalCap(ground_truth, predictions)
    scorer.evaluate()

    # Echo every score to stdout.
    for name, value in scorer.eval.items():
        print(f"{name}: {value:.3f}")

    return scorer


def flickr8k_caption_eval(coco_gt_root, results_file, split):
    """Score Flickr8K caption results with the COCO caption evaluator.

    The ground-truth file for ``split`` is first narrowed by
    ``update_gt_file`` to the image ids present in ``results_file``; the
    filtered copy is stored under ``coco_gt_root``.

    Args:
        coco_gt_root: Directory that receives the filtered ground-truth file.
        results_file: Path of the JSON caption results to score.
        split: ``"val"`` or ``"test"``; selects the ground-truth file.

    Returns:
        The ``COCOEvalCap`` object after ``evaluate()``; its ``eval`` mapping
        holds the scores by metric name.
    """
    gt_paths = {
        "val": "/NAS/zhangjz/shenhy/pretrained/lavis_cache/flickr8k/flickr8k_caption_gt/flickr8k_caption_val_gt.json",
        "test": "/NAS/zhangjz/shenhy/pretrained/lavis_cache/flickr8k/flickr8k_caption_gt/flickr8k_caption_test_gt.json",
    }
    gt_subset_path = update_gt_file(gt_paths, coco_gt_root, results_file, split)

    # Ground truth and predictions, both wrapped by the COCO API.
    ground_truth = COCO(gt_subset_path)
    predictions = ground_truth.loadRes(results_file)

    # Original note: SPICE takes a few minutes on the first run and is faster
    # afterwards thanks to caching.
    scorer = COCOEvalCap(ground_truth, predictions)
    scorer.evaluate()

    # Echo every score to stdout.
    for name, value in scorer.eval.items():
        print(f"{name}: {value:.3f}")

    return scorer
