import matplotlib.pyplot as plt
import numpy as np

from util.Logger import LoggerLogFlags
from scipy.stats import multivariate_normal
from util.plot_results import draw_2d_gaussian
# from depricated.Plot_utils import draw_2d_gaussian


class RTPlotter:
    def __init__(self, logger, save2path, vis_plots, save2png):
        """Store the plotting configuration and build the per-component colour table.

        Args:
            logger: Object the plotting methods read their data from (they access
                ``from_flag2_attr`` and ``c`` on it).
            save2path: Prefix that ``save_plot2png`` prepends to every file name.
            vis_plots: When truthy, the plotting methods call ``plt.show()``.
            save2png: When truthy, the plotting methods write their figure to disk.
        """
        base_palette = ['blue', 'green', 'red', 'purple', 'brown', 'cyan', 'yellow']
        # The seven base colours are cycled so that indexing by component id
        # works for up to 28000 components.
        self.colors = base_palette * 4000
        self.logger = logger
        self.save2path = save2path
        self.vis_plots = vis_plots
        self.save2png = save2png
        # Figures are created lazily and cached here by get_fig_from_flag.
        self.fig_dict = {}

    def it_plot(self, c_iteration, flags):
        """Plot the per-iteration history of every flag in ``LoggerLogFlags.EPOCHS_VALUED``.

        Flags that are not in ``EPOCHS_VALUED`` are skipped. For flags in
        ``LIST_VALUED`` one labelled curve per list entry is drawn, otherwise a
        single blue curve. Only the first ``c_iteration`` rows of column 0 are
        plotted.

        Args:
            c_iteration: Number of iterations to plot; also used in the file name.
            flags: Iterable of logger flags to consider.
        """
        for flag in flags:
            if flag not in LoggerLogFlags.EPOCHS_VALUED:
                continue

            figure = self.get_fig_from_flag(flag)
            axes = self.get_axis_from_fig(figure)
            history = self.logger.from_flag2_attr[flag]

            if flag in LoggerLogFlags.LIST_VALUED:
                # One curve per entry, coloured and labelled by its index.
                for entry_idx in range(len(history)):
                    axes.plot(range(c_iteration), history[entry_idx][:c_iteration, 0],
                              color=self.colors[entry_idx], label=str(entry_idx))
            else:
                axes.plot(range(c_iteration), history[:c_iteration, 0], color='b')

            legend_handles, legend_labels = axes.get_legend_handles_labels()
            axes.legend(legend_handles, legend_labels, loc='upper left')

            if self.vis_plots:
                plt.show()
                plt.pause(0.0000001)
            if self.save2png:
                self.save_plot2png(figure.number, flag.name, c_iteration)

    def non_it_plot(self, flags, env, c_iteration, actions=None):
        """Plot every flag that is NOT in ``LoggerLogFlags.EPOCHS_VALUED``.

        ``CO_POSITIONS_c_i`` with ``actions`` given is delegated, one component
        at a time, to ``plot_comp_positions`` (which saves its own files). Every
        other case scatters the logged values over the logged context samples.

        Args:
            flags: Iterable of logger flags to consider.
            env: Forwarded to ``plot_comp_positions``.
            c_iteration: Iteration number used in the saved file names.
            actions: Optional, indexable by component; forwarded per component
                to ``plot_comp_positions``.
        """
        # Context samples shared by all scatter plots below.
        contexts = self.logger.from_flag2_attr[LoggerLogFlags.MC_VALS_s_i]

        for flag in flags:
            if flag in LoggerLogFlags.EPOCHS_VALUED:
                continue

            is_positions_flag = flag is LoggerLogFlags.CO_POSITIONS_c_i
            figure = self.get_fig_from_flag(flag)

            if is_positions_flag and actions is not None:
                for comp in range(self.logger.c.num_components):
                    print(comp)
                    # The same figure is wiped and reused for each component.
                    plt.figure(figure.number).clf()
                    plt.title('comp_positions' + str(comp))
                    self.plot_comp_positions(actions[comp], c_iteration, env, comp, figure)
            else:
                axes = self.get_axis_from_fig(figure)
                values = self.logger.from_flag2_attr[flag]
                if flag in LoggerLogFlags.LIST_VALUED:
                    for comp in range(len(values)):
                        axes.scatter(contexts, values[comp], marker='.', s=10, c=self.colors[comp])
                else:
                    axes.scatter(contexts, values, marker='.', s=10, c='b')
                legend_handles, legend_labels = axes.get_legend_handles_labels()
                axes.legend(legend_handles, legend_labels, loc='upper left')

            if self.vis_plots:
                plt.show()
                plt.pause(0.0000001)
            # The component-position figure is never saved from here.
            if self.save2png and not is_positions_flag:
                self.save_plot2png(figure.number, flag.name, c_iteration)

    def cc_subplot(self, means, covs, local_ctxt_samples, ctxt_samples, c_iteration):
        """Plot each component's local context samples and Gaussian context distribution.

        The number of components is the length of the logged
        ``CC_REWARDS_LOC_C_SAMPLES_c_i`` entry.

        * ``ctxt_dim == 1``: the upper axes shows the logged rewards over the
          local samples; the lower axes shows the samples on the zero line, the
          Gaussian density evaluated at the samples and a column of five
          markers at the mean. Both x-ranges are set to ``context_range_bounds``.
        * ``ctxt_dim == 2``: ``draw_2d_gaussian`` is called for the component and
          the local samples are plotted as crosses on the current axes.

        Args:
            means: Indexable by component; mean of the component's Gaussian.
            covs: Indexable by component; covariance of the component's Gaussian.
            local_ctxt_samples: Indexable by component; that component's samples.
            ctxt_samples: Unused. Kept so that existing callers keep working.
            c_iteration: Iteration number used in the saved file name.
        """
        rewards_loc_samples = self.logger.from_flag2_attr[LoggerLogFlags.CC_REWARDS_LOC_C_SAMPLES_c_i]

        c_fig = self.get_fig_from_flag('cond_ctxt')
        c_axes_1, c_axes_2 = self.get_axis_from_fig_for_cc_plot(c_fig, dim=self.logger.c.ctxt_dim)

        for cmp_idx in range(len(rewards_loc_samples)):
            c_ctxt_samples = local_ctxt_samples[cmp_idx]
            c_rewards_loc_samples = rewards_loc_samples[cmp_idx]
            c_cov = covs[cmp_idx]
            c_mean = means[cmp_idx]
            c_color = self.colors[cmp_idx]
            if self.logger.c.ctxt_dim == 1:
                # The density is only drawn in the 1-D case, so only compute it here.
                distr_densities = multivariate_normal.pdf(c_ctxt_samples, mean=c_mean, cov=c_cov)
                # The zero line must have as many points as the x values it is
                # paired with (the local samples); sizing it from the global
                # sample set fails whenever the two sets differ in size.
                c_axes_2.scatter(c_ctxt_samples, np.zeros(np.shape(c_ctxt_samples)), marker='.', s=10,
                                 c=c_color)
                c_axes_1.scatter(c_ctxt_samples, c_rewards_loc_samples, marker='.', s=10, c=c_color)
                c_axes_2.scatter(c_ctxt_samples, distr_densities, marker='.', s=10, c=c_color)
                # Vertical column of markers at the component mean.
                c_axes_2.scatter(np.ones(5) * c_mean, np.linspace(0, 1, 5), marker='.', s=30, c=c_color)
                c_axes_2_lims_x = [self.logger.c.context_range_bounds[0], self.logger.c.context_range_bounds[1]]
                c_axes_1.set_xlim(c_axes_2_lims_x[0], c_axes_2_lims_x[1])
                c_axes_2.set_xlim(c_axes_2_lims_x[0], c_axes_2_lims_x[1])
            elif self.logger.c.ctxt_dim == 2:
                draw_2d_gaussian(mu=c_mean, sigma=c_cov, fig=c_fig, color=c_color)
                # Re-activate the figure before drawing through the pyplot interface.
                plt.figure(c_fig.number)
                plt.plot(c_ctxt_samples[:, 0], c_ctxt_samples[:, 1], 'x', color=c_color)
                # NOTE(review): limits are only applied when a second axes exists;
                # confirm whether they should also be applied to the single 2-D axes.
                if c_axes_2 is not None:
                    c_axes_2.set_xlim(self.logger.c.context_range_bounds[0][0],
                                      self.logger.c.context_range_bounds[1][0])
                    c_axes_2.set_ylim(self.logger.c.context_range_bounds[0][1],
                                      self.logger.c.context_range_bounds[1][1])

        if self.vis_plots:
            plt.show()
            plt.pause(0.0000001)
        if self.save2png:
            self.save_plot2png(c_fig.number, 'cond_ctxt', c_iteration)

    def cg_plots(self, c_iteration, model):
        """Plot the gating probability of every component over the logged context samples.

        Probabilities come from ``model.gating_distribution.probabilities``
        evaluated at the logged ``MC_VALS_s_i`` samples; one column per
        component. Only the ``ctxt_dim == 1`` case draws anything.

        Args:
            c_iteration: Iteration number used in the saved file name.
            model: Object exposing ``gating_distribution.probabilities(samples)``.
        """
        ctxt_samples = self.logger.from_flag2_attr[LoggerLogFlags.MC_VALS_s_i]
        gating_probs = model.gating_distribution.probabilities(ctxt_samples)

        c_fig = self.get_fig_from_flag('gatings')
        c_axes_1, c_axes_2 = self.get_axis_from_fig_for_cc_plot(c_fig, dim=self.logger.c.ctxt_dim)

        for cmp_idx in range(gating_probs.shape[1]):
            if self.logger.c.ctxt_dim == 1:
                c_axes_2.scatter(ctxt_samples, np.zeros(ctxt_samples.shape), marker='.', s=10, c=self.colors[cmp_idx])
                c_axes_2.scatter(ctxt_samples, gating_probs[:, cmp_idx], marker='.', s=10, c=self.colors[cmp_idx])
                # Keep the upper axes aligned with the x-range of the lower one.
                c_axes_2_lims_x = c_axes_2.get_xlim()
                c_axes_1.set_xlim(c_axes_2_lims_x[0], c_axes_2_lims_x[1])

        if self.vis_plots:
            plt.show()
            plt.pause(0.0000001)
        if self.save2png:
            # Save under the figure's own name; using 'cond_ctxt' here collided
            # with the file written by cc_subplot for the same iteration.
            self.save_plot2png(c_fig.number, 'gatings', c_iteration)

    def plot_comp_positions(self, actions, c_iteration, env, cmp_idx, c_fig):
        """Visualise the actions of one component through the environment.

        ``env.visualize`` is tried first; if it raises, the call is repeated on
        ``env.hole_reacing_instances[0]``. If the logged local context samples
        for this component cannot be read, all context samples are used instead.

        Args:
            actions: Actions of the component, forwarded to ``visualize``.
            c_iteration: Iteration number used in the saved file name.
            env: Environment providing ``visualize`` (directly or through
                ``hole_reacing_instances[0]``).
            cmp_idx: Component index; selects the local samples and names the file.
            c_fig: Figure forwarded to ``visualize`` and saved afterwards.
        """
        all_ctxt_samples = self.logger.from_flag2_attr[LoggerLogFlags.MC_VALS_s_i]
        try:
            local_ctxt_samples = self.logger.from_flag2_attr[LoggerLogFlags.CC_LOC_C_SAMPLES_c_i][cmp_idx]
        except Exception:
            # Best-effort fallback. `Exception` rather than a bare `except` so
            # that KeyboardInterrupt / SystemExit are not swallowed.
            print('no tracking of local_ctxt_samples setting to all ctxt samples')
            local_ctxt_samples = all_ctxt_samples

        cond_gatings = None
        try:
            env.visualize(actions, self.logger.c, c_fig, cond_gatings, local_ctxt_samples, all_ctxt_samples,
                          show=self.vis_plots)
        except Exception:
            # Fall back to the first wrapped instance; still lets Ctrl-C through.
            env.hole_reacing_instances[0].visualize(actions, self.logger.c, c_fig, cond_gatings, local_ctxt_samples,
                                                    all_ctxt_samples, show=self.vis_plots)

        if self.vis_plots:
            plt.show()
            plt.pause(0.0000001)
        if self.save2png:
            self.save_plot2png(c_fig.number, 'comp_positions_' + str(cmp_idx), c_iteration)

    def plot_final_comp_positions(self, flags, env):
        """Draw the logged positions of all components into one 'all positions' figure.

        Only acts when ``LoggerLogFlags.CO_POSITIONS_c_i`` is among ``flags``.
        For each logged component ``env.visualize`` is tried first; if it raises,
        ``env.hole_reacing_instances[0].visualize`` is called instead.

        Args:
            flags: Iterable of logger flags.
            env: Environment providing ``visualize`` (directly or through
                ``hole_reacing_instances[0]``).
        """
        for flag in flags:
            if flag is not LoggerLogFlags.CO_POSITIONS_c_i:
                continue

            fig = plt.figure('all positions')
            all_ctxt_samples = self.logger.from_flag2_attr[LoggerLogFlags.MC_VALS_s_i]
            comp_positions = self.logger.from_flag2_attr[LoggerLogFlags.CO_POSITIONS_c_i]
            for i in range(len(comp_positions)):
                c_actions = comp_positions[i]
                try:
                    env.visualize(actions=c_actions, config=self.logger.c, fig=fig, cond_gating=None, contexts=None,
                                  all_contexts=all_ctxt_samples, show=self.vis_plots, color=self.colors[i])
                except Exception:
                    # `Exception` rather than a bare `except`, so that
                    # KeyboardInterrupt / SystemExit are not swallowed.
                    env.hole_reacing_instances[0].visualize(actions=c_actions, config=self.logger.c, fig=fig,
                                                            cond_gating=None, contexts=all_ctxt_samples,
                                                            local_ctxt_samples=all_ctxt_samples,
                                                            show=self.vis_plots, color=self.colors[i],
                                                            all_plot=True,
                                                            # True only for the first component drawn.
                                                            first_time_all_plot=(i == 0))
            if self.vis_plots:
                plt.show()
                plt.pause(0.0000001)
            if self.save2png:
                self.save_plot2png(fig.number, 'comp_positions_all', self.logger.c.n_tot_it)

    def plot_test_progress(self):
        """Draw the three-panel 'test progress' figure.

        Panels, top to bottom: logged test rewards, logged mixture entropy
        during test, and the last entry of ``logger.test_num_comps_chosen``
        over the component indices. The file name uses ``logger.c.n_tot_it``.
        """
        fig = plt.figure('test progress')
        rewards = np.array(self.logger.from_flag2_attr[LoggerLogFlags.TEST_REWARD_e])
        entropies = np.array(self.logger.from_flag2_attr[LoggerLogFlags.TEST_MIXT_ENTR_e])
        comps_chosen = self.logger.test_num_comps_chosen

        # Both time series share one x-axis: the index of the test run.
        test_steps = np.arange(rewards.shape[0])

        plt.subplot(3, 1, 1)
        plt.plot(test_steps, rewards)
        plt.title('test rewards')

        plt.subplot(3, 1, 2)
        plt.title('mean entropy during test')
        plt.plot(test_steps, entropies)

        plt.subplot(3, 1, 3)
        plt.title('comps chosen in last iteration')
        plt.plot(range(self.logger.c.num_components), comps_chosen[-1])

        if self.vis_plots:
            plt.show()
            plt.pause(0.0000001)
        if self.save2png:
            self.save_plot2png(fig.number, 'test_progress', self.logger.c.n_tot_it)

    def get_fig_from_flag(self, flag):
        """Return the figure cached for ``flag``, creating it on first use.

        A new figure is named after ``flag`` itself when it is a plain ``str``
        and after ``flag.name`` otherwise. The returned figure is also made the
        current pyplot figure.
        """
        try:
            figure = self.fig_dict[flag]
        except Exception:
            figure_label = flag if type(flag) is str else flag.name
            self.fig_dict[flag] = plt.figure(figure_label)
            figure = self.fig_dict[flag]
        # Make it the active figure for subsequent pyplot calls.
        plt.figure(figure.number)
        return figure

    def get_axis_from_fig(self, fig):
        """Return a single axes of ``fig`` that is ready to draw on.

        An empty figure gets a new ``111`` subplot; otherwise the figure's last
        axes is cleared and returned.
        """
        existing_axes = fig.axes
        if not existing_axes:
            return fig.add_subplot(111)

        last_axes = existing_axes[-1]
        last_axes.clear()
        return last_axes

    def get_axis_from_fig_for_cc_plot(self, fig, dim=1):
        """Return the (upper, lower) axes used by the conditional-context style plots.

        Missing axes are created: two stacked subplots (``211``/``212``) for a
        1-D context, one ``111`` subplot for a 2-D context (the second returned
        value is then ``None``). A figure that already has two axes returns
        them as they are. Every returned axes is cleared.

        With exactly one existing axes and ``dim == 2`` that axes is reused.
        Calling ``fig.add_subplot(111)`` again would stack a new axes on top of
        it on Matplotlib versions where ``add_subplot`` always creates one.

        Args:
            fig: Figure-like object exposing ``axes`` and ``add_subplot``.
            dim: Dimensionality of the context space (1 or 2).

        Returns:
            Tuple ``(c_axes_1, c_axes_2)``; ``c_axes_2`` may be ``None``.

        Raises:
            ValueError: If the figure has more than two axes, or if axes have
                to be set up and ``dim`` is neither 1 nor 2.
        """
        c_ax_list = fig.axes
        n_axes = len(c_ax_list)

        if n_axes > 2:
            raise ValueError("too many axes for conditional context plot")

        if n_axes == 2:
            c_axes_1, c_axes_2 = c_ax_list
        elif dim == 1:
            c_axes_1 = c_ax_list[0] if n_axes == 1 else fig.add_subplot(211)
            c_axes_2 = fig.add_subplot(212)
        elif dim == 2:
            # Reuse the axes that is already there instead of adding another one.
            c_axes_1 = c_ax_list[0] if n_axes == 1 else fig.add_subplot(111)
            c_axes_2 = None
        else:
            raise ValueError("plotting higher then 2 dimensional context space not supported")

        # Start every call from empty axes.
        c_axes_1.clear()
        if c_axes_2 is not None:
            c_axes_2.clear()
        return c_axes_1, c_axes_2

    def save_plot2png(self, fig_number, description, c_iteration):
        """Save figure ``fig_number`` as ``<save2path>/<description>_it<c_iteration>.png``."""
        # Activate the figure first: plt.savefig writes the current figure.
        plt.figure(fig_number)
        file_name = ''.join(('/', description, '_it', str(c_iteration), '.png'))
        plt.savefig(self.save2path + file_name)