import math
import bisect


class SkiRentalOptimizer:
    """
    Discrete ski-rental randomized policy optimizer (water-filling + bisection).

    The optimizer works with ``g(t)``, the expected cost of the deterministic
    policy "rent for ``t - 1`` days, then buy" under the (normalized) predicted
    distribution of ski days ``D``::

        g(t) = sum_{d < t} p(d) * d + P(D >= t) * (t + b - 1)

    ``g`` is piecewise linear between consecutive predicted days.
    ``solve`` bisects on a cost level ``h``.  ``is_feasible(h)`` checks whether
    a buy-day distribution can be built that only puts mass where ``g <= h``,
    water-filling against the budget ``(C - 1) * s``.
    ``construct_policy_pmf(h)`` builds that distribution with the same
    procedure.
    """

    def __init__(self, b, C, predicted_days, predicted_probs):
        """
        Args:
            b: buy cost.  Rounded to an integer (``b_int``) for the discrete
                policy; must round to at least 2.
            C: ratio parameter used in the budget ``(C - 1) * s``.
            predicted_days: support of the predicted ski-day distribution.
            predicted_probs: weights for ``predicted_days``.  They are
                normalized to sum to 1, so their total must be positive.

        Raises:
            ValueError: if ``round(b) < 2``, if the inputs are empty or have
                different lengths, or if the weights do not sum to a positive
                value.
        """
        self.b = float(b)
        self.C = float(C)

        self.b_int = int(round(self.b))
        if self.b_int < 2:
            # q_base below and the water-filling step both divide by (b_int - 1).
            raise ValueError("b must round to an integer >= 2.")

        # Per-day growth factor of (F + C - 1) during the geometric phase.
        self.q_base = 1.0 + 1.0 / (self.b_int - 1)

        self.days = list(predicted_days)
        self.probs = list(predicted_probs)

        if len(self.days) != len(self.probs) or len(self.days) == 0:
            raise ValueError("predicted_days/probs must be non-empty lists of the same length.")

        s = sum(self.probs)
        if s <= 0:
            raise ValueError("Sum of probabilities must be positive.")
        self.probs = [p / s for p in self.probs]

        # Keep days and probabilities aligned while sorting by day.
        pairs = sorted(zip(self.days, self.probs), key=lambda x: x[0])
        self.days = [d for d, _ in pairs]
        self.probs = [p for _, p in pairs]

        self.segments = self._calculate_g_function_parameters()
        self.max_day = int(max(self.days))

    def _calculate_g_function_parameters(self):
        """
        Return the linear pieces of ``g`` as a list of segment dicts.

        Each dict has ``start``, ``end`` (half-open interval ``(start, end]``),
        slope ``a_k`` and intercept ``b_k``, so ``g(t) = a_k * t + b_k`` on it.
        - The first piece covers ``(0, days[0]]``.
        - One piece covers each ``(days[k], days[k+1]]``.
        - The last piece is ``(days[-1], inf)``, where ``g`` is constant and
          equal to ``E[D]``.
        """
        n = len(self.days)
        segments = []

        expected_val = sum(p * d for p, d in zip(self.probs, self.days))
        current_prob_sum_d = 0.0

        # Before the first predicted day every outcome still needs a purchase.
        segments.append(
            {"start": 0.0, "end": self.days[0], "a_k": 1.0, "b_k": self.b - 1.0}
        )

        cum_prob = 0.0
        for k in range(n - 1):
            d_k = self.days[k]
            d_next = self.days[k + 1]
            p_k = self.probs[k]

            current_prob_sum_d += p_k * d_k
            cum_prob += p_k

            a_k = 1.0 - cum_prob  # probability that skiing lasts past d_k
            b_k = current_prob_sum_d + (self.b - 1.0) * a_k

            segments.append({"start": d_k, "end": d_next, "a_k": a_k, "b_k": b_k})

        # Past the last predicted day the policy never buys: cost is E[D].
        segments.append(
            {"start": self.days[-1], "end": float("inf"), "a_k": 0.0, "b_k": expected_val}
        )
        return segments

    def _get_g_value(self, t):
        """
        Evaluate ``g(t)``.

        ``g(0)`` is defined as ``b - 1``.  A ``t`` that falls in no segment
        (i.e. ``t < 0``) falls back to the last segment's constant.
        """
        if t == 0:
            return self.b - 1.0

        for seg in self.segments:
            if seg["start"] < t <= seg["end"]:
                return seg["a_k"] * t + seg["b_k"]
        return self.segments[-1]["b_k"]

    def _segment_integer_range(self, seg_start, seg_end):
        """
        Return ``(lo, hi)``, the integer thresholds in ``(seg_start, seg_end]``.

        When ``seg_start <= 0``, 0 is also included.  ``lo > hi`` means the
        range is empty.
        """
        lo = 0 if seg_start <= 0.0 else int(seg_start) + 1
        hi = int(seg_end)
        return lo, hi

    @staticmethod
    def _absorb_rounding(pmf, key, tol):
        """Fold any drift of the total mass away from 1 (beyond ``tol``) into ``pmf[key]``; return ``pmf``."""
        total = sum(pmf.values())
        if abs(total - 1.0) > tol:
            pmf[key] += 1.0 - total
        return pmf

    def is_feasible(self, h):
        """
        Return True if a buy-day distribution supported where ``g <= h`` exists.

        Integer thresholds up to ``b_int`` are scanned segment by segment.  In
        each segment, only thresholds ``x`` with ``g(x) <= h`` are usable (up
        to ``e_i``).
        - At the first usable threshold, mass is water-filled to meet the
          budget ``(C - 1) * s``.
        - The CDF ``F`` then grows geometrically (closed form) up to ``e_i``.
        - Any mass left over must go to a single predicted day ``d > b_int``
          with ``g(d) <= h`` and ``d <= ((C - 1) * b_int - mu) / (1 - F)``.
        """
        F = 0.0   # cumulative mass placed so far
        mu = 0.0  # sum of mass * position (closed form after a geometric run)
        b = self.b_int

        for seg in self.segments:
            start_interval = seg["start"]
            end_interval = seg["end"]
            if start_interval >= b:
                break

            effective_end = min(end_interval, b)
            a_k = seg["a_k"]
            b_k = seg["b_k"]

            lo, hi = self._segment_integer_range(start_interval, effective_end)
            if lo > hi:
                continue

            # e_i: last integer threshold in [lo, hi] with g <= h (lo - 1 if none).
            if a_k > 1e-12:
                intersection_x = (h - b_k) / a_k
                e_i = min(hi, int(math.floor(intersection_x + 1e-12)))
            else:
                e_i = hi if (b_k <= h + 1e-12) else (lo - 1)

            if e_i >= lo:
                s = lo

                current_load = mu + (b - s) * F
                target_budget = (self.C - 1.0) * s
                if target_budget > current_load + 1e-9:
                    mass = (target_budget - current_load) / (b - 1)
                    F += mass
                    mu += mass * s
                    if F >= 1.0 - 1e-12:
                        return True

                step = e_i - s
                if step > 0 and F < 1.0 - 1e-12:
                    # Closed form of `step` iterations of F <- (F + C - 1) * q - (C - 1).
                    F_end = (F + (self.C - 1.0)) * (self.q_base ** step) - (self.C - 1.0)
                    if F_end >= 1.0 - 1e-12:
                        return True
                    F = F_end
                    mu = (self.C - 1.0) * e_i - (b - e_i) * F

            if F >= 1.0 - 1e-12:
                return True

        remaining_prob = 1.0 - F
        if remaining_prob <= 1e-9:
            return True

        max_allowed_d = ((self.C - 1.0) * b - mu) / remaining_prob
        start_idx = bisect.bisect_right(self.days, b)

        for i in range(start_idx, len(self.days)):
            d_i = self.days[i]
            if self._get_g_value(d_i) <= h + 1e-12 and d_i <= max_allowed_d + 1e-12:
                return True

        return False

    def solve(self, tolerance=1e-5):
        """
        Bisect for the smallest feasible cost level ``h``.

        The search interval is ``[0, b + days[-1]]``.  Returns the smallest
        ``h`` found feasible, to within ``tolerance``.  If no midpoint was
        feasible, the untested upper bound is returned.
        """
        low = 0.0
        high = self.b + self.days[-1]
        optimal_h = high

        while high - low > tolerance:
            mid = (low + high) / 2.0
            if self.is_feasible(mid):
                optimal_h = mid
                high = mid
            else:
                low = mid

        return optimal_h

    def construct_policy_pmf(self, h):
        """
        Build the randomized buy-day distribution for cost level ``h``.

        Follows the same steps as :meth:`is_feasible`, but records where mass
        is placed.  The geometric phase is stepped one day at a time so each
        day's increment becomes a pmf entry.

        Returns:
            dict mapping buy day -> probability.  Rounding drift in the total
            is folded into a single entry.

        Raises:
            RuntimeError: if the leftover mass fits on no predicted day (``h``
                is not feasible).
        """
        b = self.b_int
        F = 0.0
        mu = 0.0
        pmf = {}

        for seg in self.segments:
            start_interval = seg["start"]
            end_interval = seg["end"]
            if start_interval >= b:
                break

            effective_end = min(end_interval, b)
            a_k = seg["a_k"]
            b_k = seg["b_k"]

            lo, hi = self._segment_integer_range(start_interval, effective_end)
            if lo > hi:
                continue

            if a_k > 1e-12:
                intersection_x = (h - b_k) / a_k
                e_i = min(hi, int(math.floor(intersection_x + 1e-12)))
            else:
                e_i = hi if (b_k <= h + 1e-12) else (lo - 1)

            if e_i < lo:
                continue

            s = lo

            current_load = mu + (b - s) * F
            target_budget = (self.C - 1.0) * s
            if target_budget > current_load + 1e-9:
                mass = (target_budget - current_load) / (b - 1)
                if mass > 0:
                    pmf[s] = pmf.get(s, 0.0) + mass
                    F += mass
                    mu += mass * s
                    if F >= 1.0 - 1e-12:
                        return self._absorb_rounding(pmf, s, 1e-8)

            x = s
            while x < e_i and F < 1.0 - 1e-12:
                F_next = (F + (self.C - 1.0)) * self.q_base - (self.C - 1.0)
                F_next = min(1.0, max(0.0, F_next))
                inc = max(0.0, F_next - F)
                if inc > 0:
                    pmf[x + 1] = pmf.get(x + 1, 0.0) + inc
                F = F_next
                x += 1

            if F >= 1.0 - 1e-12:
                # The CDF saturated during the geometric run: stop here, as
                # is_feasible does.  Continuing would carry F == 1 with a stale
                # mu into later segments and could place spurious extra mass.
                # Keys are inserted in increasing order, so max(pmf) is the
                # entry just written.
                if pmf:
                    self._absorb_rounding(pmf, max(pmf), 1e-8)
                return pmf

            mu = (self.C - 1.0) * e_i - (b - e_i) * F

        remaining = 1.0 - F
        if remaining <= 1e-9:
            if pmf:
                self._absorb_rounding(pmf, min(pmf), 1e-8)
            return pmf

        max_allowed_d = ((self.C - 1.0) * b - mu) / remaining
        start_idx = bisect.bisect_right(self.days, b)

        candidates = []
        for i in range(start_idx, len(self.days)):
            d = self.days[i]
            if self._get_g_value(d) <= h + 1e-12 and d <= max_allowed_d + 1e-12:
                candidates.append(d)

        if not candidates:
            raise RuntimeError("Tail became infeasible; check consistency between feasibility and construction.")

        # Cheapest admissible tail day; ties broken by the earlier day.
        best_d = min(candidates, key=lambda d: (self._get_g_value(d), d))
        pmf[best_d] = pmf.get(best_d, 0.0) + remaining

        return self._absorb_rounding(pmf, best_d, 1e-6)

    def expected_cost(self, pmf):
        """Return ``E[g(Z)]`` for a buy-day pmf ``{day: probability}``."""
        return sum(prob * self._get_g_value(day) for day, prob in pmf.items())

    def min_g_over_integer_thresholds(self):
        """Return ``min g(t)`` over integer thresholds ``t`` in ``1..max_day + 1``."""
        t_max = self.max_day + 1
        best = float("inf")
        for t in range(1, t_max + 1):
            best = min(best, self._get_g_value(t))
        return best

    def consistency_actual(self, pmf):
        """Return ``(E[g(Z)] / gmin, E[g(Z)], gmin)`` for the buy-day pmf ``pmf``."""
        Eg = self.expected_cost(pmf)
        gmin = self.min_g_over_integer_thresholds()
        return Eg / gmin, Eg, gmin


