import os, sys, ipdb
import numpy as np 
import matplotlib.pyplot as plt 
import seaborn as sns

# Global figure styling applied once at import time: every legend, axis
# label and axis title uses a 12pt font, and seaborn's plain "white" style
# is selected for all subsequent plots.
params = {
    rc_key: 12
    for rc_key in ('legend.fontsize', 'axes.labelsize', 'axes.titlesize')
}
plt.rcParams.update(params)
sns.set_style("white")


class Plot():
    """A figure with one row of subplots, each drawn by a plotter callable.

    Subclasses customise the x/y labels, per-subplot titles and the figure
    header by overriding the ``get_*`` hook methods; the defaults here return
    empty strings.
    """

    def __init__(self, cfg, plotters, height=5, width=5):
        """Create the figure and its row of axes.

        Args:
            cfg: experiment config; must expose ``policy.name``,
                ``reward.name``, ``task.name``, ``io.save_root`` and
                ``io.prefix``.
            plotters: sequence of callables; each is called as
                ``plotter(ax)`` to draw into one subplot.
            height: figure height (matplotlib ``figsize`` units).
            width: width of each subplot column (matplotlib ``figsize`` units).
        """
        self.cfg = cfg
        self.policy = cfg.policy.name
        self.reward = cfg.reward.name
        self.task = cfg.task.name

        self.plotters = plotters
        self.save_prefix = os.path.join(cfg.io.save_root, cfg.io.prefix)
        self.num_plots = len(plotters)
        # squeeze=False always yields a 2-D (1, num_plots) array of Axes; taking
        # row 0 gives a 1-D array even when there is a single plotter (the
        # default squeeze=True would return a bare Axes, which plot() cannot
        # zip over and the format_* methods cannot index).
        fig, axs = plt.subplots(
            1,
            self.num_plots,
            figsize=(self.num_plots * width, height),
            squeeze=False,
        )
        self.fig = fig
        self.axs = axs[0]

    def plot(self):
        """Run each plotter on its axes, label/title it, then set the header."""
        for idx, (ax, plotter) in enumerate(zip(self.axs, self.plotters)):
            plotter(ax)
            self.format_axes(idx)
            self.format_titles(idx)
        self.format_header()

    def format_axes(self, idx):
        """Apply the x and y labels returned by the hooks to subplot ``idx``."""
        xlabel = self.get_xlabel(idx)
        ylabel = self.get_ylabel(idx)
        self.axs[idx].set_ylabel(ylabel)
        self.axs[idx].set_xlabel(xlabel)

    def format_titles(self, idx):
        """Apply the title returned by ``get_title`` to subplot ``idx``."""
        title = self.get_title(idx)
        self.axs[idx].set_title(title)

    def format_header(self):
        """Set the figure-level title from ``get_header``."""
        self.fig.suptitle(self.get_header())

    def get_xlabel(self, idx):
        """Hook: x-axis label for subplot ``idx`` (empty by default)."""
        return ""

    def get_ylabel(self, idx):
        """Hook: y-axis label for subplot ``idx`` (empty by default)."""
        return ""

    def get_title(self, idx):
        """Hook: title for subplot ``idx`` (empty by default)."""
        return ""

    def get_header(self):
        """Hook: figure-level title (empty by default)."""
        return ""

    def save_figure(self):
        """Save the figure to ``<save_prefix>.png`` with a tight bounding box."""
        path = f"{self.save_prefix}.png"
        print(f"Saving figure to {path}")
        self.fig.savefig(path, bbox_inches='tight')


