import math
import bisect
import numpy as np
import matplotlib

# Select the Agg backend before pyplot is imported so figures are rendered
# straight to files (see plot_paper_style) without needing a display.
matplotlib.use("Agg")
import matplotlib.pyplot as plt


class SkiRentalOptimizer:
    """Search for a randomised buy-day distribution given a predicted day distribution.

    The predicted distribution is given as ``(predicted_days, predicted_probs)``.
    From it the class builds a piecewise-linear function ``g`` where, for a
    threshold ``t`` lying in ``(days[k-1], days[k]]``::

        g(t) = sum(p_d * d for predicted d < t) + P(predicted d >= t) * (t + b - 1)

    ``solve`` bisects on a level ``h`` and ``construct_policy_pmf`` builds a
    probability mass function over integer days whose support is restricted to
    days with ``g(day) <= h``.

    NOTE(review): ``g(t)`` presumably is the expected cost of renting until day
    ``t`` and buying on day ``t`` at price ``b``, and ``C`` presumably is the
    robustness (worst-case competitive ratio) bound -- confirm against the
    derivation this code was written from.
    """

    def __init__(self, b, C, predicted_days, predicted_probs):
        """Normalise and sort the prediction and precompute the segments of ``g``.

        Args:
            b: Buy cost. It is kept as a float in ``self.b`` and rounded to the
                nearest integer in ``self.b_int``; the rounded value must be at
                least 2 because the recursion uses the ratio ``1 / (b_int - 1)``.
            C: Bound used by the feasibility recursion (stored as a float).
            predicted_days: Support points of the predicted distribution.
            predicted_probs: Non-normalised weights, one per predicted day.

        Raises:
            ValueError: If ``b`` rounds to less than 2, if the two sequences are
                empty or differ in length, or if the weights do not sum to a
                positive number.
        """
        self.b = float(b)
        self.C = float(C)

        self.b_int = int(round(self.b))
        # Guard the divisions by (b_int - 1) below and in the feasibility sweep;
        # previously this surfaced as an unexplained ZeroDivisionError.
        if self.b_int < 2:
            raise ValueError("b must round to an integer >= 2.")

        # Growth factor b / (b - 1) of the geometric recursion on F.
        self.q_base = 1.0 + 1.0 / (self.b_int - 1)

        self.days = list(predicted_days)
        self.probs = list(predicted_probs)
        if len(self.days) != len(self.probs) or len(self.days) == 0:
            raise ValueError("predicted_days/probs must be non-empty lists of the same length.")

        s = sum(self.probs)
        if s <= 0:
            raise ValueError("Sum of probabilities must be positive.")
        self.probs = [p / s for p in self.probs]

        # Keep days ascending; every later lookup relies on this order.
        pairs = sorted(zip(self.days, self.probs), key=lambda x: x[0])
        self.days = [d for d, _ in pairs]
        self.probs = [p for _, p in pairs]

        self.segments = self._calculate_g_function_parameters()
        self.max_day = int(max(self.days))

    def _calculate_g_function_parameters(self):
        """Return the linear pieces of ``g`` as a list of dicts.

        Each dict has keys ``start``, ``end``, ``a_k`` and ``b_k`` and describes
        ``g(t) = a_k * t + b_k`` on ``(start, end]``. There are ``len(days) + 1``
        pieces: one before the first predicted day, one between each pair of
        consecutive predicted days, and a constant piece after the last day
        whose value is the mean of the predicted distribution.
        """
        n = len(self.days)
        segments = []

        expected_val = sum(p * d for p, d in zip(self.probs, self.days))
        current_prob_sum_d = 0.0

        # Before the first predicted day no mass has been passed yet.
        segments.append(
            {"start": 0.0, "end": self.days[0], "a_k": 1.0, "b_k": self.b - 1.0}
        )

        cum_prob = 0.0
        for k in range(n - 1):
            d_k = self.days[k]
            d_next = self.days[k + 1]
            p_k = self.probs[k]

            current_prob_sum_d += p_k * d_k
            cum_prob += p_k

            # a_k: mass of predicted days strictly above d_k.
            a_k = 1.0 - cum_prob
            b_k = current_prob_sum_d + (self.b - 1.0) * a_k
            segments.append({"start": d_k, "end": d_next, "a_k": a_k, "b_k": b_k})

        segments.append(
            {"start": self.days[-1], "end": float("inf"), "a_k": 0.0, "b_k": expected_val}
        )
        return segments

    def _get_g_value(self, t):
        """Evaluate ``g`` at ``t``.

        ``t == 0`` returns ``b - 1``. Any ``t`` that falls in no segment (for
        example a negative ``t``) returns the constant of the last segment.
        """
        if t == 0:
            return self.b - 1.0

        # self.days is sorted and segment i spans (days[i-1], days[i]], so the
        # number of predicted days strictly below t is the segment index.
        idx = bisect.bisect_left(self.days, t)
        if idx == 0 and not (self.segments[0]["start"] < t):
            # t is at or below the start of the first segment: nothing matches.
            return self.segments[-1]["b_k"]

        seg = self.segments[idx]
        return seg["a_k"] * t + seg["b_k"]

    def _segment_integer_range(self, seg_start, seg_end):
        """Return ``(lo, hi)``, the integer days considered for a segment.

        ``lo`` is 0 for a segment starting at or below 0, otherwise the first
        integer above ``int(seg_start)``; ``hi`` is ``int(seg_end)``. The range is
        empty when ``lo > hi``.
        """
        lo = 0 if seg_start <= 0.0 else int(seg_start) + 1
        hi = int(seg_end)
        return lo, hi

    def is_feasible(self, h):
        """Return True if the sweep finds a full unit of probability mass at level ``h``.

        The sweep walks the segments that start before ``b_int``. In each one it
        only uses integer days ``x`` with ``g(x) <= h``. ``F`` accumulates the
        probability mass placed so far and ``mu`` tracks a first-moment-like
        quantity used in the budget test. If mass is still missing after the
        sweep, the remainder may be placed on a single predicted day above
        ``b_int`` that satisfies both ``g(day) <= h`` and the bound
        ``max_allowed_d``.

        NOTE(review): the budget ``(C - 1) * s`` presumably encodes the
        robustness constraint -- confirm against the derivation.
        """
        F = 0.0
        mu = 0.0
        b = self.b_int

        for seg in self.segments:
            start_interval = seg["start"]
            end_interval = seg["end"]
            if start_interval >= b:
                break

            # Only days up to b are handled by the sweep.
            effective_end = min(end_interval, b)
            a_k = seg["a_k"]
            b_k = seg["b_k"]

            lo, hi = self._segment_integer_range(start_interval, effective_end)
            if lo > hi:
                continue

            # e_i: last integer day of this segment with g(day) <= h.
            if a_k > 1e-12:
                intersection_x = (h - b_k) / a_k
                e_i = min(hi, int(math.floor(intersection_x + 1e-12)))
            else:
                # Flat piece: either every day qualifies or none does.
                e_i = hi if (b_k <= h + 1e-12) else (lo - 1)

            if e_i >= lo:
                s = lo

                # Place a point mass at the first usable day when there is slack
                # between the budget and the load accumulated so far.
                current_load = mu + (b - s) * F
                target_budget = (self.C - 1.0) * s
                if target_budget > current_load + 1e-9:
                    mass = (target_budget - current_load) / (b - 1)
                    F += mass
                    mu += mass * s
                    if F >= 1.0 - 1e-12:
                        return True

                # Closed form of `step` applications of
                # F <- (F + C - 1) * q_base - (C - 1).
                step = e_i - s
                if step > 0 and F < 1.0 - 1e-12:
                    F_end = (F + (self.C - 1.0)) * (self.q_base ** step) - (self.C - 1.0)
                    if F_end >= 1.0 - 1e-12:
                        return True
                    F = F_end
                    mu = (self.C - 1.0) * e_i - (b - e_i) * F

            if F >= 1.0 - 1e-12:
                return True

        remaining_prob = 1.0 - F
        if remaining_prob <= 1e-9:
            return True

        # Largest day on which the leftover mass may be placed.
        max_allowed_d = ((self.C - 1.0) * b - mu) / remaining_prob
        start_idx = bisect.bisect_right(self.days, b)

        for i in range(start_idx, len(self.days)):
            d_i = self.days[i]
            if self._get_g_value(d_i) <= h + 1e-12 and d_i <= max_allowed_d + 1e-12:
                return True

        return False

    def solve(self, tolerance=1e-5):
        """Bisect for the smallest level ``h`` reported feasible by ``is_feasible``.

        The search interval is ``[0, b + days[-1]]`` and stops once it is
        narrower than ``tolerance``.

        NOTE(review): the bisection assumes feasibility is monotone in ``h``.
        If no midpoint is ever feasible, the initial upper bound is returned
        without having been checked itself.
        """
        low = 0.0
        high = self.b + self.days[-1]
        optimal_h = high

        while high - low > tolerance:
            mid = (low + high) / 2.0
            if self.is_feasible(mid):
                optimal_h = mid
                high = mid
            else:
                low = mid

        return optimal_h

    def construct_policy_pmf(self, h):
        """Build the buy-day probability mass function for level ``h``.

        Follows the same sweep as ``is_feasible`` but records where the mass
        goes. The geometric growth is applied one day at a time so each
        increment can be assigned to its day.

        Returns:
            dict mapping day -> probability. Small numerical drift in the total
            is absorbed into one entry so the values sum to 1.

        Raises:
            RuntimeError: If mass is left after the sweep and no predicted day
                above ``b_int`` can take it.
        """
        b = self.b_int
        F = 0.0
        mu = 0.0
        pmf = {}

        for seg in self.segments:
            start_interval = seg["start"]
            end_interval = seg["end"]
            if start_interval >= b:
                break

            effective_end = min(end_interval, b)
            a_k = seg["a_k"]
            b_k = seg["b_k"]

            lo, hi = self._segment_integer_range(start_interval, effective_end)
            if lo > hi:
                continue

            # e_i: last integer day of this segment with g(day) <= h.
            if a_k > 1e-12:
                intersection_x = (h - b_k) / a_k
                e_i = min(hi, int(math.floor(intersection_x + 1e-12)))
            else:
                e_i = hi if (b_k <= h + 1e-12) else (lo - 1)

            if e_i < lo:
                continue

            s = lo

            current_load = mu + (b - s) * F
            target_budget = (self.C - 1.0) * s
            if target_budget > current_load + 1e-9:
                mass = (target_budget - current_load) / (b - 1)
                if mass > 0:
                    pmf[s] = pmf.get(s, 0.0) + mass
                    F += mass
                    mu += mass * s
                    if F >= 1.0 - 1e-12:
                        # Absorb overshoot/undershoot into the day just filled.
                        total = sum(pmf.values())
                        if abs(total - 1.0) > 1e-8:
                            pmf[s] += 1.0 - total
                        return pmf

            # Day-by-day version of the geometric recursion; the increment of F
            # between x and x + 1 is the mass assigned to day x + 1.
            x = s
            while x < e_i and F < 1.0 - 1e-12:
                F_next = (F + (self.C - 1.0)) * self.q_base - (self.C - 1.0)
                F_next = min(1.0, max(0.0, F_next))
                inc = max(0.0, F_next - F)
                if inc > 0:
                    pmf[x + 1] = pmf.get(x + 1, 0.0) + inc
                F = F_next
                x += 1

            # NOTE(review): is_feasible only recomputes mu when at least one
            # geometric step was taken (e_i > s); here it is recomputed even
            # when e_i == s. Confirm which of the two is intended.
            if F < 1.0 - 1e-12:
                mu = (self.C - 1.0) * e_i - (b - e_i) * F

        remaining = 1.0 - F
        if remaining <= 1e-9:
            total = sum(pmf.values())
            if abs(total - 1.0) > 1e-8:
                k = min(pmf.keys())
                pmf[k] += 1.0 - total
            return pmf

        max_allowed_d = ((self.C - 1.0) * b - mu) / remaining
        start_idx = bisect.bisect_right(self.days, b)

        candidates = []
        for i in range(start_idx, len(self.days)):
            d = self.days[i]
            if self._get_g_value(d) <= h + 1e-12 and d <= max_allowed_d + 1e-12:
                candidates.append(d)

        if not candidates:
            raise RuntimeError("Tail became infeasible; check feasibility/construction consistency.")

        # All leftover mass goes to the admissible day with the smallest g
        # (ties broken towards the earlier day).
        best_d = min(candidates, key=lambda d: (self._get_g_value(d), d))
        pmf[best_d] = pmf.get(best_d, 0.0) + remaining

        total = sum(pmf.values())
        if abs(total - 1.0) > 1e-6:
            pmf[best_d] += 1.0 - total

        return pmf

    def expected_cost(self, pmf):
        """Return ``sum(prob * g(day))`` over the entries of ``pmf``.

        ``g`` is the function of this instance, so a pmf built by another
        optimizer can be scored against this instance's distribution.
        """
        return sum(prob * self._get_g_value(day) for day, prob in pmf.items())

    def min_g_over_integer_thresholds(self):
        """Return the minimum of ``g(t)`` over integers ``t`` in ``1 .. max_day + 1``.

        Returns ``inf`` when that range is empty.
        """
        t_max = self.max_day + 1
        return min(
            (self._get_g_value(t) for t in range(1, t_max + 1)),
            default=float("inf"),
        )


