#!/usr/bin/env python3

from math import exp, sqrt, erf, pi
import random
import matplotlib.pyplot as plt

# ---- helpers

def discretization(a: float, b: float, steps: int = 200):
    """Yield ``steps + 1`` evenly spaced points from *a* towards *b*, inclusive.

    The i-th point is ``a + i * (b - a) / steps``.  The first point is exactly
    *a*; the last is ``a + (b - a)``, which can differ from *b* by rounding.

    Raises:
        ValueError: if *steps* < 1.  Because this is a generator, the error
            surfaces on the first iteration.
    """
    # steps == 0 would divide by zero, and a negative count would yield an
    # empty grid that later blows up as "max() arg is an empty sequence".
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps!r}")
    width = b - a
    for i in range(steps + 1):
        yield a + i * width / steps

def ratio(p_star: float, T: float) -> float:
    """Performance ratio of threshold *T* when the maximum price is *p_star*.

    Returns ``p_star / T`` when the maximum reaches the threshold
    (``p_star >= T``); otherwise returns *p_star* unchanged.

    NOTE(review): the fallback branch equals ``p_star / 1``.  This presumably
    assumes the price obtained when the threshold is never hit is normalised
    to 1; confirm against the model.
    """
    if p_star >= T:
        return p_star / T
    return p_star

# ---- weights on z = |p* - p̂| / h in [0,1]

def linear_weight(z: float) -> float:
    """Triangular kernel: weight 1 at z = 0, falling linearly to 0 at z = 1."""
    complement = 1.0 - z
    return complement

K = 2  # exponent of polynomial_weight's kernel; read at call time

def polynomial_weight(z: float) -> float:
    """Piecewise-polynomial kernel of degree ``K`` in ``|z|``.

    Evaluates to ``1 - 2**(K-1) * |z|**K`` for ``|z| < 0.5`` and
    ``2**(K-1) * (1 - |z|)**K`` otherwise.  The two pieces meet at value 0.5
    when ``|z| == 0.5``, so the weight falls from 1 at z = 0 to 0 at |z| = 1.
    """
    dist = abs(z)
    coeff = 2 ** (K - 1)
    if dist < 0.5:
        return 1.0 - coeff * (dist ** K)
    return coeff * ((1.0 - dist) ** K)

SIGMA = 0.25  # standard deviation of the Gaussian kernel, in units of z

def _unnorm_gauss(z: float) -> float:
    """Unnormalised Gaussian kernel ``exp(-0.5 * (z / SIGMA)**2)``."""
    scaled = z / SIGMA
    return exp(-0.5 * scaled ** 2)

def _unnorm_gauss_int(z: float) -> float:
    """Closed-form integral of ``_unnorm_gauss`` from 0 to *z* (via ``erf``)."""
    amplitude = sqrt(pi / 2.0) * SIGMA
    return amplitude * erf(z / (sqrt(2.0) * SIGMA))

# Normalising constant: area under the unnormalised kernel over [-1, 1].
_GNORM = _unnorm_gauss_int(1.0) - _unnorm_gauss_int(-1.0)

def gaussian_weight(z: float) -> float:
    """Gaussian kernel rescaled to have unit area on [-1, 1]."""
    return _unnorm_gauss(z) / _GNORM

def gaussian_integration(z: float) -> float:
    """Integral of ``gaussian_weight`` from 0 to *z*.

    The mass on ``[-1, x]`` is therefore
    ``gaussian_integration(x) - gaussian_integration(-1.0)``.
    """
    return _unnorm_gauss_int(z) / _GNORM

# ---- objectives on [p̂-h, p̂+h]

def _weighted_diff(p_hat: float, p_star: float, h: float, T: float, w) -> float:
    """Excess ratio ``ratio(p_star, T) - 1`` scaled by the kernel *w*.

    *w* is evaluated at the normalised distance ``|p_star - p_hat| / h``.
    """
    excess = ratio(p_star, T) - 1.0
    distance = abs(p_star - p_hat) / h
    return excess * w(distance)

def _max_pstar(p_hat: float, h: float, T: float, w):
    """Worst case over the p* grid on [p_hat - h, p_hat + h] for threshold *T*.

    Returns the tuple ``(largest weighted excess, p*)``.  Because tuples
    are compared, ties on the excess go to the larger p*.
    """
    grid = discretization(p_hat - h, p_hat + h)
    scored = ((_weighted_diff(p_hat, ps, h, T, w), ps) for ps in grid)
    return max(scored)

def _sum_pstar(p_hat: float, h: float, T: float, w) -> float:
    """Sum of weighted excess ratios over the p* grid, divided by ``2 * h``.

    NOTE(review): this divides a point-sum by the interval width. It does not
    divide by the number of grid points, nor multiply by the spacing, so the
    value is ``(steps + 1) / (2h)`` times the grid mean. That constant factor
    does not change which T minsum_T selects, but the returned score is not
    a true average. Confirm this is intended.
    """
    grid = discretization(p_hat - h, p_hat + h)
    total = sum(_weighted_diff(p_hat, ps, h, T, w) for ps in grid)
    return total / (2.0 * h)

def minmax_T(p_hat: float, h: float, w):
    """Choose T on [p_hat - h, p_hat + h] minimising the worst weighted excess.

    Returns ``((worst_excess, worst_p_star), T)`` for the chosen threshold.
    """
    lo, hi = p_hat - h, p_hat + h
    candidates = ((_max_pstar(p_hat, h, T, w), T) for T in discretization(lo, hi))
    return min(candidates)

