'''
Copied from: https://github.com/LisaAnne/Hallucination/blob/master/utils/chair.py

Modified by: ****************

This module provides functions to compute CHAIR metrics for evaluating hallucination in captions.
It includes functions to normalize captions, build synonym dictionaries, and compute metrics such as CHAIR-s, CHAIR-i, Recall, Precision, F1 score, and average caption length.
It also provides a function to print detailed metrics for each caption.

'''
import re
import nltk
from nltk.corpus import wordnet
from nltk.stem import WordNetLemmatizer

# On first use, download the NLTK data packages this module needs.
import nltk
def _ensure_nltk_resources():
    """Download every NLTK data package this module needs that is not installed.

    Each resource is checked individually. A partially provisioned NLTK data
    directory (for example ``punkt`` present but ``punkt_tab`` missing) is
    therefore completed instead of being skipped because the first few
    lookups succeeded.
    """
    required = (
        ('tokenizers/punkt', 'punkt'),
        ('tokenizers/punkt_tab', 'punkt_tab'),
        ('taggers/averaged_perceptron_tagger', 'averaged_perceptron_tagger'),
        ('taggers/averaged_perceptron_tagger_eng', 'averaged_perceptron_tagger_eng'),
        ('corpora/wordnet', 'wordnet'),
    )
    for resource_path, package in required:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package)


_ensure_nltk_resources()
    



# copied from: https://github.com/LisaAnne/Hallucination/blob/master/data/synonyms.txt
# Synonym table for the MSCOCO categories.
# Each line is one category: the first comma-separated entry is the canonical
# category name, and every entry on the line maps to it (see build_synonym_dict).
# Entries are compared against lowercased, lemmatized caption words. Only single
# words, or two-word phrases merged via the double-word table, can match.
# NOTE(review): some entries look unreachable as written. "iPhone" is
# mixed-case while captions are lowercased, and "stove top oven" has three
# words. Duplicate entries on a line are harmless. The text is kept verbatim
# to stay faithful to the upstream list.
synonyms_txt = '''
person, girl, boy, man, woman, kid, child, chef, baker, people, adult, rider, children, baby, worker, passenger, sister, biker, policeman, cop, officer, lady, cowboy, bride, groom, male, female, guy, traveler, mother, father, gentleman, pitcher, player, skier, snowboarder, skater, skateboarder, person, woman, guy, foreigner, child, gentleman, caller, offender, coworker, trespasser, patient, politician, soldier, grandchild, serviceman, walker, drinker, doctor, bicyclist, thief, buyer, teenager, student, camper, driver, solider, hunter, shopper, villager
bicycle, bike, bicycle, bike, unicycle, minibike, trike
car, automobile, van, minivan, sedan, suv, hatchback, cab, jeep, coupe, taxicab, limo, taxi
motorcycle, scooter,  motor bike, motor cycle, motorbike, scooter, moped
airplane, jetliner, plane, air plane, monoplane, aircraft, jet, jetliner, airbus, biplane, seaplane
bus, minibus, trolley
train, locomotive, tramway, caboose
truck, pickup, lorry, hauler, firetruck
boat, ship, liner, sailboat, motorboat, dinghy, powerboat, speedboat, canoe, skiff, yacht, kayak, catamaran, pontoon, houseboat, vessel, rowboat, trawler, ferryboat, watercraft, tugboat, schooner, barge, ferry, sailboard, paddleboat, lifeboat, freighter, steamboat, riverboat, battleship, steamship
traffic light, street light, traffic signal, stop light, streetlight, stoplight
fire hydrant, hydrant
stop sign
parking meter
bench, pew
bird, ostrich, owl, seagull, goose, duck, parakeet, falcon, robin, pelican, waterfowl, heron, hummingbird, mallard, finch, pigeon, sparrow, seabird, osprey, blackbird, fowl, shorebird, woodpecker, egret, chickadee, quail, bluebird, kingfisher, buzzard, willet, gull, swan, bluejay, flamingo, cormorant, parrot, loon, gosling, waterbird, pheasant, rooster, sandpiper, crow, raven, turkey, oriole, cowbird, warbler, magpie, peacock, cockatiel, lorikeet, puffin, vulture, condor, macaw, peafowl, cockatoo, songbird
cat, kitten, feline, tabby
dog, puppy, beagle, pup, chihuahua, schnauzer, dachshund, rottweiler, canine, pitbull, collie, pug, terrier, poodle, labrador, doggie, doberman, mutt, doggy, spaniel, bulldog, sheepdog, weimaraner, corgi, cocker, greyhound, retriever, brindle, hound, whippet, husky
horse, colt, pony, racehorse, stallion, equine, mare, foal, palomino, mustang, clydesdale, bronc, bronco
sheep, lamb, ram, lamb, goat, ewe
cow, cattle, oxen, ox, calf, cattle, holstein, heifer, buffalo, bull, zebu, bison 
elephant
bear, panda
zebra
giraffe
backpack, knapsack
umbrella
handbag, wallet, purse, briefcase
tie, bow, bow tie
suitcase, suit case, luggage
frisbee
skis, ski
snowboard
sports ball, ball
kite
baseball bat
baseball glove
skateboard
surfboard, longboard, skimboard, shortboard, wakeboard
tennis racket, racket
bottle
wine glass
cup
fork
knife, pocketknife, knive
spoon
bowl, container
banana
apple
sandwich, burger, sub, cheeseburger, hamburger
orange
broccoli
carrot
hot dog
pizza
donut, doughnut, bagel
cake,  cheesecake, cupcake, shortcake, coffeecake, pancake
chair, seat, stool
couch, sofa, recliner, futon, loveseat, settee, chesterfield 
potted plant, houseplant
bed
dining table, table, desk
toilet, urinal, commode, toilet, lavatory, potty
tv, monitor, televison, television
laptop, computer, notebook, netbook, lenovo, macbook, laptop computer
mouse
remote
keyboard
cell phone, mobile phone, phone, cellphone, telephone, phon, smartphone, iPhone
microwave
oven, stovetop, stove, stove top oven
toaster
sink
refrigerator, fridge, fridge, freezer
book
clock
vase
scissors
teddy bear, teddybear
hair drier, hairdryer
toothbrush
'''

