from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from tqdm import tqdm

# Scalar numeric type accepted for user-supplied quantities.
Number = Union[int, float]
# Same as ``Number`` but also admits ``None`` (meaning "not provided").
OptionalNumber = Union[int, float, None]

@dataclass
class RoundLog:
    """Per-round record of one step of an interaction run."""

    # Round index.
    t: int

    # Bid submitted in this round (0.0 for rounds that were not played).
    bid: float

    # Reward and cost observed for this round.
    reward: float
    cost: float

    # Per-round loss ``gamma * cost - reward`` and its running sum.
    gamma_cost_minus_reward: float
    cum_gamma_loss: float

    # Budget remaining after this round; None when no budget is tracked.
    budget_left: Optional[float]

    # Free-form metadata attached to the round.
    info: Dict[str, Any]

@dataclass
class RunResult:
    """Aggregate outcome of a complete interaction run."""

    # One log entry per round.
    rounds: List[RoundLog]

    # Sums of the per-round reward, cost and gamma-loss values.
    total_reward: float
    total_cost: float
    total_gamma_loss: float

    # Number of rounds in which the environment was actually stepped,
    # and the number of rounds that were scheduled.
    steps_called: int
    steps_total: int

    # Early-stop flag set by the runner, and the round at which the stop
    # was first recorded (None when the run never stopped early).
    stopped_by_budget: bool
    stop_round: Optional[int]

