from prettytable import PrettyTable
from scipy.stats import spearmanr, pearsonr, kendalltau
import pandas as pd
import re
import json
import glob


# "the score is <number>" (case-insensitive), where the number is followed by
# optional whitespace and a '.' or ',', or by the end of the text.
_SCORE_PATTERN = re.compile(r"the score is (\d+(\.\d+)?)(?=\s*[.,]|$)", re.IGNORECASE)


def extract_answer(text, unmatched=None, idx=None):
    """Extract the numeric score from a model response.

    Args:
        text: The response text to search.
        unmatched: Optional list that collects failures. When no score is
            found, ``idx`` (if given) or otherwise ``text`` is appended to it.
            Defaults to ``None``, in which case failures are not recorded.
            A shared mutable default would accumulate them across unrelated
            calls.
        idx: Optional identifier recorded in ``unmatched`` instead of ``text``.

    Returns:
        The matched number as a string (e.g. ``"3.5"``) when found. Otherwise
        the float fallback ``2.5``. Callers in this file wrap the result in
        ``float()``.
    """
    if unmatched is None:
        unmatched = []

    match = _SCORE_PATTERN.search(text)
    if match:
        return match.group(1)

    unmatched.append(idx if idx is not None else text)
    # NOTE(review): 2.5 looks like a mid-scale default for unparsable answers.
    # Confirm it matches the rating scale used in the labels.
    return 2.5


