import os
import sys
from typing import Dict, List, Tuple, Any
from typing import OrderedDict as OrderedDictType
import numpy as np
import traceback
from collections import OrderedDict
import preference_model
import json

# Sentinel score recorded for a reward function whose evaluation raised an exception.
ERROR_SCORE = -1


class RewardFunctionBuffer:
    """Buffer of scored reward-function candidates.

    Entries are stored in ``self.buffer`` under the key
    ``"iter=<iter>,sample=<sample>"``. ``self.keys_data`` keeps a running log
    of the best/worst keys, scores and metrics, appended to on every call to
    :meth:`get_preference_result`.
    """

    def __init__(self, pref_model, mask_error_code):
        """Create an empty buffer.

        Args:
            pref_model (preference_model.PreferenceModel): Model used to score
                reward function code. Must expose ``compute_scores(code)``,
                ``design_mode`` and ``env_name``.
            mask_error_code: When truthy, candidates whose code failed to
                execute are excluded from the best/worst ranking.
        """
        self.buffer: Dict[str, Dict] = {}
        self.pref_model: preference_model.PreferenceModel = pref_model
        self.mask_error_code = mask_error_code
        # One entry is appended to each list per get_preference_result() call.
        self.keys_data = {
            "best": [],
            "worst": [],
            "best_scores": [],
            "best_metrics": [],
            "worst_scores": [],
            "worst_metrics": [],
            "design_mode": self.pref_model.design_mode,
            "env_name": self.pref_model.env_name,
            "early_stopping": None,
        }

    def _save_results(
        self,
        reward_objs: List[Dict],
    ) -> None:
        """Cache a set of reward function objects.

        An object whose ``iter``/``sample`` key is already present in the
        buffer is left untouched (the first stored object wins).

        :param reward_objs: list of reward function dicts (see ``update_buffer``).
        """
        for r_obj in reward_objs:
            key = f"iter={r_obj['iter']},sample={r_obj['sample']}"
            self.buffer.setdefault(key, r_obj)

    def update_buffer(
        self,
        reward_objs: List[Dict],
    ) -> None:
        """Calculate the missing fields of each object in place and save them.

        For every object the ``score``, ``metrics`` and ``executable_flag``
        fields are filled in. If scoring raises, the object gets
        ``ERROR_SCORE``, empty metrics and ``executable_flag=False``; the
        traceback and the offending code are printed.

        Args:
            reward_objs (List[Dict]): Dict likes:

            {
                "iter": cur_iter,
                "sample": cur_sample,
                "response": [loss_rsp, backward_rsp, optimizer_rsp],
                "model": [loss_model, backward_model, optimizer_model],
                "finish_reason": [loss_finish_reason, backward_finish_reason, optimizer_finish_reason],
                "usage": [loss_usage, backward_usage, optimizer_usage],
                "code": code,
                "score": None,
                "metrics": None,
                "executable_flag": None
            }

        """
        total = len(reward_objs)
        for i, r_obj in enumerate(reward_objs):
            print(
                "=" * 10
                + f"Calculate the score of the {i+1}/{total} reward function."
                + "=" * 10
            )
            try:
                score, metrics = self.pref_model.compute_scores(r_obj["code"])
                executable_flag = True
                print(
                    "-" * 10 + f"The code is executable. Score={score}" + "-" * 10,
                    end="\n\n",
                )
            except Exception:
                # Deliberately broad: any failure while scoring the candidate
                # code is recorded as a failed candidate rather than aborting.
                score, metrics = ERROR_SCORE, {}
                executable_flag = False
                print("-" * 10 + "Execution error:" + "-" * 10, end="\n\n")
                print(traceback.format_exc())
                print("-" * 10 + "Error code:" + "-" * 10, end="\n\n")
                print(r_obj["code"])

            r_obj["score"] = score
            r_obj["metrics"] = metrics
            r_obj["executable_flag"] = executable_flag

        self._save_results(reward_objs)

    def get_preference_result(self):
        """Pick the best and worst buffered reward functions and log them.

        Candidates are ranked by ``score``. When ``mask_error_code`` is truthy
        only entries with a truthy ``executable_flag`` are ranked. If no entry
        survives the filter, the most recently inserted entry is reported as
        best and the first inserted entry as worst, with ``all_failed_flag``
        set.

        Returns:
            ``((best_code, best_score), (worst_code, worst_score), delta,
            all_failed_flag)`` where ``delta = best_score - worst_score``.

        Raises:
            ValueError: If the buffer contains no entries at all.
        """
        if not self.buffer:
            raise ValueError("Cannot rank reward functions: the buffer is empty.")

        candidates = [
            (key, r_obj)
            for key, r_obj in self.buffer.items()
            if not self.mask_error_code or r_obj["executable_flag"]
        ]
        # Ascending, stable sort: on ties the earliest entry is "worst" and the
        # latest entry is "best".
        candidates.sort(key=lambda item: item[1]["score"])

        all_failed_flag = not candidates
        if all_failed_flag:
            # dict views are not subscriptable, so materialise the keys first.
            keys = list(self.buffer)
            best_tuple = (keys[-1], self.buffer[keys[-1]])
            worst_tuple = (keys[0], self.buffer[keys[0]])
        else:
            best_tuple, worst_tuple = candidates[-1], candidates[0]

        # best/worst_tuple looks like (key, reward_obj)
        best_key, best_obj = best_tuple
        worst_key, worst_obj = worst_tuple
        best_func, best_score = best_obj["code"], best_obj["score"]
        worst_func, worst_score = worst_obj["code"], worst_obj["score"]
        delta = best_score - worst_score

        # log best and worst
        self.keys_data["best"].append(best_key)
        self.keys_data["worst"].append(worst_key)
        self.keys_data["best_scores"].append(best_score)
        self.keys_data["worst_scores"].append(worst_score)
        self.keys_data["best_metrics"].append(best_obj["metrics"])
        self.keys_data["worst_metrics"].append(worst_obj["metrics"])

        # early stopping: remember only the first key that reached score 1.0
        if best_score == 1.0 and self.keys_data["early_stopping"] is None:
            self.keys_data["early_stopping"] = best_key

        return (
            (best_func, best_score),
            (worst_func, worst_score),
            delta,
            all_failed_flag,
        )

    def _dump_ranked_code(self, file_dir, label):
        """Write the code and score file for every key logged under ``label``."""
        for i, key in enumerate(self.keys_data[label]):
            entry = self.buffer[key]
            filename = os.path.join(file_dir, f"{label}-{i}-{key}.py")
            with open(filename, "w", encoding="utf-8") as f:
                f.write(f"{entry['code']}\n")
            score_filename = os.path.join(file_dir, f"{label}-{i}-{key}-score.txt")
            with open(score_filename, "w", encoding="utf-8") as f:
                f.write(f"{entry['score']}\n")

    def save(self, log_dir):
        """Write the buffer, the best/worst log and the selected code to disk.

        Creates, under ``log_dir``: ``buffer.json``, ``keys.json``,
        ``early_stopping.txt`` and a ``code`` directory holding one ``.py`` and
        one ``-score.txt`` file per logged best/worst key.

        Args:
            log_dir: Output directory; created if it does not exist.
        """
        if log_dir:
            # An empty path means the current directory, which already exists.
            os.makedirs(log_dir, exist_ok=True)

        buffer_path = os.path.join(log_dir, "buffer.json")
        with open(buffer_path, "w", encoding="utf-8") as f:
            json.dump(self.buffer, f, ensure_ascii=False, indent=4)

        keys_path = os.path.join(log_dir, "keys.json")
        with open(keys_path, "w", encoding="utf-8") as f:
            json.dump(self.keys_data, f, ensure_ascii=False, indent=4)

        file_dir = os.path.join(log_dir, "code")
        os.makedirs(file_dir, exist_ok=True)
        self._dump_ranked_code(file_dir, "best")
        self._dump_ranked_code(file_dir, "worst")

        early_stopping_filename = os.path.join(log_dir, "early_stopping.txt")
        with open(early_stopping_filename, "w", encoding="utf-8") as f:
            f.write(f"{self.keys_data['early_stopping']}\n")
