from recording.Recorder import RecorderModule, plt, Colors
import numpy as np
# Module-level side effect: numpy print options are process-global, so this
# affects how arrays are printed everywhere after this module is imported,
# not only in the log strings built below.
np.set_printoptions(precision=2)

def log_res(res, key_prefix):
    """Format the statistics of a single update for logging and recording.

    Args:
        res: Iterable with exactly three entries ``(num_iters, kl, entropy)``.
            Each entry is converted with ``np.array``; ``num_iters`` must be
            integer-valued because it is formatted with ``{:d}``.
        key_prefix: Prefix for the keys of the returned record dict.

    Returns:
        Tuple ``(log_string, last_rec)``. ``last_rec`` maps
        ``<prefix>_num_iterations``, ``<prefix>_kl`` and ``<prefix>_entropy``
        to the converted values.

    Raises:
        ValueError: If ``res`` does not contain exactly three entries.
    """
    values = [np.array(x) for x in res]
    # Fail loudly here instead of swallowing the error and crashing later on an
    # unbound local, which hid the real cause of the problem.
    if len(values) != 3:
        raise ValueError(
            "log_res expects (num_iters, kl, entropy), got {:d} entries".format(len(values)))
    num_iters, kl, entropy = values

    last_rec = {key_prefix + "_num_iterations": num_iters,
                key_prefix + "_kl": kl,
                key_prefix + "_entropy": entropy}
    log_string = "Updated for {:d} iterations. ".format(num_iters)
    log_string += "KL: {:.5f}. ".format(kl)
    log_string += "Entropy: {:.5f} ".format(entropy)
    return log_string, last_rec

def log_res_weight(res, key_prefix):
    """Format the statistics of a weight update for logging and recording.

    Args:
        res: Iterable with exactly seven entries
            ``(num_iters, kl, entropy, reward, target_lnpdfs, entropies,
            log_responsibilities)``. Only the first four are used; they are
            converted with ``np.array``.
        key_prefix: Prefix for the keys of the returned record dict.

    Returns:
        Tuple ``(log_string, last_rec)``. ``log_string`` reports the reward.
        ``last_rec`` maps ``<prefix>_num_iterations``, ``<prefix>_kl``,
        ``<prefix>_entropy`` and ``<prefix>_reward`` to the converted values.

    Raises:
        ValueError: If ``res`` does not contain exactly seven entries.
    """
    # Unpacking still enforces the 7-entry layout. The trailing three entries
    # are not logged, so they are not converted with np.array; converting them
    # was wasted work and could fail on ragged data.
    num_iters, kl, entropy, reward, _, _, _ = res
    num_iters, kl, entropy, reward = [np.array(x) for x in (num_iters, kl, entropy, reward)]

    last_rec = {key_prefix + "_num_iterations": num_iters,
                key_prefix + "_kl": kl,
                key_prefix + "_entropy": entropy,
                key_prefix + "_reward": reward}
    log_string = "Reward: {}".format(reward)
    return log_string, last_rec


class WeightUpdateRecMod(RecorderModule):
    """Recorder module for the weight update.

    Every ``record`` call formats one update result with ``log_res_weight``
    (key prefix ``"weights"``) and logs it. When plotting is enabled, the KL
    and entropy of each call are appended to a history and drawn as two
    stacked subplots.
    """

    def __init__(self, plot):
        super().__init__()
        self._plot = plot
        self._last_rec = None
        # Histories of the plotted statistics, one entry per record() call.
        self._kls, self._entropies = [], []
        # Reserved for a reward curve; nothing in this class fills it.
        self._reward = []
        # Placeholder until initialize() provides the real x-axis limit.
        self._num_iters = -1

    def initialize(self, recorder, plot_realtime, save, num_iters):
        super().initialize(recorder, plot_realtime, save)
        # Used as the upper x-limit of both subplots.
        self._num_iters = num_iters

    def record(self, res):
        log_string, rec = log_res_weight(res, "weights")
        self._last_rec = rec
        self._logger.info(log_string)
        if not self._plot:
            return
        self._kls.append(rec["weights_kl"])
        self._entropies.append(rec["weights_entropy"])
        self._recorder.handle_plot("Weight Update", self._plot_fn)

    def _plot_fn(self):
        """Draw the KL history on top and the entropy history below it."""
        panels = (("Expected KL", self._kls), ("Expected Entropy", self._entropies))
        for row, (title, history) in enumerate(panels, start=1):
            plt.subplot(2, 1, row)
            plt.title(title)
            plt.plot(history)
            plt.xlim(0, self._num_iters)
        plt.tight_layout()

    @property
    def logger_name(self):
        return "Weight Update"

    def get_last_rec(self):
        """Return the record dict of the most recent ``record`` call."""
        rec = self._last_rec
        assert rec is not None
        return rec

    def finalize(self):
        if not self._plot:
            return
        self._recorder.save_img("WeightUpdates", self._plot_fn)