def baseline_lambda_from_R(R, b):
    """
    Return the baseline policy parameter ``lambda`` for ratio ``R`` and buy cost ``b``.

    Computes ``1/b - ln(1 - (1 + 1/b) / R)``.

    Raises:
        ValueError: if ``b`` or ``R`` is not positive, or if
            ``1 - (1 + 1/b) / R <= 0`` (the logarithm would be undefined).
    """
    if b <= 0 or R <= 0:
        # Guard the divisions below.  A negative R would otherwise slip past
        # the `inside` check (inside > 1) and produce a meaningless lambda.
        raise ValueError("Invalid (R,b)")
    inside = 1.0 - (1.0 + 1.0 / b) / R
    if inside <= 0:
        raise ValueError("Invalid (R,b)")
    return 1.0 / b - math.log(inside)


def prob_Y_ge_b(days, probs, b):
    """Return the total probability assigned to days ``d >= b``.

    ``days`` and ``probs`` are paired positionally; extra items in the longer
    one are ignored.
    """
    at_or_after_b = [weight for day, weight in zip(days, probs) if day >= b]
    return sum(at_or_after_b)


def baseline_branch_pmf(b, lam, branch):
    """
    Return a truncated-geometric buy-day pmf on ``{1, ..., n}``.

    Day ``i`` gets weight proportional to ``alpha ** (n - i)``, where
    ``alpha = (b - 1) / b`` and ``b`` is rounded to an integer.  The support
    size is ``n = max(1, floor(lam * b))`` for ``branch == "ge"`` and
    ``n = max(1, ceil(b / lam))`` for ``branch == "lt"``.  Weights are
    normalized by ``b * (1 - alpha ** n)``.

    Raises:
        ValueError: if ``branch`` is neither ``"ge"`` nor ``"lt"``.
    """
    b = int(round(b))
    alpha = (b - 1.0) / b

    if branch == "ge":
        support = max(1, int(math.floor(lam * b + 1e-12)))
    elif branch == "lt":
        support = max(1, int(math.ceil(b / lam - 1e-12)))
    else:
        raise ValueError("branch must be 'ge' or 'lt'")

    normalizer = b * (1.0 - alpha ** support)
    pmf = {}
    for day in range(1, support + 1):
        pmf[day] = (alpha ** (support - day)) / normalizer
    return pmf


