import numpy as np
import pandas as pd
from copy import deepcopy
from src.utils import LIST_FEATURE_DUMMY, combine_dt_dummy_belief, LIST_FEATURE_DUMMY_CB


class LinearBanditsBelief(object):
    """Batched linear-UCB bandit whose context features are combined with a belief vector.

    Each round the belief estimator is advanced one step and the environment rows
    of the round (``index_env == current_iter + 1``) are selected. An action is
    then chosen: uniformly at random during the warm-up rounds, and by the highest
    UCB score afterwards. The realized reward updates a ridge-regression estimate
    of the reward weights. The weights and design matrix used for action selection
    are only refreshed every ``batch_lengh`` rounds.

    Args:
        dict_params: configuration dict. Keys read by this class: ``seed``,
            ``lmd`` (ridge penalty), ``union_bound``, ``UCB_multiply``,
            ``vec_init``, ``batch_lengh``, ``hot_start``, ``number_rounds``,
            ``verbose``.
        dt_env_rewards: frame with columns ``index_env``, ``ACTION``,
            ``realized_reward`` and ``believed_reward``.
        dt_env_dummy: frame with columns ``index_env``, ``ACTION`` and the
            feature columns listed in ``LIST_FEATURE_DUMMY``.
        belief_estimator: object exposing
            ``update_belief_one_iteration(vec_init=..., number_round=...)``,
            whose result is indexed with ``[-1]`` to get the latest belief.
        prefix_sep: stored on the instance; not used by this class.
    """

    __LIST_ACTIONS__ = ['no_action', 'email', 'call']
    # Verbose progress lines are printed only on iterations divisible by this.
    _LOG_EVERY = 1000

    def __init__(self, dict_params, dt_env_rewards, dt_env_dummy, belief_estimator, prefix_sep="_abc_"):
        self.dict_params = dict_params
        self.prefix_sep = prefix_sep
        self.dt_env_rewards = dt_env_rewards
        self.dt_env_dummy = dt_env_dummy
        self.belief_estimator = belief_estimator
        self._init_constant()
        # Seeds NumPy's *global* RNG. The random warm-up actions and the
        # coin flips that break UCB ties both draw from it.
        np.random.seed(self.dict_params["seed"])

    def _init_constant(self):
        """Reset counters and histories and cache the UCB hyper-parameters."""
        self.current_iter = 0
        self.current_batch = 0
        self.actions = []
        self.rewards = []
        self.believed_rewards = []
        self.hist_belief = []
        self.current_belief = []
        self.lmd = self.dict_params["lmd"]
        self.union_bound = self.dict_params["union_bound"]
        self.UCB_multiply = self.dict_params["UCB_multiply"]

    def _should_log(self):
        """Return True when verbose output is enabled and this iteration is a logging step."""
        return self.dict_params["verbose"] is True and self.current_iter % self._LOG_EVERY == 0

    def _update_belief(self):
        """Advance the belief estimator by one round and record its latest belief."""
        hist_belief = self.belief_estimator.update_belief_one_iteration(
            vec_init=self.dict_params['vec_init'], number_round=self.current_iter + 1)
        latest_belief = hist_belief[-1]
        # Add a leading axis so the belief can be combined with a context row.
        self.current_belief = latest_belief[np.newaxis, :]
        self.hist_belief.append(latest_belief)

    def _get_context(self):
        """Keep only the environment rows of the current round (``index_env == current_iter + 1``)."""
        round_id = self.current_iter + 1
        rewards = self.dt_env_rewards
        dummies = self.dt_env_dummy
        self.current_env_rewards_grp = rewards.loc[rewards['index_env'] == round_id].reset_index(drop=True)
        self.current_env_dummy_grp = dummies.loc[dummies['index_env'] == round_id].reset_index(drop=True)

    def _cal_constant_ucb(self):
        """Recompute the multiplier applied to the UCB exploration bonus.

        With ``union_bound`` it is ``(b + 1) * sqrt(b * batch_lengh) * UCB_multiply``.
        Otherwise it is ``(b + 1) * sqrt(batch_lengh) * UCB_multiply``.
        Here ``b`` is the current batch index.
        """
        batch = self.current_batch
        length = self.dict_params['batch_lengh']
        spread = np.sqrt(batch * length) if self.union_bound else np.sqrt(length)
        self.current_constant_UCB = ((batch + 1) * spread) * self.UCB_multiply

    def _update_weights(self):
        """Add the latest (context, reward) pair to the ridge-regression estimate.

        Accumulates ``V = sum(x x^T)`` and ``b = sum(x * r)`` over all rounds.
        It then sets ``reward_weights = ((V + lmd * I)^-1 b)^T``. For a single
        context row of ``d`` features this is a ``(1, d)`` array.
        """
        x = self.current_context_dummy.values
        if self.current_iter == 0:
            self.vt_raw = np.outer(x, x)
            self.x_r = x * self.current_reward
        else:
            self.vt_raw += np.outer(x, x)
            self.x_r += x * self.current_reward

        n_features = x.shape[1]
        # Ridge regularisation keeps the matrix invertible from the first round.
        self.vt = self.vt_raw + np.identity(n_features) * self.lmd

        self.reward_weights = np.inner(np.linalg.inv(self.vt), self.x_r).T

    def _update_weights_batch(self):
        """Freeze the current design matrix and weights for the next batch's decisions."""
        self.vt_batch = self.vt.copy()
        self.reward_weights_batch = self.reward_weights.copy()

    def _cal_bonus_ucb(self, action, context):
        """Return the UCB exploration bonus for each row of ``context``.

        Args:
            action: action label, used only in the verbose log line.
            context: 2-D array of context rows (features combined with the belief).

        Returns:
            1-D array ``current_constant_UCB * ||context_i @ inv(vt_batch)||_2``.
        """
        # NOTE(review): the textbook LinUCB width is sqrt(x^T V^-1 x). This
        # computes ||x V^-1||_2 = sqrt(x^T V^-2 x) instead. Confirm it is intended.
        current_bonus_ucb = self.current_constant_UCB * np.linalg.norm(context @ np.linalg.inv(self.vt_batch), axis=1)

        if self._should_log():
            print(f"For action {action}: Average bonus {np.round(np.mean(current_bonus_ucb), 4)}")

        return current_bonus_ucb

    def _context_for_action(self, action):
        """Return the current round's feature row(s) for ``action``, combined with the current belief."""
        grp = self.current_env_dummy_grp
        context = grp.loc[grp['ACTION'] == action, LIST_FEATURE_DUMMY].reset_index(drop=True)
        return combine_dt_dummy_belief(context, self.current_belief)

    def _take_action(self, random_action=False):
        """Choose an action (uniform random or max-UCB), record it, and return it."""
        if random_action:
            action = np.random.choice(self.__LIST_ACTIONS__)
        else:
            action = self._action_max_reward()
        self.actions.append(action)
        self.current_action = action

        return action

    def _update_action_context(self):
        """Store the context row of the chosen action for the weight update."""
        self.current_context_dummy = self._context_for_action(self.current_action)

    def _get_reward(self):
        """Look up the realized and believed reward of the chosen action and record both."""
        grp = self.current_env_rewards_grp
        chosen = grp[grp['ACTION'] == self.current_action].reset_index(drop=True)
        self.current_reward = chosen["realized_reward"][0]
        self.current_believed_reward = chosen["believed_reward"][0]
        self.rewards.append(self.current_reward)
        self.believed_rewards.append(self.current_believed_reward)

    def _action_max_reward(self):
        """Return the action with the highest UCB score (batch weights plus bonus).

        An exact tie with the current best is resolved by a coin flip.
        """
        best_action = None
        # Scalar -inf sentinel so that any finite score can be selected.
        best_reward = -np.inf

        for action_ in self.__LIST_ACTIONS__:
            context_ = self._context_for_action(action_).values
            bonus_ = self._cal_bonus_ucb(action=action_, context=context_)
            # .item() collapses the single-row score to a float. It raises
            # ValueError if an action does not have exactly one context row.
            reward_ucb = (np.squeeze(np.inner(context_, self.reward_weights_batch)) + bonus_).item()
            # Drawn on every loop iteration so seeded runs consume the RNG identically.
            coin_flip = np.random.rand() > 0.5
            if reward_ucb > best_reward or (reward_ucb == best_reward and coin_flip):
                best_action = action_
                best_reward = reward_ucb

        return best_action

    def _run_one_iteration(self):
        """Play one round: update belief, choose an action, observe the reward, update the model."""
        self._update_belief()
        self._get_context()

        if self._should_log():
            print(f"Iteration {self.current_iter}")

        # Warm-up lasts at least batch_lengh + 1 rounds, so the first batch update
        # (vt_batch / reward_weights_batch) has run before UCB selection is used.
        rounds_random = max(self.dict_params["hot_start"], self.dict_params["batch_lengh"] + 1)
        self._take_action(random_action=self.current_iter <= rounds_random - 1)

        if self._should_log():
            print(f"current action: {self.current_action}")

        self._update_action_context()
        self._get_reward()
        self._update_weights()

        # Refresh the frozen batch estimates at the end of every batch.
        if (self.current_iter + 1) % self.dict_params["batch_lengh"] == 0:
            self._update_weights_batch()
            self.current_batch += 1

        self._cal_constant_ucb()

    def run_simulation(self):
        """Run ``number_rounds`` rounds, filling ``actions``, ``rewards`` and ``believed_rewards``."""
        for _ in range(self.dict_params["number_rounds"]):
            self._run_one_iteration()
            if self._should_log():
                print(f"Cumulative Rewards: {np.round(np.sum(self.rewards), 4)}")
                print(f"Cumulative Believed Rewards: {np.round(np.sum(self.believed_rewards), 4)}")
                print("------------------------------------")
            self.current_iter += 1
