from functools import partial
from .base import *
import scipy.stats as scs
import ipdb
# NOTE: it might be easier to parse all data in the folder, rather than only the selected data.

def Beta2ValuePlotter(results, betas, ax, std_over_prompts=True):
    """Draw the mean value per beta on ``ax`` with a +/- standard-error band.

    Args:
        results: Nested mapping indexed as ``results[str(beta)]['-1']``; the
            innermost mapping's values are sequences of float-convertible
            items (one sequence per entry -- presumably one per prompt, going
            by the ``std_over_prompts`` name; confirm against the caller).
        betas: Sequence of beta values; used both as lookup keys (via
            ``str``) and as the x coordinates of the plot.
        ax: Object exposing ``plot``, ``fill_between`` and ``set_xscale``
            (e.g. a matplotlib axis).
        std_over_prompts: If true, every item of every entry is pooled into
            one sample per beta. Otherwise the entries are first averaged
            per repeat index, and the statistics are taken over those
            per-repeat averages.

    Returns:
        None. The function only draws on ``ax``.
    """
    key = '-1'
    # The repeat count is read once, from the first entry of the first beta,
    # and reused for every beta.
    first_entries = results[str(betas[0])][key]
    n_repeats = len([*first_entries.values()][0])

    def _samples(entries):
        # Gather the sample whose mean and standard error are plotted.
        if std_over_prompts:
            return [float(item) for seq in entries.values() for item in seq]
        return [
            np.mean([float(seq[rep]) for seq in entries.values()])
            for rep in range(n_repeats)
        ]

    # One (mean, standard error) pair per beta, evaluated beta by beta.
    summary = [
        (np.mean(sample), scs.sem(sample))
        for sample in (_samples(results[str(beta)][key]) for beta in betas)
    ]
    centre = np.array([mean for mean, _ in summary])
    spread = np.array([err for _, err in summary])

    ax.plot(betas, centre)
    ax.fill_between(betas, centre + spread, centre - spread, alpha=0.5)
    ax.set_xscale('log')

class BetaPlot(Plot):
    """Plot with two beta-sweep panels, one built from ``results`` and one from ``indices``.

    Going by the labels below, the first panel shows the fraction of correct
    answers and the second the number of queries N (presumably; confirm
    against the data passed in).
    """

    def __init__(self, cfg, results, indices):
        """Build one plotter per data source for the last configured method.

        Args:
            cfg: Configuration object; ``betas``, ``method_list`` and
                ``std_over_prompts`` are read here, and it is forwarded to
                the base-class constructor.
            results: Mapping keyed by method name; the selected entry becomes
                the data of the first plotter.
            indices: Mapping keyed by method name; the selected entry becomes
                the data of the second plotter.
        """
        beta_values = list(cfg.betas)
        # Only the last method listed in the config is plotted.
        last_method = list(cfg.method_list)[-1]
        # Each plotter still expects the axis as its remaining positional
        # argument.
        self.plotters = [
            partial(
                Beta2ValuePlotter,
                per_method[last_method],
                beta_values,
                std_over_prompts=cfg.std_over_prompts,
            )
            for per_method in (results, indices)
        ]
        super().__init__(cfg, self.plotters)
        self.save_prefix = self.save_prefix + "_betas"

    def get_xlabel(self, idx):
        """Return the x-axis label, which is the same for every panel."""
        return r"$\beta$"

    def get_ylabel(self, idx):
        """Return the y-axis label of panel ``idx``."""
        first_label, second_label = "Correct", r"Queries $N$"
        return [first_label, second_label][idx]

    def get_title(self, idx):
        """Return the title of panel ``idx``."""
        first_title, second_title = "Fraction correct", r"Queries $N$"
        return [first_title, second_title][idx]

    def get_header(self):
        """Return the figure header naming the policy and reward in use."""
        policy, reward = self.policy, self.reward
        return fr"Pessimism with $\pi =$ {policy} and $r =$ {reward}"