import json
import re
import numpy as np
import os
import argparse
import scipy.stats
from scipy.stats import spearmanr, kendalltau

def spearman_corr(x, y):
    """Return the Spearman rank correlation coefficient of ``x`` and ``y``.

    Thin wrapper around ``scipy.stats.spearmanr`` that keeps only the
    coefficient and drops the accompanying p-value.
    """
    result = spearmanr(x, y)
    # Element 0 is the coefficient; element 1 (the p-value) is not needed.
    return result[0]

def kendall_corr(x, y):
    """Return Kendall's tau of ``x`` and ``y``.

    Thin wrapper around ``scipy.stats.kendalltau`` that keeps only the
    coefficient and drops the accompanying p-value.
    """
    result = kendalltau(x, y)
    # Element 0 is tau; element 1 (the p-value) is not needed.
    return result[0]

def extract_last_0_to_3(text):
    """Return the last standalone digit 0-3 found in ``text``.

    A digit only counts when it is not directly adjacent to another digit,
    so the "0" in "10" or the "3" in "2023" is not mistaken for a score.

    Args:
        text: String to search.

    Returns:
        The matched digit as a one-character string ("0".."3"), or ``None``
        when ``text`` contains no standalone digit in that range.
    """
    # The lookbehind/lookahead reject digits that are part of a longer number.
    pattern = r'(?<!\d)[0-3](?!\d)'

    matches = re.findall(pattern, text)
    return matches[-1] if matches else None

def correlation(v, u):
    """Compute Pearson, Spearman and Kendall correlations between ``v`` and ``u``.

    Returns:
        A dict with the keys ``"Pearson"``, ``"Spearman"`` and ``"Kendall"``
        mapping to the respective coefficient.
    """
    # Pearson is the off-diagonal entry of the 2x2 correlation matrix.
    pearson = np.corrcoef(v, u)[0, 1]
    spearman = spearman_corr(v, u)
    kendall = kendall_corr(v, u)
    return {"Pearson": pearson, "Spearman": spearman, "Kendall": kendall}

def correlation_sample_for_n(v, u, n):
    """Average Pearson/Spearman/Kendall correlations over windows of ``n`` items.

    Windows ``v[i:i + n]`` / ``u[i:i + n]`` are taken starting at ``i = 0``.
    When the Spearman coefficient of a window is NaN, the window is skipped
    and the start index advances by one; otherwise the window's three
    coefficients are accumulated and the start index advances by ``n``.

    Args:
        v: First sequence of scores (anything sliceable with ``len``).
        u: Second sequence of scores, sliced with the same indices as ``v``.
        n: Window size; must be a positive integer.

    Returns:
        A dict with the keys ``"pearson"``, ``"spearman"`` and ``"kendall"``
        holding the mean coefficient over all accepted windows. When no
        window is accepted (input shorter than ``n``, or every window gives
        a NaN Spearman coefficient) all three values are NaN.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        # A non-positive window would never advance past the data.
        raise ValueError("n must be a positive integer, got {!r}".format(n))

    sum_p = 0.0
    sum_s = 0.0
    sum_k = 0.0
    cnt = 0
    i = 0

    while i + n <= len(v):
        vs = v[i: i + n]
        us = u[i: i + n]

        # Computed once and reused for both the NaN check and the running sum.
        spearman = scipy.stats.spearmanr(vs, us)[0]
        if np.isnan(spearman):
            i += 1
            continue

        sum_p += np.corrcoef(vs, us)[0, 1]
        sum_s += spearman
        sum_k += scipy.stats.kendalltau(vs, us)[0]
        cnt += 1
        i += n

    if cnt == 0:
        # No usable window: report NaN instead of dividing by zero.
        nan = float("nan")
        return {"pearson": nan, "spearman": nan, "kendall": nan}

    return {
        "pearson": sum_p / cnt,
        "spearman": sum_s / cnt,
        "kendall": sum_k / cnt,
    }

if __name__ == "__main__":
    # Script entry point: read a JSONL file of inference results, parse a 0-3
    # score out of each record's 'predict' and 'label' fields, and print the
    # overall Pearson / Spearman / Kendall correlation between the two.
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default=r"xxx\inference_results.jsonl")
    # NOTE(review): --output_path is parsed but not used by the code visible here.
    parser.add_argument('--output_path', type=str, default="./results")
    args = parser.parse_args()

    # One JSON object per line; blank lines (e.g. a trailing newline) are
    # skipped rather than handed to json.loads, which would raise on them.
    with open(args.input_path, "r", encoding='utf-8') as f:
        data = [json.loads(line) for line in f if line.strip()]

    misattrillm_scores = []
    human_scores = []
    human_data = []
    cnt = 0
    for i, d in enumerate(data):
        cnt += 1

        output = d['predict']
        score = extract_last_0_to_3(output)

        if score is None:
            # No parsable score in the prediction: keep the record, using -1
            # as a sentinel value.
            score = -1

        goldscore = extract_last_0_to_3(d['label'])
        if goldscore is None:
            # Fail with a clear message instead of an opaque int(None) TypeError.
            raise ValueError(
                "record {} has no 0-3 score in its 'label' field: {!r}".format(i, d['label'])
            )

        misattrillm_scores.append(int(score))
        human_scores.append(int(goldscore))
        human_data.append(d)

    # Single conversion to integer arrays.
    misattrillm_scores = np.array(misattrillm_scores, dtype=int)
    human_scores = np.array(human_scores, dtype=int)

    print(human_scores)
    print(misattrillm_scores)

    print("Overall Scores:")
    print("pearson", np.corrcoef(human_scores, misattrillm_scores)[0, 1])
    print("spearman", scipy.stats.spearmanr(human_scores, misattrillm_scores)[0])
    print("kendall", scipy.stats.kendalltau(human_scores, misattrillm_scores)[0])