import warnings

import numpy as np
from scipy.optimize import minimize


class Environment:
    """Stochastic linear bandit environment with a per-arm safety cost.

    Arm ``k`` has mean reward ``arm_set[k] @ theta_star`` and mean safety
    cost ``c_mean[k]``. Each round, rewards and costs for all arms are drawn
    with additive Gaussian noise. The benchmark policy ``pi_star`` is the
    distribution over arms that maximises expected reward subject to the
    expected safety cost being non-positive.

    Args:
        env_config: Dict with keys ``"K"`` (number of arms), ``"d"``
            (feature dimension), ``"arm_set"`` (K x d arm features),
            ``"theta_star"`` (length-d reward parameter), ``"c_mean"``
            (length-K mean safety costs), ``"c_sigma"`` (cost noise std) and
            ``"sigma"`` (reward noise std).
    """

    def __init__(self, env_config: dict):
        self.env_config = env_config

        self.K = env_config["K"]
        self.d = env_config["d"]
        self.arm_set = np.array(env_config["arm_set"])  # dim: K * d
        self.theta_star = np.array(env_config["theta_star"])  # dim: d
        self.c_mean = np.array(env_config["c_mean"])  # dim: K
        self.c_sigma = env_config["c_sigma"]  # dim: 1, cost noise std
        self.sigma = env_config["sigma"]  # dim: 1, reward noise std

        self.pi_star = None
        self.opt_reward = None
        self.calculate_pi_star()

        # statistics (one entry appended per update_statistics call)
        self.arm_pull_history = []
        self.cumu_regret = []
        self.cumu_safety_cost = []

    def calculate_pi_star(self):
        """Compute the optimal safe arm distribution and its expected reward.

        Solves the linear program

            max_pi  pi @ mu   s.t.  pi @ c_mean <= 0,  sum(pi) == 1,  0 <= pi <= 1

        where ``mu = arm_set @ theta_star``. The solver is SLSQP. The solution
        is stored in ``self.pi_star`` and the attained objective in
        ``self.opt_reward``.

        If the solver reports failure, a ``RuntimeWarning`` is emitted. One
        cause is an infeasible constraint set, e.g. every arm having positive
        mean cost. The solver's last iterate is still stored so existing
        callers keep working, but it should not be trusted.
        """
        mu = self.arm_set @ self.theta_star
        c_mean = self.c_mean

        # 1. objective: minimise the negative expected reward
        objective = lambda pi: -np.dot(pi, mu)

        # 2. constraints; SLSQP "ineq" means fun(pi) >= 0, i.e. expected cost <= 0
        constraints = [
            {"type": "ineq", "fun": lambda pi: -np.sum(pi * c_mean)},
            {"type": "eq", "fun": lambda pi: np.sum(pi) - 1},
        ]

        # 3. bounds: each probability lies in [0, 1]
        bounds = [(0, 1) for _ in range(self.K)]

        # 4. initial guess: uniform distribution over arms
        init_guess = np.ones(self.K) / self.K

        # SLSQP is what scipy selects by default for bounds + constraints;
        # stated explicitly for clarity.
        result = minimize(
            objective,
            init_guess,
            method="SLSQP",
            bounds=bounds,
            constraints=constraints,
        )
        if not result.success:
            warnings.warn(
                f"calculate_pi_star: optimizer did not converge "
                f"({result.message}); pi_star/opt_reward may be invalid.",
                RuntimeWarning,
                stacklevel=2,
            )

        self.pi_star = result.x
        self.opt_reward = -result.fun

    def get_safety_costs_t(self) -> np.ndarray:
        """Draw one round of safety costs for all K arms (mean + Gaussian noise)."""
        return self.c_mean + np.random.randn(self.K) * self.c_sigma

    def get_rewards_t(self) -> np.ndarray:
        """Draw one round of rewards for all K arms (linear mean + Gaussian noise)."""
        return self.arm_set @ self.theta_star + np.random.randn(self.K) * self.sigma

    def update_statistics(
        self, t: int, arms_t: list, rewards_t: np.ndarray, safety_costs_t: np.ndarray
    ):
        """Record the arm(s) pulled in a round and update the cumulative metrics.

        Regret is measured against ``self.opt_reward`` using the realised
        (noisy) reward of the pulled arm(s). The realised safety cost is
        accumulated the same way.

        Args:
            t: Round index. Kept for interface compatibility; accumulation is
                driven by the recorded history, so rounds may be numbered from
                either 0 or 1.
            arms_t: Index into ``rewards_t`` / ``safety_costs_t``. An int
                index yields scalar statistics; a list of indices yields
                per-arm arrays.
            rewards_t: Realised rewards for all arms this round.
            safety_costs_t: Realised safety costs for all arms this round.
        """
        self.arm_pull_history.append(arms_t)

        # Accumulate onto the previous total whenever one exists. Keying this
        # off the history (not ``t > 1``) stops the running sums from
        # silently restarting at t == 1 when rounds are numbered from 0.
        if self.cumu_regret:
            self.cumu_regret.append(
                self.cumu_regret[-1] + self.opt_reward - rewards_t[arms_t]
            )
        else:
            self.cumu_regret.append(self.opt_reward - rewards_t[arms_t])

        if self.cumu_safety_cost:
            self.cumu_safety_cost.append(
                self.cumu_safety_cost[-1] + safety_costs_t[arms_t]
            )
        else:
            self.cumu_safety_cost.append(safety_costs_t[arms_t])

    def get_results(self):
        """Return the recorded statistics (the live internal lists, not copies)."""
        return {
            "arm_pull_history": self.arm_pull_history,
            "cumu_regret": self.cumu_regret,
            "cumu_safety_cost": self.cumu_safety_cost,
        }