def purohit_lambda_from_R(R, b):
    """Return ``1/b - log(1 - (1 + 1/b) / R)`` for the given ``R`` and ``b``.

    Raises:
        ValueError: If the argument of the logarithm is not positive.
    """
    ratio = (1.0 + 1.0 / b) / R
    log_arg = 1.0 - ratio
    if log_arg <= 0:
        raise ValueError("Invalid (R,b)")
    lam = 1.0 / b - math.log(log_arg)
    return lam


def prob_Y_ge_b(days, probs, b):
    """Return the total probability of the days that are at least ``b``."""
    tail_mass = 0
    for day, prob in zip(days, probs):
        if day >= b:
            tail_mass += prob
    return tail_mass


def purohit_branch_pmf(b, lam, branch):
    """Return a geometric-shaped pmf over days ``1 .. n`` for one branch.

    ``b`` is rounded to an integer. The support size ``n`` is
    ``floor(lam * b)`` for ``branch == "ge"`` and ``ceil(b / lam)`` for
    ``branch == "lt"`` (both with a 1e-12 guard and clamped to at least 1).
    Day ``i`` gets weight ``((b - 1) / b) ** (n - i)``, normalised so the
    weights sum to 1.

    Raises:
        ValueError: If ``branch`` is neither ``"ge"`` nor ``"lt"``.
    """
    b = int(round(b))
    decay = (b - 1.0) / b

    if branch == "ge":
        horizon = int(math.floor(lam * b + 1e-12))
    elif branch == "lt":
        horizon = int(math.ceil(b / lam - 1e-12))
    else:
        raise ValueError("branch must be 'ge' or 'lt'")

    horizon = max(1, horizon)
    norm = b * (1.0 - decay ** horizon)

    pmf = {}
    for day in range(1, horizon + 1):
        pmf[day] = (decay ** (horizon - day)) / norm
    return pmf


