import numpy as np

from find_optimal import find_optimal
import scipy.stats as stats
from get_api_mu import answer_multiple_choice_questions


def get_value_bound(price, L):
    """Return per-item lower and upper bounds centred on each price.

    For every index ``i < L`` the pair ``(lower[i], upper[i])`` is an
    interval whose midpoint is ``price[i]``:

    * ``[2p - 1, 1]`` when ``1 - p <= p``
    * ``[0, 2p]``     when ``p < 1 - p``

    Entries for which neither comparison holds (e.g. NaN) stay at 0.

    Args:
        price: indexable sequence with at least ``L`` numeric entries
            (presumably values in [0, 1] -- confirm with the caller).
        L: number of leading entries of ``price`` to process.

    Returns:
        Tuple ``(value_lower, value_upper)`` of float arrays of length ``L``.
    """
    lo = np.zeros(L)
    hi = np.zeros(L)
    for i in range(L):
        p = price[i]
        mirrored = 1 - p
        if mirrored <= p:
            # Upper half: the interval is capped at 1 on the right.
            lo[i], hi[i] = 2 * p - 1, 1
        elif p < mirrored:
            # Lower half: the interval is capped at 0 on the left.
            lo[i], hi[i] = 0, 2 * p
    return lo, hi


def get_cost_bound(cost, L):
    """Return per-item lower and upper bounds centred on each cost.

    For every index ``i < L`` the pair ``(lower[i], upper[i])`` is an
    interval whose midpoint is ``cost[i]``:

    * ``[2c - 1, 1]`` when ``1 - c <= c``
    * ``[0, 2c]``     when ``c < 1 - c``

    Entries for which neither comparison holds (e.g. NaN) stay at 0.

    Args:
        cost: indexable sequence with at least ``L`` numeric entries
            (presumably values in [0, 1] -- confirm with the caller).
        L: number of leading entries of ``cost`` to process.

    Returns:
        Tuple ``(cost_lower, cost_upper)`` of float arrays of length ``L``.
    """
    lo = np.zeros(L)
    hi = np.zeros(L)
    for i in range(L):
        c = cost[i]
        mirrored = 1 - c
        if mirrored <= c:
            # Upper half: the interval is capped at 1 on the right.
            lo[i], hi[i] = 2 * c - 1, 1
        elif c < mirrored:
            # Lower half: the interval is capped at 0 on the left.
            lo[i], hi[i] = 0, 2 * c
    return lo, hi


def get_mu(price, L):
    """Return half of each of the first ``L`` prices as a float array.

    Args:
        price: indexable sequence with at least ``L`` numeric entries.
        L: number of leading entries of ``price`` to process.

    Returns:
        Float array of length ``L`` where element ``i`` is ``price[i] / 2``.
    """
    halves = [price[i] / 2 for i in range(L)]
    return np.array(halves, dtype=float)


def get_feedback_cost(cost_lower, cost_upper, cost, sigma=0.08):
    """Draw one cost observation from a truncated normal distribution.

    The sample comes from a normal distribution with mean ``cost`` and
    standard deviation ``sigma``, truncated to ``[cost_lower, cost_upper]``.

    Args:
        cost_lower: lower truncation bound (scalar).
        cost_upper: upper truncation bound (scalar).
        cost: mean of the underlying normal distribution.
        sigma: standard deviation of the underlying normal distribution.

    Returns:
        A single sampled value as a float scalar.
    """
    # A cost of exactly 0 or 1 yields coinciding bounds. scipy's truncnorm
    # needs a strictly smaller lower bound and would raise a domain error,
    # so return the only feasible point without sampling.
    if cost_lower == cost_upper:
        return np.float64(cost_lower)

    # truncnorm takes its bounds in standardised units: (bound - loc) / scale.
    a = (cost_lower - cost) / sigma
    b = (cost_upper - cost) / sigma
    return stats.truncnorm(a, b, loc=cost, scale=sigma).rvs(1)[0]


def get_value_choose(value_lower, value_upper, price, sigma=0.08):
    """Draw one value observation from a truncated normal distribution.

    The sample comes from a normal distribution with mean ``price`` and
    standard deviation ``sigma``, truncated to ``[value_lower, value_upper]``.

    Args:
        value_lower: lower truncation bound (scalar).
        value_upper: upper truncation bound (scalar).
        price: mean of the underlying normal distribution.
        sigma: standard deviation of the underlying normal distribution.

    Returns:
        A single sampled value as a float scalar.
    """
    # A price of exactly 0 or 1 yields coinciding bounds. scipy's truncnorm
    # needs a strictly smaller lower bound and would raise a domain error,
    # so return the only feasible point without sampling.
    if value_lower == value_upper:
        return np.float64(value_lower)

    # truncnorm takes its bounds in standardised units: (bound - loc) / scale.
    a = (value_lower - price) / sigma
    b = (value_upper - price) / sigma
    return stats.truncnorm(a, b, loc=price, scale=sigma).rvs(1)[0]