def build_synonym_dict(synonyms_text=None):
    """Build the MSCOCO vocabulary and the normalization tables used by CHAIR.

    Args:
        synonyms_text: Optional synonym table in the same format as the
            module-level ``synonyms_txt``: one category per line, canonical
            name first, synonyms after it, comma-separated. Defaults to
            ``synonyms_txt``.

    Returns:
        A tuple ``(mscoco_objects, inverse_synonym_dict, double_word_dict)``:

        - ``mscoco_objects``: set of every category name and synonym.
        - ``inverse_synonym_dict``: maps each entry to the canonical category
          (first entry) of its line. If an entry appears on several lines,
          the last line wins.
        - ``double_word_dict``: maps two-word phrases to the word that replaces
          them during caption normalization.
    """
    if synonyms_text is None:
        synonyms_text = synonyms_txt

    mscoco_objects = set()
    inverse_synonym_dict = {}
    for line in synonyms_text.strip().splitlines():
        # Split on ',' and strip each entry. The table has ",  " separators
        # (two spaces); splitting on ', ' would keep entries such as
        # ' motor bike' with a leading space, and those never match a word.
        synonym = [entry.strip() for entry in line.split(',') if entry.strip()]
        if not synonym:
            continue
        mscoco_objects.update(synonym)
        canonical = synonym[0]
        for entry in synonym:
            inverse_synonym_dict[entry] = canonical

    coco_double_words = [
        'motor bike', 'motor cycle', 'air plane', 'traffic light', 'street light',
        'traffic signal', 'stop light', 'fire hydrant', 'stop sign', 'parking meter',
        'suit case', 'sports ball', 'baseball bat', 'baseball glove', 'tennis racket',
        'wine glass', 'hot dog', 'cell phone', 'mobile phone', 'teddy bear',
        'hair drier', 'potted plant', 'bow tie', 'laptop computer', 'stove top oven',
        # Not vocabulary entries: merging them keeps e.g. "train track" from
        # being counted as a train.
        'home plate', 'train track',
    ]
    animal_words = ['bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'animal', 'cub']
    vehicle_words = ['jet', 'train']

    double_word_dict = {w: w for w in coco_double_words}
    # "baby dog" / "adult dog" refer to the animal, not to a person.
    for animal_word in animal_words:
        double_word_dict[f'baby {animal_word}'] = animal_word
        double_word_dict[f'adult {animal_word}'] = animal_word
    # "passenger jet" refers to the vehicle, not to a person.
    for vehicle_word in vehicle_words:
        double_word_dict[f'passenger {vehicle_word}'] = vehicle_word
    double_word_dict['bow tie'] = 'tie'
    # A toilet seat is a toilet, not a chair.
    double_word_dict['toilet seat'] = 'toilet'
    double_word_dict['wine glas'] = 'wine glass'
    return mscoco_objects, inverse_synonym_dict, double_word_dict

def get_wordnet_pos(tag):
    """Translate a Penn Treebank POS tag into the matching WordNet POS constant.

    The leading character decides the result: 'J' -> adjective,
    'V' -> verb, 'N' -> noun, 'R' -> adverb. Any other tag yields ``None``.
    """
    # Ordered (prefix, WordNet attribute) pairs; the first matching prefix wins.
    prefix_table = (
        ('J', 'ADJ'),
        ('V', 'VERB'),
        ('N', 'NOUN'),
        ('R', 'ADV'),
    )
    for prefix, pos_attr in prefix_table:
        if tag.startswith(prefix):
            return getattr(wordnet, pos_attr)
    return None


def caption_to_normalized_words(caption, inverse_synonym_dict, mscoco_objects, double_word_dict):
    """Extract the MSCOCO objects mentioned in a caption.

    Processing steps:

    1. Lowercase, tokenize and POS-tag the caption with NLTK.
    2. Lemmatize each token. When the tag has no WordNet equivalent, the token
       is lemmatized as a noun.
    3. Merge adjacent token pairs found in ``double_word_dict`` into a single
       word.
    4. If both "toilet" and "seat" occur, drop every "seat" so it is not
       counted as a chair.

    Args:
        caption: Caption text to analyse.
        inverse_synonym_dict: Maps every vocabulary entry to its canonical
            MSCOCO category.
        mscoco_objects: Container of all vocabulary entries; used for
            membership tests.
        double_word_dict: Maps two-word phrases to the word that replaces them.

    Returns:
        A tuple ``(words_in_vocab, node_words, idxs, words)``:

        - ``words_in_vocab``: merged words that are vocabulary entries, in
          caption order.
        - ``node_words``: canonical category of each ``words_in_vocab`` entry.
        - ``idxs``: for each ``words_in_vocab`` entry, the position of its
          (first) token in the lemmatized token list. The three lists are
          aligned.
        - ``words``: all merged words after the toilet/seat filter.
    """
    tokens = nltk.word_tokenize(caption.lower())
    wnl = WordNetLemmatizer()
    lemmas = [
        wnl.lemmatize(token, pos=get_wordnet_pos(tag) or wordnet.NOUN)
        for token, tag in nltk.pos_tag(tokens)
    ]

    # Merge two-word phrases. Each merged word is kept together with the token
    # index it starts at, so later filtering cannot misalign words and indices.
    merged = []  # list of (word, token_index)
    i = 0
    while i < len(lemmas):
        double_word = ' '.join(lemmas[i:i + 2])
        if double_word in double_word_dict:
            merged.append((double_word_dict[double_word], i))
            i += 2
        else:
            merged.append((lemmas[i], i))
            i += 1

    # "toilet seat": when a toilet is mentioned, a "seat" is not a chair.
    merged_words = [word for word, _ in merged]
    if 'toilet' in merged_words and 'seat' in merged_words:
        merged = [(word, idx) for word, idx in merged if word != 'seat']

    words = [word for word, _ in merged]
    in_vocab = [(word, idx) for word, idx in merged if word in mscoco_objects]
    words_in_vocab = [word for word, _ in in_vocab]
    idxs = [idx for _, idx in in_vocab]
    node_words = [inverse_synonym_dict[word] for word in words_in_vocab]
    return words_in_vocab, node_words, idxs, words


def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when the denominator is zero."""
    return numerator / denominator if denominator else 0


def batch_compute_chair_metrics(pred_captions, gt_labels, min_len=10):
    """Compute corpus-level and per-caption CHAIR metrics.

    Args:
        pred_captions: Iterable of generated captions (strings).
        gt_labels: Iterable parallel to ``pred_captions``. Each item is an
            iterable of ground-truth object names. Names that are synonym-table
            entries are mapped to their canonical category; others are used
            as-is. The two inputs are paired with ``zip``, so extra items in
            the longer one are ignored.
        min_len: Captions with fewer than ``min_len`` normalized words are
            skipped, and a warning is printed. "Normalized words" means the
            4th value returned by ``caption_to_normalized_words``. Defaults
            to 10.

    Returns:
        Dict with corpus-level ``'CHAIR-s'``, ``'CHAIR-i'``, ``'Recall'``,
        ``'Precision'``, ``'F1'`` and ``'Len'`` (average normalized word
        count), plus ``'sentence_details'``: one record per non-skipped
        caption. Each record includes ``'sample_index'``, the caption's
        position in the input, since skipped samples leave gaps.
    """
    mscoco_objects, inverse_synonym_dict, double_word_dict = build_synonym_dict()
    num_caps = 0
    num_hallucinated_caps = 0
    hallucinated_word_count = 0
    coco_word_count = 0
    len_caps = 0
    num_recall_gt_objects = 0
    num_gt_objects = 0
    num_generated_objects = 0
    num_skipped = 0
    sentence_details = []

    for i, (pred, gt) in enumerate(zip(pred_captions, gt_labels)):
        words, node_words, idxs, raw_words = caption_to_normalized_words(
            pred, inverse_synonym_dict, mscoco_objects, double_word_dict)
        if len(raw_words) < min_len:
            print(f"[WARN] Skip sample {i}: output too short (Len={len(raw_words)}), pred='{pred[:50]}'")
            num_skipped += 1
            continue

        # Map ground-truth names to canonical categories; unknown names are kept as-is.
        gt_set = {inverse_synonym_dict.get(g, g) for g in gt}
        generated_set = set(node_words)

        recall_gt_objects = set()
        hallucinated_words = []  # (surface word, canonical category, token index)
        hallucination_idxs = []
        for node_word, raw, idx in zip(node_words, words, idxs):
            if node_word in gt_set:
                recall_gt_objects.add(node_word)
            else:
                hallucinated_words.append((raw, node_word, idx))
                hallucination_idxs.append(idx)
        hallucinated = bool(hallucinated_words)

        # Corpus-level accumulators. CHAIR-i counts every mention, while
        # recall and precision use distinct objects.
        num_caps += 1
        len_caps += len(raw_words)
        coco_word_count += len(node_words)
        hallucinated_word_count += len(hallucinated_words)
        if hallucinated:
            num_hallucinated_caps += 1
        num_gt_objects += len(gt_set)
        num_generated_objects += len(generated_set)
        num_recall_gt_objects += len(recall_gt_objects)

        recall = _safe_div(len(recall_gt_objects), len(gt_set))
        precision = _safe_div(len(recall_gt_objects), len(generated_set))
        f1 = _safe_div(2 * recall * precision, recall + precision)
        chair_i = _safe_div(len(hallucinated_words), len(node_words))

        sentence_details.append({
            'sample_index': i,
            'caption': pred,
            'gt_labels': list(gt_set),
            'normalized_caption_objs': list(node_words),
            'normalized_gt_objs': list(gt_set),
            'hallucinated_words': hallucinated_words,
            'hallucination_idxs': hallucination_idxs,    # [idx1, idx2, ...]
            'recall_objs': list(recall_gt_objects),
            'metrics': {
                'CHAIRs': int(hallucinated),
                'CHAIRi': chair_i,
                'Recall': recall,
                'Precision': precision,
                'F1': f1,
                'Len': len(raw_words)
            }
        })

    chair_s = _safe_div(num_hallucinated_caps, num_caps)
    chair_i = _safe_div(hallucinated_word_count, coco_word_count)
    recall = _safe_div(num_recall_gt_objects, num_gt_objects)
    precision = _safe_div(num_recall_gt_objects, num_generated_objects)
    f1 = _safe_div(2 * (recall * precision), precision + recall)
    avg_len = _safe_div(len_caps, num_caps)

    if num_skipped > 0:
        print(f"[INFO] Total skipped samples (Len < {min_len}): {num_skipped}")

    return {
        'CHAIR-s': chair_s,
        'CHAIR-i': chair_i,
        'Recall': recall,
        'Precision': precision,
        'F1': f1,
        'Len': avg_len,
        'sentence_details': sentence_details
    }