def purohit_baseline_majority(days, probs, b, R):
    """Pick one branch pmf according to where most of the predicted mass lies.

    The ``"ge"`` branch is used when the probability of a day ``>= b`` exceeds
    0.5, otherwise the ``"lt"`` branch.

    Returns:
        ``(pmf, P)`` where ``P`` is that probability.
    """
    lam = purohit_lambda_from_R(R, b)
    tail_mass = prob_Y_ge_b(days, probs, b)
    if tail_mass > 0.5:
        chosen = "ge"
    else:
        chosen = "lt"
    return purohit_branch_pmf(b, lam, chosen), tail_mass


def purohit_baseline_mixture(days, probs, b, R):
    """Blend the two branch pmfs, weighted by the predicted mass on each side of ``b``.

    The ``"ge"`` pmf gets weight ``P`` (probability of a day ``>= b``) and the
    ``"lt"`` pmf gets weight ``1 - P``. Any numerical drift of the total away
    from 1 is added to the smallest day.

    Returns:
        ``(pmf, P)``.
    """
    lam = purohit_lambda_from_R(R, b)
    P = prob_Y_ge_b(days, probs, b)

    ge_pmf = purohit_branch_pmf(b, lam, "ge")
    lt_pmf = purohit_branch_pmf(b, lam, "lt")

    mixed = {}
    for weight, component in ((P, ge_pmf), (1.0 - P, lt_pmf)):
        for day, pr in component.items():
            mixed[day] = mixed.get(day, 0.0) + weight * pr

    drift = 1.0 - sum(mixed.values())
    if abs(-drift) > 1e-10:
        mixed[min(mixed)] += drift

    return mixed, P


