from dataclasses import dataclass
from typing import Dict, Any, Optional
import numpy as np

@dataclass
class StepResult:
    """Container for the outcome of one environment step.

    ``BanditEnv.step`` builds and returns instances of this class.

    Attributes:
        reward: Observed reward for the pulled arm (true reward plus noise).
        cost: Observed cost for the pulled arm (true cost plus noise).
        info: Free-form per-step diagnostics.
    """

    reward: float  # observed (noisy) reward
    cost: float  # observed (noisy) cost
    info: Dict[str, Any]  # extra diagnostics keyed by name

def compute_optimal_fixed_arm(env, T: int, B: float, gamma: float):
    """Find the best single arm to play on every one of the first ``T`` rows.

    For each arm ``k`` the total reward is ``env.relev[:T, k].sum()`` and the
    total cost is ``env.cost[k] * T``. An arm is feasible when its total cost
    is at most ``B`` and its total reward is at least ``gamma`` times its
    total cost. Among feasible arms, the one with the highest total reward is
    returned; ties go to the lowest arm index.

    Args:
        env: Object exposing a 2-D ``relev`` array of shape ``(N, K)`` and an
            indexable per-arm ``cost`` (e.g. ``BanditEnv``).
        T: Number of leading rows of ``env.relev`` to evaluate; must satisfy
            ``0 <= T <= N``.
        B: Upper bound on the arm's total cost.
        gamma: Minimum required ratio of total reward to total cost.

    Returns:
        ``(best_arm, best_reward, best_cost)``. When no arm is feasible the
        result is ``(None, -1, 0.0)``.

    Raises:
        ValueError: If ``T`` is outside ``[0, N]``.
    """
    N, K = env.relev.shape
    # Explicit check rather than ``assert``: asserts are stripped under -O, and
    # a negative T would silently slice off trailing rows and flip cost signs.
    if not 0 <= T <= N:
        raise ValueError(f"T must be in [0, {N}], got {T}")

    relev = env.relev[:T]
    costs = np.array(env.cost)

    best_arm, best_reward, best_cost = None, -1, 0.0

    for k in range(K):
        r_total = relev[:, k].sum()
        c_total = costs[k] * T

        if c_total <= B and r_total >= gamma * c_total:
            # Compare against the first feasible arm instead of the -1
            # sentinel, so arms whose total reward is <= -1 can still win.
            if best_arm is None or r_total > best_reward:
                best_arm, best_reward, best_cost = k, r_total, c_total

    return best_arm, best_reward, best_cost


class BanditEnv:
    """Contextual bandit simulator backed by a fixed table of arm relevances.

    Each :meth:`get_context` call draws a row index uniformly from ``[0, T)``.
    :meth:`step` then reports that row's relevance for the chosen arm as the
    reward, and the arm's fixed cost as the cost. Each value is perturbed by
    independent Gaussian noise.

    Args:
        context: Per-row context features; ``context[i]`` is returned for row ``i``.
        relev: 2-D array of shape ``(N, K)``; ``relev[i, k]`` is the true
            reward of arm ``k`` on row ``i``.
        T: Rows are sampled from the first ``T`` entries; must satisfy
            ``1 <= T <= min(len(context), N)``.
        tot_budget: Stored as ``self.budget``. NOTE(review): this class never
            reads it; presumably the caller enforces the budget.
        cost: Per-arm true costs, one entry per arm. When ``None``, costs are
            drawn from ``Normal(0.2, 0.02)`` using the environment RNG.
        reward_noise: Std-dev of the Gaussian noise added to observed rewards.
        cost_noise: Std-dev of the Gaussian noise added to observed costs.
        seed: Seed for the environment's ``numpy.random.Generator``.
        verbose: When true (the default), print the arm costs on construction.

    Raises:
        ValueError: If ``T`` is out of range or ``cost`` has the wrong shape.
    """

    def __init__(self, context: np.ndarray, relev: np.ndarray,
                 T: int = 5000, tot_budget: float = 500,
                 cost: Optional[np.ndarray] = None,
                 reward_noise: float = 0.05, cost_noise: float = 0.05,
                 seed: int = 0, *, verbose: bool = True):
        self.context = context
        self.relev = relev
        n_rows, self.K = relev.shape

        # get_context() samples indices in [0, T) and uses them to index both
        # ``context`` and ``relev``. Reject T values that would run past either
        # array, or leave an empty sampling range, here rather than failing on
        # a random later call.
        max_T = min(len(context), n_rows)
        if not 1 <= T <= max_T:
            raise ValueError(f"T must be in [1, {max_T}], got {T}")

        self.T = T
        self.budget = tot_budget
        self.rng = np.random.default_rng(seed)

        self.cost = (self.rng.normal(0.2, 0.02, size=self.K)
                     if cost is None else np.array(cost, dtype=float))
        # step() indexes self.cost by arm, so it needs exactly one entry per arm.
        if self.cost.shape != (self.K,):
            raise ValueError(
                f"cost must have shape ({self.K},), got {self.cost.shape}")
        if verbose:
            print("Arm costs:", self.cost)
            print("Min cost:", np.min(self.cost), "Max cost:", np.max(self.cost))

        self.reward_noise = reward_noise
        self.cost_noise = cost_noise

        self._cur_idx: Optional[int] = None
        self._has_context = False
        self.round = -1

    def reset(self, seed: Optional[int] = None):
        """Clear the current context; reseed the RNG only if ``seed`` is given."""
        if seed is not None:
            self.rng = np.random.default_rng(seed)
        self.round = -1
        self._cur_idx = None
        self._has_context = False

    def get_context(self) -> Dict[str, Any]:
        """Sample a row uniformly from ``[0, T)`` and make it the current context.

        Returns:
            Dict with the row's ``"context"``, the number of ``"arms"``, and
            ``"round"``. Despite its name, ``"round"`` is the sampled row
            index, not a sequential round counter.
        """
        self.round = int(self.rng.integers(0, self.T))
        self._cur_idx = self.round
        self._has_context = True
        return {
            "context": self.context[self.round],
            "arms": self.K,
            "round": self.round,
        }

    def step(self, arm: int, keep_context: bool = False) -> "StepResult":
        """Pull ``arm`` on the current context and return noisy reward/cost.

        Args:
            arm: Arm index in ``[0, K)``.
            keep_context: If true, the current context stays active so another
                arm can be pulled on the same row; otherwise it is consumed.

        Returns:
            ``StepResult`` with observed reward/cost and an ``info`` dict that
            also holds the true values.

        Raises:
            RuntimeError: If no context is active.
            ValueError: If ``arm`` is out of range.
        """
        if not self._has_context:
            raise RuntimeError("Call get_context() before step().")

        if arm < 0 or arm >= self.K:
            raise ValueError(f"arm must be in [0,{self.K-1}]")

        arm = int(arm)
        r_true = float(self.relev[self._cur_idx, arm])
        c_true = float(self.cost[arm])

        # NOTE(review): the noise is unclipped, so with cost_noise > 0 the
        # observed cost can be negative; confirm this is intended.
        r_obs = r_true + self.rng.normal(0, self.reward_noise)
        c_obs = c_true + self.rng.normal(0, self.cost_noise)

        info = {
            "arm": arm,
            "context": self.context[self._cur_idx],
            "reward_true": r_true,
            "cost_true": c_true,
            "reward_observed": r_obs,
            "cost_observed": c_obs,
        }

        if not keep_context:
            self._has_context = False
            self._cur_idx = None

        return StepResult(reward=r_obs, cost=c_obs, info=info)
