import numpy as np
import scipy.stats as stats
from find_optimal_min import find_optimal_min


def get_mu_bound(mu, L):
    """Return per-arm bounds of the widest interval centred on ``mu[i]`` inside [0, 1].

    Only the first ``L`` entries of ``mu`` are used. For ``mu[i] >= 0.5`` the
    interval is ``[2*mu[i] - 1, 1]``; otherwise it is ``[0, 2*mu[i]]``.

    Args:
        mu: indexable sequence of values (at least ``L`` long).
        L: number of entries to process.

    Returns:
        Tuple ``(mu_lower, mu_upper)`` of float arrays of length ``L``.
    """
    lower = np.zeros(L)
    upper = np.zeros(L)
    for idx in range(L):
        m = mu[idx]
        if m < 1 - m:
            # Below the midpoint: the lower edge of [0, 1] limits the width.
            lower[idx], upper[idx] = 0, 2 * m
        elif 1 - m <= m:
            # At or above the midpoint: the upper edge of [0, 1] limits the width.
            lower[idx], upper[idx] = 2 * m - 1, 1
    return lower, upper


def get_cost_bound(cost, L):
    """Return per-arm bounds of the widest interval centred on ``cost[i]`` inside [0, 1].

    Uses the first ``L`` entries of ``cost``. A cost at or above 0.5 yields
    ``[2*cost[i] - 1, 1]``; a cost below 0.5 yields ``[0, 2*cost[i]]``.
    Note that a cost of exactly 0 or 1 produces a zero-width interval.

    Args:
        cost: indexable sequence of costs (at least ``L`` long).
        L: number of entries to process.

    Returns:
        Tuple ``(cost_lower, cost_upper)`` of float arrays of length ``L``.
    """
    cost_lower, cost_upper = np.zeros(L), np.zeros(L)
    for arm in range(L):
        c = cost[arm]
        if c >= 1 - c:
            cost_lower[arm] = 2 * c - 1
            cost_upper[arm] = 1
        elif c < 1 - c:
            cost_lower[arm] = 0
            cost_upper[arm] = 2 * c
    return cost_lower, cost_upper

def get_feedback_cost(cost_lower, cost_upper, cost, sigma=0.08, random_state=None):
    """Draw one noisy cost observation from a normal truncated to [cost_lower, cost_upper].

    The underlying normal has mean ``cost`` and standard deviation ``sigma``.

    Args:
        cost_lower: lower truncation bound.
        cost_upper: upper truncation bound.
        cost: mean of the untruncated normal.
        sigma: standard deviation of the untruncated normal.
        random_state: optional seed / generator forwarded to ``rvs``. When it
            is ``None`` (the default), numpy's global random state is used.

    Returns:
        A single sampled value in [cost_lower, cost_upper].
    """
    if cost_upper == cost_lower:
        # Zero-width interval (get_cost_bound yields [0, 0] for cost 0 and
        # [1, 1] for cost 1). truncnorm requires a < b and would raise a
        # domain error, but the only admissible value is the bound itself.
        return float(cost_lower)
    # truncnorm takes its bounds in standardized units relative to loc/scale.
    a = (cost_lower - cost) / sigma
    b = (cost_upper - cost) / sigma
    dist = stats.truncnorm(a, b, loc=cost, scale=sigma)
    return dist.rvs(size=1, random_state=random_state)[0]


class Environment(object):
    """Simulated environment over ``L`` arms with click probabilities and costs.

    Attributes:
        L: number of arms.
        C: value passed at construction; not read by the methods below
            (``get_best_reward`` takes its own ``C`` argument).
        mu: per-arm click probabilities (indexable, at least ``L`` long).
        mu_lower, mu_upper: per-arm interval bounds from ``get_mu_bound``.
        cost: per-arm mean costs (indexable, at least ``L`` long).
        cost_lower, cost_upper: per-arm interval bounds from ``get_cost_bound``;
            they are used to truncate the sampled cost feedback.
    """

    def __init__(self, L, C, mu, cost):
        super().__init__()
        self.L = L
        self.C = C
        self.mu = mu
        self.mu_lower, self.mu_upper = get_mu_bound(self.mu, L)
        self.cost = cost
        self.cost_lower, self.cost_upper = get_cost_bound(self.cost, self.L)

    def _or_func(self, v):
        """Return ``1 - prod(1 - v)``.

        This is the probability that at least one of several independent events
        with probabilities ``v`` occurs. An empty ``v`` gives 0. ``v`` must
        support array arithmetic (e.g. a numpy array).
        """
        return 1 - np.prod(1 - v)

    def feedback(self, A):
        """Simulate one round of feedback for the arms selected by ``A``.

        Args:
            A: array-like; its nonzero entries mark the recommended arms.

        Returns:
            Tuple ``(feedback_cost, expected_reward, clicks)``:
              * ``feedback_cost``: length-``L`` float array, holding one sampled
                cost per recommended arm and zeros elsewhere;
              * ``expected_reward``: ``1 - prod(1 - mu)`` over the recommended arms;
              * ``clicks``: float array with one 0/1 entry per recommended arm,
                keeping at most the first click.
        """
        # flatnonzero returns indices in ascending order, so "first" below means
        # lowest arm index. NOTE(review): if A is meant to encode a display
        # ranking, that ordering is lost here -- confirm.
        recommended = np.flatnonzero(A)
        rec_mu = np.asarray(self.mu, dtype=float)[recommended]
        clicks = np.random.binomial(1, rec_mu).astype(float)
        if clicks.sum() > 1:
            # Keep only the first click; later positions are zeroed.
            first_click = np.flatnonzero(clicks)[0]
            clicks[first_click + 1:] = 0

        feedback_cost = np.zeros(self.L)
        for arm in recommended:
            # NOTE(review): sigma=1 overrides get_feedback_cost's default of
            # 0.08 (on a [0, 1] interval this is close to uniform); confirm
            # that this is intended.
            feedback_cost[arm] = get_feedback_cost(self.cost_lower[arm],
                                                   self.cost_upper[arm],
                                                   self.cost[arm],
                                                   sigma=1)

        return feedback_cost, self._or_func(rec_mu), clicks

    def get_best_reward(self, K, C):
        """Return ``(best_reward, optimal_result)`` for the selection made by ``find_optimal_min``.

        ``find_optimal_min`` is external to this class. It receives
        ``(L, K, C, log(1 - mu), cost)``, and its result is treated as a vector
        whose nonzero entries are the chosen arms. Presumably it selects at most
        ``K`` arms under budget ``C`` -- confirm against its definition.
        ``best_reward`` is ``1 - prod(1 - mu)`` over the chosen arms.

        Note that ``mu == 1`` yields ``log(0) = -inf``, and numpy emits a
        divide-by-zero warning.
        """
        mu = np.asarray(self.mu, dtype=float)
        log_mu = np.log(1 - mu)
        optimal_result = find_optimal_min(self.L, K, C, log_mu, self.cost)
        chosen = np.flatnonzero(optimal_result)
        best_reward = self._or_func(mu[chosen])
        return best_reward, optimal_result
