import json
import os
import subprocess

from glob import glob

def extract_caption(input_json, output_json):
    """Write an ``audio_id`` -> caption lookup table as JSON.

    Args:
        input_json: Path to a JSON file holding an iterable of objects, each
            with the keys ``"audio_id"``, ``"caption"`` and
            ``"timestamp_events"``.
        output_json: Path of the JSON file to write. It receives one object
            keyed by ``audio_id`` whose values carry the ``"caption"`` and
            ``"timestamp_events"`` of that record (indent of 2).

    If the same ``audio_id`` appears more than once, the last record wins.
    """
    with open(input_json, "r") as src:
        records = json.load(src)

    captions = {}
    for record in records:
        entry = {
            "caption": record["caption"],
            "timestamp_events": record["timestamp_events"],
        }
        captions[record["audio_id"]] = entry

    with open(output_json, "w") as dst:
        dst.write(json.dumps(captions, indent=2))

def format_prediction(pred_json, caption_json, out_json):
    """Attach reference captions and timestamp events to each prediction.

    Args:
        pred_json: Path to a JSON file holding a list of prediction objects,
            each with an ``"audio_id"`` key (a path; only its basename is
            used for the lookup).
        caption_json: Path to a JSON object keyed by audio file name whose
            values have ``"caption"`` and ``"timestamp_events"`` entries.
        out_json: Path of the JSON file that receives the enriched
            predictions (indent of 2).

    Raises:
        KeyError: If a prediction's basename has no entry in ``caption_json``.
    """
    with open(caption_json, "r") as cap_file:
        captions = json.load(cap_file)

    with open(pred_json, "r") as pred_file:
        predictions = json.load(pred_file)

    for entry in predictions:
        reference = captions[os.path.basename(entry["audio_id"])]
        # NOTE: the key spelling "scence_caption" is kept as-is because it is
        # part of the file format written by this function.
        entry["scence_caption"] = reference["caption"]
        entry["timestamp_events"] = reference["timestamp_events"]

    with open(out_json, "w") as dst:
        dst.write(json.dumps(predictions, indent=2))


def analyze_score(score_json):
    """Print the mean clarity, correctness and engagement scores.

    Args:
        score_json: Path to a JSON Lines file: one JSON object per line, each
            with ``"clarity"``, ``"correctness"`` and ``"engagement"`` values
            that ``float()`` can convert. Blank lines are ignored.

    Raises:
        ValueError: If the file contains no score records, since no mean can
            be computed.
    """
    clar_list, corr_list, enga_list = [], [], []
    with open(score_json, "r") as f_in:
        # Iterate the file lazily instead of materialising every line first.
        for line in f_in:
            # Blank lines (e.g. a trailing empty line) are not valid JSON
            # documents; skip them rather than failing in json.loads.
            if not line.strip():
                continue
            obj = json.loads(line)
            clar_list.append(float(obj["clarity"]))
            corr_list.append(float(obj["correctness"]))
            enga_list.append(float(obj["engagement"]))

    # All three lists grow in lockstep, so checking one is sufficient.
    if not clar_list:
        raise ValueError("no score records found in {!r}".format(score_json))

    print("Clarity score: ", sum(clar_list) / len(clar_list))
    print("Correctness score: ", sum(corr_list) / len(corr_list))
    print("Engagement score: ", sum(enga_list) / len(enga_list))

def format_mmu_pred(gt_file, pred_file, out_file):
    """Merge model predictions into the ground-truth records.

    Args:
        gt_file: Path to a JSON list of ground-truth objects, each with an
            ``"audio_id"`` key.
        pred_file: Path to a JSON list of objects with ``"audio_id"`` and
            ``"prediction"`` keys.
        out_file: Path of the JSON file to write. It holds the ground-truth
            records, each extended with a ``"model_output"`` key set to the
            prediction for the same ``audio_id`` (indent of 4).

    Raises:
        KeyError: If a ground-truth ``audio_id`` has no prediction.
    """
    with open(gt_file, "r") as gt_f:
        references = json.load(gt_f)
    with open(pred_file, "r") as pred_f:
        predictions = json.load(pred_f)

    # Index predictions by audio id; a later duplicate overrides an earlier one.
    output_by_id = {}
    for item in predictions:
        prediction = item["prediction"]
        output_by_id[item["audio_id"]] = prediction

    for record in references:
        record["model_output"] = output_by_id[record["audio_id"]]

    with open(out_file, "w") as out_f:
        out_f.write(json.dumps(references, indent=4))

def cvt_audio(folder, out_folder="new-test-mini-audios"):
    """Convert every ``*.wav`` file in ``folder`` to 16 kHz mono with ffmpeg.

    Each converted file is written to ``out_folder`` under its original file
    name. Conversions run one after another; a failing file is reported on
    stdout and does not stop the remaining conversions.

    Args:
        folder: Directory searched (non-recursively) for ``*.wav`` files.
        out_folder: Directory that receives the converted files. It is
            created if it does not exist.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found while
            there is at least one file to convert.
    """
    list_aud = glob(os.path.join(folder, "*.wav"))
    # ffmpeg does not create missing output directories itself.
    os.makedirs(out_folder, exist_ok=True)

    for audio in list_aud:
        out_audio = os.path.join(out_folder, os.path.basename(audio))
        # Pass an argument list (no shell) so paths containing spaces or
        # shell metacharacters are handed to ffmpeg verbatim.
        cmd = ["ffmpeg", "-i", audio, "-ar", "16000", "-ac", "1", out_audio]
        # run() waits for completion, so conversions do not all start at once.
        result = subprocess.run(cmd)
        if result.returncode != 0:
            print("ffmpeg failed with exit code {} for {}".format(
                result.returncode, audio))
    print(list_aud)

# TODO: unfinished MMAU evaluation stub, kept for reference.
# def evaluate_mmau(json_file):
#     corr = 0
#     with open(json_file, "r") as json_f:
#         j_data = json.load(json_f)
#         for data in j_data:
#             ans = data["answer"]

if __name__ == "__main__":
    # Paths used by the caption pipeline steps below (currently disabled).
    input_json = "inference_json/CompA-R-test.json"
    cap_json = "inference_json/audio_caption.json"
    pred_json = "inference_json/GAMMA_USW_fn.json"
    out_json = "inference_json/GAMMA_USW_fn_complete.json"

    # Earlier pipeline steps; uncomment the one to run.
    # extract_caption(input_json, cap_json)
    # format_prediction(pred_json, cap_json, out_json)
    # analyze_score("inference_json/caption_score_usw.json")
    # cvt_audio("test-mini-audios")

    # Active step: merge model predictions into the MMAU test-mini file.
    format_mmu_pred(
        gt_file="mmau/mmau-test-mini.json",
        pred_file="mmau_test_mini_GAMMA_alpha0.1_2k4steps.json",
        out_file="mmau_test_mini_GAMMA_alpha0.1_2k4steps_eval.json",
    )