import os
import pickle
import random
from typing import List, Optional

import numpy as np

class Bandit:
    """Softmax bandit over a fixed list of candidate values ("arms").

    Every arm carries one scalar weight in ``self.w``. An arm is sampled with
    probability proportional to ``exp(weight)`` (see ``get_probs``), and
    ``update_dists`` decays and adjusts the weight of the arm that produced a
    given feedback value, after normalizing that feedback against a sliding
    window of recent feedback.
    """

    # Default arm values. Kept as an immutable tuple and copied per instance so
    # that instances never share (or mutate) one default list.
    _DEFAULT_ARMS = (12e-5, 18e-5, 25e-5)

    def __init__(self,
                 arms: Optional[List] = None,
                 init: float = 0.0,
                 window: int = 5,
                 decay: float = 0.9,
                 lr: float = 0.2,
                 use_std: bool = True) -> None:
        """Create a bandit.

        Args:
            arms: Candidate values to choose between. ``None`` selects the
                defaults ``[12e-5, 18e-5, 25e-5]``. The sequence is copied.
            init: Initial weight assigned to every arm.
            window: Number of most recent feedback values kept for
                normalization.
            decay: Multiplicative decay applied to an arm's weight on update.
            lr: Step size applied to the normalized feedback.
            use_std: If true, feedback is divided by the standard deviation
                of the window (once it holds more than one value) instead of
                the ``norm`` argument of ``update_dists``.

        Raises:
            ValueError: If ``arms`` is empty.
        """
        if arms is None:
            arms = self._DEFAULT_ARMS
        self.arms = list(arms)
        if not self.arms:
            raise ValueError("Bandit requires at least one arm")

        self.w = [init] * len(self.arms)
        # Start on index 2 (the last of the default arms), clamped so that a
        # shorter arm list still yields a valid index.
        self.arm = min(2, len(self.arms) - 1)
        self.error_buffer = []
        self.window = window
        self.lr = lr
        self.use_std = use_std
        self.decay = decay

    def _softmax(self) -> np.ndarray:
        """Return the softmax of the arm weights as a probability vector."""
        weights = np.asarray(self.w, dtype=float)
        # Subtracting the max leaves the softmax unchanged mathematically but
        # keeps exp() from overflowing to inf (which would yield NaN
        # probabilities) when weights grow large.
        exps = np.exp(weights - np.max(weights))
        return exps / np.sum(exps)

    def update_lr(self):
        """Sample an arm from the current distribution and return its value.

        The sampled index is stored in ``self.arm``.
        """
        probs = self._softmax()
        self.arm = int(np.random.choice(len(probs), p=probs))
        return self.arms[self.arm]

    def get_probs(self) -> np.ndarray:
        """Return the current selection probability of each arm.

        The result is a NumPy array aligned with ``self.arms`` and sums to 1.
        """
        return self._softmax()

    def update_dists(self, feedback: float, arm: float, norm: float = 1.0) -> None:
        """Update the weight of ``arm`` from a new feedback value.

        Args:
            feedback: Raw feedback observed for ``arm``.
            arm: The arm value (an element of ``self.arms``) that produced
                the feedback.
            norm: Scale used to normalize the centered feedback. Ignored when
                ``use_std`` is set and the window holds more than one value.

        Raises:
            ValueError: If ``arm`` is not one of ``self.arms``.
        """
        self.arm = self.arms.index(arm)
        idx = self.arm

        # Since this is non-stationary, subtract the mean of the previous
        # self.window feedback values (the window includes the new value).
        self.error_buffer.append(feedback)
        self.error_buffer = self.error_buffer[-self.window:]

        # Normalize: center on the window mean, then scale.
        centered = feedback - np.mean(self.error_buffer)
        if self.use_std and len(self.error_buffer) > 1:
            norm = np.std(self.error_buffer)
        scaled = centered / (norm + 1e-3)  # epsilon avoids division by zero

        # Update arm weights: decay first, then step by the scaled feedback.
        # NOTE(review): the divisor is the unnormalized exp(weight) of the arm
        # (floored at 0.001), not its softmax probability -- confirm intended.
        self.w[idx] *= self.decay
        self.w[idx] += self.lr * (scaled / max(np.exp(self.w[idx]), 0.001))
