
import numpy as np 
import scipy.stats as scs
import os 
import ipdb
from collections import defaultdict

import matplotlib.pyplot as plt 
import seaborn as sns

# One shared font size for legend text, axis labels and subplot titles.
params = dict.fromkeys(('legend.fontsize', 'axes.labelsize', 'axes.titlesize'), 12)
plt.rcParams.update(params)

# Apply seaborn's "white" style to all subsequently drawn figures.
sns.set_style("white")


class Plotter:
    """Base class bundling a config, a results mapping and a matplotlib figure.

    Subclasses are expected to assign ``self.fig`` (and usually ``self.axs``)
    before calling ``save_figure``.
    """

    def __init__(self, cfg, results):
        """Store the inputs and derive the output path prefix.

        Args:
            cfg: Config object; ``cfg.io.save_root`` and ``cfg.io.prefix`` are
                joined into ``self.save_prefix``.
            results: Results mapping consumed by subclasses.
        """
        self.cfg = cfg
        self.results = results

        self.save_prefix = os.path.join(cfg.io.save_root, cfg.io.prefix)

        self.fig = None
        self.axs = None

    def save_figure(self):
        """Write ``self.fig`` to ``<save_prefix>.png``.

        The parent directory is created if it does not exist yet.

        Raises:
            RuntimeError: If no figure has been created yet.
        """
        if self.fig is None:
            raise RuntimeError("save_figure() called before a figure was created")
        path = f'{self.save_prefix}.png'
        out_dir = os.path.dirname(path)
        # savefig does not create missing directories; do it here.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        print(f"Saving figure to {path}")
        self.fig.savefig(path, bbox_inches='tight')

def fill_in_rejection(results, method='rejection-holdout-False', fill_method='bon',
                      fill_beta='0', skip_metrics=('nums',)):
    """Replace missing (``None``) values of ``method`` with values from ``fill_method``, in place.

    ``results`` is indexed as ``results[metric][method][beta][k][prompt_idx]``,
    yielding a list of per-sample values.

    For every ``k`` with ``int(k) > 0``, each ``None`` at position ``i`` is replaced by
    ``results[metric][fill_method][fill_beta][k][prompt_idx][i]``.

    Entries with ``int(k) <= 0`` are left unchanged. Metrics listed in
    ``skip_metrics``, and metrics that have no ``method`` entry, are skipped.

    Args:
        results: Nested results mapping; modified in place.
        method: Method whose ``None`` entries are filled.
        fill_method: Method supplying replacement values (default ``'bon'``).
        fill_beta: Beta key of ``fill_method`` to read from (default ``'0'``).
        skip_metrics: Metric names left untouched (default ``('nums',)``).

    Returns:
        The same ``results`` object, for convenient chaining.

    Raises:
        KeyError / IndexError: If the fill source lacks a needed k, prompt or
        position.
    """
    for metric, metric_dict in results.items():
        if metric in skip_metrics or method not in metric_dict:
            continue
        for beta_dict in metric_dict[method].values():
            for k, k_dict in beta_dict.items():
                if int(k) <= 0:
                    continue
                # Looked up lazily so a missing fill source only matters when needed.
                fallback = metric_dict[fill_method][fill_beta][k]
                for prompt_idx, values in k_dict.items():
                    k_dict[prompt_idx] = [
                        val if val is not None else fallback[prompt_idx][i]
                        for i, val in enumerate(values)
                    ]
    return results

        

    