def minsum_T(p_hat: float, h: float, w):
    """Choose T on [p_hat - h, p_hat + h] minimising the aggregated excess.

    Returns ``(score, T)``, where the score comes from ``_sum_pstar``.
    """
    lo, hi = p_hat - h, p_hat + h
    candidates = ((_sum_pstar(p_hat, h, T, w), T) for T in discretization(lo, hi))
    return min(candidates)

def maxcvar_T(p_hat: float, h: float, alpha: float):
    """Pick the threshold T on [p_hat - h, p_hat + h] maximising a CVaR-style score.

    For each grid value T:

    * ``qT`` is the mass of the Gaussian kernel (unit area on [-1, 1]) between
      -1 and ``x = (T - p_hat) / h``.
    * The score is ``(T * (1 - alpha - qT) + qT) / (1 - alpha)``.

    The score reads like the mean of the worst ``(1 - alpha)`` share of a
    payoff that is 1 with probability ``qT`` and T otherwise. This is
    presumably the one-max revenue with fallback price 1; confirm.

    NOTE(review): when ``qT > 1 - alpha`` the coefficient of T turns
    negative and the score drops below 1. A textbook CVaR would saturate
    at the fallback payoff there (i.e. clamp ``qT`` to ``1 - alpha``).
    Confirm whether that clamp is intended.

    Args:
        p_hat: centre of the search interval.
        h: half-width of the search interval.
        alpha: tail level; must satisfy ``0 <= alpha < 1``.

    Returns:
        ``(score, T)`` for the best grid point, or ``(p_hat - h, p_hat - h)``
        if that tuple compares larger. The score of T = p_hat - h is itself
        p_hat - h (up to rounding), so this is normally a no-op guard.

    Raises:
        ValueError: if *alpha* is outside ``[0, 1)`` or is NaN.
            ``1 - alpha`` is the divisor, so ``alpha >= 1`` would divide by
            zero or silently flip the sign of every score.
    """
    if not 0.0 <= alpha < 1.0:
        raise ValueError(f"alpha must lie in [0, 1), got {alpha!r}")

    def cvar(T: float) -> float:
        x = (T - p_hat) / h
        # Kernel mass on [-1, x]: gaussian_integration integrates from 0.
        qT = gaussian_integration(x) - gaussian_integration(-1.0)
        return (T * (1.0 - alpha - qT) + qT) / (1.0 - alpha)

    best = max((cvar(T), T) for T in discretization(p_hat - h, p_hat + h))
    alt = p_hat - h
    return max(best, (alt, alt))

def average_improvement(p_hat: float, h: float, T_star: float):
    """Compare threshold *T_star* against T = p_hat - h and T = p_hat.

    Ratios are evaluated on the p* grid over [p_hat - h, p_hat + h]. A lower
    ratio counts as an improvement.

    Returns:
        A 5-tuple of:

        * the fraction of grid points where T_star's ratio is strictly below
          the ratio for ``p_hat - h``;
        * the same fraction measured against ``p_hat``;
        * the mean ratio for ``p_hat - h``, for ``p_hat`` and for ``T_star``.
    """
    lower = p_hat - h
    grid = list(discretization(p_hat - h, p_hat + h))
    count = len(grid)

    at_lower = [ratio(ps, lower) for ps in grid]
    at_centre = [ratio(ps, p_hat) for ps in grid]
    at_star = [ratio(ps, T_star) for ps in grid]

    beats_lower = sum(1 for base, tuned in zip(at_lower, at_star) if tuned < base)
    beats_centre = sum(1 for base, tuned in zip(at_centre, at_star) if tuned < base)

    return (
        beats_lower / count,
        beats_centre / count,
        sum(at_lower) / count,
        sum(at_centre) / count,
        sum(at_star) / count,
    )

# ---- demo

if __name__ == "__main__":
    # Seeds the global RNG; nothing visible in this demo draws from it.
    random.seed(42)
    p_hat = 500.0
    h = 480.0

    # Objective used to tune T* (swap in one of the commented alternatives):
    T_star = minmax_T(p_hat, h, linear_weight)[1]
    # T_star = minsum_T(p_hat, h, linear_weight)[1]
    # T_star = maxcvar_T(p_hat, h, 0.5)[1]

    lo, hi = p_hat - h, p_hat + h
    xs = list(discretization(lo, hi))
    thresholds = [lo, p_hat, T_star]
    # curves[j][i] == ratio(xs[i], thresholds[j])
    curves = list(zip(*(tuple(ratio(ps, T) for T in thresholds) for ps in xs)))

    plt.figure(figsize=(8, 6))
    # Plot order fixes legend order: T*, then p̂, then p̂ - h.
    for j, label in ((2, 'T*'), (1, 'T = p̂'), (0, 'T = p̂ - h')):
        plt.plot(xs, curves[j], linewidth=3, label=label)
    plt.xlabel('Max price')
    plt.ylabel('Performance ratio')
    plt.grid(True)
    plt.legend()
    plt.xticks(
        [lo, T_star, p_hat, hi],
        [r'$\hat{p}-h$', r'$T^*$', r'$\hat{p}$', r'$\hat{p}+h$'],
    )
    plt.tight_layout()
    plt.savefig('one_max_plot.pdf', format='pdf')

    stats = average_improvement(p_hat, h, T_star)
    imp_ph_h, imp_ph, avg_r0, avg_r1, avg_r2 = stats
    print(f"Improvement vs T=p̂-h: {imp_ph_h:.2%}")
    print(f"Improvement vs T=p̂  : {imp_ph:.2%}")
    print(f"Avg ratio (p̂-h)     : {avg_r0:.6f}")
    print(f"Avg ratio (p̂)       : {avg_r1:.6f}")
    print(f"Avg ratio (T*)       : {avg_r2:.6f}")
