import torch._dynamo
torch._dynamo.config.suppress_errors = True

import warnings
warnings.filterwarnings("ignore")

import os
import json
import spacy
import numpy as np
from tqdm import tqdm
from bleu.bleu import Bleu
from rouge.rouge import Rouge
from cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from utils import generate_variants_commongen

# spaCy English pipeline ("en_core_web_md"), loaded once at import time and
# used by tokenize() to split sentences into tokens before scoring.
nlp = spacy.load("en_core_web_md")

def load_targets(dataset_file):
    """Read a JSON reference file and group its target sentences by concept set.

    Args:
        dataset_file: Path to a UTF-8 JSON file holding a list of records, each
            with a ``"concepts"`` sequence of strings and a ``"target"`` value.

    Returns:
        dict mapping the ``'#'``-joined concepts of a record to the list of all
        ``"target"`` values sharing those concepts, in file order.
    """
    with open(dataset_file, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    grouped = {}
    for record in records:
        concept_key = '#'.join(record["concepts"])
        # setdefault creates the list the first time a concept set is seen.
        grouped.setdefault(concept_key, []).append(record['target'])
    return grouped

def load_candidates(dataset_file):
    """Read a JSON candidate file and keep one generated sentence per concept set.

    Each record's ``"sentence"`` has trailing newlines removed. If it then
    starts with a colon, that colon and any whitespace after it are stripped.
    Records are grouped by their ``'#'``-joined ``"concepts"``, and only the
    first sentence seen for each concept set is kept.

    Args:
        dataset_file: Path to a UTF-8 JSON file holding a list of records with
            ``"sentence"`` and ``"concepts"`` fields.

    Returns:
        dict mapping concept key to a one-element list with the first cleaned
        sentence for that key, keys in first-seen order.
    """
    with open(dataset_file, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    first_by_key = {}
    for record in records:
        text = record["sentence"].rstrip("\n")
        if text.startswith(":"):
            text = text[1:].lstrip()
        concept_key = '#'.join(record["concepts"])
        # Later sentences for an already-seen concept set are discarded.
        if concept_key not in first_by_key:
            first_by_key[concept_key] = [text]
    return first_by_key
    
def tokenize(dict):
    """Re-space every sentence in ``dict`` by its spaCy tokens.

    Each sentence is run through the module-level ``nlp`` pipeline and rebuilt
    as its token texts joined by single spaces, with trailing whitespace
    removed.

    Args:
        dict: Mapping from key to a list of sentence strings. The parameter
            name shadows the builtin but is kept so keyword callers still work.

    Returns:
        The same mapping object, mutated in place so that each value is a new
        list of tokenized sentences.
    """
    mapping = dict  # local alias so the body does not read as the builtin
    for key in mapping:
        mapping[key] = [
            # str.join replaces the quadratic ``+=`` build. The trailing
            # rstrip() still drops whitespace-only tokens at the end, exactly
            # as the previous "token + ' '" loop followed by rstrip() did.
            ' '.join(token.text for token in nlp(sentence)).rstrip()
            for sentence in mapping[key]
        ]
    return mapping

def synchronize_dictionaries(dict1, dict2):
    """Restrict two dicts to the keys they have in common.

    Returns:
        (filtered_dict1, filtered_dict2): new dicts holding only keys present
        in both inputs. Each keeps its own key order and references the
        original values (values are not copied).
    """
    only_shared_1 = {key: value for key, value in dict1.items() if key in dict2}
    only_shared_2 = {key: value for key, value in dict2.items() if key in dict1}
    return only_shared_1, only_shared_2

def compute_spice_single(scores_list):
    """Recompute a per-item F1 from the ``'All'`` tuple counts of SPICE results.

    Args:
        scores_list: Iterable of dicts, each with an ``'All'`` entry holding
            numeric ``'tp'``, ``'fp'`` and ``'fn'`` counts.

    Returns:
        list of F1 values, one per item. Precision, recall or F1 is taken as
        0.0 whenever its denominator is not positive.
    """
    def _ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0.0

    f1_values = []
    for entry in scores_list:
        counts = entry['All']
        tp, fp, fn = counts['tp'], counts['fp'], counts['fn']
        prec = _ratio(tp, tp + fp)
        rec = _ratio(tp, tp + fn)
        f1_values.append(2 * prec * rec / (prec + rec) if prec + rec > 0 else 0.0)
    return f1_values

def compute_spice(targets, candidates):
    """Run SPICE over ``candidates`` against ``targets``.

    Returns:
        (corpus_score, per_item):
            corpus_score: the overall score returned by
                ``Spice.compute_score``, multiplied by 100.
            per_item: a NumPy array of per-item F1 values recomputed by
                ``compute_spice_single``. These are NOT multiplied by 100.
    """
    corpus_score, raw_item_scores = Spice().compute_score(targets, candidates)
    per_item_f1 = compute_spice_single(raw_item_scores)
    return corpus_score * 100, np.array(per_item_f1)

def weighted_mean(input, weights=None):
    """Return the weighted sum of ``input``, scaled by 100.

    Args:
        input: Iterable of numeric scores. The name is kept for backward
            compatibility; it shadows the builtin only inside this function.
        weights: Optional sequence with one weight per score. Defaults to four
            equal weights of 0.25, i.e. the plain mean of exactly four scores.

    Returns:
        ``sum(x * w for x, w in zip(input, weights)) * 100``.

    Raises:
        ValueError: if the number of scores differs from the number of
            weights. A bare ``zip`` would silently drop the extras.
    """
    if weights is None:
        weights = (0.25, 0.25, 0.25, 0.25)
    values = list(input)  # materialize so dict views / generators can be measured
    weights = list(weights)
    if len(values) != len(weights):
        raise ValueError(
            f"weighted_mean expects {len(weights)} values, got {len(values)}"
        )
    total = sum(x * w for x, w in zip(values, weights))
    return total * 100

def match_percentage(word_groups, text):
    """Return the percentage of word groups with at least one member found in ``text``.

    Matching is case-insensitive: ``text`` and every candidate word are both
    lower-cased before comparison. Lower-casing only ``text`` would mean any
    word containing an upper-case letter could never match.

    NOTE(review): matching is plain substring containment, so e.g. "cat" also
    matches inside "category". This is left unchanged because changing it
    would change the reported constraint metric.

    Args:
        word_groups: Sized sequence of groups. Each group is an iterable of
            alternative surface forms for one concept.
        text: Sentence to search.

    Returns:
        Float in [0, 100]. Returns 0.0 when ``word_groups`` is empty instead of
        dividing by zero.
    """
    if not word_groups:
        return 0.0
    lowered = text.lower()
    matched_groups = sum(
        1
        for group in word_groups
        if any(word.lower() in lowered for word in group)
    )
    return (matched_groups / len(word_groups)) * 100

def constraint_satisfaction(candidates):
    """Score how many input concepts appear in each candidate sentence.

    For every entry:
      - the key is split on ``'#'`` into concepts;
      - each concept is expanded with ``generate_variants_commongen`` plus
        capitalised forms;
      - ``match_percentage`` is applied to the FIRST sentence of the entry.

    Args:
        candidates: dict mapping ``'#'``-joined concept keys to lists of
            sentences. Only element 0 of each list is scored.

    Returns:
        (mean_percentage, per_item): the mean over all entries, and a NumPy
        array of per-entry percentages in the dict's iteration order. An empty
        dict returns ``(0.0, empty array)`` instead of dividing by zero.
    """
    percentages = []
    # Iterate items directly. Rebuilding list(keys()) / list(values()) on
    # every iteration made the loop O(n^2) in the number of entries.
    for key, sentences in candidates.items():
        variants = []
        for concept in key.split('#'):
            # Work on a copy so that appending capitalised forms can never
            # mutate a list that generate_variants_commongen might share or cache.
            forms = list(generate_variants_commongen(concept))
            forms += [s.capitalize() for s in forms if s.capitalize() not in forms]
            variants.append(forms)
        percentages.append(match_percentage(variants, sentences[0]))

    if not percentages:
        return 0.0, np.array(percentages)
    return sum(percentages) / len(percentages), np.array(percentages)

def load_times(json_path):
    """Load per-example timing records and their average.

    Args:
        json_path: Path to a UTF-8 JSON file holding a list of objects, each
            with a ``'time'`` field.

    Returns:
        (time_values, average_time): the list of ``'time'`` values in file
        order, and their arithmetic mean (``0`` when the list is empty).
    """
    # Explicit UTF-8, matching load_targets/load_candidates, so decoding does
    # not depend on the platform's locale encoding.
    with open(json_path, 'r', encoding='utf-8') as fin:
        data = json.load(fin)

    time_values = [entry['time'] for entry in data]
    average_time = sum(time_values) / len(time_values) if time_values else 0

    return time_values, average_time

def evaluator(gts, res):
    """Compute BLEU, ROUGE-L, CIDEr, SPICE and constraint coverage.

    Both inputs are tokenized IN PLACE with ``tokenize`` before scoring, so
    callers may pass raw sentences. Do not pre-tokenize them as well.

    Args:
        gts: dict mapping concept key -> list of reference sentences.
        res: dict mapping the same keys -> list of candidate sentences.
            ``constraint_satisfaction`` scores element 0 of each list.

    Returns:
        dict with:
          - ``"BLEU4"``: dict of the ``Bleu_1``..``Bleu_4`` corpus scores, plus
            a ``"BLEU4"`` entry equal to ``weighted_mean`` of those four.
          - ``"BLEU4_scores"``: the Bleu scorer's per-item output, one list
            per ``Bleu_n`` name.
          - ``"ROUGE_L"``, ``"CIDEr"`` and their ``"<name>_scores"`` lists.
          - ``"SPICE"`` (already multiplied by 100) and ``"SPICE_scores"``.
          - ``"constraint"`` and ``"constraint_scores"``.
    """
    # =================================================
    # Tokenize (in place) so every scorer sees the same token spacing
    # =================================================

    gts = tokenize(gts)
    res = tokenize(res)

    # =================================================
    # Set up scorers
    # =================================================

    scorers = [
        (Bleu(), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr"),
    ]

    # =================================================
    # Compute scores
    # =================================================
    final_scores = {}

    for scorer, method in scorers:
        score, scores = scorer.compute_score(gts, res)
        if isinstance(method, list):
            # Multi-output scorer (BLEU): one corpus score and one per-item
            # list for each name in ``method``.
            score_dict = {}
            scores_dict = {}
            for sc, scs, m in zip(score, scores, method):
                score_dict[m] = sc
                scores_dict[m] = list(scs)

            # NOTE(review): this "BLEU4" is the equally weighted mean of
            # BLEU-1..4 (x100), not the standard BLEU-4 (score_dict["Bleu_4"]).
            # Confirm this is the intended reported metric.
            score_dict["BLEU4"] = weighted_mean(score_dict.values())
            final_scores["BLEU4"] = score_dict
            final_scores["BLEU4_scores"] = scores_dict
        else:
            final_scores[method] = score
            final_scores[method + "_scores"] = list(scores)

    spice_score, spice_scores = compute_spice(gts, res)
    final_scores["SPICE"] = spice_score
    final_scores["SPICE_scores"] = list(spice_scores)

    constraint_score, constraint_scores = constraint_satisfaction(res)
    final_scores["constraint"] = constraint_score
    final_scores["constraint_scores"] = list(constraint_scores)

    return final_scores

def _collect_file_triples(run_dir):
    """List (reference, candidate, times) path triples for one results directory.

    Every file in ``<run_dir>/references`` is paired with a file in
    ``<run_dir>/candidates`` and one in ``<run_dir>/times``. Those names are
    built by replacing each occurrence of ``'references'`` in the file name
    with ``'candidates'`` or ``'times'``. Files are visited in sorted order
    because ``os.listdir`` order is arbitrary.
    """
    path_refs = os.path.join(run_dir, 'references')
    path_cands = os.path.join(run_dir, 'candidates')
    path_times = os.path.join(run_dir, 'times')

    triples = []
    for ref_name in sorted(os.listdir(path_refs)):
        cand_name = ref_name.replace('references', 'candidates')
        times_name = ref_name.replace('references', 'times')
        triples.append((
            os.path.join(path_refs, ref_name),
            os.path.join(path_cands, cand_name),
            os.path.join(path_times, times_name),
        ))
    return triples


def main():
    """Evaluate each result file in the supervised and unsupervised gpt2 CommonGen directories.

    For every file, one line with the candidate path and one summary line of
    scores are printed.
    """
    run_dirs = [
        '../results/gpt2_supervised_commongen',
        '../results/gpt2_unsupervised_commongen',
    ]
    jobs = []
    for run_dir in run_dirs:
        jobs.extend(_collect_file_triples(run_dir))

    for targ, cand, times in tqdm(jobs):

        targets = load_targets(targ)
        candidates = load_candidates(cand)
        times_list, avg_time = load_times(times)

        # Keep only concept sets present in both files. No tokenize() here:
        # evaluator() tokenizes its inputs itself, so doing it here as well
        # would run spaCy twice over every sentence.
        targets, candidates = synchronize_dictionaries(targets, candidates)

        current_result = evaluator(targets, candidates)

        current_result["times"] = list(times_list)
        current_result["avg_time"] = avg_time

        print(cand)
        print("ROUGE_L:", current_result["ROUGE_L"]*100, "BLEU4:", current_result["BLEU4"]["BLEU4"], "CIDEr:", current_result["CIDEr"]*10, "SPICE:", current_result["SPICE"], "constraint:", current_result["constraint"], "avg_time:", current_result["avg_time"])


if __name__ == "__main__":
    main()