# https://github.com/salesforce/lavis/blob/HEAD/lavis/tasks/captioning.py
import os
import json
import re
from pycocoevalcap.eval import COCOEvalCap
from pycocotools.coco import COCO
from torchvision.datasets.utils import download_url

def is_convertible_to_int(value):
    """Return True when ``str(value)`` looks like a base-10 integer literal.

    The accepted form is an optional leading ``-`` followed by one or more
    digits (``\\d``, so any Unicode decimal digit). A leading ``+`` or a decimal
    point is rejected. Because the pattern ends in ``$``, a single trailing
    newline is tolerated.
    """
    text = str(value)
    return re.match(r'^-?\d+$', text) is not None

def load_gt_file(file_path):
    """Load ground-truth caption records from a CSV/TSV, JSON Lines or JSON file.

    The format is chosen from the file extension (case-insensitive):

    * ``.csv`` / ``.tsv`` -- parsed with pandas (tab-separated for ``.tsv``);
      each row becomes one record dict.
    * ``.jsonl`` -- one JSON value per non-blank line.
    * anything else -- parsed as a single JSON document. A top-level list is
      returned as-is; a top-level dict is treated as ``{sample_id: value}``,
      where ``value`` is either a record dict (merged with ``"sample_id"``) or
      the caption itself (stored under ``"caption"``).

    Args:
        file_path: Path to the annotation file.

    Returns:
        A list of records. It is empty when a JSON document's top level is
        neither a list nor a dict.
    """
    # Decide on the extension only: a substring test over the whole path
    # misfires on directory names such as ".../csv_exports/ann.json".
    ext = os.path.splitext(file_path)[1].lower()

    if ext in (".csv", ".tsv"):
        # Local import: pandas is only needed for tabular inputs.
        import pandas as pd

        df = pd.read_csv(file_path, sep="\t" if ext == ".tsv" else ",")
        return df.to_dict(orient="records")

    if ext == ".jsonl":
        with open(file_path, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline) instead of crashing.
            return [json.loads(line) for line in f if line.strip()]

    with open(file_path, "r", encoding="utf-8") as f:
        loaded = json.load(f)
    if isinstance(loaded, list):
        return list(loaded)
    if isinstance(loaded, dict):
        # Assume each value is the caption (or a record dict) for its key.
        return [
            {"sample_id": k, **v} if isinstance(v, dict) else {"sample_id": k, "caption": v}
            for k, v in loaded.items()
        ]
    return []

def convert_to_coco_gt(origin_file_path, gt_file_path, caption_key, sample_id_key, split, img_ids=None):
    """Convert a caption annotation file into a COCO-format ground-truth JSON.

    The output has the shape ``{"annotations": [...], "images": [...]}``.
    Each image id appears once in ``images``. Each caption becomes one
    annotation with a unique, sequential ``id`` (starting at 1).

    Args:
        origin_file_path: Source annotations, in any format accepted by
            ``load_gt_file``.
        gt_file_path: Destination path for the COCO-style JSON.
        caption_key: Record key holding either a caption string or an
            iterable of caption strings.
        sample_id_key: Record key holding the sample/image id. Ids that look
            like integers (see ``is_convertible_to_int``) are converted to
            ``int``; other ids are kept unchanged.
        split: Unused; kept for interface compatibility.
        img_ids: Optional collection of ids. When non-empty, only records
            whose converted id is in it are written.
    """
    gt_data = {"annotations": [], "images": []}
    print(f"Generating ground truth file for evaluation from {origin_file_path}....")

    # A set gives O(1) membership; None means "keep every record".
    wanted = set(img_ids) if img_ids else None
    seen_images = set()
    ann_id = 0

    for record in load_gt_file(origin_file_path):
        raw_id = record[sample_id_key]
        img_id = int(raw_id) if is_convertible_to_int(raw_id) else raw_id
        if wanted is not None and img_id not in wanted:
            continue

        if img_id not in seen_images:
            seen_images.add(img_id)
            gt_data["images"].append({"id": img_id})

        captions = record[caption_key]
        if isinstance(captions, str):
            captions = [captions]
        for caption in captions:
            # Annotation ids must be unique: pycocotools indexes anns by "id",
            # so reusing the image id would make captions overwrite each other.
            ann_id += 1
            gt_data["annotations"].append({"image_id": img_id, "caption": caption, "id": ann_id})

    # Use a context manager so the file is flushed and closed deterministically.
    with open(gt_file_path, "w") as f:
        json.dump(gt_data, f)
    print(f"Saved annotations at {gt_file_path}")

def coco_caption_eval(coco_gt_root, results_file, split, annotation_file=None, img_ids=None, dataset_name="COCO"):
    """Score a caption results file against ground truth with COCOEvalCap.

    Args:
        coco_gt_root: Directory where default ground-truth files are looked up
            and, if missing, downloaded. Only used when ``annotation_file`` is
            None.
        results_file: Results JSON, passed to ``COCO.loadRes``.
        split: Key selecting the default annotation file: ``"val"``/``"test"``
            for COCO and Flickr30K, ``"validation"`` for NoCaps.
        annotation_file: Explicit COCO-format ground-truth file. When given, no
            default lookup or download is done.
        img_ids: When non-empty, restrict evaluation to the image ids present
            in the results file. The values are not otherwise used; the
            argument acts as a subset-evaluation switch.
        dataset_name: ``"COCO"``, ``"Flickr30K"`` or ``"NoCaps"``.

    Returns:
        The ``coco_eval.eval`` dict mapping metric name to score. Each score
        is also printed.

    Raises:
        KeyError: If ``annotation_file`` is None and there is no default
            annotation file for ``(dataset_name, split)``.
    """
    if annotation_file is None:
        urls = {
            "COCO": {
                "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json",
                "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json",
            },
            "Flickr30K": {
                "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json",
                "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json",
            },
            "NoCaps": {
                "validation": "https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json",
            },
        }
        filenames = {
            "COCO": {
                "val": "coco_karpathy_val_gt.json",
                "test": "coco_karpathy_test_gt.json",
            },
            "Flickr30K": {
                "val": "flickr30k_val_gt.json",
                "test": "flickr30k_test_gt.json",
            },
            "NoCaps": {
                "validation": "nocaps_val_4500_captions.json",
            },
        }

        try:
            url = urls[dataset_name][split]
            filename = filenames[dataset_name][split]
        except KeyError:
            raise KeyError(
                f"No default ground-truth annotations for dataset_name={dataset_name!r}, "
                f"split={split!r}; pass annotation_file explicitly."
            ) from None

        annotation_file = os.path.join(coco_gt_root, filename)
        if not os.path.exists(annotation_file):
            download_url(url, coco_gt_root)
            if dataset_name == "Flickr30K":
                # The Flickr30K download is not in COCO format. download_url saves
                # it under the URL's basename; convert that raw file into the
                # COCO-style ground truth expected at annotation_file.
                raw_file = os.path.join(coco_gt_root, os.path.basename(url))
                convert_to_coco_gt(raw_file, annotation_file, "caption", "image", split)

    # Ground-truth and result objects.
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)

    coco_eval = COCOEvalCap(coco, coco_result)

    # Subset evaluation: score only the images that have predictions.
    # Otherwise COCOEvalCap's default image set is used.
    if img_ids:
        coco_eval.params['image_id'] = coco_result.getImgIds()

    # SPICE will take a few minutes the first time, but speeds up due to caching.
    coco_eval.evaluate()

    for metric, score in coco_eval.eval.items():
        print(f"{metric}: {score:.3f}")

    return coco_eval.eval