import numpy as np

from Algorithms.utils import OFTRL_CVXPY


class play_GenLBINFV:
    """Simulate an optimistic-FTRL learner on a supplier x demander grid.

    Each round the learner asks ``OFTRL_CVXPY`` for an action ``a_t``, observes
    the first ``int(a_t[i][j])`` loss samples of every cell ``(i, j)``, and
    then updates the per-cell regularization weight ``beta``, the cumulative
    optimistic loss estimate ``lhat_sum`` and the optimistic prediction ``m``.

    Args:
        action_list: candidate actions; each is multiplied elementwise with
            ``cost_matrix``, so each must have the same shape.
        cost_matrix: per-cell mean loss, shape (supplier_num, demander_num).
            NOTE(review): presumably the mean of the samples in ``loss_list``
            before any corruption -- confirm with the data generator.
        regime: ``"Stochastic"`` or ``"StochasticWithCorruption"``. Any other
            value still runs, but the reported regret curve is just ``[0]``.
        m_estimation: ``"LeastSquare"`` or ``"GradientDescent"``. Any other
            value leaves ``m`` fixed at its initial value 1/2.
        n_i: per-cell sample counts, same shape as ``cost_matrix``.
        loss_list: ``loss_list[t][i][j]`` is the array of loss samples for
            cell ``(i, j)`` at round ``t``. It needs at least ``time_horizon``
            entries.
        time_horizon: number of rounds; also sets ``gamma = log(time_horizon)``.
        change_dist_ratio: under ``"StochasticWithCorruption"``, rounds with
            ``t >= time_horizon * change_dist_ratio`` use freshly drawn
            corrupted losses instead of ``loss_list[t]``.
        eta: step size of the ``"GradientDescent"`` update of ``m``.
    """

    def __init__(
        self,
        action_list,
        cost_matrix,
        regime,
        m_estimation,
        n_i,
        loss_list,
        time_horizon=2000000,
        change_dist_ratio=1.0,
        eta=1 / 4,
    ):
        self.action_list = action_list
        print("There are {} actions.".format(len(self.action_list)))
        self.cost_matrix = cost_matrix
        self.supplier_num = len(self.cost_matrix)
        self.demander_num = len(self.cost_matrix[0])
        self.regime = regime
        self.m_estimation = m_estimation
        self.n_i = n_i
        self.loss_list = loss_list
        self.time_horizon = time_horizon
        self.gamma = np.log(self.time_horizon)
        self.eta = eta
        # NOTE(review): with epsilon = n_i / 4 the ratio epsilon / n_i used in
        # beta is the constant 1/4; confirm this (rather than e.g.
        # 1 / (4 * n_i)) is the intended choice.
        self.epsilon = 1 / 4 * self.n_i
        self.change_dist_ratio = change_dist_ratio

        # Accumulators are forced to float: zeros_like on an integer cost
        # matrix would otherwise produce integer arrays that reject (+=) or
        # silently truncate (element assignment) the float updates below.
        self.lhat_sum = np.zeros_like(self.cost_matrix, dtype=float)
        self.alpha_sum = np.zeros_like(self.cost_matrix, dtype=float)
        self.m = np.full_like(self.cost_matrix, 0.5, dtype=float)
        self.beta = np.sqrt((1 + self.epsilon / self.n_i) ** 2)
        self.L_sum = np.zeros_like(self.cost_matrix, dtype=float)

        self.pull_num = np.zeros_like(self.cost_matrix, dtype=float)

        # One (a_t, a_t_index) tuple per round, as returned by OFTRL_CVXPY.
        self.chosen_action_list = []

    def update_beta(
        self,
        a_t,
        sample_mean,
        x_t,
    ):
        """Accumulate this round's alpha term and recompute ``beta``.

        alpha = (a_t / n_i)^2 * (sample_mean - m)^2
                * min(1, 2 * (1 - x_t / n_i) / ((x_t / n_i)^2 * gamma))
        beta  = sqrt((1 + epsilon / n_i)^2 + alpha_sum / gamma)

        ``m`` must still be the prediction used for this round.
        """
        a_frac = np.asarray(a_t, dtype=float) / self.n_i
        x_frac = np.asarray(x_t, dtype=float) / self.n_i
        # x_frac == 0 gives 2 / 0 = inf, which the min(1, .) clips to 1; the
        # cell's alpha is still 0 because a_t is 0 there.
        with np.errstate(divide="ignore"):
            clip = np.minimum(1, 2 * (1 - x_frac) / (x_frac**2 * self.gamma))
        alpha = a_frac**2 * (sample_mean - self.m) ** 2 * clip

        self.alpha_sum += alpha
        self.beta = np.sqrt(
            (1 + self.epsilon / self.n_i) ** 2 + (1 / self.gamma) * self.alpha_sum
        )

    def update_m(
        self,
        L,
        a_t,
        sample_mean,
    ):
        """Update the optimistic prediction ``m`` from this round's samples.

        ``"LeastSquare"``: m = (1/2 + sum of observed losses) / (1 + #pulls).
        ``"GradientDescent"``: m <- (1 - eta) * m + eta * sample_mean on the
        cells with ``a_t > 0``; other cells keep their value.
        Any other ``m_estimation`` leaves ``m`` unchanged.
        """
        if self.m_estimation == "LeastSquare":
            self.pull_num += np.asarray(a_t)
            for i in range(self.supplier_num):
                for j in range(self.demander_num):
                    self.L_sum[i][j] += np.sum(L[i][j][: int(a_t[i][j])])
            self.m = 1 / (1 + self.pull_num) * (1 / 2 + self.L_sum)
        elif self.m_estimation == "GradientDescent":
            pulled = np.asarray(a_t) > 0
            self.m = np.where(
                pulled, (1 - self.eta) * self.m + self.eta * sample_mean, self.m
            )

    def _corrupted_losses(self):
        """Draw a fresh corrupted loss sample for every (supplier, demander) cell.

        Cell (i, j) gets ``int(n_i[i][j])`` draws from
        Uniform(1 - 2 * cost_matrix[i][j], 1), whose mean is
        1 - cost_matrix[i][j]. Cells are drawn in row-major order.
        """
        return [
            [
                np.random.uniform(
                    1 - 2 * self.cost_matrix[i][j], 1, int(self.n_i[i][j])
                )
                for j in range(self.demander_num)
            ]
            for i in range(self.supplier_num)
        ]

    def _observe(self, L, a_t, x_t):
        """Fold one round's feedback into beta, lhat_sum and m, in that order.

        NOTE(review): x_t is presumably the marginal E[a_t] under the sampling
        distribution returned by OFTRL_CVXPY -- confirm.
        """
        sample_mean = np.zeros_like(self.cost_matrix, dtype=float)
        for i in range(self.supplier_num):
            for j in range(self.demander_num):
                pulls = int(a_t[i][j])
                if pulls > 0:
                    sample_mean[i][j] = np.mean(L[i][j][:pulls])

        # beta and lhat must both be centred on m_t, the prediction that was
        # fixed before a_t was drawn. Updating m first would leak the current
        # sample into its own control variate and bias the estimate.
        self.update_beta(a_t=a_t, sample_mean=sample_mean, x_t=x_t)

        a_arr = np.asarray(a_t, dtype=float)
        # Importance weight a_t / x_t, defined as 0 on unpulled cells so that a
        # cell with x_t == 0 does not turn lhat_sum into NaN.
        weight = np.divide(a_arr, x_t, out=np.zeros_like(a_arr), where=a_arr > 0)
        self.lhat_sum += self.m + weight * (sample_mean - self.m)

        self.update_m(L=L, a_t=a_t, sample_mean=sample_mean)

    def return_cumulative_regret(
        self,
    ):
        """Return the cumulative pseudo-regret curve of the recorded actions.

        The curve starts at 0 and has one extra point per entry of
        ``chosen_action_list``. It is measured against the best fixed action
        under ``cost_matrix * r + (1 - cost_matrix) * (1 - r)``, with
        ``r = change_dist_ratio``. For regimes other than the two stochastic
        ones, ``[0]`` is returned.

        NOTE(review): the learner's per-round cost is also scored with this
        time-averaged mean, rather than with the mean in force at round t.
        Confirm this is the intended regret definition.
        """
        ret = [0]
        if self.regime not in ("Stochastic", "StochasticWithCorruption"):
            return ret

        mean_cost_matrix = self.cost_matrix * self.change_dist_ratio + (
            1 - self.cost_matrix
        ) * (1 - self.change_dist_ratio)

        best_action, current_value = None, float("inf")
        for action in self.action_list:
            value = np.sum(action * mean_cost_matrix)
            if value < current_value:
                best_action, current_value = action, value
                print(best_action)

        # Iterate over what was actually played: loss_list may be longer than
        # the number of rounds run, and indexing chosen_action_list by its
        # length would then raise IndexError.
        for a_t, _ in self.chosen_action_list:
            ret.append(ret[-1] + np.sum(mean_cost_matrix * (a_t - best_action)))

        return ret

    def run(self):
        """Play ``time_horizon`` rounds and return the cumulative regret curve."""
        num_actions = len(self.action_list)
        # Uniform warm start for the solver; afterwards each round is seeded
        # with the previous round's distribution.
        p_prev = np.ones(num_actions) / num_actions

        for t in range(self.time_horizon):
            L = self.loss_list[t]
            if (
                self.regime == "StochasticWithCorruption"
                and t >= self.time_horizon * self.change_dist_ratio
            ):
                L = self._corrupted_losses()

            x_t, a_t, p_t, a_t_index = OFTRL_CVXPY(
                m=self.m,
                lhat_sum=self.lhat_sum,
                action_list=self.action_list,
                beta=self.beta,
                n_i=self.n_i,
                gamma=self.gamma,
                p_initial=p_prev,
            )
            p_prev = p_t

            self.chosen_action_list.append((a_t, a_t_index))
            self._observe(L=L, a_t=a_t, x_t=x_t)

            if t % 100 == 0:
                print(t, 100 * t / self.time_horizon)
                print(np.sort(p_t)[-10:])

        return self.return_cumulative_regret()
