import numpy as np
from sympy import *
import numpy as np
import time

def _or_func(v):
    return 1 - np.prod(1 - v)

def UCB_value(experienced_mu,count,t,CB_coefficient):
    """Return an optimistic (upper-confidence) estimate, capped at 1.

    The bound is ``experienced_mu + CB_coefficient * sqrt(log(t + 1) / count)``.

    Args:
        experienced_mu: current empirical estimate for the arm.
        count: play count used for the confidence radius (presumably > 0;
            a zero count divides by zero).
        t: zero-based round index; ``log(t + 1)`` is 0 in round 0, so no
            bonus is added then.
        CB_coefficient: scale of the exploration bonus.

    Returns:
        ``1`` if the bound is at least 1, otherwise the bound itself. A NaN
        bound is returned unchanged.
    """
    bound = experienced_mu + CB_coefficient * np.sqrt(np.log(t + 1) / count)
    if bound >= 1:
        return 1
    return bound

def LCB_value(experienced_cost,count,t,CB_coefficient):
    """Return an optimistic (lower-confidence) cost estimate, floored at 0.

    The bound is ``experienced_cost - CB_coefficient * sqrt(log(t + 1) / count)``.

    Args:
        experienced_cost: current empirical cost estimate for the arm.
        count: play count used for the confidence radius (presumably > 0;
            a zero count divides by zero).
        t: zero-based round index; ``log(t + 1)`` is 0 in round 0, so no
            bonus is subtracted then.
        CB_coefficient: scale of the exploration bonus.

    Returns:
        ``0`` if the bound is at most 0, otherwise the bound itself. A NaN
        bound is returned unchanged.
    """
    bound = experienced_cost - CB_coefficient * np.sqrt(np.log(t + 1) / count)
    if bound <= 0:
        return 0
    return bound

def combination_list(l,n):
    """Return every ``n``-element combination of ``l`` as a new list.

    Combinations are produced in lexicographic order of positions in ``l``
    (e.g. ``[0, 1], [0, 2], [1, 2]`` for ``l = [0, 1, 2], n = 2``). Each
    combination is a fresh list, so mutating a result never touches ``l``.

    Args:
        l: sequence of items to choose from.
        n: number of items per combination.

    Returns:
        A list of lists. ``n == 0`` gives ``[[]]`` (the single empty
        combination); ``n < 0`` or ``n > len(l)`` gives ``[]``.
    """
    import itertools

    # itertools.combinations rejects negative sizes; keep returning [] instead.
    if n < 0:
        return []
    return [list(combo) for combo in itertools.combinations(l, n)]


def all_combination_list(l,n):
    """Return every combination of ``l`` with 1 through ``n`` items.

    All size-1 combinations come first, then size 2, and so on up to ``n``,
    each group in the order produced by ``combination_list``.
    """
    return [
        combo
        for size in range(1, n + 1)
        for combo in combination_list(l, size)
    ]

def combination(L,n):
    """Return all subsets of the indices ``0 .. L-1`` with 1 to ``n`` members.

    Indices are the numpy integer scalars produced by ``np.arange(L)``.
    """
    indices = [*np.arange(L)]
    return all_combination_list(indices, n)

