import os
import json
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.cm as cm

# Runs to compare, as {figure key: {curve label: run directory}}.
# The figure key becomes the output file name and the curve label becomes the
# legend prefix.  Within "freq", "ns" and "topk" only the bszN / localN / topkN
# part of the directory name changes; "write" compares four differently named
# runs.  The bsz2-local4-topk30 "aw" run appears in every group.
_EVAL_ROOT = "eval-output/v0.1.1/small"
_MODEL = "gpt-4.1-mini"


def _mem0_run(tag, bsz, local, topk):
    """Return the mem0 run directory for the given tag and settings."""
    return f"{_EVAL_ROOT}/mem0/{_MODEL}-{tag}-bsz{bsz}-local{local}-topk{topk}"


figure_groups = {
    "freq": {f"freq-{bsz}": _mem0_run("aw", bsz, 4, 30) for bsz in (2, 4, 8)},
    "ns": {f"ns-{local}": _mem0_run("aw", 2, local, 30) for local in (0, 4, 8)},
    "topk": {f"topk-{topk}": _mem0_run("aw", 2, 4, topk) for topk in (10, 30, 50)},
    "write": {
        "LLM": f"{_EVAL_ROOT}/long-ctx/{_MODEL}",
        "AWI": f"{_EVAL_ROOT}/in-context/{_MODEL}-in-context-bsz2-local4",
        "RAG": _mem0_run("RAG", 2, 4, 30),
        "AWE": _mem0_run("aw", 2, 4, 30),
    },
}

def process_diagnostic_metrics(write_failure, read_failure, memory_success, overall):
    """Split each wrong answer into write, read and utilization failure shares.

    An entry counts as answered correctly when ``overall > 0.5``.

    * Utilization failure is 1.0 where the answer is wrong although
      ``memory_success`` is truthy, else 0.0.
    * Write and read failure are normalised by their sum, so the two shares
      of an entry add up to 1 whenever at least one of them is flagged.  Both
      are forced to 0 where the answer is correct or memory succeeded.

    Args:
        write_failure: NumPy array of write-failure flags or counts.
        read_failure: NumPy array of read-failure flags or counts.
        memory_success: NumPy array, truthy where memory succeeded.
        overall: NumPy array of accuracy scores.
        (All four are assumed to share one shape -- the caller in this file
        passes arrays loaded from the same result directory.)

    Returns:
        Tuple ``(write_failure, read_failure, utilization_failure)`` of new
        float arrays; the inputs are not modified.
    """
    answered_correctly = overall > 0.5
    memory_ok = memory_success.astype(bool)
    utilization_failure = (~answered_correctly & memory_ok).astype(float)

    # Cast before summing: on boolean arrays ``+`` is a logical OR, which
    # would make both shares 1.0 when write and read failures are both set.
    write_failure = write_failure.astype(float)
    read_failure = read_failure.astype(float)
    memory_failure = write_failure + read_failure

    # Entries with neither failure flagged would be 0/0 (NaN plus a
    # RuntimeWarning); leave them at 0 so later sums stay finite.
    has_memory_failure = memory_failure != 0
    write_share = np.divide(
        write_failure,
        memory_failure,
        out=np.zeros_like(memory_failure),
        where=has_memory_failure,
    )
    read_share = np.divide(
        read_failure,
        memory_failure,
        out=np.zeros_like(memory_failure),
        where=has_memory_failure,
    )

    # A correct answer, or a wrong one despite successful memory, is not
    # attributed to the write/read stages.
    not_memory_fault = answered_correctly | memory_ok
    write_share[not_memory_fault] = 0
    read_share[not_memory_fault] = 0
    return write_share, read_share, utilization_failure

def _load_run_failures(rundir):
    """Load and pool the failure arrays of every sub-directory of ``rundir``.

    Each sub-directory must hold ``diagnosis_metrics.json`` (keys
    ``write_failure``, ``read_failure``, ``memory_success``) and
    ``overall_metrics.json`` (key ``accuracy``).  The arrays are passed
    through ``process_diagnostic_metrics`` and concatenated along axis 1.

    Returns:
        Tuple ``(write_failure, read_failure, utilization_failure)`` of 2-D
        arrays.

    Raises:
        ValueError: if ``rundir`` contains no sub-directories.
    """
    pooled = ([], [], [])
    # sorted() only makes the traversal reproducible; the sums computed by the
    # caller do not depend on the concatenation order.
    for subdir in sorted(os.listdir(rundir)):
        subdir_path = os.path.join(rundir, subdir)
        if not os.path.isdir(subdir_path):
            continue
        diag_path = os.path.join(subdir_path, "diagnosis_metrics.json")
        overall_path = os.path.join(subdir_path, "overall_metrics.json")
        with open(diag_path, encoding="utf-8") as f:
            diag = json.load(f)
        with open(overall_path, encoding="utf-8") as f:
            accuracy = np.array(json.load(f)["accuracy"])
        per_subdir = process_diagnostic_metrics(
            np.array(diag["write_failure"]),
            np.array(diag["read_failure"]),
            np.array(diag["memory_success"]),
            accuracy,
        )
        for bucket, values in zip(pooled, per_subdir):
            bucket.append(values)
    if not pooled[0]:
        # np.concatenate([]) would fail with a message that hides the cause.
        raise ValueError(f"no result sub-directories found in {rundir!r}")
    return tuple(np.concatenate(bucket, axis=1) for bucket in pooled)


# One colour per run; the three failure types of a run share its colour and
# differ by marker and line style.
colors = ['#4C72B0', '#55A868', '#C44E52', '#DD8452']

# savefig does not create missing parent directories.
os.makedirs("figures/diagnosis", exist_ok=True)

for key, name2rundir in figure_groups.items():
    plt.clf()
    for i, (name, rundir) in enumerate(name2rundir.items()):
        print(name)
        print(rundir)
        write_failure, read_failure, utilization_failure = _load_run_failures(rundir)

        # Axis 0 is drawn along the x axis ("Period Index"); axis 1 holds the
        # pooled questions the rates are averaged over.
        num_questions = write_failure.shape[1]
        x = np.arange(write_failure.shape[0])
        y_write = write_failure.sum(axis=1) / num_questions
        y_read = read_failure.sum(axis=1) / num_questions
        y_util = utilization_failure.sum(axis=1) / num_questions

        # Cycle the palette so a group with more runs than colours still plots.
        color = colors[i % len(colors)]
        plt.plot(x, y_write, label=name+"-write", marker='o', color=color)
        plt.plot(x, y_read, label=name+"-read", marker='x', linestyle='--', color=color)
        plt.plot(x, y_util, label=name+"-util.", marker='s', linestyle='-.', color=color)

    # One legend column per run: the "write" group has four runs, the others
    # three.  The anchor places the legend above the axes.
    ncols = 4 if key == "write" else 3
    plt.legend(
        loc='upper center',
        bbox_to_anchor=(0.5, 1.2),
        ncol=ncols,
        fancybox=True,
        shadow=True,
        columnspacing=0.5,
        fontsize=11,
    )
    plt.xlabel("Period Index", fontsize=14)
    plt.ylabel("Failure Rate", fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    plt.grid()
    plt.savefig(f"figures/diagnosis/{key}.png", dpi=300, bbox_inches='tight')
    plt.savefig(f"figures/diagnosis/{key}.pdf", dpi=300, bbox_inches='tight')
    plt.close()
