import glob as glob
import json
import os.path
import warnings

import pandas as pd
from nlgmetricverse import NLGMetricverse
from nlgmetricverse import load_metric

# Suppress all Python warnings for the rest of the run.
warnings.filterwarnings("ignore")

# BLEU at n-gram orders 1..4; each instance gets its own result key.
metrics = [
    load_metric("bleu", resulting_name=bleu_name, compute_kwargs={"max_order": bleu_order})
    for bleu_name, bleu_order in (("bleu_1", 1), ("bleu_2", 2), ("bleu_3", 3), ("bleu_4", 4))
]
# Remaining metrics, loaded in the same order as before: BERTScore, ROUGE,
# METEOR, CIDEr.
metrics += [
    load_metric("bertscore", resulting_name="bertscore_1",
                compute_kwargs={"model_type": "bert-base-uncased"}),
    load_metric("rouge", resulting_name="rouge"),
    load_metric("meteor", resulting_name="meteor"),
    load_metric("cider", resulting_name="cider"),
]
# One scorer object that evaluates every metric above in a single call.
scorer = NLGMetricverse(metrics=metrics)

print("finish metric loading")


def remove_contain_none(tail: list):
    """Deduplicate *tail* and drop ``'none'`` placeholders unless nothing else remains.

    Duplicates are removed keeping the order of first appearance, so the
    result is deterministic across interpreter runs. A ``set`` would iterate
    strings in hash order, which is randomized per process.

    Args:
        tail: Sequence of hashable items (typically strings).

    Returns:
        ``['none']`` when ``'none'`` is the only distinct value in *tail*;
        ``[]`` when *tail* is empty; otherwise the distinct values of *tail*
        other than ``'none'``, in order of first appearance.
    """
    unique_tails = list(dict.fromkeys(tail))
    kept = [item for item in unique_tails if item != 'none']
    if not kept and unique_tails:
        # Every value was the placeholder: return a single 'none' rather
        # than an empty list.
        return ['none']
    return kept


# Every entry directly under ./eval_output is a candidate run directory.
generation_folders = glob.glob('./eval_output/*')

# Run key -> list of formatted metric scores for runs scored in this invocation.
record_dict = {}

# Resume from the existing results file when present; otherwise start from an
# empty table with the expected column layout.
result = (
    pd.read_csv('./COMET-result.csv')
    if os.path.exists('./COMET-result.csv')
    else pd.DataFrame(columns=['training-data', 'bleu_1', 'bleu_2', 'bleu_3', 'bleu_4',
                               'meteor', 'rougeL', 'cider', 'bertscore'])
)

def _run_key(folder):
    """Return the ``<training-data>_<learning-rate>`` key for one run folder.

    The folder's base name is split on ``'_'``. Names containing ``'FT'`` use
    the first two fields as the training-data label and the third as the
    learning rate; all other names use the first field and the second.
    Raises IndexError if the name has too few ``'_'``-separated fields.
    """
    # basename/normpath instead of split('/') so Windows separators also work.
    name = os.path.basename(os.path.normpath(folder))
    parts = name.split('_')
    if 'FT' in name:
        training_data, learning_rate = '_'.join(parts[:2]), parts[2]
    else:
        training_data, learning_rate = parts[0], parts[1]
    return training_data + '_' + learning_rate


def _load_predictions(path):
    """Read a JSON-lines file into a DataFrame (one row per object), skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as f:
        return pd.DataFrame([json.loads(line) for line in f if line.strip()])


# Keys already present in the results file are not re-scored.
existing_runs = set(result['training-data'].tolist())

# Sorted so that which folder is processed first does not depend on the
# filesystem's directory order.
for g in sorted(generation_folders):
    prediction_file = os.path.join(g, 'generated_predictions.jsonl')
    if not os.path.exists(prediction_file):
        continue

    training_data = _run_key(g)
    if training_data in existing_runs:
        continue
    if training_data in record_dict:
        # A second folder mapping to the same key would silently overwrite
        # the first one's scores after a full (expensive) scoring pass.
        print('skip {}: duplicate run key {}'.format(g, training_data))
        continue

    df = _load_predictions(prediction_file)
    if df.empty:
        print('skip {}: no predictions'.format(g))
        continue
    # Keep only predictions of at most 20 space-separated tokens.
    df = df[df['predict'].apply(lambda x: len(x.split(' ')) <= 20)]
    if df.empty:
        print('skip {}: every prediction is longer than 20 tokens'.format(g))
        continue

    generations = df['predict'].tolist()
    # A label may hold several references separated by '|'.
    references = df['label'].apply(lambda x: [ss.strip() for ss in x.split('|')]).tolist()

    print('\n', g)

    print('\n\n')
    print(df['predict'].tolist()[:3])
    print(df['label'].tolist()[:3])
    print('\n\n')

    eval_result = scorer(predictions=generations, references=references)

    print(eval_result['total_items'], eval_result['empty_items'], eval_result['total_time_elapsed'])
    # Order must match the metric columns of `result` (after 'training-data').
    scores = []
    for metric in ['bleu_1', 'bleu_2', 'bleu_3', 'bleu_4', 'meteor', 'rougeL', 'cider', 'bertscore_1']:
        # ROUGE-L is read from the 'rouge' entry; every other metric exposes 'score'.
        value = eval_result['rouge']['rougeL'] if metric == 'rougeL' else eval_result[metric]['score']
        print(metric, value)
        scores.append("%.3f" % (100 * value))
    # Stored only after every metric succeeded, so no partial row is recorded.
    record_dict[training_data] = scores
    print('\n\n\n')

# Append one row per newly scored run: [run key, *formatted scores].
for run_key, run_scores in record_dict.items():
    result.loc[len(result)] = [run_key, *run_scores]
# Numeric dtype so the sort below orders by value, not lexicographically.
result = result.astype({'bertscore': float})
deduplicated = result.drop_duplicates()
deduplicated.sort_values(by='bertscore').to_csv('./COMET-result.csv', index=None)

for data in result['training-data'].unique():
    print('\n\n', data)
    # Runs whose key contains 'candle' report the maximum of each metric over
    # their rows; all other runs report the minimum.
    pick = pd.Series.max if 'candle' in data else pd.Series.min
    rows = result[result['training-data'] == data]
    cells = []
    for metric in ['bleu_1', 'bleu_2', 'bleu_3', 'bleu_4', 'meteor', 'rougeL', 'cider', 'bertscore']:
        value = pick(rows[metric].astype(float))
        # Two decimals for values >= 10, three otherwise; joined with ' & ' separators.
        cells.append(("%.2f" if value >= 10 else "%.3f") % value)
    print(" & ".join(cells))
