import numpy as np
# Open the archive that holds the per-query arrays used by the environment.
data = np.load("mslr.npz")

# Keep only the first 20 entries along the second axis of each array, so the
# two arrays stay aligned on that axis.
relev = data["relevances"][:, 0:20]
context = data["features"][:, 0:20, :]

# Dedicated, fixed-seed generator (seed 0) used below for drawing arm costs.
random_state = np.random.RandomState(seed=0)

class mlsr():
    """Budget-constrained simulation environment over the module-level arrays.

    Each interaction consists of ``sendcontext()`` (pick a random row and
    return its context) followed by ``feedback(arm)`` (record the reward and
    cost of the chosen arm for that row).  Cumulative reward and cumulative
    cost are tracked in ``rewardlist`` / ``costlist``; once the cumulative
    cost exceeds ``budget`` the ``stopping`` flag is set to 1.
    """

    def __init__(self, budget = 10000):
        """Create the environment.

        Args:
            budget: Cumulative-cost threshold; ``stopping`` becomes 1 once the
                running cost total is strictly greater than this value.
        """
        # NOTE(review): `arm` is 50 although the arrays are sliced to 20
        # entries and only 20 costs are drawn -- confirm the intended value.
        self.arm = 50
        # Presumably the size of the last axis of `context`; it is not derived
        # from the data -- TODO confirm.
        self.dim = 136
        self.relev = relev
        self.context = context
        self.budget = budget
        # One fixed cost per arm, drawn uniformly from [0, 5).
        self.cost = random_state.uniform(0, 5, size=20)
        # Initialise the per-run bookkeeping (running totals, round, flag).
        self.reset()

    def sendcontext(self):
        """Select a random row and return its context.

        The selected row index is stored in ``self.round`` so that the next
        ``feedback`` call reads the reward of the same row.  The draw uses
        NumPy's global RNG, not the module-level ``random_state``.
        """
        # Never index past the rows actually available; with at least 10000
        # rows this is the same draw as a fixed upper bound of 10000.
        upper = min(10000, len(self.context))
        self.round = np.random.randint(0, upper)
        return self.context[self.round]

    def feedback(self, arm):
        """Record the outcome of playing ``arm`` on the current row.

        Args:
            arm: Index of the chosen arm.

        Returns:
            A ``(reward, cost)`` tuple where each component is the true value
            plus independent Gaussian noise (mean 0, standard deviation 0.05).
            The running totals are updated with the noise-free values.
        """
        reward = self.relev[self.round, arm]
        cost = self.cost[arm]
        self.rewardlist.append(self.rewardlist[-1] + reward)
        self.costlist.append(self.costlist[-1] + cost)
        if self.costlist[-1] > self.budget:
            self.stopping = 1
        # Reward noise is drawn before cost noise, keeping the RNG call order.
        noisy_reward = reward + np.random.normal(loc=0, scale=0.05)
        noisy_cost = cost + np.random.normal(loc=0, scale=0.05)
        return noisy_reward, noisy_cost

    def remain(self, round):
        """Pad ``rewardlist`` with ``round`` copies of its last value.

        Args:
            round: Number of entries to append; values <= 0 append nothing.
        """
        last = self.rewardlist[-1]
        self.rewardlist.extend([last for _ in range(0, round)])

    def returnresult(self):
        """Return the running average reward, limited to the first 5001 entries.

        Entry ``i`` is ``rewardlist[i] / (i + 1)``.  The division is done on
        the full-length arrays and the result is truncated afterwards, so a
        history longer than 5001 entries no longer causes a shape mismatch.
        """
        totals = np.asarray(self.rewardlist)
        steps = np.arange(1, len(self.rewardlist) + 1)
        return (totals / steps)[:5001]

    def reset(self):
        """Clear the running totals, round index and stop flag; keep costs."""
        self.rewardlist = [0]
        self.costlist = [0]
        self.round = -1
        self.stopping = 0

    def restart(self):
        """Reset the bookkeeping and draw a fresh set of arm costs.

        NOTE(review): costs are redrawn from [0, 1) here but from [0, 5) in
        ``__init__`` -- confirm that the different range is intentional.
        """
        self.reset()
        self.cost = random_state.uniform(0, 1, size=20)
