import argparse
import numpy as np
import json
import os
import pandas as pd
from scipy.stats import kendalltau, spearmanr, pearsonr

def _detect_format_type(path):
    """Infer the prompt format from a marker substring in the results filename.

    Markers are tested in order ("-nl-", "-no-format-", "-json-") and the
    first match wins.

    Returns:
        "nl", "no-format", "json", or None when no marker is present.
    """
    if "-nl-" in path:
        return "nl"
    if "-no-format-" in path:
        return "no-format"
    if "-json-" in path:
        return "json"
    return None


def _report_path(path):
    """Return the report filename ``<stem>_report<ext>`` for results file *path*.

    Only the final extension is split off, so the report path always differs
    from *path* and can never overwrite the input file (a plain
    ``str.replace('.jsonl', ...)`` returns *path* unchanged for non-.jsonl
    names and also rewrites ``.jsonl`` inside directory names).
    """
    root, ext = os.path.splitext(path)
    return f"{root}_report{ext}"


def main(args):
    """Score a JSONL file of model outputs and write a per-example report.

    Each input line is decoded as JSON and scored with ``check_answer``.
    Well-formatted records that yielded a rating contribute to the Spearman
    and Kendall-tau correlations between gold labels and predicted ratings.

    Writes ``<stem>_report<ext>`` next to the input. Its first line is a
    summary object and each following line is one per-example result. The
    same summary metrics are printed to stdout.

    Args:
        args: namespace with a ``file`` attribute holding the results path.
            The path's name selects the output parser (see
            ``_detect_format_type``).
    """
    format_type = _detect_format_type(args.file)

    formats = []
    labels = []
    answers = []
    report_lines = []
    with open(args.file, 'r', encoding='utf-8') as f:
        for raw in f:
            if not raw.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            new_line = check_answer(json.loads(raw), format_type)
            # Only parseable records with a rating feed the correlations.
            if new_line["format"] and new_line["answer"] is not None:
                labels.append(new_line["label"])
                answers.append(new_line["answer"])
            report_lines.append(new_line)
            formats.append(new_line["format"])

    # Labels are strings and answers may be int or str; numpy parses both
    # into floats here (non-numeric text raises ValueError).
    labels_valid = np.array(labels, dtype=float)
    answers_valid = np.array(answers, dtype=float)

    well_formatted = np.mean(formats)
    spearman = spearmanr(labels_valid, answers_valid).statistic
    kendall = kendalltau(labels_valid, answers_valid).statistic

    with open(_report_path(args.file), 'w', encoding='utf-8') as f:
        # Summary record first, then one line per example.
        f.write(json.dumps({
            "Well-formatted": well_formatted,
            "Spearman Correlation": spearman,
            "Kendall Tau": kendall,
            }) + "\n")
        for line in report_lines:
            f.write(json.dumps(line) + "\n")

    print(f"Well-formatted: {well_formatted:.4f}")
    print(f"Spearman Correlation: {spearman}")
    print(f"Kendall Tau: {kendall}")
            
def check_answer(line, format_type):
    """Parse one model record and score its rating against the gold label.

    Args:
        line: one decoded JSONL record. ``line["label"]`` is read here; the
            parser helpers read ``line["output"]``.
        format_type: "json", "nl" or "no-format", as detected from the
            results filename.

    Returns:
        dict with keys, in this order:
        "format" (bool): whether the output could be parsed.
        "correct" (bool): whether the parsed rating equals the label.
        "label" (str): the gold label.
        "answer": the parsed rating. The structured parser gives an int, the
        JSON parser gives a str, and it is None when parsing failed.

    Raises:
        ValueError: if ``format_type`` is not one of the supported values.
    """
    if format_type == "json":
        is_formatted, answer = check_json_answer(line)
    elif format_type in ("nl", "no-format"):
        # Both free-text formats are parsed by looking for "Rating: N".
        # No global state (e.g. CLI args) is consulted, so this works when
        # the function is imported and called outside the script.
        is_formatted, answer = check_structured_answer(line)
    else:
        raise ValueError(
            f"Unsupported format_type {format_type!r}; "
            "expected 'json', 'nl' or 'no-format'"
        )
    label = str(line["label"])
    # The structured parser returns an int and the JSON parser a str, so
    # compare string forms; a missing answer is never correct.
    correct = answer is not None and label == str(answer)
    return {"format": is_formatted, "correct": correct, "label": label, "answer": answer}


