import numpy as np


class Algorithm:
    """Base interface shared by the bandit algorithms in this module.

    Subclasses override ``select_arm``, ``update`` and ``get_estimation``.
    The base implementations do nothing and return ``None``.
    """

    def __init__(self, params: dict):
        # Raw configuration dictionary; subclasses read their settings from it.
        self.params = params

    def select_arm(self, t: int, safety_costs_t: np.ndarray) -> int:
        """Return the index of the arm to pull at round ``t``.

        Args:
            t: Current round index.
            safety_costs_t: Safety costs for round ``t`` (indexed by arm in
                the subclasses).
        """
        pass

    def update(self, t: int, arm: int, reward: float, safety_costs_t: np.ndarray):
        """Incorporate the reward observed after pulling ``arm`` at round ``t``."""
        pass

    def get_estimation(self):
        """Return the algorithm's current parameter estimate (``None`` here)."""
        # Previously declared without ``self``, so any call on an instance
        # raised ``TypeError``; it is now a no-op hook like the others.
        pass


class OFUL(Algorithm):
    """Linear UCB bandit (optimism in the face of uncertainty).

    Maintains a ridge-regression estimate of the reward parameter and always
    plays the arm with the largest upper confidence bound. The ``t`` argument
    of ``select_arm`` and all ``safety_costs_t`` arguments are ignored.
    """

    def __init__(self, params: dict):
        super().__init__(params)

        # Problem description.
        self.d = params["d"]  # dimension of the linear parameter
        self.T = params["T"]  # total time horizon
        self.K = params["K"]  # number of arms
        self.arm_set = np.array(params["arm_set"])  # one row per arm: K x d

        # Confidence-radius hyper-parameters.
        self.lam = params["lam"]  # ridge regularisation added to Sigma
        self.sigma = params["sigma"]  # multiplies the log term of the radius
        self.S = params["S"]  # enters the radius as S * sqrt(lam)
        self.L = params["L"]  # enters the radius via t * L**2 / lam
        self.delta = params["delta"]  # enters the radius via log(. / delta)

        # Ridge-regression state: Gram matrix, feature-weighted reward sum,
        # current estimate and current confidence radius.
        self.Sigma = self.lam * np.eye(self.d)
        self.Obs = np.zeros(self.d)
        self.theta_hat = np.zeros(self.d)
        self.rad = self.S * np.sqrt(self.lam) + self.sigma * np.sqrt(
            self.d * np.log(1 / self.delta)
        )

    def select_arm(self, t: int, safety_costs_t: np.ndarray) -> int:
        """Return the index of the arm whose UCB is largest (first on ties)."""
        bounds = np.array(
            [
                self.compute_ucb(
                    self.arm_set[idx], self.Sigma, self.theta_hat, self.rad
                )
                for idx in range(self.K)
            ]
        )
        return np.argmax(bounds)

    @staticmethod
    def compute_ucb(
        x: np.ndarray, Sigma: np.ndarray, theta_hat: np.ndarray, radius: float
    ) -> float:
        """Return ``x . theta_hat + radius * sqrt(x^T Sigma^{-1} x)``."""
        bonus = np.sqrt(x.T @ np.linalg.solve(Sigma, x))
        return x.T @ theta_hat + radius * bonus

    def update(self, t: int, arm: int, reward: float, safety_costs_t: np.ndarray):
        """Absorb the ``(arm, reward)`` observation; refresh estimate and radius."""
        feat = self.arm_set[arm, :]
        self.Sigma += np.outer(feat, feat)
        self.Obs += reward * feat
        self.theta_hat = np.linalg.solve(self.Sigma, self.Obs)

        growth = 1.0 + t * self.L * self.L / self.lam
        self.rad = self.S * np.sqrt(self.lam) + self.sigma * np.sqrt(
            self.d * np.log(growth / self.delta)
        )

    def get_estimation(self) -> np.ndarray:
        """Return the current ridge-regression estimate ``theta_hat``."""
        return self.theta_hat


