import numpy as np
from cloud import cloud
import time


def UCB_value(experienced_mu, count, t, CB_coefficient):
    """Return an upper confidence bound on a mean, capped at 1.

    The bound is ``experienced_mu + CB_coefficient * sqrt(log(t + 1) / count)``.
    Values at or above 1 are clipped to the integer ``1``.

    Args:
        experienced_mu: Current empirical mean estimate.
        count: Number of samples behind the estimate (must be non-zero).
        t: Zero-based round index; at ``t == 0`` the exploration bonus is 0.
        CB_coefficient: Weight of the exploration bonus.

    Returns:
        ``1`` if the bound is >= 1, otherwise the bound itself.

    Raises:
        ValueError: If the bound is NaN. Previously this case fell through
            both branches and surfaced as an ``UnboundLocalError``.
    """
    bound = experienced_mu + CB_coefficient * np.sqrt(np.log(t + 1) / count)
    if np.isnan(bound):
        raise ValueError(
            f"UCB_value: confidence bound is NaN "
            f"(experienced_mu={experienced_mu!r}, count={count!r}, t={t!r})"
        )
    return 1 if bound >= 1 else bound


def LCB_value(experienced_cost, count, t, CB_coefficient):
    """Return a lower confidence bound on a mean cost, floored at 0.

    The bound is ``experienced_cost - CB_coefficient * sqrt(log(t + 1) / count)``.
    Values at or below 0 are clipped to the integer ``0``.

    Args:
        experienced_cost: Current empirical mean cost estimate.
        count: Number of samples behind the estimate (must be non-zero).
        t: Zero-based round index; at ``t == 0`` the exploration bonus is 0.
        CB_coefficient: Weight of the exploration bonus.

    Returns:
        ``0`` if the bound is <= 0, otherwise the bound itself.

    Raises:
        ValueError: If the bound is NaN. Previously this case fell through
            both branches and surfaced as an ``UnboundLocalError``.
    """
    bound = experienced_cost - CB_coefficient * np.sqrt(np.log(t + 1) / count)
    if np.isnan(bound):
        raise ValueError(
            f"LCB_value: confidence bound is NaN "
            f"(experienced_cost={experienced_cost!r}, count={count!r}, t={t!r})"
        )
    return 0 if bound <= 0 else bound


