from dataclasses import dataclass
from typing import Optional, Literal, Dict, Any
import numpy as np
import math


def truncated_normal(mean: float, std: float, low: float, high: float, size=None,
                     rng: Optional[np.random.Generator] = None):
    """Sample from N(mean, std**2) restricted to ``[low, high]`` by rejection.

    Draws are taken in batches from ``rng.normal`` and only those inside the
    closed interval are kept, until enough samples have been collected.

    Args:
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        low: Inclusive lower bound.
        high: Inclusive upper bound.
        size: Number of samples. ``None`` (default) requests a single scalar.
        rng: Random generator; a fresh ``default_rng()`` is used when omitted.

    Returns:
        A Python ``float`` when ``size`` is None, otherwise a 1-D float array
        of length ``size`` (including ``size=1``, for consistency with NumPy
        samplers).

    Raises:
        ValueError: if ``low > high`` (the rejection loop could never finish).
    """
    if low > high:
        raise ValueError(f"truncated_normal: low ({low}) must not exceed high ({high})")
    if rng is None:
        rng = np.random.default_rng()

    scalar = size is None
    n = 1 if scalar else int(size)
    out = np.empty(n, dtype=float)
    filled = 0
    # Rejection sampling: only redraw as many values as are still missing.
    # NOTE: acceptance rate is the normal mass inside [low, high]; intervals
    # far in the tails make this loop very slow.
    while filled < n:
        draws = rng.normal(mean, std, n - filled)
        accepted = draws[(draws >= low) & (draws <= high)]
        k = accepted.size  # never exceeds n - filled: it is a subset of the batch
        out[filled:filled + k] = accepted
        filled += k
    return float(out[0]) if scalar else out


def compute_optimal_fixed_bid(env, T, B, gamma, grid_size=81, seed=0, value_kind=None):
    """Find the best constant bid in hindsight on a freshly simulated horizon.

    ``T`` value / competing-bid pairs are drawn with an RNG seeded by ``seed``
    and every bid on an evenly spaced grid over ``[0, env.vbar]`` is scored.
    A bid ``b`` wins a round when ``b >= d`` and then earns ``v - b`` and
    costs ``b``. A bid is feasible when its total cost is at most ``B`` and
    its total reward is at least ``gamma`` times its total cost.

    Args:
        env: Environment exposing ``vbar``. If it also exposes
            ``value_dist.kind`` (as ``FPAEnvSimple`` does), values are drawn
            from that same distribution kind.
        T: Number of simulated rounds.
        B: Budget on total cost.
        gamma: Minimum required ratio of total reward to total cost.
        grid_size: Number of candidate bids on the grid.
        seed: Seed for the simulation RNG.
        value_kind: Optional override of the value distribution kind; when
            None it is taken from ``env`` (falling back to ``"normal"``).

    Returns:
        ``(best_bid, best_reward, best_cost)`` for the feasible bid with the
        highest total reward, or ``(None, -inf, None)`` if no grid bid is
        feasible.
    """
    if value_kind is None:
        # Simulate the same value distribution the environment actually uses.
        value_kind = getattr(getattr(env, "value_dist", None), "kind", "normal")

    rng = np.random.default_rng(seed)
    value_dist = ValueDist(kind=value_kind, vbar=env.vbar, rng=rng)
    comp_dist = CompetingBidDist(vbar=env.vbar, rng=rng)
    # atleast_1d guards against samplers returning a bare scalar for T == 1.
    vs = np.atleast_1d(value_dist.sample(T))
    ds = np.atleast_1d(comp_dist.sample(T))

    # -inf (not a magic -1) so any feasible total reward can become the best.
    best_reward, best_cost, best_b = -math.inf, None, None

    for b in np.linspace(0.0, env.vbar, grid_size):
        won = b >= ds  # ties are won, matching FPAEnvSimple.step
        total_r = float(((vs - b) * won).sum())
        total_c = float((b * won).sum())

        feasible = total_c <= B and total_r >= gamma * total_c
        if feasible and total_r > best_reward:
            best_reward, best_cost, best_b = total_r, total_c, float(b)

    return best_b, best_reward, best_cost


class ValueDist:
    """Sampler for the bidder's value, kept within ``[0, vbar]``.

    Supported ``kind`` values:
      * ``"normal"``    -- N(0.6, 0.1**2) truncated to ``[0, vbar]``.
      * ``"lognormal"`` -- ``exp(N(-0.4, 0.1**2))`` clipped to ``[0, vbar]``.
      * ``"uniform"``   -- U(0.25, vbar).

    Any other kind is only rejected when :meth:`sample` is called.
    """

    def __init__(self, kind: Literal["normal", "lognormal", "uniform"] = "normal",
                 vbar: float = 1.0, rng: Optional[np.random.Generator] = None):
        self.kind = kind
        self.vbar = float(vbar)
        self.rng = rng if rng else np.random.default_rng()

    def sample(self, size: Optional[int] = None):
        """Draw a scalar value (``size=None``) or ``size`` values.

        Raises:
            ValueError: if ``self.kind`` is not a supported kind.
        """
        kind = self.kind
        if kind == "normal":
            return truncated_normal(0.6, 0.1, 0.0, self.vbar, size=size, rng=self.rng)
        if kind == "lognormal":
            return self._sample_lognormal(size)
        if kind == "uniform":
            return self._sample_uniform(size)
        raise ValueError("unknown value kind")

    def _sample_lognormal(self, size: Optional[int]):
        """Exponentiate normal draws and clip them into ``[0, vbar]``."""
        gen = self.rng
        if size is not None:
            return np.clip(np.exp(gen.normal(-0.4, 0.1, size)), 0.0, self.vbar)
        # Draw through the size-1 array path so the RNG stream matches batch sampling.
        log_value = float(gen.normal(-0.4, 0.1, 1)[0])
        return min(max(0.0, float(math.exp(log_value))), self.vbar)

    def _sample_uniform(self, size: Optional[int]):
        """Draw uniformly from ``[0.25, vbar)``."""
        if size is None:
            return float(self.rng.uniform(0.25, self.vbar))
        return self.rng.uniform(0.25, self.vbar, size)