class PO(Algorithm):
    """Linear UCB bandit whose arm index is penalised by accumulated safety costs.

    Rewards use the same ridge-regression UCB as ``OFUL``. In addition, ``Q``
    accumulates ``safety_costs_t[arm] + epsilon_list[t]`` (clipped at zero)
    after every pull, and ``select_arm`` subtracts
    ``safety_costs_t * Q / V_list[t]`` from the UCBs before taking the argmax.
    """

    def __init__(self, params: dict):
        super().__init__(params)

        # configuration parameters
        self.d = params["d"]  # dimension of the linear parameter
        self.T = params["T"]  # total time horizon
        self.K = params["K"]  # number of arms
        self.arm_set = np.array(params["arm_set"])  # arm set dim: K * d

        self.lam = params["lam"]  # ridge regularisation added to Sigma
        self.sigma = params["sigma"]
        self.S = params["S"]
        self.L = params["L"]
        self.delta = params["delta"]

        # Per-round penalty scale V_t = delta * sqrt(t); note V_list[0] == 0.
        self.V_list = [self.delta * np.sqrt(t) for t in range(self.T + 1)]
        # Per-round slack 0.4 / sqrt(t); the (t < 0.5) term turns the t == 0
        # entry into 0.4 instead of a division by zero.
        self.epsilon_list = [0.4 / (np.sqrt(t) + (t < 0.5)) for t in range(self.T + 1)]

        # statistics
        self.Q = 0.0  # accumulated cost + slack, clipped at zero
        self.Sigma = np.eye(self.d) * self.lam  # regularised Gram matrix
        self.Obs = np.zeros(self.d)  # sum of reward * arm features
        self.theta_hat = np.zeros(self.d)  # ridge estimate Sigma^{-1} Obs
        # Confidence radius scaling the exploration bonus in compute_ucb.
        self.rad = self.S * np.sqrt(self.lam) + self.sigma * np.sqrt(
            self.d * np.log(1 / self.delta)
        )

    def select_arm(self, t: int, safety_costs_t: np.ndarray) -> int:
        """Return the arm maximising ``UCB - safety_costs_t * Q / V_list[t]``.

        Args:
            t: Current round; indexes ``V_list``.
            safety_costs_t: Safety cost of every arm for this round.
        """
        r_hat = np.zeros(self.K)
        for a in range(self.K):
            x_a = self.arm_set[a, :]
            r_hat[a] = self.compute_ucb(x_a, self.Sigma, self.theta_hat, self.rad)

        if self.Q == 0.0:
            # An empty queue contributes no penalty. Computing it anyway would
            # give 0 / 0 = NaN at t == 0 (V_list[0] == 0), and np.argmax would
            # then return the first arm regardless of the UCBs.
            indices_t = r_hat
        else:
            indices_t = r_hat - safety_costs_t * self.Q / self.V_list[t]
        a_t = np.argmax(indices_t)
        return a_t

    @staticmethod
    def compute_ucb(
        x: np.ndarray, Sigma: np.ndarray, theta_hat: np.ndarray, radius: float
    ) -> float:
        """Return ``x . theta_hat + radius * sqrt(x^T Sigma^{-1} x)``."""
        mean = x.T @ theta_hat
        sigma_inv_x = np.linalg.solve(Sigma, x)
        ucb = mean + radius * np.sqrt(x.T @ sigma_inv_x)
        return ucb

    def update(self, t: int, arm: int, reward: float, safety_costs_t: np.ndarray):
        """Update the cost queue and the ridge-regression statistics after a pull."""
        self.Q = max(0.0, self.Q + safety_costs_t[arm] + self.epsilon_list[t])
        x_t = self.arm_set[arm, :]
        self.Sigma += np.outer(x_t, x_t)
        self.Obs += reward * x_t
        self.theta_hat = np.linalg.solve(self.Sigma, self.Obs)
        self.rad = self.S * np.sqrt(self.lam) + self.sigma * np.sqrt(
            self.d * np.log((1.0 + (t) * self.L * self.L / self.lam) / self.delta)
        )

    def get_estimation(self) -> np.ndarray:
        """Return the current ridge-regression estimate ``theta_hat``."""
        return self.theta_hat


