import numpy as np 
import ipdb

class BaseAlgo:
    """Base class that wires a sampler's outputs and rewards to an algorithm.

    The instance keeps the current run parameters (``prompt_idx``, ``repeat``,
    ``k``, ``beta``) together with the reward array ``rhat`` and a reference
    policy slot ``piref``. Subclasses are expected to override
    ``sample_policy`` (and may override ``set_piref`` / ``set_rhat``).

    The ``sampler`` object must provide, as used below: a ``prompt_idx``
    attribute, an ``outputs`` sequence, and the methods ``load_files``,
    ``sample_outputs`` and ``get_rewards``.
    """

    def __init__(self, cfg, sampler, **kwargs):
        """Store the sampler and read settings from ``cfg``.

        Args:
            cfg: Configuration object; ``cfg.use_subsampling`` and
                ``cfg.ks.kmax`` are read.
            sampler: Sampler object used to load data and draw outputs.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        self.sampler = sampler
        self.subsampling = cfg.use_subsampling

        # A positive kmax is treated as an exponent (2**kmax); a non-positive
        # value is stored unchanged.
        kmax = cfg.ks.kmax
        self.kmax = kmax if kmax <= 0 else 2**kmax

        # Run parameters, filled in by set_params().
        self.prompt_idx = None
        self.repeat = None
        self.k = None
        self.beta = None

        # Rewards and reference policy, filled in by set_rhat()/set_piref().
        self.rhat = None
        self.piref = None

    def set_params(self, prompt_idx, repeat, k, beta):
        """Configure the instance for one (prompt, repeat, k, beta) run.

        Args:
            prompt_idx: Prompt index; the sampler's files are (re)loaded when
                it differs from ``self.sampler.prompt_idx``.
            repeat: Repeat identifier; in subsampling mode a change of
                ``repeat`` triggers a fresh draw of ``kmax`` outputs.
            k: Number of outputs to use. A non-positive value means
                ``len(self.sampler.outputs)``.
            beta: Stored on ``self.beta``.

        Raises:
            Exception: In subsampling mode, when ``repeat`` changed but ``k``
                is not equal to ``self.kmax``.
        """
        self.beta = beta
        # load the data and set prompt idx
        if prompt_idx != self.sampler.prompt_idx:
            self.sampler.load_files(prompt_idx)
        self.prompt_idx = prompt_idx
        # if subsampling, only collect kmax data at the beginning
        if self.subsampling and repeat != self.repeat:
            if k == self.kmax:  # kmax=kmin=-1 should trigger this every time
                self.sampler.sample_outputs(self.kmax)
                # self.k is set before set_rhat()/set_piref() so that
                # overrides of those hooks can read it.
                self.k = k if k > 0 else len(self.sampler.outputs)
                self.set_rhat()
                self.set_piref()
            else:
                raise Exception("kmax must be run first for subsampling")
            self.repeat = repeat

        self.k = k if k > 0 else len(self.sampler.outputs)
        self.reset()

    def reset(self):
        """Refresh ``rhat`` and ``piref`` for the current ``self.k``.

        In subsampling mode the existing ``rhat`` is cut down to its first
        ``self.k`` entries (or recomputed from the sampler when ``self.k`` is
        not positive). Otherwise ``self.k`` new outputs are sampled and the
        rewards are recomputed. ``set_piref`` is called in both modes.
        """
        if self.subsampling:
            if self.k > 0:
                # NOTE(review): this slices in place, so a later call with a
                # larger k in the same repeat cannot recover the dropped
                # rewards - confirm callers only ever decrease k.
                self.rhat = self.rhat[:self.k]
            else:
                self.set_rhat()
            self.set_piref()
        else:
            self.sampler.sample_outputs(self.k)
            self.set_rhat()
            self.set_piref()

    def set_piref(self):
        """Reset the reference policy; the base implementation stores None."""
        self.piref = None

    def set_rhat(self):
        """Store the sampler's rewards on ``self.rhat`` as a NumPy array."""
        self.rhat = np.array(self.sampler.get_rewards())

    def sample_policy(self):
        """Sample from the policy; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        # NotImplemented is a non-callable sentinel constant; calling it
        # raised TypeError. NotImplementedError is the intended exception.
        raise NotImplementedError("Policy sampling method not implemented.")