def check_structured_answer(line):
    """Look for a "Rating: N" marker (N in 1..5) in the record's output text.

    Accepted spellings are "Rating: N", "Rating:N" and "Rating: **N"
    (markdown bold). Ratings are tried in ascending order, so the smallest
    rating whose marker appears anywhere in the text wins, regardless of
    where it occurs. Plain substring matching also means "Rating: 1" is
    found inside e.g. "Rating: 10".

    Returns:
        ``(True, N)`` with N an int when a marker is found, otherwise
        ``(False, None)``.
    """
    text = line["output"]
    found = next(
        (
            rating
            for rating in range(1, 6)
            if any(
                marker in text
                for marker in (
                    f'Rating: {rating}',
                    f'Rating:{rating}',
                    f'Rating: **{rating}',
                )
            )
        ),
        None,
    )
    if found is None:
        return False, None
    return True, found
    
def check_json_answer(line):
    """Parse the record's output as JSON and extract its "rating" field.

    If the output contains a markdown code fence, only the text inside the
    first fenced block is parsed (a leading "```json" fence is normalized to
    "```" first).

    Returns:
        ``(True, str(rating))`` when the output is a JSON object with a
        non-null "rating". Otherwise ``(False, None)``: invalid JSON, a
        non-object value, or a missing or null "rating" all count as badly
        formatted rather than aborting the whole evaluation.
    """
    output = line["output"]
    # check json validity
    if '```' in output:
        output = output.replace('```json', '```').split('```')[1]
    try:
        parsed_json = json.loads(output)
    except json.JSONDecodeError:
        return False, None
    # Valid JSON that lacks a usable rating used to raise KeyError/TypeError
    # or yield the string "None" (which later breaks float conversion).
    if not isinstance(parsed_json, dict) or parsed_json.get("rating") is None:
        return False, None
    return True, str(parsed_json["rating"])

def parse_and_clean(input_string):
    """Return the digits 1-5 found in the leading "rating-like" run of text.

    Scanning stops at the first character that is not one of the digits
    1-5, a comma, a single or double quote, or a space. From that prefix
    only the digits are kept, e.g. ``'\\'3\\', "4" abc'`` -> ``"34"``.
    """
    allowed = set("12345,'\" ")
    end = 0
    while end < len(input_string) and input_string[end] in allowed:
        end += 1
    prefix = input_string[:end]
    # After dropping quotes, spaces and commas only the digits 1-5 remain.
    return "".join(ch for ch in prefix if ch in "12345")

def parse_args(argv=None):
    """Parse command-line options for the evaluation script.

    Args:
        argv: argument list to parse. None (the default) makes argparse read
            ``sys.argv[1:]``, matching the previous zero-argument behavior.

    Returns:
        argparse.Namespace with ``file`` (path to the results JSONL) and
        ``feature`` (always True; no option sets it, and it is not read by
        the code in this script).
    """
    parser = argparse.ArgumentParser(
        description="Score model rating outputs against gold labels."
    )
    # Required: an empty default only failed later with an opaque
    # FileNotFoundError on the path ''.
    parser.add_argument(
        '--file',
        type=str,
        required=True,
        help="Results JSONL file. Its name should contain -nl-, -no-format- "
             "or -json- so the output format can be detected.",
    )
    parser.set_defaults(feature=True)
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: parse CLI options, then run the evaluation.
    # NOTE: the namespace is deliberately bound to the module-level name
    # `args`; keep that name, since functions in this file may read it as a
    # global rather than receiving it as a parameter.
    args = parse_args()
    main(args)