def expand_to_full_grid(days, probs, M):
    """Spread a sparse distribution onto the dense grid ``1 .. M`` and normalise it.

    Each day is truncated with ``int()`` and its weight is accumulated on that
    grid point; grid points with no entry get probability 0.

    Args:
        days: Day values; after truncation each must lie in ``1 .. M``.
        probs: Weights paired with ``days`` (need not sum to 1).
        M: Largest day of the grid.

    Returns:
        ``(full_days, full_probs)`` with ``full_days == [1, ..., M]`` and
        ``full_probs`` the normalised probabilities in the same order.

    Raises:
        ValueError: If a day falls outside ``1 .. M`` or the weights do not sum
            to a positive number.
    """
    # Index 0 is unused padding so that a day can be used directly as an index.
    grid = [0.0] * (M + 1)
    total = 0.0
    for d, p in zip(days, probs):
        idx = int(d)
        # Reject out-of-grid days explicitly: a negative index would silently
        # wrap to the end of the list, and mass on day 0 would be counted in
        # the normaliser but dropped from the result.
        if not 1 <= idx <= M:
            raise ValueError(f"day {d!r} is outside the grid 1..{M}")
        weight = float(p)
        grid[idx] += weight
        total += weight

    if total <= 0:
        raise ValueError("Sum of probabilities must be positive.")

    full_days = list(range(1, M + 1))
    return full_days, [grid[d] / total for d in full_days]


