from typing import Callable, Dict, Any, List, Optional, Tuple
import math
import numpy as np
import math
from utils import * 

class L2FOB_FPA:
    """Grid-search bidder with budget (rho) and ROI (gamma) penalty queues.

    Each round the agent scores every bid on a uniform grid as

        hat_r(v, b) - pen_rho(b) - pen_gamma(b)

    where the penalties are built from two non-negative accumulators,
    ``Q_rho`` and ``Q_gamma``, that grow whenever the estimated cost exceeds
    ``rho`` or ``gamma * cost`` exceeds the estimated reward. Win
    probabilities are estimated empirically from the competing bids reported
    to :meth:`update` via ``info["d_compete"]``.

    Args:
        T: Horizon (number of rounds); used for the default ``rho`` and for
            the confidence bonus.
        B: Total budget. ``None`` means "no budget": the remaining budget is
            then treated as infinite.
        rho: Per-round spend target. Defaults to ``B / max(1, T)`` when ``B``
            is given, otherwise ``0.01``.
        gamma: Multiplier applied to the cost estimate in the ROI term
            ``gamma * cost - reward``.
        delta: Confidence parameter of the optimism bonus.
        eta_rho: Step size of the ``Q_rho`` accumulator.
        eta_gamma: Step size of the ``Q_gamma`` accumulator.
        phi: Potential function; stored but not evaluated by this class.
            Defaults to ``0.5 * x**2``.
        phi_prime: Derivative of the potential, used to weight the
            penalties. Defaults to the identity.
        grid_size: Number of grid points for the bid search (at least 2).
        init_Q_rho: Initial value of ``Q_rho``.
        init_Q_gamma: Initial value of ``Q_gamma``.
        use_optimism: If true, add/subtract a confidence bonus in
            :meth:`hat_r` / :meth:`check_c`.
    """

    def __init__(self,
                 T: int,
                 B: Optional[float] = None,
                 rho: Optional[float] = None,
                 gamma: float = 1.2,
                 delta: float = 0.01,
                 eta_rho: float = 0.1,
                 eta_gamma: float = 0.1,
                 phi: Optional[Callable[[float], float]] = None,
                 phi_prime: Optional[Callable[[float], float]] = None,
                 grid_size: int = 101,
                 init_Q_rho: float = 0.0,
                 init_Q_gamma: float = 0.0,
                 use_optimism: bool = False):
        self.T = int(T)
        self.B = None if B is None else float(B)
        self.gamma = float(gamma)
        self.delta = float(delta)
        self.eta_rho = float(eta_rho)
        self.eta_gamma = float(eta_gamma)
        self.grid_size = max(2, int(grid_size))
        self.use_optimism = bool(use_optimism)
        # B is optional: without a budget nothing can be exhausted, so the
        # remaining budget is infinite instead of float(None) raising.
        self.remaining_budget = math.inf if self.B is None else self.B
        if rho is not None:
            self.rho = float(rho)
        elif self.B is not None:
            self.rho = self.B / float(max(1, self.T))
        else:
            self.rho = 0.01
        self.phi = (lambda x: 0.5 * x * x) if phi is None else phi
        self.phi_prime = (lambda x: x) if phi_prime is None else phi_prime
        self.Q_rho = float(init_Q_rho)
        self.Q_gamma = float(init_Q_gamma)
        self.history_d: List[float] = []
        self.history_v: List[float] = []
        self.history_b: List[float] = []
        self.history_reward: List[float] = []
        self.history_cost: List[float] = []
        self.t = 1
        self._last_ctx: Optional[Dict[str, Any]] = None

    def _empirical_win_prob(self, b: float) -> float:
        """Return the fraction of recorded competing bids that are <= ``b``.

        Returns 0.0 when no competing bid has been recorded yet.
        """
        if not self.history_d:
            return 0.0
        wins = sum(1 for d in self.history_d if d <= b)
        return wins / len(self.history_d)

    def _confidence_bonus(self) -> float:
        """Return ``sqrt(log(2T/delta) / (2 * max(1, t - 1)))``."""
        return math.sqrt(
            math.log(2 * self.T / self.delta) / (2 * max(1, self.t - 1))
        )

    def hat_r(self, v: float, b: float) -> float:
        """Estimate the reward of bidding ``b`` at value ``v``.

        The estimate is ``win_prob(b) * max(0, v - b)``; with
        ``use_optimism`` the confidence bonus is added.
        """
        estimate = self._empirical_win_prob(b) * max(0.0, v - b)
        if self.use_optimism:
            estimate += self._confidence_bonus()
        return estimate

    def check_c(self, v: float, b: float) -> float:
        """Estimate the cost of bidding ``b``.

        The estimate is ``win_prob(b) * b``; with ``use_optimism`` the
        confidence bonus is subtracted. ``v`` is accepted for signature
        symmetry with :meth:`hat_r` and is not used.
        """
        estimate = self._empirical_win_prob(b) * b
        if self.use_optimism:
            estimate -= self._confidence_bonus()
        return estimate

    def _E_over_v(self, func: Callable[[float, float], float], b: float) -> float:
        """Evaluate ``func(v, b)`` at the value of the most recent context.

        Returns 0.0 when :meth:`act` has not been called yet. Despite the
        name, this is a single-point evaluation, not an average over values.
        """
        if self._last_ctx is None:
            return 0.0
        v = float(self._last_ctx.get("v", 0.0))
        return func(v, b)

    def act(self, ctx: Dict[str, Any]) -> float:
        """Choose a bid for the current round.

        Args:
            ctx: Round context. Reads ``"v"`` (value, default 0.0) and
                ``"vbar"`` (bid cap, default 1.0).

        Returns:
            The chosen bid, also appended to ``history_b``. The bid is 0.0
            when the budget is exhausted or on the first round.
        """
        self._last_ctx = dict(ctx)
        v_t = float(ctx.get("v", 0.0))
        vbar = float(ctx.get("vbar", 1.0))
        if self.remaining_budget <= 0 or self.t == 1:
            self.history_b.append(0.0)
            return 0.0

        max_bid = min(v_t, vbar, self.remaining_budget)
        steps = self.grid_size - 1
        best_b = 0.0
        best_score = -float("inf")
        for i in range(self.grid_size):
            b = max_bid * i / steps
            # The stored context holds exactly v_t, so the estimates can be
            # evaluated directly and reused for both penalty and score.
            cost_est = self.check_c(v_t, b)
            reward_est = self.hat_r(v_t, b)
            excess_rho = max(0.0, cost_est - self.rho)
            excess_gamma = max(0.0, self.gamma * cost_est - reward_est)
            # Penalties use the look-ahead accumulator values, i.e. what the
            # accumulators would become if this bid were played.
            pen_rho = (
                self.phi_prime(self.Q_rho + self.eta_rho * excess_rho)
                * self.eta_rho * excess_rho
            )
            pen_gamma = (
                self.phi_prime(self.Q_gamma + self.eta_gamma * excess_gamma)
                * self.eta_gamma * excess_gamma
            )
            score = reward_est - pen_rho - pen_gamma
            # Strictly better score wins; near-ties go to the smaller bid.
            if score > best_score + 1e-12 or (
                abs(score - best_score) <= 1e-12 and b < best_b
            ):
                best_score = score
                best_b = b

        # When the bid would leave less than 0.2 of budget and the whole
        # remainder is below the cap, bid the entire remainder.
        if self.remaining_budget - best_b < 0.2 and self.remaining_budget < vbar:
            best_b = self.remaining_budget
        b_chosen = float(max(0.0, min(best_b, vbar)))
        self.history_b.append(b_chosen)
        return b_chosen

    def update(self, reward: float, cost: float, info: Dict[str, Any]) -> None:
        """Record the round outcome and advance the accumulators.

        Args:
            reward: Realised reward of the round.
            cost: Realised cost; deducted from the remaining budget
                (floored at 0).
            info: Feedback dict. If it contains a non-``None``
                ``"d_compete"`` convertible to float, that value is added to
                the competing-bid history.
        """
        self.history_reward.append(float(reward))
        spent = float(cost)
        self.history_cost.append(spent)

        if self._last_ctx is not None:
            # Best effort: an unconvertible value is simply not recorded.
            try:
                self.history_v.append(float(self._last_ctx.get("v", 0.0)))
            except (TypeError, ValueError, OverflowError):
                pass

        d_compete = info.get("d_compete") if isinstance(info, dict) else None
        if d_compete is not None:
            # Best effort: an unconvertible competing bid is skipped.
            try:
                self.history_d.append(float(d_compete))
            except (TypeError, ValueError, OverflowError):
                pass

        self.remaining_budget = max(0.0, self.remaining_budget - spent)

        # Accumulators are advanced with the estimates at the bid just
        # played, evaluated after the new competing bid has been recorded.
        b_t = self.history_b[-1] if self.history_b else 0.0
        cost_est = self._E_over_v(self.check_c, b_t)
        reward_est = self._E_over_v(self.hat_r, b_t)
        self.Q_rho += self.eta_rho * max(0.0, cost_est - self.rho)
        self.Q_gamma += self.eta_gamma * max(
            0.0, self.gamma * cost_est - reward_est
        )
        self.t += 1

    def get_state(self) -> Dict[str, Any]:
        """Return a snapshot of the round counter, accumulators and totals."""
        return {
            "t": self.t,
            "Q_rho": self.Q_rho,
            "Q_gamma": self.Q_gamma,
            "rho": self.rho,
            "gamma": self.gamma,
            "total_spent": sum(self.history_cost),
            "total_reward": sum(self.history_reward),
            "n_obs": len(self.history_v),
        }