class SERMiSC(Algorithm):
    """Two-phase safe linear bandit.

    Phase I (``t <= T_1``): arms are sampled from a mixture of the fixed
    policy ``pi_prime`` and a softmax over ``-C_hat * zeta_t``. ``C_hat``
    accumulates importance-weighted safety costs, and per-arm pull counts and
    reward sums are recorded for ``get_estimation``.

    Phase II (``t > T_1``): arms are chosen by a ridge-regression UCB
    penalised by ``safety_costs_t * Q / V_list[t]``, where ``Q`` accumulates
    cost plus slack (clipped at zero).
    """

    def __init__(self, params: dict):
        super().__init__(params)

        # configuration parameters
        self.d = params["d"]  # dimension of the linear parameter
        self.T = params["T"]  # total time horizon
        self.T_1 = params["T_1"]  # Phase I time horizon
        self.K = params["K"]  # number of arms
        self.arm_set = np.array(params["arm_set"])  # arm set dim: K * d
        self.pi_prime = np.array(params["pi_prime"])  # fixed exploration policy

        self.lam = params["lam"]  # ridge regularisation
        self.sigma = params["sigma"]
        self.S = params["S"]
        self.L = params["L"]
        self.delta = params["delta"]

        # Despite the key name, a single mixing weight replicated per round.
        q = params["q_list"]  # hyper-parameter for the exploration policy
        self.q_list = [q for _ in range(self.T + 1)]

        # Softmax temperatures for Phase I. The (t < 0.5) / (t < 1.5) terms
        # make the t == 0 and t == 1 entries 1 / 60 instead of invalid sqrts.
        # log(1) == 0 would make every zeta infinite (and the Phase I
        # probabilities NaN) when T_1 == 1, so that case normalises by 1.
        log_T_1 = np.log(self.T_1) if self.T_1 != 1 else 1.0
        self.zeta_list = (
            np.array(
                [
                    1.0 / (60.0 * np.sqrt(2 * (t < 0.5) + (t < 1.5) + 2 * t - 2))
                    for t in range(self.T + 1)
                ]
            )
            / log_T_1
        )

        # Phase II penalty scale (V_list[0] == 0) and per-round slack.
        self.V_list = [self.delta * np.sqrt(t) for t in range(self.T + 1)]
        self.epsilon_list = [0.4 / (np.sqrt(t) + (t < 0.5)) for t in range(self.T + 1)]

        # statistics
        self.phase_I_flag = True  # set to False on the first Phase II selection
        self.C_hat = np.zeros(self.K)  # importance-weighted cost sums (Phase I)
        self.N_ = np.zeros(self.K)  # pull counts per arm (Phase I)
        self.Y_ = np.zeros(self.K)  # reward sums per arm (Phase I)
        self.pi_t_a_t = None  # probability of the last Phase I draw

        self.Q = 0.0  # accumulated cost + slack (Phase II), clipped at zero
        self.Sigma = np.eye(self.d) * self.lam
        self.Obs = np.zeros(self.d)
        self.theta_hat = np.zeros(self.d)
        self.rad = self.S * np.sqrt(self.lam) + self.sigma * np.sqrt(
            self.d * np.log(1 / self.delta)
        )

    def select_arm(self, t: int, safety_costs_t: np.ndarray) -> int:
        """Choose the arm for round ``t`` according to the current phase.

        In Phase I the draw uses ``np.random.choice`` and its probability is
        stored in ``pi_t_a_t`` for the following ``update`` call.
        """
        is_phase_I = t <= self.T_1
        if is_phase_I:  # Phase I
            q_t = self.q_list[t]
            zeta_t = self.zeta_list[t]
            # Softmax over -C_hat * zeta_t, shifted by its maximum so the
            # largest weight is exactly 1. Without the shift, large accumulated
            # costs underflow every weight to 0 (or overflow) and the
            # normalisation yields NaN probabilities. The shift cancels out.
            logits = -self.C_hat * zeta_t
            exp_t = np.exp(logits - logits.max())
            exp_t /= exp_t.sum()
            pi_t = self.pi_prime * q_t + exp_t * (1 - q_t)
            a_t = np.random.choice(self.K, p=pi_t)
            self.pi_t_a_t = pi_t[a_t]
        else:  # Phase II
            self.phase_I_flag = False
            r_hat = np.zeros(self.K)
            for a in range(self.K):
                x_a = self.arm_set[a, :]
                r_hat[a] = self.compute_ucb(x_a, self.Sigma, self.theta_hat, self.rad)
            indices_t = r_hat - safety_costs_t * self.Q / self.V_list[t]
            a_t = np.argmax(indices_t)
        return a_t

    @staticmethod
    def compute_ucb(
        x: np.ndarray, Sigma: np.ndarray, theta_hat: np.ndarray, radius: float
    ) -> float:
        """Return ``x . theta_hat + radius * sqrt(x^T Sigma^{-1} x)``."""
        mean = x.T @ theta_hat
        sigma_inv_x = np.linalg.solve(Sigma, x)
        ucb = mean + radius * np.sqrt(x.T @ sigma_inv_x)
        return ucb

    def update(self, t: int, arm: int, reward: float, safety_costs_t: np.ndarray):
        """Record the outcome of pulling ``arm`` at round ``t``.

        Phase I updates the per-arm counts, reward sums and the
        importance-weighted cost estimate (requires a preceding
        ``select_arm``). Phase II updates the cost queue and the ridge
        regression, whose radius grows with ``t - T_1``.
        """
        is_phase_I = t <= self.T_1
        if is_phase_I:
            self.N_[arm] += 1
            self.Y_[arm] += reward

            # select_arm must have stored the draw probability for this round.
            assert self.pi_t_a_t is not None
            self.C_hat[arm] += safety_costs_t[arm] / self.pi_t_a_t
            self.pi_t_a_t = None
        else:
            self.Q = max(0.0, self.Q + safety_costs_t[arm] + self.epsilon_list[t])
            x_t = self.arm_set[arm, :]
            self.Sigma += np.outer(x_t, x_t)
            self.Obs += reward * x_t
            self.theta_hat = np.linalg.solve(self.Sigma, self.Obs)
            self.rad = self.S * np.sqrt(self.lam) + self.sigma * np.sqrt(
                self.d
                * np.log(
                    (1.0 + (t - self.T_1) * self.L * self.L / self.lam) / self.delta
                )
            )

    def get_estimation(self) -> np.ndarray:
        """Return the ridge estimate built from the Phase I per-arm statistics.

        Solves ``(sum_a N_a x_a x_a^T + lam I) theta = sum_a Y_a x_a`` using the
        pull counts ``N_`` and reward sums ``Y_``.
        """
        X_stat = self.arm_set * (np.sqrt(self.N_)[:, np.newaxis])

        # Unpulled arms have a zero row in X_stat; clamping their count to 1
        # only avoids dividing by zero.
        N_ = self.N_.copy()
        N_[N_ < 0.5] = 1.0
        Y_tilde = self.Y_ / np.sqrt(N_)

        gram = X_stat.T @ X_stat + np.eye(self.d) * self.lam
        theta_stat = np.linalg.solve(gram, X_stat.T @ Y_tilde)
        return theta_stat


# Algorithm name -> implementing class; each key is the class's own name.
ALG_REGISTRY = {alg_cls.__name__: alg_cls for alg_cls in (SERMiSC, OFUL, PO)}
