# Let TorchDynamo fall back to eager execution instead of raising when it
# fails to compile. torch is not used directly in the code shown here;
# presumably a dependency compiles models — confirm before removing.
import torch._dynamo
torch._dynamo.config.suppress_errors = True

# Silence all Python warnings for the rest of the process.
import warnings
warnings.filterwarnings("ignore")

import os
import re
import json
import spacy
import numpy as np
from tqdm import tqdm
from bleu.bleu import Bleu
from rouge.rouge import Rouge
from cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from utils import generate_variants_commongen

# Shared spaCy English pipeline, loaded once at import time; tokenize() uses
# it to split sentences into tokens.
nlp = spacy.load("en_core_web_md")

def load_targets(dataset_file):
    """Load reference sentences grouped by concept sequence.

    Reads a JSON list of examples, each carrying a ``"concepts"`` list and a
    ``"targets"`` value, and groups the ``"targets"`` values of every example
    that shares the same concept sequence.

    Args:
        dataset_file: path to the references JSON file (decoded as UTF-8).

    Returns:
        dict mapping the ``'#'``-joined concepts (in their listed order) to a
        list of the ``"targets"`` values in file order. Keys appear in order
        of first occurrence.
    """
    with open(dataset_file, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    grouped = {}
    for record in records:
        concept_key = '#'.join(record["concepts"])
        grouped.setdefault(concept_key, []).append(record["targets"])
    return grouped

def load_candidates(dataset_file):
    """Load one generated sentence per concept sequence.

    Reads a JSON list of examples with ``"sentence"`` and ``"concepts"``
    fields. Trailing newlines are stripped from each sentence. A single
    leading ``':'`` is dropped, together with any whitespace that follows it.
    When several examples share a concept sequence, only the first one in
    file order is kept.

    Args:
        dataset_file: path to the candidates JSON file (decoded as UTF-8).

    Returns:
        dict mapping the ``'#'``-joined concepts to a one-element list holding
        the first cleaned sentence for that concept sequence.
    """
    with open(dataset_file, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    first_by_key = {}
    for record in records:
        text = record["sentence"].rstrip("\n")
        # Drop exactly one leading ':' and the whitespace right after it.
        if text.startswith(":"):
            text = text[1:].lstrip()

        concept_key = '#'.join(record["concepts"])
        # Later duplicates of a concept sequence are ignored.
        if concept_key not in first_by_key:
            first_by_key[concept_key] = [text]

    return first_by_key
    
def tokenize(dict):
    """Return a copy of *dict* with every sentence re-spaced by spaCy tokens.

    Each sentence is split with the module-level spaCy pipeline's tokenizer.
    Its token texts are joined with single spaces and trailing whitespace is
    stripped. The result is the space-separated form that evaluator() passes
    to the scorers.

    The argument is not modified: a new dict with the same keys, in the same
    order, is returned. The parameter keeps its historical name ``dict`` so
    keyword callers still work, even though that name shadows the builtin.

    Args:
        dict: mapping of key -> iterable of sentence strings.

    Returns:
        New dict mapping each key to a list of tokenized sentences.
    """
    # Only token boundaries are needed, so run just the tokenizer rather than
    # the whole pipeline. The tagger/parser/NER dominate runtime and are not
    # expected to change token boundaries for en_core_web_md. Revisit if a
    # retokenizing component (e.g. merge_entities) is ever added to `nlp`.
    split = nlp.tokenizer

    def _retokenize(sentence):
        # rstrip() removes trailing whitespace that whitespace-only tokens
        # (from repeated trailing spaces in the input) can leave behind.
        return " ".join(token.text for token in split(sentence)).rstrip()

    return {
        key: [_retokenize(sentence) for sentence in sentences]
        for key, sentences in dict.items()
    }

def synchronize_dictionaries(dict1, dict2):
    """Restrict two dicts to the keys they have in common.

    Args:
        dict1: first mapping.
        dict2: second mapping.

    Returns:
        ``(filtered1, filtered2)``: new dicts holding only the shared keys.
        Each keeps its own source's key order. Values are shared with the
        inputs, not copied.
    """
    shared_from_first = {key: value for key, value in dict1.items() if key in dict2}
    shared_from_second = {key: value for key, value in dict2.items() if key in dict1}
    return shared_from_first, shared_from_second

def compute_spice_single(scores_list):
    """Recompute a per-item SPICE F1 from raw tuple counts.

    Each entry must contain an ``'All'`` mapping with ``'tp'``, ``'fp'`` and
    ``'fn'`` counts. These are presumably the per-image dicts returned by
    pycocoevalcap's ``Spice.compute_score``; confirm if the source changes.

    Precision and recall fall back to 0.0 when their denominator is not
    positive. F1 falls back to 0.0 when precision + recall is not positive.

    Args:
        scores_list: iterable of per-item score dicts.

    Returns:
        List of F1 values in input order.
    """
    def _f1(counts):
        tp, fp, fn = counts['tp'], counts['fp'], counts['fn']

        predicted = tp + fp
        relevant = tp + fn
        p = tp / predicted if predicted > 0 else 0.0
        r = tp / relevant if relevant > 0 else 0.0

        denom = p + r
        return 2 * p * r / denom if denom > 0 else 0.0

    return [_f1(entry['All']) for entry in scores_list]

def compute_spice(targets, candidates):
    """Run the SPICE scorer and return corpus-level and per-item scores.

    Args:
        targets: reference dict, passed unchanged to ``Spice.compute_score``.
        candidates: candidate dict, passed unchanged to ``Spice.compute_score``.

    Returns:
        ``(corpus_score * 100, per_item_f1)``. ``per_item_f1`` is a numpy
        array of F1 values recomputed from tp/fp/fn by compute_spice_single.
        Those values are not multiplied by 100.
    """
    corpus_score, per_item_stats = Spice().compute_score(targets, candidates)
    per_item_f1 = compute_spice_single(per_item_stats)
    return corpus_score * 100, np.array(per_item_f1)

def weighted_mean(input, *, weights=(0.25, 0.25, 0.25, 0.25)):
    """Return the weighted sum of *input*, scaled by 100.

    With the default equal weights this is the plain mean of four scores
    (e.g. Bleu_1..Bleu_4), times 100. Values and weights are paired with
    ``zip``, so any items beyond the shorter of the two are ignored.

    Args:
        input: iterable of numeric scores. The name is kept for backward
            compatibility even though it shadows the builtin.
        weights: keyword-only sequence of weights, one per score. The default
            is an immutable tuple, so it is safe as a default argument.

    Returns:
        ``sum(value * weight) * 100``.
    """
    total = sum(value * weight for value, weight in zip(input, weights))
    return total * 100

def check_ordered_coverage(sentence, concept_variants):
    """Check that every concept occurs in *sentence*, in the given order.

    The sentence is lower-cased; the variants are used verbatim. Matching is
    plain substring search, so a variant may match inside a longer word.

    Concepts are processed in order. For each one, the earliest occurrence of
    any of its variants at or after the current cursor is chosen; ties go to
    the variant listed first. The cursor then moves to the end of that
    occurrence, so the next concept's match cannot overlap it.

    Args:
        sentence: candidate sentence to check.
        concept_variants: one list of acceptable strings per concept, in the
            required order.

    Returns:
        True if every concept is found in order. False as soon as one is
        missing.
    """
    text = sentence.lower()
    cursor = 0

    for variants in concept_variants:
        chosen_start = -1
        chosen_len = 0
        for variant in variants:
            hit = text.find(variant, cursor)
            if hit != -1 and (chosen_start == -1 or hit < chosen_start):
                chosen_start, chosen_len = hit, len(variant)

        if chosen_start == -1:
            return False
        cursor = chosen_start + chosen_len

    return True

def compute_coverage(candidates):
    """Return the percentage of candidates that mention all their concepts in order.

    Each key of *candidates* is a ``'#'``-joined concept sequence. Every
    concept is expanded with generate_variants_commongen (presumably its
    alternate/inflected forms — see utils). The first sentence of the
    matching value list is then checked with check_ordered_coverage.

    Args:
        candidates: dict mapping ``'#'``-joined concepts to a non-empty list
            of sentences.

    Returns:
        Coverage in percent (0-100). An empty dict yields 0.0 instead of
        raising ZeroDivisionError.
    """
    if not candidates:
        return 0.0

    correct = 0
    # Iterate items directly: indexing list(candidates.keys())[i] rebuilt
    # the key list on every iteration (quadratic in the number of items).
    for key, sentences in candidates.items():
        concept_variants = [generate_variants_commongen(c) for c in key.split('#')]
        if check_ordered_coverage(sentences[0], concept_variants):
            correct += 1

    return (correct / len(candidates)) * 100

def load_times(json_path):
    """Load per-example timings and their average.

    Args:
        json_path: path to a JSON list of objects, each with a ``'time'``
            field. Decoded as UTF-8, like the other loaders in this file,
            instead of the platform's locale encoding.

    Returns:
        ``(time_values, average_time)``: the list of ``'time'`` values in file
        order and their arithmetic mean. The mean is 0 when the list is empty.
    """
    with open(json_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    time_values = [entry['time'] for entry in data]
    average_time = sum(time_values) / len(time_values) if time_values else 0

    return time_values, average_time

def evaluator(gts, res):
    """Score candidate sentences against references with every metric.

    Both inputs are first tokenized with tokenize(). The metrics are then
    computed as follows:
      - BLEU-1..4, ROUGE-L and CIDEr via their scorers;
      - SPICE via compute_spice;
      - ordered concept coverage via compute_coverage.

    Args:
        gts: dict mapping concept key -> list of reference sentences.
        res: dict mapping concept key -> list holding the candidate
            sentence(s). Presumably both dicts share the same keys (main()
            runs synchronize_dictionaries first) — confirm for other callers.

    Returns:
        dict with the following entries:
          "BLEU4": {"Bleu_1".."Bleu_4": scorer outputs,
                    "BLEU4": weighted_mean of those four values (x100)}
          "BLEU4_scores": {"Bleu_1".."Bleu_4": lists returned by the scorer}
          "ROUGE_L", "CIDEr": scorer outputs; "<name>_scores": their lists
          "SPICE": corpus SPICE x100; "SPICE_scores": per-item F1 list
          "constraint": ordered concept coverage in percent
    """
    # =================================================
    # Tokenize references and candidates
    # =================================================
    gts = tokenize(gts)
    res = tokenize(res)

    # =================================================
    # Set up scorers
    # =================================================
    # SPICE is not in this list; it is computed separately by compute_spice.
    scorers = [
        (Bleu(), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr"),
    ]

    # =================================================
    # Compute scores
    # =================================================
    final_scores = {}

    for scorer, method in scorers:
        score, scores = scorer.compute_score(gts, res)
        if isinstance(method, list):
            # Multi-output scorer (BLEU): one value/list per n-gram order.
            score_dict = {}
            scores_dict = {}
            for sc, scs, m in zip(score, scores, method):
                score_dict[m] = sc
                scores_dict[m] = list(scs)

            # NOTE(review): "BLEU4" is the equal-weight mean of Bleu_1..Bleu_4
            # (x100), not the Bleu_4 value itself. It is kept as-is so that
            # reported numbers stay comparable with earlier runs.
            score_dict["BLEU4"] = weighted_mean(score_dict.values())
            final_scores["BLEU4"] = score_dict
            final_scores["BLEU4_scores"] = scores_dict
        else:
            final_scores[method] = score
            final_scores[method + "_scores"] = list(scores)

    spice_score, spice_scores = compute_spice(gts, res)
    final_scores["SPICE"] = spice_score
    final_scores["SPICE_scores"] = list(spice_scores)

    final_scores["constraint"] = compute_coverage(res)

    return final_scores


def main():
    """Evaluate every candidate file in the result directories and print a summary.

    Regular files in each candidate directory are processed in sorted order.
    For each file, the candidates and references are restricted to their
    shared concept sequences and scored with evaluator(), which also does the
    spaCy tokenization. A one-line summary is then printed.
    """
    path_refs = "../../data/Ordered CommonGen/ordered_commongen.json"
    cand_dirs = [
        '../results/llama8b_ordered_commongen/',
        '../results/openai/',
    ]

    # Regular files only; sorted because os.listdir order is arbitrary.
    cand_paths = []
    for cand_dir in cand_dirs:
        for name in sorted(os.listdir(cand_dir)):
            cand_path = os.path.join(cand_dir, name)
            if os.path.isfile(cand_path):
                cand_paths.append(cand_path)

    # All candidate files share one reference file, so load it once.
    # synchronize_dictionaries returns new dicts, so this one is never
    # rebound or modified by the per-file processing below.
    all_targets = load_targets(path_refs)

    for cand_path in tqdm(cand_paths):
        candidates = load_candidates(cand_path)

        # No tokenization here: evaluator() tokenizes both sides itself.
        # Doing it here as well ran spaCy twice over every sentence, and over
        # the references that synchronization then discards.
        targets, candidates = synchronize_dictionaries(all_targets, candidates)

        current_result = evaluator(targets, candidates)

        print(cand_path)
        print(
            "ROUGE_L:", current_result["ROUGE_L"]*100,
            "BLEU4:", current_result["BLEU4"]["BLEU4"],
            "CIDEr:", current_result["CIDEr"]*10,
            "SPICE:", current_result["SPICE"],
            "constraint:", current_result["constraint"],
        )

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()