"""
Provides implementations of a stochastic and an adversarial combinatorial multi-armed bandit environment.
"""
import random
import numpy as np


class Environment:
    """
    Base class for bandit environments whose arms give Bernoulli feedback.

    The constructor only records the configuration. Subclasses are expected
    to fill in ``mean_losses`` (per-arm means; it must support indexing by an
    action and ``.sum()``, e.g. a numpy array) and ``baseline`` (the value the
    played action is compared against) and to implement ``reset``.
    """
    def __init__(self, gap, dim, m_size, n_steps):
        # Plain configuration, stored verbatim.
        self.gap, self.dim = gap, dim
        self.m_size, self.n_steps = m_size, n_steps
        # Placeholders that concrete environments populate.
        self.baseline = None
        self.mean_losses = None

    def reset(self):
        """Must be provided by subclasses; the base class always raises."""
        raise NotImplementedError

    def play(self, action, time_step):
        """
        Play ``action`` (an iterable of arm indices) once.

        Returns a pair ``(feedback, regret)``: ``feedback`` holds one entry
        per arm in ``action`` -- ``1.0`` with probability ``mean_losses[arm]``
        and ``0`` otherwise -- and ``regret`` is ``baseline`` minus the summed
        means of the played arms. ``time_step`` is ignored.
        """
        del time_step  # the i.i.d. setting does not depend on time
        feedback = []
        for arm in action:
            # Exactly one uniform draw per arm, taken before the mean lookup.
            draw = random.random()
            feedback.append(0 if draw > self.mean_losses[arm] else 1.0)
        played_mean = self.mean_losses[action].sum()
        regret = self.baseline - played_mean
        return feedback, regret


class Stochastic(Environment):
    """
    Stochastic bandit environment with a fixed vector of 16 Bernoulli arm means.

    The arm means are hard-coded (see ``__init__``); feedback is drawn from
    Bernoulli distributions with these means by ``Environment.play``.
    ``baseline`` is the summed mean of the two arms at indices ``m_size`` and
    ``m_size + 1``, which this environment treats as the optimal action.

    NOTE: ``action_set`` and ``gap`` are accepted for interface compatibility
    but do not currently influence the arm means; an earlier gap-based
    construction selected by ``action_set`` is no longer active.
    """

    def __init__(self, action_set, gap, dim, m_size, n_steps):
        """
        Args:
            action_set: unused by the current construction.
            gap: stored on the instance; does not affect the arm means.
            dim: stored on the instance; the mean vector always has 16 entries
                regardless of this value.
            m_size: stored on the instance and used to locate the reference
                action ``[m_size, m_size + 1]``; both indices must be valid
                for the 16-element mean vector.
            n_steps: stored on the instance.
        """
        super().__init__(gap, dim, m_size, n_steps)

        # Fixed arm means. A previous version first derived a vector from
        # m_size and then overwrote it with these values; that dead
        # computation (which warned or raised for some m_size values without
        # affecting the result) has been dropped.
        self.mean_losses = np.array([1,    0.9,  0.9,  0.9,
                                     1,    0.9,  0.85, 0.8,
                                     0.72, 0.54, 0.36, 0.18,
                                     0.72, 0.54, 0.36, 0.18])

        # Selecting these two base arms together is taken to be the optimal action.
        best_action = [m_size, m_size + 1]

        # Echo the arm means to stdout (existing behaviour, kept unchanged).
        print(self.mean_losses)
        self.baseline = self.mean_losses[best_action].sum()

    def reset(self) -> None:
        """No-op: nothing is re-initialised between runs."""
        pass