class RewardvsNPlotter(Plotter):
    """Plot each metric's mean (± standard error) against the number of samples N.

    One subplot is drawn per name in ``metrics``. ``plot`` reads ``self.results``
    as ``results[metric][method][beta][k][prompt_idx] -> list of values``,
    looking up the beta and k levels with ``str(...)`` keys.
    """

    def __init__(self, metrics, *args):
        """Fill missing rejection values and create one subplot per metric.

        Args:
            metrics: Metric names to plot, in subplot order; each must be a key
                of ``results``.
            *args: Forwarded to ``Plotter.__init__`` (``cfg, results``).
        """
        super().__init__(*args)
        self.metrics = metrics
        self.results = fill_in_rejection(self.results)

        # squeeze=False keeps ``self.axs`` a 1-D array of Axes even for a single
        # metric (plain ``plt.subplots(1, 1)`` returns a bare Axes object).
        fig, axs = plt.subplots(1, len(self.metrics),
                                figsize=(len(self.metrics) * 5, 8), squeeze=False)
        self.fig = fig
        self.axs = axs[0]
        self.legend_ax = 0

        # Data extents of the curves on the first axis; None until one is drawn.
        # (Seeding ymin/ymax with 1/0 clamped them to <= 1 and >= 0.)
        self.xmin = None
        self.xmax = None
        self.ymin = None
        self.ymax = None

    def generate_plot_dicts(self):
        """Build the beta->colour, metric->title and metric->ylabel lookups."""
        betas = list(self.cfg.betas) + [0]
        self.beta_to_color = {beta: f'C{idx}' for idx, beta in enumerate(betas)}
        self.metric_to_title = {metric: f'{metric.capitalize()} : {self.cfg.policy.name}' for metric in self.metrics}
        self.metric_to_ylabel = {metric: metric.capitalize() for metric in self.metrics}

    def get_line_params(self, **kwargs):
        """Return ``(linestyle, color)`` for the curve of ``method``/``beta``.

        Subsampling and piref curves are dashed. BoN is black, piref is grey, and
        everything else is coloured by ``self.beta_to_color[beta]``.
        """
        method = kwargs['method']
        beta = kwargs['beta']
        linestyle = '--' if (('subsampling' in method) or ('piref' in method)) else '-'
        if method == 'bon':
            linecolor = 'black'
        elif method == 'piref':
            linecolor = 'grey'
        else:
            # NOTE(review): for non-rejection methods ``beta`` comes from the
            # results keys (strings); confirm they match cfg.betas' key type.
            linecolor = self.beta_to_color[beta]
        return linestyle, linecolor

    def get_label(self, **kwargs):
        """Return the legend label for the curve of ``method``/``beta``."""
        method = kwargs['method']
        beta = kwargs['beta']
        if method == 'bon':
            label = 'BoN'
        elif method == 'piref':
            label = r"$\pi_{\mathsf{ref}}$"
        else:
            label = rf"$\beta={beta}$"
        return label

    def format_axis(self, i):
        """Apply x-limits, labels, log x-scale and grid to subplot ``i``."""
        ax = self.axs[i]
        ax.set_xlim(left=self.xmin, right=self.xmax)
        ax.set_ylabel(self.metric_to_ylabel[self.metrics[i]])
        ax.set_xlabel(r'Samples $N$')
        ax.set_xscale('log')
        ax.grid(True, which="both", linestyle='--', linewidth=0.5)

    def format_title(self, i):
        """Set the title of subplot ``i`` from ``metric_to_title``."""
        self.axs[i].set_title(self.metric_to_title[self.metrics[i]])

    def format_method(self, method):
        """Return the capitalised algorithm prefix of ``method`` (text before the first '-')."""
        algo = method.split('-')[0]
        return algo.capitalize()

    def _get_legend_labels(self):
        """Assemble legend handles/labels grouped by method, with blank spacers.

        An invisible, data-less line labelled 'fake' is added to the legend axis
        and used as the spacer handle. It has no data, so it cannot alter the
        axis limits.
        """
        ax = self.axs[self.legend_ax]
        ax.plot([], [], marker='None', linestyle='None', label='fake')
        handles, labels = ax.get_legend_handles_labels()

        handles_ = []
        labels_ = []
        for idx, method in enumerate(self.legend_idxs.keys()):
            if idx > 2:
                handles_ += [handles[-1]]
                labels_ += ['']
            if idx > 1:
                # Blank row followed by a method-name header row.
                handles_ += [handles[-1], handles[-1]]
                labels_ += ['', self.format_method(method)]
            idxs = self.legend_idxs[method]
            handles_ += [handles[i] for i in idxs]
            labels_ += [labels[i] for i in idxs]
        return handles_, labels_

    def add_legend(self):
        """Draw the grouped legend on the legend axis, bolding method headers."""
        ax = self.axs[self.legend_ax]
        handles, labels = self._get_legend_labels()

        legend = ax.legend(handles, labels,
                fontsize='12',
                loc='center',
                frameon=True,
                handlelength=1.,
                bbox_to_anchor=(-0.5, 0.5),
                # matplotlib rejects a non-positive column count, so with
                # fewer than three methods fall back to one column.
                ncol=max(1, len(self.legend_idxs) - 2),
                )
        for label in legend.get_texts():
            if label.get_text() in ['Rejection', 'Subsampling']:
                label.set_fontweight('bold')

    def update_yminmax(self, x):
        """Widen ``ymin``/``ymax`` to include the values in ``x``."""
        self.ymin = min(self.ymin, min(x)) if self.ymin is not None else min(x)
        self.ymax = max(self.ymax, max(x)) if self.ymax is not None else max(x)

    def update_xminmax(self, x):
        """Widen ``xmin``/``xmax`` to include the values in ``x``."""
        self.xmin = min(self.xmin, min(x)) if self.xmin is not None else min(x)
        self.xmax = max(self.xmax, max(x)) if self.xmax is not None else max(x)

    @staticmethod
    def _summarize(beta_dict, k_list):
        """Return ``(means, standard errors)`` arrays, one entry per k in ``k_list``.

        For each k, the value lists of all prompts in ``beta_dict[str(k)]`` are
        pooled. NaN entries are dropped, whether stored as the string 'nan' or as
        a float NaN.
        """
        avgs, sems = [], []
        for k in k_list:
            values = np.array([float(v) for vals in beta_dict[str(k)].values() for v in vals],
                              dtype=float)
            values = values[~np.isnan(values)]
            avgs.append(np.mean(values))
            sems.append(scs.sem(values))
        return np.array(avgs), np.array(sems)

    def plot(self, k_list):
        """Draw every metric/method/beta curve, add the legend and save the figure.

        Args:
            k_list: Sample counts N to plot. ``str(k)`` must be a key of every
                plotted beta dict, and ``int(k)`` gives the x coordinate.
        """
        sns.set_palette(sns.color_palette("husl"))

        self.legend_idxs = defaultdict(list)
        # Lookups depend only on cfg and metrics, so build them once.
        self.generate_plot_dicts()
        # Keep the caller's k_list for key lookups; xs holds the numeric x values.
        xs = np.array([int(k) for k in k_list])
        line_idx = 0
        alpha = 0.2
        for axis_idx, ax in enumerate(self.axs):
            metric_dict = self.results[self.metrics[axis_idx]]
            self.methods = sorted(metric_dict)
            for method, method_dict in metric_dict.items():
                # Rejection curves are drawn for the configured betas only;
                # other methods plot every beta they have results for.
                betas = self.cfg.betas if 'rejection' in method else list(method_dict.keys())
                for beta in betas:
                    avgs, sems = self._summarize(method_dict[str(beta)], k_list)
                    linestyle, linecolor = self.get_line_params(method=method, beta=beta)
                    label = self.get_label(method=method, beta=beta)
                    ax.plot(xs, avgs, linestyle=linestyle, color=linecolor, label=label)
                    ax.fill_between(xs, avgs + sems, avgs - sems, color=linecolor, alpha=alpha)
                    # The legend is built from the first axis only, so only its
                    # lines are indexed (in plotting order).
                    if axis_idx == 0:
                        self.update_yminmax(avgs)
                        self.update_xminmax(xs)
                        self.legend_idxs[method].append(line_idx)
                        line_idx += 1
            self.format_axis(axis_idx)
            self.format_title(axis_idx)
        self.legend_idxs = dict(self.legend_idxs)
        self.add_legend()
        self.save_figure()