def calculate_correlation(pred_score, human_score, result):
    """Add Pearson, Spearman and Kendall-tau correlations to ``result``.

    The coefficients are added (``+=``) to the existing values, so repeated
    calls accumulate sums that can later be averaged.

    Args:
        pred_score: Sequence of numeric scores.
        human_score: Sequence of numeric scores with the same length.
        result: Dict with keys ``'pearson'``, ``'spearman'`` and
            ``'kendalltau'``, updated in place. An empty dict or ``None``
            starts from zero. Missing keys are initialised to 0.

    Returns:
        The updated result dict (``result`` itself unless it was empty/None).

    Raises:
        ValueError: if the two sequences have different lengths.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(pred_score) != len(human_score):
        raise ValueError(
            f"pred_score and human_score differ in length: "
            f"{len(pred_score)} != {len(human_score)}"
        )

    if not result:
        result = {'pearson': 0, 'spearman': 0, 'kendalltau': 0}
    else:
        for key in ('pearson', 'spearman', 'kendalltau'):
            result.setdefault(key, 0)

    # [0] selects the correlation coefficient (the second element is the p-value).
    result['pearson'] += pearsonr(pred_score, human_score)[0]
    result['spearman'] += spearmanr(pred_score, human_score)[0]
    result['kendalltau'] += kendalltau(pred_score, human_score)[0]

    return result


def print_correlations(result, n):
    """Print the three correlation values, divided by ``n``, as a one-row table.

    Each of ``result['pearson']``, ``result['spearman']`` and
    ``result['kendalltau']`` is divided by ``n``, e.g. to average sums built up
    by repeated ``calculate_correlation`` calls. Each value is rounded to four
    decimals. An ``n`` of zero is treated as one to avoid dividing by zero.
    """
    divisor = n if n != 0 else 1
    row = [round(result[metric] / divisor, 4) for metric in ('pearson', 'spearman', 'kendalltau')]

    table = PrettyTable(['Pearson', 'Spearman', 'Kendall'])
    table.add_row(row)
    print(table)


def compute_scores(path, scores_tf, scores_fc, scores_vf):
    """Correlate model scores stored under ``path`` with reference scores.

    Every file in ``path`` (in sorted name order) is parsed as a JSON object.
    The free-text fields ``'Textual_Faithfulness'``, ``'Frame_Consistency'`` and
    ``'Video_Fidelity'`` are converted to numbers with ``extract_answer``. A
    missing field is treated as an empty string, which gets extract_answer's
    fallback value.

    Args:
        path: Directory containing one JSON output file per sample.
        scores_tf: Reference Textual Faithfulness scores aligned with the
            sorted file list.
        scores_fc: Reference Frame Consistency scores, aligned the same way.
        scores_vf: Reference Video Fidelity scores, aligned the same way.

    Returns:
        Dict mapping ``'textual_faithfulness'``, ``'frame_consistency'`` and
        ``'video_fidelity'`` to the dict returned by ``calculate_correlation``.
    """
    scores = []
    for score_file in sorted(glob.glob(f'{path}/*')):
        # The context manager closes each file. JSON text is UTF-8 by spec.
        with open(score_file, 'r', encoding='utf-8') as fh:
            scores.append(json.load(fh))

    def _parse(key):
        # Pass an explicit list so parse failures are never recorded in a
        # default argument shared across calls.
        unmatched = []
        return [float(extract_answer(item.get(key, ''), unmatched=unmatched)) for item in scores]

    tf_scores = _parse('Textual_Faithfulness')
    fc_scores = _parse('Frame_Consistency')
    vf_scores = _parse('Video_Fidelity')

    results = {'pearson': 0, 'spearman': 0, 'kendalltau': 0}
    # Each aspect gets its own copy because calculate_correlation updates it in place.
    results_tf = calculate_correlation(scores_tf, tf_scores, results.copy())
    results_fc = calculate_correlation(scores_fc, fc_scores, results.copy())
    results_vf = calculate_correlation(scores_vf, vf_scores, results.copy())

    return {'textual_faithfulness': results_tf, 'frame_consistency': results_fc, 'video_fidelity': results_vf}


def _load_score_entries(score_dir):
    """Return the list of model-output dicts described by ``score_dir``.

    A ``str`` is treated as a directory: every file in it (sorted by name) is
    parsed as JSON. A ``list`` is returned unchanged.

    Raises:
        TypeError: if ``score_dir`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(score_dir, str):
        entries = []
        for file_path in sorted(glob.glob(f'{score_dir}/*')):
            # The context manager closes each file. JSON text is UTF-8 by spec.
            with open(file_path, 'r', encoding='utf-8') as fh:
                entries.append(json.load(fh))
        return entries
    if isinstance(score_dir, list):
        return score_dir
    raise TypeError(
        f'score_dir must be a directory path (str) or a list of dicts, '
        f'got {type(score_dir).__name__}'
    )


def _parse_aspect_scores(entries, json_key, label):
    """Parse one aspect's score from each entry and print the unmatched count.

    Entries lacking ``json_key`` yield ``None``. Responses with no
    recognisable score get ``extract_answer``'s fallback value and are counted
    as unmatched.
    """
    unmatched = []
    parsed = [
        float(extract_answer(entry[json_key], unmatched=unmatched, idx=index)) if json_key in entry else None
        for index, entry in enumerate(entries)
    ]
    print(f'Unmatched {label} Count: {len(unmatched)}')
    return parsed


def _correlate_aspect(reference, predicted):
    """Correlate ``reference`` with ``predicted``, skipping positions whose prediction is None.

    Raises:
        ValueError: if the two sequences have different lengths.
    """
    if len(reference) != len(predicted):
        raise ValueError(
            f'reference and parsed scores differ in length: '
            f'{len(reference)} != {len(predicted)}'
        )
    kept = [(ref, pred) for ref, pred in zip(reference, predicted) if pred is not None]
    ref_kept = [ref for ref, _ in kept]
    pred_kept = [pred for _, pred in kept]
    return calculate_correlation(ref_kept, pred_kept, {'pearson': 0, 'spearman': 0, 'kendalltau': 0})


def calculate_all_correlations(score_dir, reference_scores_tf=None, reference_scores_fc=None, reference_scores_vf=None, LLM_REFIX=False, PARTIAL_CALCULATION=False, return_scores=False):
    """Correlate model scores with human reference scores for three aspects.

    Args:
        score_dir: Either a directory whose files (sorted by name) each hold a
            JSON object, or an already-loaded list of such dicts. Each dict may
            carry the free-text fields ``'Textual_Faithfulness'``,
            ``'Frame_Consistency'`` and ``'Video_Fidelity'``.
        reference_scores_tf: Human Textual Faithfulness scores aligned with
            ``score_dir``, or ``None`` to skip that aspect.
        reference_scores_fc: Same, for Frame Consistency.
        reference_scores_vf: Same, for Video Fidelity.
        LLM_REFIX: Accepted for compatibility. Re-fixing unparsable responses
            is not implemented, so the flag currently has no effect.
        PARTIAL_CALCULATION: Truncate each reference list to the number of
            loaded outputs, so an incomplete run can still be evaluated.
        return_scores: Return the parsed model scores instead of correlations.

    Returns:
        By default, a dict mapping ``'textual_faithfulness'``,
        ``'frame_consistency'`` and ``'video_fidelity'`` to
        ``{'pearson', 'spearman', 'kendalltau'}``. The values are ``None`` when
        the aspect was skipped or had no parsed scores. With
        ``return_scores=True``, a dict mapping each evaluated aspect to its
        parsed score list, with ``None`` where the field was missing.

    Raises:
        TypeError: if ``score_dir`` is neither a ``str`` nor a ``list``.
        ValueError: if an aspect's reference and parsed score counts differ.
    """
    # Step 1: Load all score files
    scores = _load_score_entries(score_dir)

    # (result key, JSON field in each output, label for the log line, human reference scores)
    aspects = [
        ('textual_faithfulness', 'Textual_Faithfulness', 'Textual Faithfulness', reference_scores_tf),
        ('frame_consistency', 'Frame_Consistency', 'Frame Consistency', reference_scores_fc),
        ('video_fidelity', 'Video_Fidelity', 'Video Fidelity', reference_scores_vf),
    ]

    # Step 2: Extract and convert scores for every aspect that has a reference.
    prepared = []
    for result_key, json_key, label, reference in aspects:
        if reference is None:
            prepared.append((result_key, None, None))
            continue
        predicted = _parse_aspect_scores(scores, json_key, label)
        if PARTIAL_CALCULATION:
            reference = reference[:len(predicted)]
        # TODO: LLM_REFIX (re-querying for unmatched responses) is not implemented.
        prepared.append((result_key, reference, predicted))

    # Step 3: Calculate correlations.
    correlation_results = {}
    final_scores = {}
    for result_key, reference, predicted in prepared:
        # Test against None rather than truthiness so legitimate 0.0 scores count.
        if predicted is not None and any(score is not None for score in predicted):
            correlation_results[result_key] = _correlate_aspect(reference, predicted)
            final_scores[result_key] = predicted
        else:
            correlation_results[result_key] = {'pearson': None, 'spearman': None, 'kendalltau': None}

    # Step 4: Print one ' & '-separated row per aspect (e.g. for pasting into a LaTeX table).
    for result_key in ('textual_faithfulness', 'frame_consistency', 'video_fidelity'):
        row = correlation_results[result_key]
        values = (row['pearson'], row['spearman'], row['kendalltau'])
        if all(value is not None for value in values):
            pearson, spearman, kendall = values
            print(f"{pearson:.2f} & {spearman:.2f} & {kendall:.2f}")
        else:
            print("N/A & N/A & N/A")

    if return_scores:
        return final_scores
    return correlation_results


# Human-labelled reference scores, one row per sample.
df = pd.read_csv('labeled_full.csv')
scores_tf = list(df['Textual Faithfulness'])
scores_fc = list(df['Frame Consistency'])
scores_vf = list(df['Video Fidelity'])

# Evaluate each MLLM judge's saved outputs against the human labels.
for model_name in ('VideoLLaMA2', 'Gemini-pro', 'LLaVA-OneVision-7B'):
    print('-' * 40 + '\n' + model_name + ':')
    calculate_all_correlations(f'MLLM_outputs/{model_name}', scores_tf, scores_fc, scores_vf)