class CompetingBidDist:
    """Sampler for the highest competing bid: N(0.4, 0.1**2) truncated to ``[0, vbar]``."""

    def __init__(self, vbar: float = 1.0, rng: Optional[np.random.Generator] = None):
        self.vbar = float(vbar)
        self.rng = rng if rng else np.random.default_rng()

    def sample(self, size: Optional[int] = None):
        """Draw one competing bid (``size=None``) or ``size`` of them."""
        lower, upper = 0.0, self.vbar
        return truncated_normal(0.4, 0.1, lower, upper, size=size, rng=self.rng)


@dataclass
class StepResult:
    """Outcome of one auction round, as returned by ``FPAEnvSimple.step``.

    Attributes:
        reward: Utility earned this round (0.0 when the auction is lost).
        cost: Payment made this round (0.0 when the auction is lost).
        info: Per-round diagnostics dictionary.
    """

    reward: float  # utility for the round
    cost: float  # amount paid for the round
    info: Dict[str, Any]  # extra diagnostics


class FPAEnvSimple:
    def __init__(self, vbar: float = 1.0,
                 feedback: Literal["full", "one-sided"] = "one-sided",
                 value_kind: Literal["normal", "lognormal", "uniform"] = "normal",
                 seed: Optional[int] = 2025):
        self.vbar = float(vbar)
        self.feedback = feedback
        self.rng = np.random.default_rng(seed)
        self.value_dist = ValueDist(kind=value_kind, vbar=vbar, rng=self.rng)
        self.comp_dist = CompetingBidDist(vbar=vbar, rng=self.rng)

        self._cur_v: Optional[float] = None
        self._cur_d: Optional[float] = None
        self._has_context = False

    def reset(self, seed: Optional[int] = None):
        if seed is not None:
            self.rng = np.random.default_rng(seed)
            self.value_dist.rng = self.rng
            self.comp_dist.rng = self.rng
        self._cur_v = None
        self._cur_d = None
        self._has_context = False

    def get_context(self) -> Dict[str, Any]:
        self._cur_v = float(self.value_dist.sample())
        self._cur_d = float(self.comp_dist.sample())
        self._has_context = True
        return {
            "v": self._cur_v,
            "vbar": self.vbar,
            "d_compete": self._cur_d,
            "feedback_mode": self.feedback,
            "max_safe_bid": self.vbar,
        }

    def step(self, bid: float, keep_context: bool = False) -> StepResult:
        if not self._has_context:
            raise RuntimeError("Call get_context() before step().")

        v = float(self._cur_v)
        d = float(self._cur_d)

        b = 0.0 if (bid is None or not math.isfinite(float(bid))) else float(bid)
        b = max(0.0, min(b, self.vbar))

        won = (b >= d)
        reward = (v - b) if won else 0.0
        cost = b if won else 0.0

        if self.feedback == "full":
            d_obs = d
            d_compete = d
        else:
            d_obs = (None if won else d)
            d_compete = (None if won else d)

        info = {
            "bid": b,
            "won": bool(won),
            "v": v,
            "d_observed": d_obs,
            "d_compete": d_compete,
            "roi": (reward / cost) if cost > 0 else (float("inf") if reward > 0 else 0.0),
        }

        if not keep_context:
            self._has_context = False
        return StepResult(reward=reward, cost=cost, info=info)


if __name__ == "__main__":
    # Smoke run: bid a fixed 70% of the value for ten rounds under full feedback.
    demo_env = FPAEnvSimple(vbar=1.0, feedback="full", value_kind="normal", seed=2025)
    rewards, costs = [], []
    for _ in range(10):
        ctx = demo_env.get_context()
        bid = 0.7 * ctx["v"]
        outcome = demo_env.step(bid)
        print(
            f"v={ctx['v']:.3f}  bid={bid:.3f}  "
            f"won={outcome.info['won']}  reward={outcome.reward:.3f}  cost={outcome.cost:.3f}  "
            f"d_obs={outcome.info['d_observed']}"
        )
        rewards.append(outcome.reward)
        costs.append(outcome.cost)
    total_r, total_c = sum(rewards), sum(costs)
    roi = total_r / total_c if total_c > 0 else 0.0
    print(f"TOTAL: reward={total_r:.3f}, cost={total_c:.3f}, ROI={roi:.3f}")
