import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from nips2025.utils_plot import plot_abs_hr, plot_hr_incre
import json
from utils import APPLICATION

import numpy as np
import logging
logging.disable(level=logging.DEBUG)

class QualityMetric:
    """Compare result files of generated codes against human-heuristic baselines.

    Result files are JSONL; each line is one evaluation entry carrying at least
    ``algo``, ``trace_type``, ``trace_file_name`` and
    ``tuned_param_mr_info.mr_test`` (optionally ``cache_cap``).  The baselines
    ("sota" codes) are loaded once from
    ``data-{APPLICATION}/human-heuristics/eval.jsonl`` next to this file.
    """

    # Code id of the hybrid algorithm that can be pinned to the front of plots.
    _HYBRID_CODE = "533"

    # Exact code ids treated as human-designed baselines ("sota").
    _SOTA_CODES = frozenset({
        "s3fifo",
        "slru",
        "tinyLFU",
        "lfu",
        "arc",
        "sieve",
        "clock",
        "lru",
        "fifo",
        "0",  # tsp-aco
    })

    # Path substring -> code-type label.  A code must match at most one marker.
    _CODE_PATH_MARKERS = (
        ("/rsdict/", "mm"),
        ("/rsdict-sf/", "mm"),
        ("/plansearch/", "ps"),
        ("/rs/", "rs"),
        ("/mcts/", "mcts"),
        ("/reevo/", "reevo"),
        ("/openevolve/", "openevolve"),
    )

    # Plot/group ordinal of each code type.  "openevolve" is intentionally not
    # listed: it is currently excluded from plots, and such codes raise a clear
    # error instead of being silently misplaced.
    _CODE_TYPE_ORDER = {
        "mm": 0,
        "ps": 1,
        "mcts": 2,
        "reevo": 3,
        "rs": 4,
        "sota": 5,
    }

    def __init__(self, is_target_trace_func=lambda e: True):
        """Load the human-heuristic ("sota") reference results.

        Args:
            is_target_trace_func: predicate over one parsed JSONL entry; entries
                for which it returns a falsy value are ignored, both here and
                when candidate result files are loaded later.
        """
        sota_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            f"data-{APPLICATION}",
            "human-heuristics",
            "eval.jsonl",
        )
        self.m_sota_trace_type_result = self.load_code_trace_type_result(sota_path, is_target_trace_func)
        # The baseline whose per-trace-type file lists form the reference set
        # that every candidate code must reproduce exactly.
        self.sota_repr = "lfu" if APPLICATION == "cache" else "0"
        self.m_trace_type_filename = {
            t: [e["trace_file_name"] for e in entries]
            for t, entries in self.m_sota_trace_type_result[self.sota_repr].items()
        }
        self.is_target_trace_func = is_target_trace_func

    @classmethod
    def load_code_trace_type_result(cls, result_jsonl_path, is_target_trace_func=lambda e: True):
        """Group the entries of a result JSONL file by code and trace type.

        An entry is skipped when ``tuned_param_mr_info.mr_test`` is null, when it
        has a ``cache_cap`` below 2, or when ``is_target_trace_func`` returns a
        falsy value for it.  Blank lines are ignored.

        Args:
            result_jsonl_path: path of the JSONL result file (read as UTF-8).
            is_target_trace_func: entry filter predicate.

        Returns:
            ``{algo: {trace_type: [entry, ...]}}`` where every entry list is
            sorted by ``trace_file_name``.
        """
        m_code_trace_type_result = {}
        with open(result_jsonl_path, "r", encoding="utf-8") as file:
            for line in file:
                if not line.strip():
                    # Tolerate blank / trailing empty lines in the JSONL file.
                    continue
                entry_dict = json.loads(line)
                if entry_dict["tuned_param_mr_info"]["mr_test"] is None:
                    continue
                if "cache_cap" in entry_dict and entry_dict["cache_cap"] < 2:
                    continue
                if not is_target_trace_func(entry_dict):
                    continue
                by_trace_type = m_code_trace_type_result.setdefault(entry_dict["algo"], {})
                by_trace_type.setdefault(entry_dict["trace_type"], []).append(entry_dict)

        # Sort so that per-file results of different codes line up positionally.
        return {
            code: {
                trace_type: sorted(entry_list, key=lambda e: e["trace_file_name"])
                for trace_type, entry_list in trace_type_result.items()
            }
            for code, trace_type_result in m_code_trace_type_result.items()
        }

    def check_code(self, trace_type_result, trace_type_list):
        """Return True iff a code covers exactly the reference trace files.

        For every trace type in ``trace_type_list`` the code must have results,
        and their (sorted) ``trace_file_name`` list must equal the reference
        list taken from the ``sota_repr`` baseline.
        """
        return all(
            t in trace_type_result
            and [e["trace_file_name"] for e in trace_type_result[t]] == self.m_trace_type_filename[t]
            for t in trace_type_list
        )

    def get_valid_trace_type_list(self, trace_type):
        """Return, sorted, the baseline trace types containing ``trace_type`` as a substring."""
        return sorted(t for t in self.m_sota_trace_type_result[self.sota_repr] if trace_type in t)

    def get_valid_code_list(self, m_code_trace_type_result, trace_type_list):
        """Return the codes (in input order) that pass :meth:`check_code`."""
        return [
            c for c, trace_type_result in m_code_trace_type_result.items()
            if self.check_code(trace_type_result, trace_type_list)
        ]

    @staticmethod
    def _mr_to_metric(mr):
        """Convert a raw ``mr_test`` value into the plotted metric.

        For ``tsp-aco`` the value is used as-is (plotted as a cost); otherwise it
        becomes ``1 - mr`` rounded to 4 decimals (plotted as a hit ratio).
        """
        if APPLICATION == "tsp-aco":
            return mr
        return round(1 - mr, 4)

    @classmethod
    def _trace_type_hr(cls, trace_type_result, trace_type_list):
        """Map each trace type in ``trace_type_list`` to the list of per-file metrics."""
        return {
            t: [cls._mr_to_metric(e["tuned_param_mr_info"]["mr_test"]) for e in trace_type_result[t]]
            for t in trace_type_list
        }

    def get_valid_code_valid_trace_type_hr(self, result_jsonl_path, trace_type):
        """Collect per-file metrics of all valid codes plus every baseline.

        Args:
            result_jsonl_path: candidate result file to load.
            trace_type: substring selecting the trace types to include.

        Returns:
            ``{code: {trace_type: [metric, ...]}}`` for every valid candidate
            code followed by every baseline code.

        Raises:
            AssertionError: if a candidate code shares its id with a baseline.
        """
        m_code_trace_type_result = self.load_code_trace_type_result(result_jsonl_path, self.is_target_trace_func)

        # filter valid trace_types
        valid_trace_type_list = self.get_valid_trace_type_list(trace_type)
        print(f"valid_trace_types = {valid_trace_type_list}")

        # filter valid codes
        valid_code_list = self.get_valid_code_list(m_code_trace_type_result, valid_trace_type_list)
        print(f"#valid_codes (exclude_sota) = {len(valid_code_list)}")

        m_vc_vt_hr = {
            c: self._trace_type_hr(m_code_trace_type_result[c], valid_trace_type_list)
            for c in valid_code_list
        }
        for sota, sota_result in self.m_sota_trace_type_result.items():
            assert sota not in m_vc_vt_hr, f"candidate code {sota!r} collides with a baseline code"
            m_vc_vt_hr[sota] = self._trace_type_hr(sota_result, valid_trace_type_list)

        # debug: per-trace-type list sizes of the first code only
        first_code_hr = next(iter(m_vc_vt_hr.values()), {})
        for t, hr_list in first_code_hr.items():
            print(f"{t} size = {len(hr_list)}")
        return m_vc_vt_hr

    @classmethod
    def _get_code_type(cls, code):
        """Return the plot/group ordinal of ``code`` (see ``_CODE_TYPE_ORDER``).

        Raises:
            ValueError: if the code matches no marker, several markers, or a
                code type that is not enabled for plotting (e.g. openevolve).
        """
        if code in cls._SOTA_CODES:
            return cls._CODE_TYPE_ORDER["sota"]
        matched = [label for marker, label in cls._CODE_PATH_MARKERS if marker in code]
        if not matched:
            raise ValueError(f"Unknown code_type for {code}")
        if len(matched) > 1:
            raise ValueError(f"Ambiguous code_type for {code}: matches {matched}")
        code_type = matched[0]
        if code_type not in cls._CODE_TYPE_ORDER:
            raise ValueError(f"code_type {code_type!r} of {code} is not enabled for plotting")
        return cls._CODE_TYPE_ORDER[code_type]

    def plot_hr(self, result_jsonl_path, xtick_list, trace_type, png_path, relative_algo=None, sort_code_func=None, rename_code_func=None, includ_hybrid=False, include_sota=True):
        """Plot per-file metrics of all valid codes and baselines.

        Args:
            result_jsonl_path: candidate result file.
            xtick_list: forwarded to the plotting helper.
            trace_type: substring selecting the trace types to include.
            png_path: output image path, forwarded to the plotting helper.
            relative_algo: if None, plot absolute values with ``plot_abs_hr``;
                otherwise plot increments relative to this algorithm with
                ``plot_hr_incre``.
            sort_code_func: sort key for codes; defaults to grouping by code type,
                then best first (higher hit ratio for cache, lower cost
                otherwise).
            rename_code_func: maps a code id to its plotted label (identity by
                default).
            includ_hybrid: pin the hybrid code first, loading it from
                ``data-{APPLICATION}/case-study/eval.jsonl`` when missing.
            include_sota: forwarded to the plotting helper.
        """
        hybrid = self._HYBRID_CODE
        m_vc_vt_hr = self.get_valid_code_valid_trace_type_hr(result_jsonl_path, trace_type)
        if includ_hybrid and hybrid not in m_vc_vt_hr:
            case_study_path = os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                f"data-{APPLICATION}",
                "case-study",
                "eval.jsonl",
            )
            m_vc_vt_hr[hybrid] = self.get_valid_code_valid_trace_type_hr(
                result_jsonl_path=case_study_path,
                trace_type=trace_type,
            )[hybrid]

        # Flatten each code's per-trace-type lists into one list aligned with the baseline.
        tot_hr_list_len = sum(
            len(self.m_sota_trace_type_result[self.sota_repr][t])
            for t in self.get_valid_trace_type_list(trace_type)
        )
        print(f"#tot_entries = {tot_hr_list_len}")
        m_vc_hr_list = {}
        for c, m_vt_hr in m_vc_vt_hr.items():
            m_vc_hr_list[c] = [hr for hr_list in m_vt_hr.values() for hr in hr_list]
            assert len(m_vc_hr_list[c]) == tot_hr_list_len, (
                f"code {c!r} has {len(m_vc_hr_list[c])} entries, expected {tot_hr_list_len}"
            )

        init_code = [hybrid] if includ_hybrid else []
        other_codes = [c for c in m_vc_hr_list if c != hybrid]
        if sort_code_func is None:
            # Group by code type; within a group put the best code first:
            # decreasing hit ratio for cache, increasing cost otherwise.
            sign = -1.0 if APPLICATION == "cache" else 1.0
            sort_key = lambda c: (self._get_code_type(c), sign * float(np.mean(m_vc_hr_list[c])))
        else:
            sort_key = sort_code_func
        sorted_vc_list = init_code + sorted(other_codes, key=sort_key)

        # Count non-sota codes per code type (used to group the plot).
        sota_order = self._CODE_TYPE_ORDER["sota"]
        m_code_type_alg_num = {}
        for c in other_codes:
            code_type = self._get_code_type(c)
            if code_type == sota_order:
                continue
            m_code_type_alg_num[code_type] = m_code_type_alg_num.get(code_type, 0) + 1

        assert sorted(m_code_type_alg_num) == list(range(len(m_code_type_alg_num))), (
            f"code types present must be a contiguous prefix of {self._CODE_TYPE_ORDER}; "
            f"got counts {m_code_type_alg_num}"
        )
        alg_num_list = [m_code_type_alg_num[i] for i in range(len(m_code_type_alg_num))]
        print(m_code_type_alg_num)

        if rename_code_func is None:
            rename_code_func = lambda c: c
        abs_m_algo_hr = {rename_code_func(vc): m_vc_hr_list[vc] for vc in sorted_vc_list}

        if relative_algo is None:
            plot_abs_hr(
                abs_m_algo_hr=abs_m_algo_hr,
                png_path=png_path,
                xtick_list=xtick_list,
                include_hybrid=includ_hybrid,
                include_sota=include_sota,
                alg_num_list=alg_num_list,
                title="Hit Ratio" if APPLICATION == "cache" else "Costs"
            )
        else:
            plot_hr_incre(
                abs_m_algo_hr=abs_m_algo_hr,
                png_path=png_path,
                relative_algo=relative_algo,
                xtick_list=xtick_list,
                include_hybrid=includ_hybrid,
                include_sota=include_sota,
            )

    # def tab_hr_incre(self, result_jsonl_path, trace_type, avg_mode, top_k=None, need_sota=True):
    #     m_valid_code_valid_trace_type_hr = self.get_valid_code_valid_trace_type_hr(result_jsonl_path, trace_type)
    #     # get hr incre
    #     if avg_mode == "per_entry":
    #         tot_hr_list_len = sum([len(self.m_sota_trace_type_result["lfu"][t]) for t in self.get_valid_trace_type_list(trace_type)])
    #         print(f"#tot_entries = {tot_hr_list_len}")
    #         m_vc_hr_list = dict()
    #         for c in m_valid_code_valid_trace_type_hr:
    #             m_vc_hr_list[c] = []
    #             for t in m_valid_code_valid_trace_type_hr[c]:
    #                 m_vc_hr_list[c] += m_valid_code_valid_trace_type_hr[c][t]
    #             assert len(m_vc_hr_list[c]) == tot_hr_list_len

    #         m_vc_avg_hr = {
    #             c: float(np.mean(m_vc_hr_list[c]))
    #             for c in m_valid_code_valid_trace_type_hr
    #         }
            

    #     elif avg_mode == "per_dataset":
    #         m_vc_avg_hr = {
    #             c: float(np.mean([
    #                 float(np.mean([
    #                     m_valid_code_valid_trace_type_hr[c][t]
    #                 ]))
    #                 for t in self.get_valid_trace_type_list(trace_type)
    #             ]))
    #             for c in m_valid_code_valid_trace_type_hr
    #         }
    #     else:
    #         raise ValueError(f"Unknown avg_mode: {avg_mode}")

    #     sota_hr = max(list([m_vc_avg_hr[s] for s in self.m_sota_trace_type_result]))
    #     print(f"sota_hr = {round(sota_hr * 100, 2)}")
    #     if top_k == None:
    #         top_k = len(m_vc_avg_hr)
        
    #     printed_vc = set()
    #     sorted_vc_list = sorted(list(m_vc_avg_hr.keys()), key=lambda c: m_vc_avg_hr[c], reverse=True)
    #     for vc in sorted_vc_list[:top_k]:
    #         print(f"{vc}: {round((m_vc_avg_hr[vc] - sota_hr) * 100, 2)}\\%")
    #         printed_vc.add(vc)
        
    #     if need_sota == True:
    #         for vc in sorted_vc_list:
    #             if vc in self.m_sota_trace_type_result and vc not in printed_vc:
    #                 print(f"{vc}: {round(m_vc_avg_hr[vc] * 100, 2)}\% {round((m_vc_avg_hr[vc] - sota_hr) * 100, 2)}\\%")

