import os
import json
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.cm as cm

# Figure groups to render.
# Each top-level key names one figure (it is used as the output file stem);
# its value maps a legend label to the directory holding that run's
# evaluation output. The plotting loop reads every sub-directory of that
# directory.
# The commented-out groups are alternative comparisons kept for reference;
# only the uncommented entries are plotted.
figure_groups = {
    # "freq": {
    #     "freq-2": "eval-output/v0.1.1/small/mem0/gpt-4.1-mini-aw-bsz2-local4-topk30",
    #     "freq-4": "eval-output/v0.1.1/small/mem0/gpt-4.1-mini-aw-bsz4-local4-topk30",
    #     "freq-8": "eval-output/v0.1.1/small/mem0/gpt-4.1-mini-aw-bsz8-local4-topk30"
    # },
    # "ns": {
    #     "ns-0": "eval-output/v0.1.1/small/mem0/gpt-4.1-mini-aw-bsz2-local0-topk30",
    #     "ns-4": "eval-output/v0.1.1/small/mem0/gpt-4.1-mini-aw-bsz2-local4-topk30",
    #     "ns-8": "eval-output/v0.1.1/small/mem0/gpt-4.1-mini-aw-bsz2-local8-topk30"
    # },
    # "topk": {
    #     "topk-10": "eval-output/v0.1.1/small/mem0/gpt-4.1-mini-aw-bsz2-local4-topk10",
    #     "topk-30": "eval-output/v0.1.1/small/mem0/gpt-4.1-mini-aw-bsz2-local4-topk30",
    #     "topk-50": "eval-output/v0.1.1/small/mem0/gpt-4.1-mini-aw-bsz2-local4-topk50"
    # },
    # Active group: compares runs from the in-context-evolution "info_type"
    # output directories.
    "feedback": {
        "No Evolution": "amem-eval-output/in-context-evolution/info_type/gpt-4.1-mini-in-context-bsz2-local4",
        "Question Only": "amem-eval-output/in-context-evolution/info_type/vanilla_question_only",
        "Complete": "amem-eval-output/in-context-evolution/info_type/vanilla",
    }
}

def process_diagnostic_metrics(write_failure, read_failure, memory_success, overall):
    """Turn raw diagnostic indicators into per-question failure attributions.

    All four arguments are array-likes of the same shape (presumably 0/1
    indicators per question -- confirm against the producer of the JSON
    files). The inputs are not modified.

    Args:
        write_failure: Write-failure indicator/count per question.
        read_failure: Read-failure indicator/count per question.
        memory_success: Truthy where the memory stage succeeded.
        overall: Answer accuracy; a value above 0.5 counts as a correct answer.

    Returns:
        A tuple ``(write_share, read_share, utilization_failure)`` of float
        arrays:

        * ``write_share`` / ``read_share``: the write / read failure divided
          by their sum, and 0 wherever the answer was correct, the memory
          stage succeeded, or neither failure was recorded.
        * ``utilization_failure``: 1.0 where the answer was wrong although
          ``memory_success`` is set, else 0.0.
    """
    write_failure = np.asarray(write_failure)
    read_failure = np.asarray(read_failure)
    answered = np.asarray(overall) > 0.5
    memory_ok = np.asarray(memory_success).astype(bool)

    # Wrong answer even though the memory stage reported success.
    utilization_failure = (~answered & memory_ok).astype(float)

    # NOTE: the sum is taken on the inputs as given (before float conversion)
    # so it follows NumPy's addition rule for whatever dtype was loaded.
    memory_failure = (write_failure + read_failure).astype(float)

    # Failures are only attributed to memory when the answer was wrong and the
    # memory stage did not succeed.
    attributable = ~(answered | memory_ok)

    # Questions with no recorded write/read failure would otherwise yield
    # 0/0 = NaN, which survives the masking above and poisons downstream sums.
    # Substitute a harmless denominator and force those entries to 0 instead.
    has_failure = memory_failure != 0
    safe_total = np.where(has_failure, memory_failure, 1.0)
    keep = attributable & has_failure

    write_share = np.where(keep, write_failure.astype(float) / safe_total, 0.0)
    read_share = np.where(keep, read_failure.astype(float) / safe_total, 0.0)
    return write_share.astype(float), read_share.astype(float), utilization_failure