class C2MAB_V_direct(object):
    """Budget-constrained combinatorial bandit with cascading click feedback.

    Each round, every arm subset of size 1..``K`` whose summed lower-confidence
    cost estimate is within the budget ``env.C`` is scored with ``_or_func``
    over the arms' upper-confidence reward estimates. The best subset is
    played and the running estimates are updated from ``env.feedback``.
    """

    def __init__(self, K, env, T,CB_coefficient,log_ind,LCB_coefficient):
        """Store the configuration and allocate per-round result buffers.

        Args:
            K: maximum number of arms in a played subset.
            env: environment object. This class reads its ``L``, ``C``,
                ``cost``, ``mu_lower``/``mu_upper`` and ``cost_lower``/
                ``cost_upper`` attributes and calls ``feedback(At)``.
            T: number of rounds played by ``run``.
            CB_coefficient: scale of the exploration bonus added to reward
                estimates.
            log_ind: stored as ``self.log_ind``; not read by this class.
            LCB_coefficient: scale of the exploration bonus subtracted from
                cost estimates.
        """
        super(C2MAB_V_direct, self).__init__()
        self.K = K
        self.env = env
        self.T = T
        self.L = self.env.L  # number of arms
        self.C = self.env.C  # budget the played subset's total cost is compared against
        self.rewards = np.zeros(self.T)
        self.violation = np.zeros(self.T)
        # The regret / cumulative buffers are allocated but not filled by ``run``.
        self.regret = np.zeros(self.T)
        self.regret_cumulative = np.zeros(self.T)
        self.rewards_cumulative = np.zeros(self.T)
        # Starts at 1: the confidence radius sqrt(log(t+1)/count) stays finite,
        # and the random initial estimate in ``run`` counts as one sample.
        self.choosing_count = np.ones(self.L)
        self.cost = self.env.cost
        self.CB_coefficient = CB_coefficient
        self.log_ind = log_ind
        self.LCB_coefficient = LCB_coefficient

    def run(self):
        """Play ``self.T`` rounds and return ``(rewards, violation, starttime)``.

        ``rewards[t]`` is the reward reported by ``env.feedback`` in round
        ``t``, and ``violation[t]`` is ``max(cost . At - C, 0)`` for the played
        subset. ``starttime`` is the ``time.time()`` value taken at entry (a
        timestamp, not a duration).

        NOTE(review): ``choosing_count`` lives on the instance and is not
        reset here, so a second call continues from the previous counts.
        """
        starttime = time.time()
        combination_all = combination(self.L, self.K)
        # Candidate subsets never change between rounds, so build each one's
        # index array and 0/1 mask once rather than inside the round loop.
        candidates = []
        for subset in combination_all:
            mask = np.zeros(self.L)
            mask[subset] = 1
            candidates.append((np.asarray(subset), mask))
        experienced_mu = np.random.uniform(self.env.mu_lower, self.env.mu_upper, self.L)
        experienced_cost = np.random.uniform(self.env.cost_lower, self.env.cost_upper, self.L)

        for t in range(self.T):
            UCB_mu = np.zeros(self.L)
            LCB_cost = np.zeros(self.L)
            for i in range(self.L):
                UCB_mu[i] = UCB_value(experienced_mu[i], self.choosing_count[i], t, self.CB_coefficient)
                LCB_cost[i] = LCB_value(experienced_cost[i], self.choosing_count[i], t, self.LCB_coefficient)

            # Pick the subset whose optimistic cost fits the budget and whose
            # optimistic OR-reward is strictly largest (ties keep the earliest).
            # If none beats 0, nothing is played (an all-zero mask).
            best_reward = 0
            best_mask = None
            for subset_index, mask in candidates:
                if np.dot(LCB_cost, mask) - self.C <= 0:
                    tmp_reward = _or_func(UCB_mu[subset_index])
                    if tmp_reward > best_reward:
                        best_reward = tmp_reward
                        best_mask = mask
            # Hand the environment a fresh array, never a cached candidate mask.
            At = np.zeros(self.L) if best_mask is None else best_mask.copy()

            tmp_violation = np.dot(self.cost, At) - self.C
            self.violation[t] = tmp_violation if tmp_violation > 0 else 0

            index_At_choosing = np.flatnonzero(At)
            feedback_cost, reward_t, users_choosing_K = self.env.feedback(At)
            # click_At is used as a position inside the played subset.
            # NOTE(review): assumes users_choosing_K is indexed the same way
            # (by position in the played subset) — confirm against env.feedback.
            n_clicks = users_choosing_K.sum()
            if n_clicks >= 1:
                click_At = np.flatnonzero(users_choosing_K)[0]
            elif n_clicks == 0:
                # No click: every played position is treated as examined.
                click_At = len(index_At_choosing) - 1

            # Each played arm's cost feedback updates its running mean.
            for arm in index_At_choosing:
                self.choosing_count[arm] += 1
                count = self.choosing_count[arm]
                experienced_cost[arm] = (experienced_cost[arm] * (count - 1) + feedback_cost[arm]) / count

            # Cascade update of reward estimates:
            #   - positions before click_At are recorded as 0;
            #   - position click_At is recorded as its click value (1 or 0);
            #   - positions after it are left unchanged.
            # NOTE(review): those later arms still had choosing_count bumped
            # above, which shrinks their reward confidence radius without a
            # reward observation — confirm this is intended.
            for i in range(click_At + 1):
                if i < click_At:
                    observed = 0
                elif users_choosing_K[click_At] == 1:
                    observed = 1
                elif users_choosing_K[click_At] == 0:
                    observed = 0
                else:
                    continue
                arm = index_At_choosing[i]
                count = self.choosing_count[arm]
                experienced_mu[arm] = (experienced_mu[arm] * (count - 1) + observed) / count
            self.rewards[t] = reward_t
        return self.rewards, self.violation, starttime