def wasserstein1_distance_on_line(p, q):
    """Return the sum of absolute differences between the running sums of ``p`` and ``q``.

    For two probability vectors on the same unit-spaced grid this is the
    1-Wasserstein distance between them.

    Args:
        p: Sequence of masses.
        q: Sequence of masses on the same grid as ``p``.

    Raises:
        ValueError: If the two sequences differ in length. A longer ``q`` used
            to be truncated silently, which gave a wrong distance.
    """
    if len(p) != len(q):
        raise ValueError("p and q must have the same length.")

    cum_p = 0.0
    cum_q = 0.0
    dist = 0.0
    for mass_p, mass_q in zip(p, q):
        cum_p += mass_p
        cum_q += mass_q
        dist += abs(cum_p - cum_q)
    return dist


def eta_full_bidirectional(days, probs, M):
    """Return the larger of the two costs of moving all mass to an end of the grid.

    The distribution is first expanded onto ``1 .. M``. ``to_first`` weights each
    grid point by its distance to day 1 and ``to_last`` by its distance to
    day ``M``.
    """
    _, grid_probs = expand_to_full_grid(days, probs, M)

    to_first = 0
    to_last = 0
    for offset, mass in enumerate(grid_probs):
        to_first += mass * offset
    for offset, mass in enumerate(grid_probs):
        to_last += mass * (M - 1 - offset)

    return max(to_first, to_last)