def baseline_majority(days, probs, b, R):
    """
    Return ``(pmf, P)`` for the majority-vote baseline.

    ``P = P(Y >= b)`` under the prediction.  The ``"ge"`` branch pmf is used
    when ``P > 0.5``; otherwise the ``"lt"`` branch is used.
    """
    lam = baseline_lambda_from_R(R, b)
    p_long = prob_Y_ge_b(days, probs, b)
    chosen_branch = "ge" if p_long > 0.5 else "lt"
    return baseline_branch_pmf(b, lam, chosen_branch), p_long


def baseline_mixture(days, probs, b, R):
    """
    Return ``(pmf, P)`` for the mixture baseline.

    The ``"ge"`` and ``"lt"`` branch pmfs are mixed with weights ``P`` and
    ``1 - P``, where ``P = P(Y >= b)``.  Any drift of the total beyond
    ``1e-10`` is folded into the smallest day.
    """
    lam = baseline_lambda_from_R(R, b)
    p_long = prob_Y_ge_b(days, probs, b)

    ge_pmf = baseline_branch_pmf(b, lam, "ge")
    lt_pmf = baseline_branch_pmf(b, lam, "lt")

    mixed = {}
    for weight, component in ((p_long, ge_pmf), (1.0 - p_long, lt_pmf)):
        for day, mass in component.items():
            mixed[day] = mixed.get(day, 0.0) + weight * mass

    deficit = 1.0 - sum(mixed.values())
    if abs(deficit) > 1e-10:
        mixed[min(mixed)] += deficit

    return mixed, p_long