class OnlineInteractionRunner:
    """Run an online bidding loop between an environment and an algorithm.

    For each round ``t = 1..T`` the runner asks the environment for a
    context, lets the algorithm pick a bid, applies the bid to the
    environment, logs the outcome and feeds it back to the algorithm.

    Once the budget is used up, or the environment raises ``StopIteration``
    from ``get_context()``, the loop goes idle: the remaining rounds are
    still logged (with zero bid/reward/cost) but neither the environment
    nor the algorithm is called again.

    Expected collaborators (duck-typed):
        env:  ``get_context()`` and ``step(bid)``; ``step`` must return an
              object exposing ``reward``, ``cost`` and ``info`` attributes.
        algo: ``act(ctx) -> bid`` and ``update(reward, cost, info)``.
    """

    # Values written to ``info["reason"]`` of idle rounds.
    _REASON_BUDGET = "budget_exhausted"
    _REASON_ENV_DONE = "env_done"

    def __init__(self,
                 env: Any,
                 algo: Any,
                 T: int,
                 budget: OptionalNumber = None,
                 gamma: float = 1.2):
        """Store the run configuration and validate the algorithm interface.

        Args:
            env: Environment object (see class docstring).
            algo: Algorithm object (see class docstring).
            T: Number of rounds to schedule; coerced with ``int()``.
            budget: Total cost allowed before the run goes idle, or ``None``
                for no budget tracking; coerced with ``float()``.
            gamma: Weight applied to cost in the per-round loss
                ``gamma * cost - reward``; coerced with ``float()``.

        Raises:
            TypeError: If ``algo`` lacks a callable ``act`` or ``update``.
        """
        self.env = env
        self.algo = algo
        self.T = int(T)
        self.budget = None if budget is None else float(budget)
        self.gamma = float(gamma)

        if not callable(getattr(algo, "act", None)):
            raise TypeError("algo must implement act(ctx) -> bid")
        if not callable(getattr(algo, "update", None)):
            raise TypeError("algo must implement update(reward, cost, info)")

    @staticmethod
    def _to_float(x, default=0.0) -> float:
        """Convert ``x`` to a finite float, falling back to ``default``.

        ``default`` is returned (as a float) when ``x`` cannot be converted
        or converts to NaN or +/- infinity.
        """
        # Deliberately broad: values come from user-supplied env/algo objects
        # whose ``__float__`` may fail in arbitrary ways; this is best-effort.
        try:
            v = float(x)
        except Exception:
            return float(default)
        # ``v != v`` is true only for NaN.
        if v != v or v in (float("inf"), float("-inf")):
            return float(default)
        return v

    @staticmethod
    def _idle_round(t: int,
                    cum_gamma_loss: float,
                    remaining: Optional[float],
                    reason: str) -> RoundLog:
        """Build the zero-valued log entry for a round that was not played."""
        return RoundLog(
            t=t,
            bid=0.0,
            reward=0.0,
            cost=0.0,
            gamma_cost_minus_reward=0.0,
            cum_gamma_loss=cum_gamma_loss,
            budget_left=(None if remaining is None else max(0.0, remaining)),
            info={"stopped": True, "reason": reason, "round": t},
        )

    def run(self) -> RunResult:
        """Execute up to ``T`` rounds and return the full run record.

        Returns:
            A ``RunResult`` with exactly one ``RoundLog`` per scheduled
            round. ``stopped_by_budget`` is True only when the budget was
            the reason the run went idle. ``stop_round`` is the round at
            which the stop was detected: the round whose cost exhausted the
            budget, or the first round that could not be played.

        Raises:
            TypeError: If ``env.step(bid)`` returns an object without
                ``reward``, ``cost`` and ``info`` attributes.
        """
        logs: List[RoundLog] = []
        total_reward = 0.0
        total_cost = 0.0
        cum_gamma_loss = 0.0
        steps_called = 0
        stopped_by_budget = False
        stop_round: Optional[int] = None
        # None while the loop is live; afterwards the reason it went idle.
        stop_reason: Optional[str] = None

        remaining = self.budget  # already a float or None

        with tqdm(range(1, self.T + 1), desc="Running", unit="step") as pbar:
            for t in pbar:
                # Only reachable with a budget that was non-positive from the
                # start; exhaustion during a round sets stop_reason below.
                if stop_reason is None and remaining is not None and remaining <= 0.0:
                    stop_reason = self._REASON_BUDGET
                    stopped_by_budget = True
                    stop_round = t

                ctx = None
                if stop_reason is None:
                    try:
                        ctx = self.env.get_context()
                    except StopIteration:
                        # The environment ran dry: this is an early stop, but
                        # not a budget stop, so stopped_by_budget stays False.
                        stop_reason = self._REASON_ENV_DONE
                        stop_round = t

                if stop_reason is not None:
                    logs.append(self._idle_round(t, cum_gamma_loss, remaining, stop_reason))
                    continue

                bid = self._to_float(self.algo.act(ctx), 0.0)
                ret = self.env.step(bid)
                if not (hasattr(ret, "reward") and hasattr(ret, "cost") and hasattr(ret, "info")):
                    raise TypeError("env.step(bid) must return StepResult(reward, cost, info)")

                reward = self._to_float(ret.reward, 0.0)
                cost = self._to_float(ret.cost, 0.0)
                # Copy so the environment's own dict is never mutated.
                info: Dict[str, Any] = dict(ret.info) if isinstance(ret.info, dict) else {}
                steps_called += 1

                L_t = self.gamma * cost - reward
                cum_gamma_loss += L_t
                total_reward += reward
                total_cost += cost

                if remaining is not None:
                    remaining -= cost
                    if remaining <= 0.0:
                        # This round was still played and is logged as such;
                        # the loop goes idle from the next round on.
                        stop_reason = self._REASON_BUDGET
                        stopped_by_budget = True
                        stop_round = t

                info_all = {**info, "round": t, "bid": bid, "gamma": self.gamma,
                            "gamma_cost_minus_reward": L_t, "cum_gamma_loss": cum_gamma_loss,
                            "stopped": False}
                logs.append(RoundLog(
                    t=t,
                    bid=bid,
                    reward=reward,
                    cost=cost,
                    gamma_cost_minus_reward=L_t,
                    cum_gamma_loss=cum_gamma_loss,
                    budget_left=(None if remaining is None else max(0.0, remaining)),
                    info=info_all,
                ))

                self.algo.update(reward, cost, info_all)

                roi = total_reward / total_cost if total_cost > 0 else 0
                pbar.set_postfix({
                    "Reward": f"{total_reward:.2f}",
                    "Cost": f"{total_cost:.2f}",
                    "ROI": f"{roi:.2f}",
                    "GammaLoss": f"{cum_gamma_loss:.2f}",
                })

        return RunResult(
            rounds=logs,
            total_reward=total_reward,
            total_cost=total_cost,
            total_gamma_loss=cum_gamma_loss,
            steps_called=steps_called,
            steps_total=self.T,
            stopped_by_budget=stopped_by_budget,
            stop_round=stop_round,
        )
