import numpy as np


class play_GenCTS:
    """Thompson-sampling player on a supplier x demander grid of cells.

    Each cell ``(i, j)`` keeps a Beta(``p[i][j]``, ``q[i][j]``) posterior
    (uniform Beta(1, 1) prior).  Every round of :meth:`run` the player:

    1. samples ``theta`` from the posteriors (:meth:`sample_theta`),
    2. plays the action in ``action_list`` minimising ``sum(theta * action)``
       (:meth:`oracle`),
    3. observes the first ``action[i][j]`` loss samples of every cell, and
    4. updates the posteriors with the Bernoulli trick (:meth:`update_beta`).

    Args:
        action_list: candidate actions; ``action[i][j]`` is used as the number
            of loss samples observed from cell ``(i, j)`` (presumably a
            non-negative integer matrix -- confirm with callers).
        cost_matrix: per-cell mean loss, one row per supplier and one column
            per demander.  Stored as a float ndarray.
        regime: ``"Stochastic"``, ``"StochasticWithCorruption"``, or any other
            value (for which no regret is computed and ``[0]`` is returned).
        n_i: ``n_i[i][j]`` is the number of loss samples drawn for cell
            ``(i, j)`` in a corrupted round.
        loss_list: ``loss_list[t][i][j][k]`` is the k-th loss sample of cell
            ``(i, j)`` at round ``t``.
        time_horizon: number of rounds played by :meth:`run`.
        change_dist_ratio: in ``"StochasticWithCorruption"``, rounds
            ``t >= time_horizon * change_dist_ratio`` use corrupted losses.
    """

    def __init__(
        self,
        action_list,
        cost_matrix,
        regime,
        n_i,
        loss_list,
        time_horizon,
        change_dist_ratio=1.0,
    ):
        self.action_list = action_list
        print("There are {} actions.".format(len(self.action_list)))
        # Keep a float ndarray: the element-wise arithmetic in
        # return_cumulative_regret needs an array (a nested list raises
        # TypeError), and a float dtype stops the *_like arrays below from
        # truncating Beta samples to 0 when an integer matrix is passed.
        self.cost_matrix = np.asarray(cost_matrix, dtype=float)
        self.supplier_num = len(self.cost_matrix)
        self.demander_num = len(self.cost_matrix[0])
        self.regime = regime
        self.n_i = n_i
        self.loss_list = loss_list
        self.time_horizon = time_horizon
        self.change_dist_ratio = change_dist_ratio

        # Beta(p, q) posterior parameters per cell, starting from Beta(1, 1).
        self.p = np.ones_like(self.cost_matrix)
        self.q = np.ones_like(self.cost_matrix)

        # Not updated by this class; kept for backward compatibility.
        self.pull_num = np.zeros_like(self.cost_matrix)
        # (action, action_index) chosen at each round of run().
        self.chosen_action_list = []

    def sample_theta(
        self,
    ):
        """Draw one Beta(p[i][j], q[i][j]) sample for every cell.

        Returns:
            Float ndarray with the same shape as ``cost_matrix``.
        """
        return np.random.beta(self.p, self.q)

    def oracle(self, theta):
        """Return the action minimising ``sum(theta * action)``.

        Ties are broken in favour of the earliest action in ``action_list``.

        Returns:
            ``(action, index)``, or ``(None, None)`` if ``action_list`` is empty.
        """
        current = float("inf")
        a_t = None
        a_t_index = None
        for a_i, action in enumerate(self.action_list):
            summation = np.sum(theta * action)
            if current > summation:
                current = summation
                a_t = action
                a_t_index = a_i
        return a_t, a_t_index

    def update_beta(self, collected_obs):
        """Update the Beta posteriors with observed losses (Bernoulli trick).

        Each loss ``x`` in ``collected_obs[i][j]`` is replaced by a Bernoulli(x)
        draw ``y``.  Then ``p[i][j]`` grows by ``y`` and ``q[i][j]`` by ``1 - y``.
        Losses must lie in [0, 1]; ``np.random.binomial`` raises ValueError
        otherwise.
        """
        for i in range(self.supplier_num):
            for j in range(self.demander_num):
                for obs in collected_obs[i][j]:
                    # Scalar draw (no size=1) so a plain int is added to the
                    # scalar cell instead of a 1-element array.
                    y = np.random.binomial(1, obs)
                    self.p[i, j] += y
                    self.q[i, j] += 1 - y

    def return_cumulative_regret(
        self,
    ):
        """Return the cumulative pseudo-regret of the rounds played so far.

        The per-round regret is ``sum(mean * (a_t - best))``, where ``best``
        minimises ``sum(mean * action)`` over ``action_list``.

        Returns:
            List starting with 0 and followed by one running total per round
            recorded in ``chosen_action_list``.  For regimes other than
            ``"Stochastic"`` / ``"StochasticWithCorruption"`` this is ``[0]``.
        """
        ret = [0]
        if self.regime not in ("Stochastic", "StochasticWithCorruption"):
            return ret

        if self.regime == "StochasticWithCorruption":
            # A change_dist_ratio fraction of rounds uses loss_list (assumed to
            # have mean cost_matrix -- confirm); the remaining rounds use
            # Uniform(1 - 2c, 1) draws, whose mean is 1 - c.
            mean_cost_matrix = self.cost_matrix * self.change_dist_ratio + (
                1 - self.cost_matrix
            ) * (1 - self.change_dist_ratio)
        else:
            # run() never corrupts losses in the plain stochastic regime, so
            # change_dist_ratio must not affect the means here.
            mean_cost_matrix = self.cost_matrix

        # Find the best action with respect to the mean cost.
        best_action, current_value = None, float("inf")
        for action in self.action_list:
            value = np.sum(action * mean_cost_matrix)
            if value < current_value:
                best_action = action
                current_value = value

        # Iterate over the rounds actually played rather than over loss_list,
        # which may be longer than chosen_action_list.
        for a_t, _ in self.chosen_action_list:
            round_regret = np.sum(mean_cost_matrix * (a_t - best_action))
            ret.append(ret[-1] + round_regret)

        return ret

    def _corrupted_losses(self):
        """Draw a fresh loss table with Uniform(1 - 2*c_ij, 1) samples per cell."""
        # NOTE(review): for c_ij > 0.5 the lower bound is negative, so observed
        # samples may fall below 0 and make update_beta raise; presumably
        # cost_matrix entries are assumed <= 0.5 -- confirm.
        return [
            [
                np.random.uniform(
                    1 - 2 * self.cost_matrix[i][j], 1, int(self.n_i[i][j])
                )
                for j in range(self.demander_num)
            ]
            for i in range(self.supplier_num)
        ]

    def _collect_observations(self, L, a_t):
        """Return the first ``a_t[i][j]`` loss samples of each cell of ``L``."""
        return [
            [
                [L[i][j][x] for x in range(int(a_t[i][j]))]
                for j in range(self.demander_num)
            ]
            for i in range(self.supplier_num)
        ]

    def run(
        self,
    ):
        """Play ``time_horizon`` rounds and return the cumulative regret list."""
        for t in range(self.time_horizon):
            # Loss table for this round: corrupted after the change point.
            if (
                self.regime == "StochasticWithCorruption"
                and t >= self.time_horizon * self.change_dist_ratio
            ):
                L = self._corrupted_losses()
            else:
                L = self.loss_list[t]

            theta_t = self.sample_theta()
            a_t, a_t_index = self.oracle(theta_t)
            self.chosen_action_list.append((a_t, a_t_index))

            self.update_beta(self._collect_observations(L, a_t))

        # Compute the cumulative regret.
        return self.return_cumulative_regret()