class ComponentUpdateRecMod(RecorderModule):
    """Recorder module for the per-component updates.

    ``record`` receives one result per updated component, each formatted by
    ``log_res`` with key prefix ``"component_<i>"``. When plotting is enabled,
    the KL and entropy of every component are stored per call. Components not
    updated in a call get NaN. The histories are drawn as one coloured curve
    per component.
    """

    def __init__(self, plot, summarize=True):
        """
        Args:
            plot: Whether to keep per-component histories and produce plots.
            summarize: If True, log one summary line per ``record`` call;
                otherwise log one line per component.
        """
        super().__init__()
        self._plot = plot
        self._last_rec = None
        self._summarize = summarize
        # Per-component histories; allocated in initialize() once the number
        # of components is known.
        self._kls = None
        self._entropies = None
        self._num_iters = -1
        self._num_components = -1
        self._c = Colors()

    def initialize(self, recorder, plot_realtime, save, num_iters, num_components):
        """Store the plot x-limit and allocate one history per component."""
        super().initialize(recorder, plot_realtime, save)
        self._num_iters = num_iters
        self._num_components = num_components
        self._kls = [[] for _ in range(self._num_components)]
        self._entropies = [[] for _ in range(self._num_components)]

    def record(self, res_list):
        """Log (and optionally store for plotting) the results of one call.

        Args:
            res_list: One ``log_res``-compatible result per updated component,
                indexed by component. May be empty.
        """
        self._last_rec = {}
        # Tracked explicitly instead of reading the loop variable after the
        # loop, which is unbound when res_list is empty.
        num_updated = 0
        for i, res in enumerate(res_list):
            key_prefix = "component_{:d}".format(i)
            cur_log_string, cur_last_rec = log_res(res, key_prefix)
            self._last_rec.update(cur_last_rec)
            if not self._summarize:
                self._logger.info("Component{:d}: ".format(i + 1) + cur_log_string)
            if self._plot:
                self._kls[i].append(cur_last_rec[key_prefix + "_kl"])
                self._entropies[i].append(cur_last_rec[key_prefix + "_entropy"])
            num_updated = i + 1
        if self._plot:
            # Pad components that were not updated in this call with NaN so
            # that all histories keep one entry per call.
            for j in range(num_updated, len(self._kls)):
                self._kls[j].append(np.nan)
                self._entropies[j].append(np.nan)
        if self._summarize:
            self._summarize_results(res_list)
        if self._plot:
            self._recorder.handle_plot("Component Update", self._plot_fn)

    def _summarize_results(self, res_list):
        """Log how many components were updated and how many succeeded.

        A result counts as failed when the string form of its last entry
        contains ``"fail"`` (case-insensitive).
        """
        fail_ct = sum(1 for res in res_list if "fail" in str(res[-1]).lower())
        num_updt = len(res_list)
        log_str = "{:d} components updated - {:d} successful".format(num_updt, num_updt - fail_ct)
        self._logger.info(log_str)

    def _plot_fn(self):
        """Draw KL (top) and entropy (bottom) histories, one curve per component."""
        plt.subplot(2, 1, 1)
        plt.title("Expected KL")
        handles, labels = [], []
        for i in range(self._num_components):
            line = plt.plot(self._kls[i], c=self._c(i))[0]
            # Only components with at least one non-NaN value get a legend
            # entry. The matching line handle is passed explicitly so a skipped
            # component cannot shift labels onto the wrong colours.
            if np.any(np.logical_not(np.isnan(self._kls[i]))):
                handles.append(line)
                labels.append("Component {:d}".format(i + 1))
        plt.legend(handles, labels)
        plt.xlim(0, self._num_iters)
        plt.subplot(2, 1, 2)
        plt.title("Expected Entropy")
        for i in range(self._num_components):
            plt.plot(self._entropies[i], c=self._c(i))
        plt.xlim(0, self._num_iters)
        plt.tight_layout()

    @property
    def logger_name(self):
        return "Component Update"

    def get_last_rec(self):
        """Return the merged record dict of the most recent ``record`` call."""
        assert self._last_rec is not None
        return self._last_rec

    def finalize(self):
        if self._plot:
            self._recorder.save_img("ComponentUpdates", self._plot_fn)

class ContextUpdateRecMod(RecorderModule):
    """Recorder module for the context update.

    Every ``record`` call formats one update result with ``log_res`` (key
    prefix ``"contexts"``) and logs it. When plotting is enabled, the KL and
    entropy of each call are appended to a history and drawn as two stacked
    subplots.
    """

    def __init__(self, plot):
        super().__init__()
        self._last_rec = None
        self._plot = plot
        # Histories of the plotted statistics, one entry per record() call.
        self._kls = []
        self._entropies = []
        self._num_iters = -1

    def initialize(self, recorder, plot_realtime, save, num_iters):
        super().initialize(recorder, plot_realtime, save)
        # Used as the upper x-limit of both subplots.
        self._num_iters = num_iters

    def record(self, res):
        log_string, self._last_rec = log_res(res, "contexts")
        self._logger.info(log_string)
        if self._plot:
            self._kls.append(self._last_rec["contexts_kl"])
            self._entropies.append(self._last_rec["contexts_entropy"])
            self._recorder.handle_plot("Context-Weights Update", self._plot_fn)

    def _plot_fn(self):
        """Draw the KL history on top and the entropy history below it."""
        plt.subplot(2, 1, 1)
        plt.title("Expected KL")
        plt.plot(self._kls)
        plt.xlim(0, self._num_iters)
        plt.subplot(2, 1, 2)
        plt.title("Expected Entropy")
        plt.plot(self._entropies)
        plt.xlim(0, self._num_iters)
        plt.tight_layout()

    @property
    def logger_name(self):
        return "Context Update"

    def get_last_rec(self):
        """Return the record dict of the most recent ``record`` call."""
        assert self._last_rec is not None
        return self._last_rec

    def finalize(self):
        if self._plot:
            # Saved under its own name. It previously reused "WeightUpdates"
            # (copied from WeightUpdateRecMod), which would overwrite that
            # module's image.
            self._recorder.save_img("ContextUpdates", self._plot_fn)