import numpy as np 
from .base import BaseAlgo

class Algo(BaseAlgo):
    """Threshold-based acceptance sampler over a fixed set of candidates.

    A threshold ``z`` is estimated from the candidate rewards, and each
    candidate ``i`` is then accepted independently with probability
    ``max(rhat[i] - z, 0) / (beta * M)``; the first accepted candidate is
    returned by :meth:`sample_policy`.

    NOTE(review): ``self.k``, ``self.rhat``, ``self.beta`` and
    ``self.sampler`` are not assigned in this class; they are presumably
    provided by ``BaseAlgo`` (``__init__`` / ``reset``) -- confirm there.
    """

    def __init__(self, cfg, sampler, holdout_sampler=None, eps=1e-2):
        """Store configuration and seed NumPy's global RNG.

        Args:
            cfg: Configuration object; ``cfg.seed`` and
                ``cfg.method.batch_rmax`` are read here.
            sampler: Forwarded to ``BaseAlgo.__init__`` together with ``cfg``.
            holdout_sampler: Optional second sampler. When set (and
                ``holdout_rhat`` has been populated), ``estimate_z`` uses the
                held-out rewards instead of ``self.rhat``.
            eps: Margin added to the batch maximum reward in ``set_M`` when
                ``batch_rmax`` is truthy.
        """
        super().__init__(cfg, sampler)
        # Seeds the *global* NumPy RNG, which sample_policy draws from.
        np.random.seed(cfg.seed)
        self.holdout_sampler = holdout_sampler
        # NOTE: nothing in this class populates holdout_rhat. The previously
        # disabled logic filled it with a copy of self.rhat when no holdout
        # sampler was given, and with
        # np.array(self.holdout_sampler.get_rewards()) otherwise.
        self.holdout_rhat = None
        self.batch_rmax = cfg.method.batch_rmax
        self.eps = eps

    def set_piref(self):
        """Set the reference policy to the uniform distribution over ``k`` items."""
        self.piref = np.ones(self.k) / self.k

    def set_M(self):
        """Compute the reward upper bound ``rmax`` and the scale ``M``.

        ``rmax`` is the batch maximum of ``self.rhat`` plus ``eps`` when
        ``batch_rmax`` is truthy, otherwise ``self.sampler.rmax``.
        ``M = (rmax - z) / beta``, so ``beta * M`` (the denominator used in
        ``sample_policy``) equals ``rmax - z``. Requires ``self.z`` to be set.
        """
        if self.batch_rmax:
            self.rmax = self.rhat.max() + self.eps
        else:
            self.rmax = self.sampler.rmax
        self.M = (self.rmax - self.z) / self.beta

    def estimate_z(self):
        """Estimate the threshold ``z`` and store it in ``self.z``.

        Rewards are sorted ascending and the lowest ones are dropped one at a
        time. For the remaining ("active") items, ``z`` is chosen so that
        ``sum(piref[i] * (rhat[i] - z)) == 2 * beta`` over the active set. The
        search stops as soon as ``z`` lies below the smallest active reward
        and at or above the largest dropped one. If no such position is
        found, the last candidate value of ``z`` is kept.

        Raises:
            ValueError: If ``self.k`` is less than 1 (no candidates).
        """
        if self.k < 1:
            # Previously surfaced as an UnboundLocalError on ``z``.
            raise ValueError("estimate_z requires at least one candidate (k >= 1)")

        # Use the held-out rewards only when they are actually available;
        # otherwise fall back to self.rhat instead of indexing into None.
        use_holdout = (
            self.holdout_sampler is not None and self.holdout_rhat is not None
        )
        rhat = self.holdout_rhat if use_holdout else self.rhat

        sort_idxs = np.argsort(rhat)
        sort_rhat = rhat[sort_idxs]
        sort_piref = self.piref[sort_idxs]

        # Running totals over the active set: probability mass and
        # piref-weighted reward.
        partition = np.sum(self.piref)
        exp_reward = sort_piref @ sort_rhat
        for i in range(self.k):
            z = (exp_reward - 2 * self.beta) / partition
            if i == 0 and z < sort_rhat[i]:
                # Every item is active and z is below all rewards.
                break
            elif z < sort_rhat[i] and z >= sort_rhat[i - 1]:
                # z separates the dropped items from the active ones.
                break
            else:
                # Drop item i from the active set and try again.
                partition -= sort_piref[i]
                exp_reward -= sort_piref[i] * sort_rhat[i]
        self.z = z

    def reset(self):
        """Reset base state, then recompute ``z`` and ``M``."""
        super().reset()
        # NOTE: resetting the holdout sampler and refreshing holdout_rhat
        # used to happen here but is currently disabled.
        self.estimate_z()
        self.set_M()

    def sample_policy(self):
        """Draw one accept/reject decision per candidate; return the first accepted.

        Returns:
            ``(self.sampler.get_outputs(idx), idx)`` for the lowest index
            ``idx`` whose Bernoulli draw succeeded, or
            ``(None, len(self.sampler.outputs))`` when no candidate was
            accepted.
        """
        probs = np.maximum(self.rhat - self.z, 0) / (self.beta * self.M)
        # One scalar draw per candidate, in index order, from the global RNG.
        xis = [np.random.binomial(1, p) for p in probs]
        # TODO: decide on the output length so all algos output the same.
        # TODO: check rmax.
        # TODO: decide what to do when no candidate is accepted.
        try:
            idx = xis.index(1)
        except ValueError:
            # No draw succeeded. Only this case is handled here, so errors
            # raised by the sampler itself are no longer swallowed.
            return None, len(self.sampler.outputs)
        return self.sampler.get_outputs(idx), idx

    