class C2MAB_V(object):
    """Confidence-bound bandit learner that picks arm subsets via ``cloud``.

    In each of ``T`` rounds the learner does the following:

    1. Builds per-arm reward upper confidence bounds (passed to ``cloud`` in
       log space) and per-arm cost lower confidence bounds.
    2. Asks ``cloud`` for an action vector ``At``.
    3. Records by how much the cost of ``At`` (under ``self.cost``) exceeds
       the budget ``C``.
    4. Updates its running estimates from ``env.feedback(At)``.

    Args:
        K: Compared against the sum of the per-position feedback returned by
            ``env.feedback``; presumably the number of slots per action.
        env: Environment object. ``run`` reads ``L``, ``C``, ``cost``,
            ``mu_lower``, ``mu_upper``, ``cost_lower``, ``cost_upper`` and
            calls ``feedback(At)``.
        T: Number of rounds.
        CB_coefficient: Exploration weight of the reward UCB.
        log_ind: Stored on the instance; not used by ``run``.
        LCB_coefficient: Exploration weight of the cost LCB.
    """

    def __init__(self, K, env, T, CB_coefficient, log_ind, LCB_coefficient):
        super(C2MAB_V, self).__init__()
        self.K = K
        self.env = env
        self.T = T
        self.L = self.env.L
        self.C = self.env.C
        self.rewards = np.zeros(self.T)
        self.violation = np.zeros(self.T)
        # regret / cumulative arrays are allocated but not filled by run().
        self.regret = np.zeros(self.T)
        self.regret_cumulative = np.zeros(self.T)
        self.rewards_cumulative = np.zeros(self.T)
        # Starts at 1: the random initial estimate counts as one pseudo-sample.
        self.choosing_count = np.ones(self.L)
        self.cost = self.env.cost
        self.CB_coefficient = CB_coefficient
        self.log_ind = log_ind
        self.LCB_coefficient = LCB_coefficient

    @staticmethod
    def _running_mean(old_mean, count, value):
        """Fold ``value`` into a mean whose new sample count is ``count``."""
        return (old_mean * (count - 1) + value) / count

    def _confidence_bounds(self, experienced_mu, experienced_cost, t):
        """Return ``(log_UCB_mu, LCB_cost)``, each of length ``L``, for round ``t``."""
        UCB_mu = np.zeros(self.L)
        log_UCB_mu = np.zeros(self.L)
        LCB_cost = np.zeros(self.L)
        for i in range(self.L):
            UCB_mu[i] = UCB_value(experienced_mu[i], self.choosing_count[i], t, self.CB_coefficient)
            log_UCB_mu[i] = np.log(UCB_mu[i])
            LCB_cost[i] = LCB_value(experienced_cost[i], self.choosing_count[i], t, self.LCB_coefficient)
        return log_UCB_mu, LCB_cost

    def _update_estimates(self, experienced_mu, experienced_cost, index_At_choosing,
                          feedback_cost, users_choosing_K):
        """Update counts and the cost/reward estimates in place from one round's feedback."""
        # Every selected arm reveals its cost and has its count incremented.
        # NOTE(review): arms after the last examined position also get a
        # count increment even though their reward estimate is not updated.
        # Confirm this is intended.
        for arm in index_At_choosing:
            self.choosing_count[arm] += 1
            experienced_cost[arm] = self._running_mean(
                experienced_cost[arm], self.choosing_count[arm], feedback_cost[arm])

        # click_At is the last position whose reward is observed. Presumably
        # users_choosing_K is a binary per-position vector: either all ones,
        # or the first zero marks where examination stops.
        feedback_total = users_choosing_K.sum()
        if feedback_total == self.K:
            click_At = len(index_At_choosing) - 1
        elif feedback_total < self.K:
            click_At = np.argwhere(users_choosing_K == 0)[0][0]
        else:
            # Previously this silently reused the prior round's click_At.
            raise ValueError(
                f"feedback sum {feedback_total!r} exceeds K={self.K!r}")

        for pos in range(click_At + 1):
            arm = index_At_choosing[pos]
            if pos < click_At:
                observed = 1
            elif users_choosing_K[click_At] == 1:
                observed = 1
            elif users_choosing_K[click_At] == 0:
                observed = 0
            else:
                # Any other feedback value leaves the estimate untouched.
                continue
            experienced_mu[arm] = self._running_mean(
                experienced_mu[arm], self.choosing_count[arm], observed)

    def run(self):
        """Run ``T`` rounds of the learner.

        Returns:
            ``(rewards, violation, starttime)``:

            - ``rewards[t]`` is the reward reported by ``env.feedback`` in
              round ``t``.
            - ``violation[t]`` is ``max(cost . At - C, 0)``.
            - ``starttime`` is the ``time.time()`` value at the start of the
              run. It is not an elapsed duration.
        """
        starttime = time.time()
        # Estimates are re-drawn per run, so the counts must restart as well.
        # Otherwise a second run() would mix fresh estimates with stale counts.
        self.choosing_count = np.ones(self.L)
        experienced_mu = np.random.uniform(self.env.mu_lower, self.env.mu_upper, self.L)
        experienced_cost = np.random.uniform(self.env.cost_lower, self.env.cost_upper, self.L)
        for t in range(self.T):
            log_UCB_mu, LCB_cost = self._confidence_bounds(experienced_mu, experienced_cost, t)
            At = cloud(self.L, self.K, self.C, log_UCB_mu, LCB_cost)

            tmp_violation = np.dot(self.cost, At.T) - self.C
            self.violation[t] = tmp_violation if tmp_violation > 0 else 0

            index_At_choosing = np.flatnonzero(At)
            feedback_cost, reward_t, users_choosing_K = self.env.feedback(At)
            self._update_estimates(experienced_mu, experienced_cost, index_At_choosing,
                                   feedback_cost, users_choosing_K)
            self.rewards[t] = reward_t
        return self.rewards, self.violation, starttime