def make_p_hat_random_transport(days, probs, eta, M, rng=None, max_iters=200000):
    """Randomly perturb a distribution by moving mass under a transport budget.

    The input is expanded onto the grid ``1 .. M``. Mass is then moved in random
    chunks from one grid point to another; each move of ``m`` over ``dist`` grid
    steps consumes ``m * dist`` of the budget ``eta``. The loop stops when the
    budget or the movable mass is exhausted, or after ``max_iters`` iterations.

    Args:
        days: Day values of the input distribution.
        probs: Weights paired with ``days``.
        eta: Transport budget; ``eta <= 0`` returns the expanded input unchanged.
        M: Largest day of the grid.
        rng: Optional numpy random generator; a fresh unseeded one is created
            when omitted.
        max_iters: Upper bound on loop iterations (skipped draws count too).

    Returns:
        ``(full_days, q, achieved)``: the grid, the perturbed probabilities
        (renormalised), and the value of ``wasserstein1_distance_on_line``
        between the expanded input and ``q``.
    """
    full_days, full_probs = expand_to_full_grid(days, probs, M)
    p = full_probs[:]

    if rng is None:
        rng = np.random.default_rng()

    if eta <= 0:
        return full_days, p[:], 0.0

    q = p[:]
    # `remaining` is the original mass still sitting at each point. Mass that
    # arrives at a destination is not added here, so it is never moved twice.
    remaining = p[:]
    budget = float(eta)

    idxs = np.arange(len(p))

    for _ in range(max_iters):
        if budget <= 1e-15:
            break
        total_rem = float(sum(remaining))
        if total_rem <= 1e-15:
            break

        # Source is drawn proportionally to the mass that has not moved yet.
        src = int(rng.choice(idxs, p=np.array(remaining) / total_rem))
        if remaining[src] <= 1e-18:
            continue

        # Draw from len(p) - 1 slots and shift past src, which gives a uniform
        # destination over all indices other than src.
        dst = int(rng.integers(0, len(p) - 1))
        if dst >= src:
            dst += 1

        dist = abs(dst - src)
        if dist <= 0:
            continue

        # Cap the move by both the available mass and what the budget affords.
        max_move = min(remaining[src], budget / dist)
        if max_move <= 1e-18:
            continue

        # Move a uniformly random fraction of the cap.
        m = max_move * float(rng.random())
        if m <= 1e-18:
            continue

        q[src] -= m
        q[dst] += m
        remaining[src] -= m
        budget -= m * dist

    # Remove negative entries left by rounding and renormalise.
    q = [max(0.0, x) for x in q]
    s = sum(q)
    q = [x / s for x in q]

    achieved = wasserstein1_distance_on_line(p, q)
    return full_days, q, achieved


def make_discrete_gaussian(mu, sigma, M):
    """Return days ``1 .. M`` with normalised Gaussian-shaped weights.

    Day ``d`` gets weight ``exp(-0.5 * ((d - mu) / sigma) ** 2)`` before the
    weights are scaled to sum to 1.

    Returns:
        ``(days, probs)`` as two lists of length ``M``.
    """
    grid = list(range(1, M + 1))

    weights = []
    for day in grid:
        weights.append(math.exp(-0.5 * ((day - mu) / sigma) ** 2))

    total = sum(weights)
    return grid, [w / total for w in weights]


def plot_paper_style(eta_vals, ours_mean, maj_mean, mix_mean, filename):
    """Plot the three curves against ``eta_vals`` and save the figure to ``filename``.

    The figure is written at 300 dpi and closed afterwards.
    """
    curves = (
        (ours_mean, "o", "Ours"),
        (maj_mean, "s", "Purohit (majority)"),
        (mix_mean, "^", "Purohit (mixture)"),
    )

    plt.figure(figsize=(7.5, 4.8))
    for values, marker, label in curves:
        plt.plot(eta_vals, values, marker=marker, linewidth=2, label=label)

    plt.xlabel(r"Wasserstein distance budget $\eta$")
    plt.ylabel("Consistency")
    plt.grid(True, alpha=0.3)
    plt.legend(loc="best")
    plt.tight_layout()
    plt.savefig(filename, dpi=300)
    plt.close()