# One colour per run within a figure; reused cyclically when a group has more
# runs than palette entries.
# colors = cm.tab10(np.linspace(0, 1, len(name2rundir)))
colors = ['#4C72B0', '#55A868', '#C44E52', '#DD8452']
# colors = cm.Dark2(np.linspace(0, 1, 4))

# Directory that receives the rendered figures.
_OUTPUT_DIR = "figures/diagnosis"


def _load_run_failures(rundir):
    """Load and concatenate the processed failure arrays of one run.

    Every sub-directory of ``rundir`` must contain ``diagnosis_metrics.json``
    (keys ``write_failure``, ``read_failure``, ``memory_success``) and
    ``overall_metrics.json`` (key ``accuracy``). Non-directory entries are
    ignored.

    Args:
        rundir: Path of the run's output directory.

    Returns:
        ``(write_failure, read_failure, utilization_failure)``: 2-D arrays
        formed by joining the per-sub-directory results along axis 1 (the
        axis the caller counts as questions).

    Raises:
        ValueError: If ``rundir`` contains no sub-directories.
    """
    write_parts, read_parts, util_parts = [], [], []
    # Sorted so the concatenation (and hence float summation) order is
    # reproducible; os.listdir makes no ordering guarantee.
    for subdir in sorted(os.listdir(rundir)):
        subdir_path = os.path.join(rundir, subdir)
        if not os.path.isdir(subdir_path):
            continue
        with open(os.path.join(subdir_path, "diagnosis_metrics.json")) as f:
            diag = json.load(f)
        with open(os.path.join(subdir_path, "overall_metrics.json")) as f:
            overall_i = np.array(json.load(f)["accuracy"])
        write_i, read_i, util_i = process_diagnostic_metrics(
            np.array(diag["write_failure"]),
            np.array(diag["read_failure"]),
            np.array(diag["memory_success"]),
            overall_i,
        )
        write_parts.append(write_i)
        read_parts.append(read_i)
        util_parts.append(util_i)

    if not write_parts:
        # Same exception type np.concatenate would raise, but with a message
        # that names the offending directory.
        raise ValueError(f"no result sub-directories found in {rundir!r}")

    return (
        np.concatenate(write_parts, axis=1),
        np.concatenate(read_parts, axis=1),
        np.concatenate(util_parts, axis=1),
    )


for key, name2rundir in figure_groups.items():
    plt.clf()
    for i, (name, rundir) in enumerate(name2rundir.items()):
        print(name)
        print(rundir)
        write_failure, read_failure, utilization_failure = _load_run_failures(rundir)

        # Axis 0 is plotted as the period index; axis 1 holds the questions.
        num_questions = write_failure.shape[1]
        x = np.arange(write_failure.shape[0])
        y_write = write_failure.sum(axis=1) / num_questions
        y_read = read_failure.sum(axis=1) / num_questions
        y_util = utilization_failure.sum(axis=1) / num_questions

        # Wrap around the palette instead of raising IndexError on the fifth run.
        color = colors[i % len(colors)]
        plt.plot(x, y_write, label=name+"-write", marker='o', color=color)
        plt.plot(x, y_read, label=name+"-read", marker='x', linestyle='--', color=color)
        plt.plot(x, y_util, label=name+"-util.", marker='s', linestyle='-.', color=color)

        print(f"{name}, Write: {np.mean(y_write):.3f}, Read: {np.mean(y_read):.3f}, Util: {np.mean(y_util):.3f}")

    # The legend sits above the axes; a group named "write" gets four columns,
    # every other group three.
    ncols = 4 if key == "write" else 3
    plt.legend(
        loc='upper center',
        bbox_to_anchor=(0.5, 1.2),
        ncol=ncols,
        fancybox=True,
        shadow=True,
        columnspacing=0.5,
        fontsize=11,
    )
    plt.xlabel("Period Index", fontsize=14)
    plt.ylabel("Failure Rate", fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    plt.grid()
    # savefig does not create missing parent directories.
    os.makedirs(_OUTPUT_DIR, exist_ok=True)
    plt.savefig(f"{_OUTPUT_DIR}/{key}.png", dpi=300, bbox_inches='tight')
    plt.savefig(f"{_OUTPUT_DIR}/{key}.pdf", dpi=300, bbox_inches='tight')
    plt.close()
