import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

import json
import matplotlib.pyplot as plt
import itertools
import numpy as np
from matplotlib.ticker import PercentFormatter
from utils import APPLICATION

def load_entries_from_record(record_jsonl_path, entry_filter=lambda e: True):
    """Load records from a JSON Lines file, keeping those accepted by a filter.

    Args:
        record_jsonl_path: Path of a text file holding one JSON document per line.
        entry_filter: Callable invoked with each decoded record; the record is
            kept when the result is truthy. Defaults to accepting everything.

    Returns:
        A list of the decoded records that passed the filter, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    entry_dict_list = []
    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open(record_jsonl_path, 'r', encoding="utf-8") as file:
        for line in file:
            if not line.strip():
                # Blank separator/trailing lines carry no record; json.loads
                # would raise on them, so skip instead of aborting the load.
                continue
            entry_dict = json.loads(line)
            if entry_filter(entry_dict):
                entry_dict_list.append(entry_dict)

    return entry_dict_list

def plot_abs_hr(abs_m_algo_hr: dict, png_path, xtick_list, alg_num_list, title, include_sota=True, include_hybrid=True):
    """Scatter-plot per-algorithm percentiles and mean, and save the figure.

    Algorithms are laid out along the x-axis in the key order of
    ``abs_m_algo_hr``. The first ``sum(alg_num_list)`` positions are split into
    consecutive groups (one per entry of ``xtick_list``) that share a single
    centred tick label; when ``include_sota`` is set, the trailing SOTA
    baselines each get their own tick label. A textual comparison of the first
    group against every other group is printed to stdout for each series.

    Args:
        abs_m_algo_hr: Maps algorithm name to a sequence of numeric samples.
            Key order determines the x position of each algorithm.
        png_path: Destination passed to ``plt.savefig``.
        xtick_list: One label per algorithm group.
        alg_num_list: Number of algorithms in each group; must have the same
            length as ``xtick_list``.
        title: Text used as the y-axis label.
        include_sota: When falsy, the SOTA baseline algorithms are excluded and
            no SOTA tick labels are drawn.
        include_hybrid: When falsy, the key ``"hybrid"`` is excluded. When
            truthy, the key at index ``alg_num_list[0]`` is excluded instead
            (if such an index exists).
            NOTE(review): the intent of dropping that positional key is not
            visible here -- confirm against the callers.

    Raises:
        AssertionError: If ``xtick_list`` and ``alg_num_list`` differ in length.
    """
    print(xtick_list)
    print(alg_num_list)
    assert len(xtick_list) == len(alg_num_list)

    is_cache = APPLICATION == "cache"
    # Single source of truth for the SOTA baselines: both the name filter and
    # the number of trailing SOTA columns are derived from this list.
    if is_cache:
        sota_algos = ["s3fifo", "slru", "tinyLFU", "lfu", "arc", "sieve", "clock", "lru", "fifo"]
    else:
        sota_algos = ["0"]
    sota_num = len(sota_algos)

    xtick_num = len(xtick_list)
    # group_starts[i] is the x index of the first algorithm of group i;
    # group_starts[xtick_num] is one past the last grouped algorithm.
    group_starts = [sum(alg_num_list[:i]) for i in range(xtick_num + 1)]

    # For the cache application the larger value of a segment is taken as its
    # representative "max"; for any other application the roles are swapped.
    pick_max = max if is_cache else min
    pick_min = min if is_cache else max

    def latency_improvement(hr, base_hr):
        '''
        latency = t_h x %hit + t_m x (1 - %hit)
                = [ %hit + k x (1 - %hit) ] x t_h
        where t_m = k x t_h.

        Returns the ratio latency(hr) / latency(base_hr); t_h cancels out.
        '''
        k = 60
        return (hr + k * (1 - hr)) / (base_hr + k * (1 - base_hr))

    def print_y(y):
        """Print how the first group compares with every other segment of y."""
        seg_y_list = [y[group_starts[i]:group_starts[i + 1]] for i in range(xtick_num)]
        loc_xtick_list = list(xtick_list[1:])
        if include_sota:
            seg_y_list.append(y[-sota_num:])
            loc_xtick_list.append("SOTA")
        seg1_max = pick_max(seg_y_list[0])
        for seg, loc_xtick in zip(seg_y_list[1:], loc_xtick_list):
            print(f"{xtick_list[0]} - {loc_xtick}:")
            seg_min = pick_min(seg)
            seg_max = pick_max(seg)
            print(f"\t{round((seg1_max - seg_max) * 100, 2)}% - {round((seg1_max - seg_min) * 100, 2)}%")
            print(f"\t{round(latency_improvement(seg1_max, seg_max), 2)}x - {round(latency_improvement(seg1_max, seg_min), 2)}x")

    # Marker and colour sequences run in opposite directions for the two
    # application modes.
    marker_chars = "osv>v*p"
    palette = ["#b2182b", "#ef8a62", "#67a9cf", "#2166ac"]
    markers = itertools.cycle(marker_chars if is_cache else reversed(marker_chars))
    colors = itertools.cycle(reversed(palette) if is_cache else palette)

    ### Select algo
    raw_algo_list = list(abs_m_algo_hr.keys())
    excluded_algos = set()
    if not include_hybrid:
        excluded_algos.add("hybrid")
    elif len(raw_algo_list) > alg_num_list[0]:
        excluded_algos.add(raw_algo_list[alg_num_list[0]])
    if not include_sota:
        excluded_algos.update(sota_algos)
    algo_list = [a for a in raw_algo_list if a not in excluded_algos]
    print(f"algo_list:\n\t", "\n\t".join(algo_list))

    ### Figure geometry scales with the number of label blocks.
    block_size = xtick_num + 1 if include_sota else xtick_num
    if block_size > 1:
        width = (20 * block_size / 5) * len(xtick_list) / 3
        height = 6 * 1.4 * block_size / 5
        font_size = 35 * block_size / 5
    else:
        width = 12
        height = 3.6 * 1.4
        font_size = 21

    fig = plt.figure(figsize=(width, height))
    try:
        ax = plt.gca()
        ax.set_axisbelow(True)
        ax.yaxis.grid(True, linestyle='--', color='lightgray')
        for x_line in range(len(algo_list)):
            ax.axvline(x=x_line, linestyle="--", color='lightgray', zorder=0, linewidth=0.8)

        def plot_series(label, y):
            """Draw one scatter series and print its group comparison."""
            plt.scatter(range(len(y)), y, label=label, marker=next(markers), color=next(colors), s=font_size * 6)
            print(label)
            print_y(y)

        ### collect y data
        percentiles = [50, 75, 90] if is_cache else [10, 25, 50]
        for perc in percentiles:
            plot_series(f"P{perc}", [np.percentile(abs_m_algo_hr[algo], perc) for algo in algo_list])
            if perc == 50:
                # The mean is plotted right after the median.
                plot_series("Mean", [np.mean(abs_m_algo_hr[algo]) for algo in algo_list])
        if plt.ylim()[0] < -0.1:
            plt.ylim(bottom=-0.1)
        ax.set_xlim(-0.5, len(algo_list) - 0.5)

        ## plot
        # methods: one tick centred under each group
        xtick_pos = [(alg_num_list[i] - 1) / 2 + group_starts[i] for i in range(xtick_num)]
        xtick_label = list(xtick_list)
        if include_sota:
            # sotas: one tick per trailing baseline
            xtick_pos += [sum(alg_num_list) + i for i in range(sota_num)]
            xtick_label += algo_list[-sota_num:]
        print("xtick_label:", xtick_label)
        ax.set_xticks(xtick_pos)
        tick_labels = ax.set_xticklabels(xtick_label)
        for i, label in enumerate(tick_labels):
            # Group labels stay horizontal; individual SOTA labels are rotated.
            is_sota_label = i >= xtick_num and include_sota and len(xtick_list) > 0
            label.set_rotation(90 if is_sota_label else 0)
            label.set_fontsize(font_size)

        # vlines separating the groups (and the SOTA block when present)
        last_boundary = xtick_num + 1 if include_sota else xtick_num
        for i in range(1, last_boundary):
            ax.axvline(x=-0.5 + group_starts[i], linestyle="-.", color='black', linewidth=1)

        #grids
        ax.xaxis.grid(False)  # Disable vertical grid lines

        plt.ylabel(title, fontsize=font_size * 1.15)
        # NOTE(review): this appears to supersede the per-label font size set
        # in the loop above -- confirm which size is intended.
        plt.xticks(fontsize=font_size * 1.15)
        if is_cache:
            ax.yaxis.set_major_formatter(PercentFormatter(xmax=1, decimals=0))
            plt.yticks([i/20 for i in range(11)], fontsize=font_size)
            plt.ylim(bottom=0.2)
        plt.legend(
            ncol=6,
            loc="upper center",
            fontsize=font_size * 1.06,
            bbox_to_anchor=(0.5, 1.18),
            frameon=False,
            columnspacing=width/50,  # space between columns (i.e., between entries)
        )
        plt.savefig(png_path, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is closed; clearing it is
        # not enough, so close it to avoid accumulating figures across calls.
        plt.close(fig)

def plot_hr_incre(abs_m_algo_hr: dict, png_path, xtick_list, relative_algo="lfu", include_hybrid=True, include_sota=True, alg_num_list=None):
    """Plot each algorithm's values relative to a baseline algorithm.

    For every algorithm the baseline's value is subtracted sample by sample
    (pairs are formed with ``zip``, so extra samples on either side are
    ignored), and the differences are handed to ``plot_abs_hr``.

    Args:
        abs_m_algo_hr: Maps algorithm name to a sequence of numeric samples.
        png_path: Output path forwarded to ``plot_abs_hr``.
        xtick_list: Group labels forwarded to ``plot_abs_hr``.
        relative_algo: Key of the baseline algorithm in ``abs_m_algo_hr``.
        include_hybrid: Forwarded to ``plot_abs_hr``.
        include_sota: Forwarded to ``plot_abs_hr``.
        alg_num_list: Number of algorithms per group, forwarded to
            ``plot_abs_hr``, which requires it. It is a trailing keyword
            parameter only to keep the existing positional signature intact.

    Raises:
        AssertionError: If ``relative_algo`` is not a key of ``abs_m_algo_hr``.
        ValueError: If ``alg_num_list`` is not supplied.
    """
    assert relative_algo in abs_m_algo_hr
    if alg_num_list is None:
        # plot_abs_hr cannot lay out the groups without the per-group counts;
        # fail with a clear message instead of shifting the other arguments
        # into the wrong parameters.
        raise ValueError("alg_num_list is required: pass one algorithm count per entry of xtick_list")

    base_hr_list = abs_m_algo_hr[relative_algo]
    m_algo_hr = {
        algo: [hr - base_hr for hr, base_hr in zip(hr_list, base_hr_list)]
        for algo, hr_list in abs_m_algo_hr.items()
    }

    # Flags are passed by keyword so they cannot be bound to the wrong slot.
    plot_abs_hr(
        m_algo_hr,
        png_path,
        xtick_list,
        alg_num_list,
        f"Hit Ratio Improvement vs. {relative_algo}",
        include_sota=include_sota,
        include_hybrid=include_hybrid,
    )
