"""
Provides implementations of a stochastic and an adversarial combinatorial multi-armed bandit environment.
"""
import random
import numpy as np


class Environment:
    """
    Common base for bandit environments whose arms produce Bernoulli losses.

    ``mean_losses`` and ``baseline`` start out as ``None``; ``play`` reads
    both, so they have to be assigned before ``play`` is called.
    """
    def __init__(self, gap, dim, m_size, n_steps):
        """Store the configuration values unchanged on the instance."""
        self.gap, self.dim, self.m_size, self.n_steps = gap, dim, m_size, n_steps
        # Placeholders; ``play`` indexes ``mean_losses`` and subtracts from ``baseline``.
        self.baseline = self.mean_losses = None

    def reset(self):
        """Not provided by the base class; always raises NotImplementedError."""
        raise NotImplementedError()

    def play(self, action, time_step):
        """
        Sample one Bernoulli loss per arm in ``action`` and compute the regret.

        Args:
            action: Iterable of arm indices; it is also used to index
                ``self.mean_losses`` directly, so it must be accepted there
                as an index (e.g. a list of ints for a NumPy array).
            time_step: Ignored, because the losses are i.i.d. across steps.

        Returns:
            A ``(feedback, regret)`` pair. ``feedback`` is a list with one
            entry per arm in ``action``: ``1.0`` when a uniform draw does not
            exceed that arm's mean loss, otherwise ``0``. ``regret`` is
            ``self.baseline`` minus the summed mean losses of the played arms.
        """
        del time_step  # unused in i.i.d.

        feedback = []
        for arm in action:
            # Exactly one uniform draw per arm, taken in the order given.
            draw = random.random()
            if draw > self.mean_losses[arm]:
                feedback.append(0)
            else:
                feedback.append(1.0)

        # NOTE(review): the sign convention here is baseline minus played
        # loss -- confirm this is what callers expect.
        played_loss = self.mean_losses[action].sum()
        regret = self.baseline - played_loss
        return feedback, regret


class Stochastic(Environment):
    """
    Stochastic bandit environment with a fixed table of Bernoulli mean losses.

    The mean losses come from a hard-coded 24-entry table and the baseline is
    the summed mean loss of arms 12, 13 and 14. The constructor arguments are
    stored on the instance (except ``action_set``) but do not influence the
    table or the baseline.
    """

    # Mean loss per arm, laid out as four groups of six values.
    # The first entry exceeds 1: ``play`` compares a uniform draw from [0, 1)
    # against it, so that arm always yields a loss of 1.0.
    _MEAN_LOSSES = (
        1.044, 0.783, 0.522, 0.261, 0, 0,
        0.936, 0.702, 0.468, 0.234, 0, 0,
        0.828, 0.621, 0.414, 0.207, 0, 0,
        0.72, 0.54, 0.36, 0.18, 0, 0,
    )

    # Arm indices whose summed mean loss is used as the regret baseline.
    # NOTE(review): presumably the best action in the intended action set --
    # confirm against the caller's ``action_set``.
    _BEST_ACTION = (12, 13, 14)

    def __init__(self, action_set, gap, dim, m_size, n_steps):
        """
        Build the environment from the fixed mean-loss table.

        Args:
            action_set: Accepted but not used by this class.
            gap: Stored as ``self.gap``; not used to build the means.
            dim: Stored as ``self.dim``.
            m_size: Stored as ``self.m_size``.
            n_steps: Stored as ``self.n_steps``.
        """
        super().__init__(gap, dim, m_size, n_steps)

        # A fresh array per instance, so mutating one environment's means
        # cannot affect another's.
        self.mean_losses = np.array(self._MEAN_LOSSES)

        # Index with a list: a tuple would be read by NumPy as a
        # multi-dimensional index rather than a selection of arms.
        self.baseline = self.mean_losses[list(self._BEST_ACTION)].sum()

        # NOTE(review): construction-time output kept so existing runs see
        # the same stdout; consider moving to logging.
        print(self.mean_losses)

    def reset(self) -> None:
        """No-op: this environment keeps no per-run state to reset."""


