import matplotlib.pyplot as plt
import numpy as np
from itertools import count

from numpy.core.fromnumeric import argmax

class SamplerException(Exception):
    """Signal that a :class:`Sampler` run has used up its sampling budget.

    Raised by ``Sampler.sample`` and suppressed by ``Sampler.__exit__``, so it
    ends a ``with sampler:`` block without propagating to the caller.
    """

    def __init__(self, *args):
        # A bare ``super()`` only builds a proxy object and never runs
        # ``Exception.__init__``; forward the args so ``exc.args`` / ``str(exc)``
        # behave like any other exception.
        super().__init__(*args)


class Sampler(object):
    """Budgeted sampler over a fixed 1-D array of values that records histories.

    Each ``with sampler:`` block is one run. Inside it, ``sample(idx)`` returns
    ``data[idx]``. The first time an index is seen in the run, its value is
    appended to the run's history and counts against the budget. When the
    budget is spent, ``sample`` raises :class:`SamplerException`, which
    ``__exit__`` swallows, so a strategy can simply loop until it is stopped.

    Runs accumulate in ``history`` until :meth:`reset` moves them, together
    with their plot kwargs, into ``old_history`` for plotting.

    Args:
        data: 1-D array of candidate values.
        samples: maximum number of distinct indices sampled per run.
        best: reference optimum used for regret. Defaults to the max of ``data``.
    """

    def __init__(self, data, samples=100, best=None):
        # A trailing 0 is appended so that index -1 (pre-marked as sampled in
        # ``__enter__``) returns 0 without consuming budget.
        # NOTE(review): presumably a "no-op" sample for strategies with no
        # candidate left -- confirm with callers.
        self.data = np.concatenate((data, np.array([0])))
        self.num_samples = samples
        self.counter = 0
        self.history = []
        self.label = 'Label not specified. Call reset before.'
        self.old_history = []
        self.plot_kwargs = {'label': 'Label not specified'}
        if best is None:
            # Exclude the padding 0: it is not a real candidate, and including
            # it made ``best`` (and every regret) wrong for all-negative data.
            real = self.data[:-1]
            best = np.max(real) if real.size else 0
        self.best = best

    def __enter__(self):
        """Start a new run: reset the budget counter and open a fresh history."""
        self.counter = 0
        self.history += [[]]
        # -1 addresses the padding value, so it is treated as already sampled.
        self.sampled_idxs = [-1]
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # Returning True suppresses the exception: running out of budget is the
        # normal way for a run to end. Every other exception propagates.
        return exc_type is not None and issubclass(exc_type, SamplerException)

    def sample(self, idx, only_unique=False):
        """Return ``data[idx]``, recording it if ``idx`` is new in this run.

        Args:
            idx: index into the data.
            only_unique: if true, return ``None`` instead of the value when
                ``idx`` was already sampled in this run.

        Raises:
            SamplerException: the run's budget is spent. That happens when
                ``samples`` distinct indices have been drawn, or when as many
                indices have been drawn as there are data points.
        """
        # Cap the budget at len(self). Once every index has been seen, no call
        # can increase the counter, so a strategy that loops until
        # SamplerException (e.g. run_sim_random) would otherwise spin forever
        # when samples > len(data).
        # NOTE: the counter counts distinct index *values*. Aliases such as
        # negative indices or the padding index len(self) also consume budget.
        if self.counter >= min(self.num_samples, len(self)):
            raise SamplerException

        x = self.data[idx]
        if idx in self.sampled_idxs:
            return None if only_unique else x

        self.sampled_idxs.append(idx)
        self.counter += 1
        self.history[-1].append(x)
        return x

    def __len__(self):
        """Number of real data points (the padding entry is excluded)."""
        return len(self.data) - 1

    def get_history(self):
        """Per-step mean of the sampled values over the runs in ``history``.

        Assumes all runs have the same length, so that ``np.array`` yields a
        2-D array.
        """
        return np.array(self.history).mean(0)

    def get_best(self, history):
        """Mean best-so-far curve: running max of each run, averaged over runs.

        ``history`` is a list of equal-length runs, e.g. ``self.history``.
        """
        return np.maximum.accumulate(history, axis=1).mean(0)

    def get_regret(self, history):
        """Mean regret curve: ``self.best`` minus :meth:`get_best` of ``history``."""
        return self.best - self.get_best(history)

    def reset(self, **plot_kwargs):
        """Archive the current runs and set the plot kwargs for the next batch.

        A non-empty ``history`` is stored in ``old_history``, paired with the
        kwargs that were active while it was recorded. ``plot_kwargs`` then
        apply to the runs recorded from now on.
        """
        if self.history:
            self.old_history += [(self.plot_kwargs, self.history)]
        self.history = []
        self.plot_kwargs = plot_kwargs

    def plot_best(self):
        """Plot the mean best-so-far curve of every archived batch.

        Pending runs are archived first via :meth:`reset`. Runs recorded later
        without an explicit ``reset`` get a placeholder legend label.
        """
        self.reset(label = 'Label not specified')
        for args, h in self.old_history:
            plt.plot(self.get_best(h), **args)
        plt.legend()

    def plot_regrets(self):
        """Plot the mean regret curve of every archived batch.

        Pending runs are archived first via :meth:`reset`. Runs recorded later
        without an explicit ``reset`` get a placeholder legend label.
        """
        self.reset(label ='Label not specified. Call reset before.')
        for args, h in self.old_history:
            plt.plot(self.get_regret(h), **args)
        plt.legend()

def run_sim_random(sampler):
    """Run one uniform-random-search episode on ``sampler``.

    Repeatedly draws an index uniformly from ``range(len(sampler))`` using
    NumPy's global RNG, then samples it. The loop has no exit of its own. It
    ends only when ``sampler.sample`` raises an exception that the sampler's
    ``__exit__`` suppresses; for ``Sampler`` that is the ``SamplerException``
    raised once the budget is exhausted. Any other exception propagates.
    """
    with sampler:
        while True:
            # Repeated indices are allowed; Sampler does not charge budget for them.
            pick = np.random.randint(len(sampler))
            sampler.sample(pick)