class L2FOB_Bandit:
    """Arm-selection agent with budget (rho) and ROI (gamma) penalty queues.

    Reward and cost estimates come from two externally supplied models.
    Each round the agent scores the arms as

        rmodel.predict(ctx) - pen_rho - pen_gamma

    and plays the argmax. The penalties are built from the non-negative
    accumulators ``Q_rho`` and ``Q_gamma`` with the quadratic potential
    ``phi(x) = x**2`` (``phi_prime(x) = 2 * x``).

    Args:
        rmodel: Reward model. Must provide ``predict(context)`` and
            ``update(features, reward)``.
        cmodel: Cost model. Must provide ``predict(context)`` and
            ``update(action, cost)``.
        T: Horizon (number of rounds); used for the default ``rho``.
        B: Total budget; only used to derive the default ``rho``.
        ARMS: Number of arms; stored on the instance and not otherwise used
            by this class.
        rho: Per-round spend target. Defaults to ``B / max(1, T)`` when ``B``
            is given, otherwise ``0.01``.
        gamma: Multiplier applied to the cost estimate in the ROI term
            ``gamma * cost - reward``.
        eta_rho: Step size of the ``Q_rho`` accumulator.
        eta_gamma: Step size of the ``Q_gamma`` accumulator.
        init_Q_rho: Initial value of ``Q_rho``.
        init_Q_gamma: Initial value of ``Q_gamma``.
    """

    def __init__(self,
                 rmodel,
                 cmodel,
                 T: int,
                 B: Optional[float] = None,
                 ARMS: int = 4,
                 rho: Optional[float] = None,
                 gamma: float = 1.2,
                 eta_rho: float = 0.1,
                 eta_gamma: float = 0.1,
                 init_Q_rho: float = 0.0,
                 init_Q_gamma: float = 0.0,):
        self.T = int(T)
        self.B = None if B is None else float(B)
        self.gamma = float(gamma)
        self.eta_rho = float(eta_rho)
        self.eta_gamma = float(eta_gamma)
        # NOTE(review): diagnostic output kept so constructor behaviour is
        # unchanged; consider routing it through `logging` instead.
        print(f"eta gamma: {self.eta_gamma}")
        self.ARMS = int(ARMS)
        self.b = 0
        self.rmodel = rmodel
        self.cmodel = cmodel
        if rho is not None:
            self.rho = float(rho)
        elif self.B is not None:
            self.rho = self.B / float(max(1, self.T))
        else:
            self.rho = 0.01
        self.phi = lambda x: x * x
        self.phi_prime = lambda x: 2 * x
        self.Q_rho = float(init_Q_rho)
        self.Q_gamma = float(init_Q_gamma)
        self.history_b: List[float] = []
        self.history_v: List[float] = []
        self.history_reward: List[float] = []
        self.history_cost: List[float] = []
        self.t = 1
        self._last_ctx: Optional[Dict[str, Any]] = None

    def sample(self, dist):
        """Draw one index from the probability array ``dist``.

        ``dist`` is flattened and used as the probabilities of a single
        multinomial trial; the index of the drawn outcome is returned.
        """
        return np.argmax(np.random.multinomial(1, dist.flatten()))

    def act(self, ctx: Dict[str, Any]) -> float:
        """Choose an arm for the current round.

        Args:
            ctx: Round context. ``ctx["context"]`` is passed unchanged to
                both models' ``predict``.

        Returns:
            The chosen arm index as an ``int`` (always 0 on the first
            round). The index is also appended to ``history_b``.
        """
        self._last_ctx = dict(ctx)
        v_t = ctx.get("context")
        if self.t == 1:
            self.b = 0
            self.history_b.append(self.b)
            return self.b

        # Assumes both models return one estimate per arm as array-likes
        # that support elementwise arithmetic — TODO confirm with the models.
        E_checkc = self.cmodel.predict(v_t)
        E_checkcr = self.gamma * E_checkc - self.rmodel.predict(v_t)
        excess_rho = np.maximum(0.0, E_checkc - self.rho)
        excess_gamma = np.maximum(0.0, E_checkcr)
        # Penalties use the look-ahead accumulator values, i.e. what the
        # accumulators would become if the arm were played.
        pen_rho = (
            self.phi_prime(self.Q_rho + self.eta_rho * excess_rho)
            * self.eta_rho * excess_rho
        )
        pen_gamma = (
            self.phi_prime(self.Q_gamma + self.eta_gamma * excess_gamma)
            * self.eta_gamma * excess_gamma
        )
        # NOTE(review): the reward model is queried a second time, as in the
        # original call sequence; if predict() is deterministic this could
        # reuse the first result.
        r_hat_current = self.rmodel.predict(v_t)
        score = r_hat_current - pen_rho - pen_gamma
        # Store a plain int so history_b holds the same type every round.
        b_chosen = int(np.argmax(score))
        self.history_b.append(b_chosen)
        return b_chosen

    def update(self, reward: float, cost: float, info: Dict[str, Any]) -> None:
        """Feed the round outcome to the models and advance the accumulators.

        Args:
            reward: Realised reward of the round.
            cost: Realised cost of the round.
            info: Feedback dict. Reads ``"bid"`` (index of the arm played)
                and ``"context"`` (indexable by that arm index).

        If :meth:`act` has not been called yet, the context and arm from
        ``info`` are used for the accumulator update instead of the stored
        ones.
        """
        context = info.get("context")
        action = int(info.get("bid"))
        self.rmodel.update(context[action], reward)
        self.cmodel.update(action, cost)
        self.history_reward.append(float(reward))
        self.history_cost.append(float(cost))

        # Prefer the context seen by act(); fall back to the one reported in
        # `info` so update() does not fail when called before act().
        if self._last_ctx is not None:
            round_ctx = self._last_ctx.get("context")
        else:
            round_ctx = context
        self.history_v.append(round_ctx)
        b_t = self.history_b[-1] if self.history_b else action

        # Estimates are taken after the models have seen this round.
        E_checkc = self.cmodel.predict(round_ctx)[b_t]
        E_checkcr = self.gamma * E_checkc - self.rmodel.predict(round_ctx)[b_t]
        self.Q_rho = self.Q_rho + self.eta_rho * max(0.0, E_checkc - self.rho)
        self.Q_gamma = self.Q_gamma + self.eta_gamma * max(0.0, E_checkcr)
        self.t += 1

    def get_state(self) -> Dict[str, Any]:
        """Return a snapshot of the round counter, accumulators and totals."""
        return {
            "t": self.t,
            "Q_rho": self.Q_rho,
            "Q_gamma": self.Q_gamma,
            "rho": self.rho,
            "gamma": self.gamma,
            "total_spent": sum(self.history_cost),
            "total_reward": sum(self.history_reward),
            "n_obs": len(self.history_v),
        }
