import os
import json
import argparse
from scipy import stats
import contextlib

def extend(array):
    """Concatenate the inner iterables of *array* into one new flat list.

    Only one level of nesting is removed.

    >>> extend([[1, 2], [3, 4]])
    [1, 2, 3, 4]
    """
    flat = []
    for inner in array:
        flat.extend(inner)
    return flat

def calculate_correlation(file_name):
    """Correlate code_gpt_score with pass/fail outcomes for one result file.

    The file must be JSON of the form ``{"data": [entry, ...]}``. Each entry
    has a ``"pass"`` field and a nested score at
    ``entry["code_gpt_score"]["code_gpt_score"]``. Entries whose score equals
    ``-1.0`` are skipped and counted as invalid.

    Args:
        file_name: Path to the JSON result file.

    Returns:
        Tuple ``(kendalltau, spearmanr, invalid_cnt, total)``:
        - ``kendalltau`` and ``spearmanr`` are the coefficients, rounded to
          3 decimals. scipy may return NaN when they cannot be computed
          (e.g. constant input).
        - ``invalid_cnt`` is the number of skipped entries.
        - ``total`` is the number of entries in ``data["data"]``.
    """
    with open(file_name, "r", encoding="utf-8") as f:
        data = json.load(f)

    entries = data["data"]
    references, predictions = [], []
    invalid_cnt = 0
    for d in entries:
        # Convert first so the sentinel check sees the same value that would
        # otherwise be fed to the correlation (e.g. "-1" and -1 both match).
        score = float(d["code_gpt_score"]["code_gpt_score"])
        if score == -1.0:
            invalid_cnt += 1
            continue
        references.append(float(d["pass"]))
        predictions.append(score)

    kendalltau, _ = stats.kendalltau(references, predictions)
    spearmanr, _ = stats.spearmanr(references, predictions)

    return round(kendalltau, 3), round(spearmanr, 3), invalid_cnt, len(entries)

def print_correlation(file_path):
    """Print a Kendall/Spearman report on code_gpt_score for one JSON file.

    The report has four lines:
    - the file's base name,
    - the two coefficients (3 decimals),
    - ``invalid_cnt: <skipped>/<total>``.
    """
    kendall, spearman, n_invalid, n_total = calculate_correlation(file_path)
    report = (
        os.path.basename(file_path),
        f"kendalltau: {kendall:.3f}",
        f"spearmanr: {spearman:.3f}",
        f"invalid_cnt: {n_invalid}/{n_total}",
    )
    for line in report:
        print(line)

def print_other_correlation(file_path):
    """Print Kendall & Spearman correlations for the baseline metrics.

    Reads an ``other-metrics.json`` style file: a JSON list whose entries
    carry a ``"pass"`` field plus one score field per metric. The output is:
    - the file's base name,
    - then, for each metric, its label followed by the Kendall and Spearman
      coefficients with their p-values in parentheses (3 decimals each).

    Args:
        file_path: Path to the JSON metrics file.
    """
    # (label printed in the report, key in each JSON entry), in output order.
    metric_keys = (
        ("bleu", "bleu"),
        ("chrf", "chrf"),
        ("rougeL", "rougel"),
        ("ruby", "ruby"),
        ("meteor", "meteor"),
        ("CodeBERTScore_f1", "code_bert_score_f1"),
        ("CodeBERTScore_f3", "code_bert_score_f3"),
    )

    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Build every column before printing, so a malformed entry fails before
    # any partial report is emitted. Every metric is coerced to float,
    # including the CodeBERTScore columns, so all metrics are treated alike.
    refs = [float(x["pass"]) for x in data]
    columns = {label: [float(x[key]) for x in data] for label, key in metric_keys}

    print(os.path.basename(file_path))
    for label, values in columns.items():
        kt, kp = stats.kendalltau(refs, values)
        sr, sp = stats.spearmanr(refs, values)
        print(label)
        print(f"  kendalltau: {kt:.3f} ({kp:.3f})")
        print(f"  spearmanr: {sr:.3f} ({sp:.3f})")

def batch_process(roots=None, test_cases=None):
    """Write a correlation report next to every JSON result file.

    For each ``<root>/<test_case>`` folder that exists, every regular file
    ending in ``.json`` is processed in sorted name order. Its report is
    written to a sibling file with the same stem and a ``.txt`` extension:
    - ``other-metrics.json`` is reported via ``print_other_correlation``;
    - every other JSON file via ``print_correlation``.

    Missing folders are skipped.

    Args:
        roots: Output root directories to scan. Defaults to
            ``("output/humaneval_test4",)``. Earlier runs used roots such as
            ``output/humaneval_test2``, ``output/humaneval_test3`` and
            ``output_add_path/humaneval_test2``.
        test_cases: Sub-folder names visited under each root. Defaults to the
            cpp/go/java/js ``*-small-test`` folders.
    """
    if roots is None:
        roots = ("output/humaneval_test4",)
    if test_cases is None:
        test_cases = ("cpp-small-test", "go-small-test", "java-small-test", "js-small-test")

    for root in roots:
        for tc in test_cases:
            folder = os.path.join(root, tc)
            if not os.path.isdir(folder):
                continue
            # Sorted so reports are produced in a reproducible order.
            for fname in sorted(os.listdir(folder)):
                in_path = os.path.join(folder, fname)
                if not fname.endswith(".json") or not os.path.isfile(in_path):
                    continue
                # splitext swaps only the final extension; str.replace would
                # also rewrite any earlier ".json" inside the file name.
                out_path = os.path.join(folder, os.path.splitext(fname)[0] + ".txt")

                # Redirect all prints into the .txt file
                with open(out_path, "w", encoding="utf-8") as outf, contextlib.redirect_stdout(outf):
                    if fname == "other-metrics.json":
                        print_other_correlation(in_path)
                    else:
                        print_correlation(in_path)

                print(f"[ok] {in_path} → {out_path}")

def _main():
    """Script entry point: run the batch job with its built-in configuration."""
    batch_process()


if __name__ == "__main__":
    _main()