def exp_reward_cost(mu, cost, commend_list):
    """Sum the entries of ``mu`` selected by ``commend_list``.

    Args:
        mu: indexable sequence of per-item values.
        cost: accepted but not read by this function.
            NOTE(review): the name suggests cost was meant to take part in
            the reward -- confirm whether ignoring it is intentional.
        commend_list: sequence of indices into ``mu``.

    Returns:
        The sum of ``mu[idx]`` over all indices, or ``0`` when the list is
        empty.
    """
    total = 0
    # Explicit left-to-right accumulation keeps the floating-point result
    # independent of how the builtin ``sum`` treats floats.
    for idx in commend_list:
        total = total + mu[idx]
    return total


class Environment(object):
    """Simulation environment holding per-item prices, costs and bounds.

    Attributes are only populated when the instance is built with
    ``synthetic=True``; every per-item array has length ``L``.
    """

    def __init__(self, L, C, synthetic=True):
        """Set up prices, costs, their sampling bounds and ``mu``.

        Args:
            L: length of the per-item arrays.
            C: stored as ``self.C``; not read by any method of this class.
            synthetic: when false, no attribute is initialised.
                NOTE(review): the other methods read these attributes, so a
                non-synthetic instance presumably must be filled in by the
                caller -- confirm.
        """
        super(Environment, self).__init__()
        if not synthetic:
            return

        self.L = L
        self.C = C
        questions = []  # from SciQ
        self.price = answer_multiple_choice_questions(questions)
        # NOTE(review): cost is an empty placeholder here, while
        # get_cost_bound indexes cost[i] for every i < L -- real cost data
        # presumably has to be supplied; confirm.
        self.cost = []
        self.value_lower, self.value_upper = get_value_bound(self.price, self.L)
        self.cost_lower, self.cost_upper = get_cost_bound(self.cost, self.L)
        self.mu = get_mu(self.price, self.L)

    def total_price(self, v):
        """Return the summed price of every item with a non-zero entry in ``v``."""
        paid = 0
        for idx in np.flatnonzero(v):
            paid = paid + self.price[idx]
        return paid

    def feedback(self, A):
        """Simulate one round for the items flagged by the non-zero entries of ``A``.

        For each flagged item a cost and a value are sampled; the item counts
        as chosen when its sampled value is at least its price.

        Returns:
            Tuple ``(users_choosing_L, feedback_cost, total_price,
            exp_reward_t)``: the 0/1 choice indicator array, the sampled
            costs (0 for items not flagged), the summed price of the chosen
            items, and the sum of ``mu`` over the flagged items.
        """
        commend_list = np.flatnonzero(A)
        exp_reward_t = exp_reward_cost(self.mu, self.cost, commend_list)

        users_choosing_L = np.zeros(self.L)
        feedback_cost = np.zeros(self.L)
        value_choose = np.zeros(self.L)

        for idx in commend_list:
            # The trailing 1 is sigma, overriding the samplers' default.
            # The cost draw precedes the value draw, which fixes the order
            # in which random numbers are consumed.
            feedback_cost[idx] = get_feedback_cost(
                self.cost_lower[idx], self.cost_upper[idx], self.cost[idx], 1)
            value_choose[idx] = get_value_choose(
                self.value_lower[idx], self.value_upper[idx], self.price[idx], 1)
            if value_choose[idx] >= self.price[idx]:
                users_choosing_L[idx] = 1

        total_price = self.total_price(users_choosing_L)
        return users_choosing_L, feedback_cost, total_price, exp_reward_t

    def get_best_reward(self, K, C):
        """Return the reward of the selection found by ``find_optimal``.

        Args:
            K: forwarded to ``find_optimal``.
            C: forwarded to ``find_optimal`` (the stored ``self.C`` is not
                used here).

        Returns:
            Tuple ``(best_reward, optimal_result)`` where ``best_reward`` is
            the dot product of ``optimal_result`` with ``mu``.
        """
        optimal_result = find_optimal(self.L, K, C, self.mu, self.cost)
        return np.dot(optimal_result, self.mu.T), optimal_result