def run_experiment(
    b_cost=50,
    r_ratio=1.7,
    mu=90,
    sigma=12,
    base_m=150,
    z=6.0,
    n_eta=41,
    n_samples_per_eta=25,
    base_seed=20260103,
    max_iters=200000,
    tolerance=1e-5,
):
    """Sweep the perturbation budget and compare the optimizer with both baselines.

    A discrete Gaussian on ``1 .. M`` is the reference distribution. For each of
    ``n_eta`` budgets between 0 and ``eta_full_bidirectional`` the reference is
    perturbed ``n_samples_per_eta`` times. Each perturbed distribution is fed
    to the optimizer and to the two baseline constructions, and every resulting
    pmf is scored on the reference distribution as expected cost divided by
    the reference minimum of ``g``. Per-budget means are plotted and saved.

    Args:
        b_cost: Buy cost passed to the optimizer and baselines.
        r_ratio: Passed as ``C`` to the optimizer and as ``R`` to the baselines.
        mu: Mean parameter of the discrete Gaussian.
        sigma: Width parameter of the discrete Gaussian.
        base_m: Minimum grid size; the grid is at least ``mu + z * sigma`` long.
        z: Number of ``sigma`` to cover above ``mu``.
        n_eta: Number of budget values.
        n_samples_per_eta: Perturbed samples per budget value.
        base_seed: Base of the per-sample random seeds.
        max_iters: Iteration cap for each perturbation.
        tolerance: Bisection tolerance for ``SkiRentalOptimizer.solve``.

    Returns:
        ``(eta_vals, ours_mean, maj_mean, mix_mean, out_png)``.
    """
    # Local import so this function carries its own dependency.
    import warnings

    M = max(base_m, int(math.ceil(mu + z * sigma)))

    days_p, probs_p = make_discrete_gaussian(mu, sigma, M)
    eval_opt = SkiRentalOptimizer(b_cost, r_ratio, days_p, probs_p)
    gmin_p = eval_opt.min_g_over_integer_thresholds()

    eta_full = eta_full_bidirectional(days_p, probs_p, M)
    eta_vals = np.linspace(0.0, eta_full, n_eta)

    ours_mean = []
    maj_mean = []
    mix_mean = []

    # Makes the seeds depend on the (mu, sigma) case as well as on the indices.
    case_hash = int(mu * 1000 + sigma * 100)

    for i_eta, eta in enumerate(eta_vals):
        cons_ours_samples = []
        cons_maj_samples = []
        cons_mix_samples = []

        for s_idx in range(n_samples_per_eta):
            rng = np.random.default_rng(base_seed + 100000 * case_hash + 1000 * i_eta + s_idx)

            try:
                days_hat, probs_hat, _ = make_p_hat_random_transport(
                    days_p, probs_p, eta, M, rng=rng, max_iters=max_iters
                )

                opt_hat = SkiRentalOptimizer(b_cost, r_ratio, days_hat, probs_hat)
                h_hat = opt_hat.solve(tolerance=tolerance)
                pmf_ours = opt_hat.construct_policy_pmf(h_hat)

                pmf_maj, _ = purohit_baseline_majority(days_hat, probs_hat, b_cost, r_ratio)
                pmf_mix, _ = purohit_baseline_mixture(days_hat, probs_hat, b_cost, r_ratio)

                cost_ours = eval_opt.expected_cost(pmf_ours) / gmin_p
                cost_maj = eval_opt.expected_cost(pmf_maj) / gmin_p
                cost_mix = eval_opt.expected_cost(pmf_mix) / gmin_p
            except Exception as exc:
                # Best-effort sweep: one bad sample must not abort the run, but
                # the failure is reported so a systematic error (e.g. invalid
                # parameters) does not silently produce an all-NaN plot.
                warnings.warn(
                    f"run_experiment: sample {s_idx} at eta index {i_eta} "
                    f"(eta={float(eta):.6g}) failed and is recorded as NaN: {exc!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                cost_ours = cost_maj = cost_mix = np.nan

            # The three methods are appended together so samples stay paired.
            cons_ours_samples.append(cost_ours)
            cons_maj_samples.append(cost_maj)
            cons_mix_samples.append(cost_mix)

        ours_mean.append(float(np.nanmean(np.array(cons_ours_samples, dtype=float))))
        maj_mean.append(float(np.nanmean(np.array(cons_maj_samples, dtype=float))))
        mix_mean.append(float(np.nanmean(np.array(cons_mix_samples, dtype=float))))

    out_png = f"paper_mu{mu}_sigma{sigma}_b{b_cost}_R{r_ratio}_avg{n_samples_per_eta}.png"
    plot_paper_style(eta_vals, ours_mean, maj_mean, mix_mean, filename=out_png)
    print(f"[saved] {out_png}")
    return eta_vals, ours_mean, maj_mean, mix_mean, out_png


# Run the experiment with its default parameters when executed as a script.
if __name__ == "__main__":
    run_experiment()