class Plotter():
    """Base class for drawing one view of a set of results onto a single Axes.

    Subclasses must implement the hook methods that raise
    ``NotImplementedError`` here: ``plot``, ``get_legend_labels``,
    ``get_title``, ``get_ylabel``, ``get_xlabel``, ``set_yscale``,
    ``set_xscale`` and ``set_grid``.
    """

    # Legend entries whose text is rendered in bold by ``draw_legend``.
    bold_labels = ('Rejection', 'Subsampling')

    def __init__(self, cfg, results, ax, add_legend=False):
        """Store inputs and reset the tracked axis limits.

        Args:
            cfg: experiment config (stored, not read here).
            results: data to plot; its structure is defined by subclasses.
            ax: matplotlib Axes to draw into.
            add_legend: whether ``draw_legend`` should draw a legend.
        """
        self.cfg = cfg
        self.results = results
        self.ax = ax
        # NOTE: this instance attribute shadows the ``add_legend`` method, so
        # ``self.add_legend()`` is not callable on an instance. The bool keeps
        # this name for backward compatibility; use ``draw_legend()`` instead.
        self.add_legend = add_legend

        # Axis limits grown by update_xminmax / update_yminmax; None = unset.
        self.xmin = None
        self.xmax = None
        self.ymin = None
        self.ymax = None

        # seaborn.set_palette changes the global colour cycle, not just self.ax.
        palette = sns.color_palette("husl")
        sns.set_palette(palette)

    def plot(self):
        """Hook: draw the results onto ``self.ax``."""
        raise NotImplementedError

    def get_legend_labels(self):
        """Hook: return ``(handles, labels)`` for the legend."""
        raise NotImplementedError

    def get_title(self):
        """Hook: return the axes title."""
        raise NotImplementedError

    def get_ylabel(self):
        """Hook: return the y-axis label."""
        raise NotImplementedError

    def get_xlabel(self):
        """Hook: return the x-axis label."""
        raise NotImplementedError

    def set_yscale(self):
        """Hook: configure the y-axis scale of ``self.ax``."""
        raise NotImplementedError

    def set_xscale(self):
        """Hook: configure the x-axis scale of ``self.ax``."""
        raise NotImplementedError

    def set_grid(self):
        """Hook: configure grid lines on ``self.ax``."""
        raise NotImplementedError

    def update_yminmax(self, x):
        """Widen the tracked y-range so it covers every value in ``x``.

        ``x`` must be non-empty; the built-in ``min``/``max`` raise
        ``ValueError`` otherwise.
        """
        lo = min(x)
        self.ymin = lo if self.ymin is None else min(self.ymin, lo)
        hi = max(x)
        self.ymax = hi if self.ymax is None else max(self.ymax, hi)

    def update_xminmax(self, x):
        """Widen the tracked x-range so it covers every value in ``x``.

        ``x`` must be non-empty; the built-in ``min``/``max`` raise
        ``ValueError`` otherwise.
        """
        lo = min(x)
        self.xmin = lo if self.xmin is None else min(self.xmin, lo)
        hi = max(x)
        self.xmax = hi if self.xmax is None else max(self.xmax, hi)

    def format_axes(self):
        """Apply title, labels, tracked limits, scales and grid to ``self.ax``."""
        self.ax.set_title(self.get_title())
        self.ax.set_ylabel(self.get_ylabel())
        self.ax.set_xlabel(self.get_xlabel())
        # min and max are always set together by update_*minmax, so checking
        # the lower bound is enough.
        if self.xmin is not None:
            self.ax.set_xlim(left=self.xmin, right=self.xmax)
        if self.ymin is not None:
            self.ax.set_ylim(bottom=self.ymin, top=self.ymax)
        self.set_xscale()
        self.set_yscale()
        self.set_grid()

    def draw_legend(self):
        """Draw the legend on ``self.ax`` if legends are enabled.

        Handles and labels come from ``get_legend_labels``; entries whose text
        is in ``bold_labels`` are rendered in bold.

        Returns:
            The legend returned by ``self.ax.legend``, or ``None`` when the
            ``add_legend`` flag is false.
        """
        if not self.add_legend:
            return None
        handles, labels = self.get_legend_labels()
        legend = self.ax.legend(
            handles,
            labels,
            fontsize=12,
            loc='center',
            frameon=True,
            handlelength=1.,
            # With loc='center', this places the legend's centre at (-0.5, 0.5)
            # in axes coordinates, i.e. to the left of the plotting area.
            bbox_to_anchor=(-0.5, 0.5),
            ncol=1,
        )
        for text in legend.get_texts():
            if text.get_text() in self.bold_labels:
                text.set_fontweight('bold')
        return legend

    def add_legend(self):
        """Backward-compatible alias for ``draw_legend``.

        ``__init__`` shadows this name with a bool on each instance, so this
        method is only reachable as ``Plotter.add_legend(plotter)``; prefer
        ``plotter.draw_legend()``.
        """
        return self.draw_legend()