if __name__ == "__main__":
    # Demo: compare the optimized policy against both baselines on several
    # synthetic ski-day predictions.
    B_COST = 50
    R_RATIO = 1.7

    def _normalized(weights):
        """Return ``weights`` rescaled to sum to one."""
        weight_sum = sum(weights)
        return [w_i / weight_sum for w_i in weights]

    def _uniform(size):
        """Return ``(days, probs)`` for the uniform law on ``1..size``."""
        return list(range(1, size + 1)), [1.0 / size] * size

    cases = []
    for label, size in (("unif100", 100), ("unif200", 200)):
        support, weights = _uniform(size)
        cases.append((label, support, weights))

    g_mean, g_std, g_size = 50, 12, 150
    g_days = list(range(1, g_size + 1))
    g_probs = _normalized(
        [math.exp(-0.5 * ((d - g_mean) / g_std) ** 2) for d in g_days]
    )
    cases.append(("gauss", g_days, g_probs))

    p_hit, geo_size = 0.05, 600
    p_miss = 1.0 - p_hit
    geo_days = list(range(1, geo_size + 1))
    geo_probs = _normalized([(p_miss ** (d - 1)) * p_hit for d in geo_days])
    cases.append(("geom", geo_days, geo_probs))

    cases.append(("twopoint", [30, 120], [0.7, 0.3]))

    for name, days, probs in cases:
        optimizer = SkiRentalOptimizer(B_COST, R_RATIO, days, probs)
        h = optimizer.solve(tolerance=1e-5)
        method_pmf = optimizer.construct_policy_pmf(h)
        method_ratio, method_cost, gmin = optimizer.consistency_actual(method_pmf)

        majority_pmf, P = baseline_majority(days, probs, B_COST, R_RATIO)
        majority_ratio = optimizer.expected_cost(majority_pmf) / gmin

        mixture_pmf, _ = baseline_mixture(days, probs, B_COST, R_RATIO)
        mixture_ratio = optimizer.expected_cost(mixture_pmf) / gmin

        print(
            f"{name}: P(Y>=b)={P:.4f}, h={h:.4f}, "
            f"E[g(Z)]={method_cost:.6f}, gmin={gmin:.6f}, "
            f"method={method_ratio:.6f}, baseline_majority={majority_ratio:.6f}, baseline_mixture={mixture_ratio:.6f}"
        )
