import numpy as np
from find_optimal import find_optimal
import scipy.stats as stats
from solve_constrain import solve_constrain
from find_optimal import find_optimal

def get_mu_bound(mu,L):
    """Return per-item (lower, upper) bounds of an interval around each ``mu``.

    For each of the first ``L`` entries ``m`` of ``mu``:
      * if ``1 - m <= m``  -> bounds are ``[2*m - 1, 1]``
      * if ``m < 1 - m``   -> bounds are ``[0, 2*m]``
    For ``m`` in [0, 1] this is the widest interval symmetric about ``m``
    that stays inside [0, 1]. An entry that passes neither test (e.g. NaN)
    keeps the bounds ``(0, 0)``.

    Args:
        mu: indexable sequence with at least ``L`` numeric entries
            (an IndexError is raised otherwise).
        L: number of leading entries to process.

    Returns:
        Tuple ``(lower, upper)`` of two float arrays of length ``L``.
    """
    bounds_low, bounds_high = np.zeros(L), np.zeros(L)
    for idx in range(L):
        centre = mu[idx]
        mirror = 1 - centre
        if mirror <= centre:
            bounds_low[idx], bounds_high[idx] = 2 * centre - 1, 1
        elif centre < mirror:
            bounds_low[idx], bounds_high[idx] = 0, 2 * centre
    return bounds_low, bounds_high
    
def get_cost_bound(cost,L):
    """Return per-item (lower, upper) bounds of an interval around each cost.

    For each of the first ``L`` entries ``c`` of ``cost``:
      * if ``c >= 1 - c`` -> bounds are ``[2*c - 1, 1]``
      * if ``1 - c > c``  -> bounds are ``[0, 2*c]``
    For ``c`` in [0, 1] this is the widest interval symmetric about ``c``
    inside [0, 1]. For ``c == 0`` or ``c == 1`` the interval collapses to a
    single point. Entries passing neither test (e.g. NaN) stay at ``(0, 0)``.

    Args:
        cost: indexable sequence with at least ``L`` numeric entries.
        L: number of leading entries to process.

    Returns:
        Tuple ``(lower, upper)`` of two float arrays of length ``L``.
    """
    lower = np.zeros(L)
    upper = np.zeros(L)
    for j in range(L):
        c = cost[j]
        if c >= 1 - c:
            lower[j] = 2 * c - 1
            upper[j] = 1
            continue
        if 1 - c > c:
            lower[j] = 0
            upper[j] = 2 * c
    return lower, upper


def get_feedback_cost(cost_lower,cost_upper,cost,sigma):
    """Sample one cost from a normal distribution truncated to an interval.

    Args:
        cost_lower: lower truncation bound (scalar).
        cost_upper: upper truncation bound (scalar).
        cost: mean (``loc``) of the underlying normal.
        sigma: standard deviation (``scale``) of the underlying normal.

    Returns:
        A single sampled value in ``[cost_lower, cost_upper]``.

    A degenerate interval (``cost_lower == cost_upper``) is handled directly
    by returning that single point. ``get_cost_bound`` produces such an
    interval for costs of exactly 0 or 1. scipy's ``truncnorm`` requires
    ``a < b`` and would otherwise raise a domain error.
    """
    if cost_lower == cost_upper:
        return float(cost_lower)
    # truncnorm expects its bounds in standard-deviation units relative to loc.
    a = (cost_lower - cost) / sigma
    b = (cost_upper - cost) / sigma
    x = stats.truncnorm(a, b, loc=cost, scale=sigma)
    feedback_cost = x.rvs(1)[0]
    return feedback_cost



class Environment(object):
    """Simulated environment over ``L`` items with per-item ``mu`` and ``cost``.

    On construction the per-item bounds from ``get_mu_bound`` and
    ``get_cost_bound`` are cached as ``mu_lower``/``mu_upper`` and
    ``cost_lower``/``cost_upper``.
    """

    def __init__(self, L, C, mu, cost):
        """Store the environment parameters and precompute per-item bounds.

        Args:
            L: number of items; the length of the bound arrays and of the cost
                vector returned by ``feedback``.
            C: stored as ``self.C``. No method of this class reads it;
                ``get_best_reward`` takes its own ``C`` argument.
            mu: per-item probabilities. Each one is used as the ``p`` of a
                Bernoulli draw in ``feedback``, so values must lie in [0, 1].
            cost: per-item costs, used as the mean of the truncated normal
                that ``feedback`` samples from.
        """
        super(Environment, self).__init__()
        self.L = L
        self.C = C
        self.mu = mu
        self.mu_lower, self.mu_upper = get_mu_bound(self.mu, L)
        self.cost = cost
        self.cost_lower, self.cost_upper = get_cost_bound(self.cost, self.L)

    def conjunctive_reward_func(self, v):
        """Return the product of the entries of ``v`` (1.0 when ``v`` is empty)."""
        return np.prod(v)

    def feedback(self, A, sigma=1):
        """Simulate one round of feedback for the items selected by ``A``.

        Args:
            A: indicator over items. Every index holding a nonzero entry is
                treated as recommended, taken in increasing index order.
            sigma: scale of the truncated normal used to sample each
                recommended item's cost. Defaults to 1.

        Returns:
            feedback_cost: length-``L`` array. Each recommended position holds
                a cost sampled around ``self.cost[i]``, truncated to
                ``[self.cost_lower[i], self.cost_upper[i]]``; all other
                positions are 0.
            reward: product of ``mu`` over the recommended items (1.0 if none).
            users_choosing_K: float array with one Bernoulli(``mu[i]``)
                outcome per recommended item. Every entry after the first 0
                is forced to 0.
        """
        recommended = np.flatnonzero(A)
        tmp_k_mu = np.asarray(self.mu, dtype=float)[recommended]
        # One draw per item, in item order, so the random stream is consumed
        # exactly as a per-item loop would consume it.
        users_choosing_K = np.array(
            [np.random.binomial(1, p) for p in tmp_k_mu], dtype=float)

        # Zero out everything after the first 0 outcome.
        misses = np.flatnonzero(users_choosing_K == 0)
        if misses.size:
            users_choosing_K[misses[0] + 1:] = 0

        feedback_cost = np.zeros(self.L)
        for item in recommended:
            feedback_cost[item] = get_feedback_cost(
                self.cost_lower[item], self.cost_upper[item], self.cost[item], sigma)

        return feedback_cost, self.conjunctive_reward_func(tmp_k_mu), users_choosing_K

    def get_best_reward(self, K, C):
        """Return the reward of the selection computed by ``find_optimal``.

        ``find_optimal`` receives ``log(mu)`` (so a product of ``mu`` becomes
        a sum) together with ``self.cost``, ``K`` and ``C``. NOTE(review):
        its exact objective and constraints live in ``find_optimal`` — confirm
        there. Its result is treated as an indicator over items: nonzero
        entries are the chosen items.

        Returns:
            Tuple ``(best_reward, optimal_result)``, where ``best_reward`` is
            the product of ``mu`` over the chosen items and ``optimal_result``
            is returned unchanged from ``find_optimal``.
        """
        mu = np.asarray(self.mu, dtype=float)
        log_mu = np.log(mu)
        optimal_result = find_optimal(self.L, K, C, log_mu, self.cost)
        chosen = np.flatnonzero(optimal_result)
        best_reward = self.conjunctive_reward_func(mu[chosen])
        return best_reward, optimal_result