import os
import numpy as np
import random
import matplotlib.pyplot as plt


class BaseAlgorithm:
    """Minimal interface for a learner that picks among ``n_arms`` arms.

    Concrete learners are expected to override both ``choose_arm`` and
    ``update``; the versions defined here only raise ``NotImplementedError``.
    """

    def __init__(self, n_arms):
        """Remember how many arms the learner can choose from."""
        self.n_arms = n_arms

    def choose_arm(self):
        """Select an arm for the current round (must be overridden)."""
        raise NotImplementedError()

    def update(self, feedback):
        """Absorb the feedback for the round just played (must be overridden)."""
        raise NotImplementedError()


class SwapRegretHedge(BaseAlgorithm):
    """Swap-regret learner that keeps one exponential-weights row per arm.

    ``logw`` is an ``(n_arms, n_arms)`` matrix of log-weights.  A row-wise
    softmax of it yields a row-stochastic matrix ``Q``; the distribution
    played each round is an approximate stationary distribution of ``Q``,
    obtained by ``station_iters`` power-iteration steps starting from the
    uniform distribution.  After a round, row ``i`` is charged the full loss
    vector scaled by ``P[i]``, the probability that was placed on arm ``i``.

    Args:
        n_arms: Number of arms.
        eta: Learning rate multiplying the losses in ``update``.
        station_iters: Number of ``p @ Q`` power-iteration steps per round.
        init_weights: Optional strictly positive ``(n_arms, n_arms)`` matrix;
            the initial log-weights are ``prior_strength * log(init_weights)``.
        init_logweights: Optional ``(n_arms, n_arms)`` matrix of initial
            log-weights.  Takes precedence over ``init_weights``.
        prior_strength: Scale applied to ``log(init_weights)``; unused when
            ``init_logweights`` is given or when neither matrix is given.

    Raises:
        ValueError: If a supplied matrix has the wrong shape, or if
            ``init_weights`` contains a non-positive entry.
    """

    def __init__(self, n_arms, eta, station_iters=20,
                 init_weights=None, init_logweights=None, prior_strength: float = 1.0):
        super().__init__(n_arms)
        self.eta = eta
        self.station_iters = station_iters
        self.logw = self._initial_logweights(
            n_arms, init_weights, init_logweights, prior_strength
        )

        # Until the first call to choose_arm the played distribution is uniform.
        self.P = np.ones(n_arms) / n_arms
        self.last_arm = None

    @staticmethod
    def _initial_logweights(n_arms, init_weights, init_logweights, prior_strength):
        """Validate the optional priors and return a fresh log-weight matrix.

        Shared by ``__init__`` and ``reset_weights`` so both apply exactly the
        same precedence (log-weights, then weights, then zeros) and checks.
        """
        expected_shape = (n_arms, n_arms)

        if init_logweights is not None:
            lw = np.asarray(init_logweights, dtype=float)
            if lw.shape != expected_shape:
                raise ValueError("init_logweights must have shape (n_arms, n_arms)")
            # Copy so later in-place updates never write into the caller's array.
            return lw.copy()

        if init_weights is not None:
            w0 = np.asarray(init_weights, dtype=float)
            if w0.shape != expected_shape:
                raise ValueError("init_weights must have shape (n_arms, n_arms)")
            if np.any(w0 <= 0):
                raise ValueError("init_weights must be strictly positive")
            return prior_strength * np.log(w0)

        return np.zeros(expected_shape)

    def choose_arm(self):
        """Sample an arm from the current approximate stationary distribution.

        Stores the sampled-from distribution in ``self.P`` and the sampled
        index in ``self.last_arm``.  Sampling uses NumPy's global random state
        via ``np.random.choice``.

        Returns:
            The index of the sampled arm.
        """
        # Row-wise softmax; subtracting each row's max keeps exp() from overflowing.
        row_max = np.max(self.logw, axis=1, keepdims=True)
        expw = np.exp(self.logw - row_max)
        Q = expw / expw.sum(axis=1, keepdims=True)

        # Power iteration from the uniform distribution.
        p = np.ones(self.n_arms) / self.n_arms
        for _ in range(self.station_iters):
            p = p @ Q
        # Renormalise so accumulated rounding error cannot make
        # np.random.choice reject the probability vector.
        p /= p.sum()

        self.P = p
        a = np.random.choice(self.n_arms, p=p)
        self.last_arm = a
        return a

    def update(self, losses):
        """Apply one exponential-weights step given the per-arm ``losses``.

        Entry ``(i, j)`` of the log-weights decreases by
        ``eta * P[i] * losses[j]``, where ``P`` is the distribution stored by
        the most recent ``choose_arm`` call (uniform if none has happened yet).
        """
        self.logw -= self.eta * np.outer(self.P, losses)

    def reset_weights(self, init_weights=None, init_logweights=None, prior_strength: float = 1.0):
        """Replace the log-weight matrix, using the same rules as the constructor.

        Only ``self.logw`` is replaced; ``self.P`` and ``self.last_arm`` keep
        their current values.

        Raises:
            ValueError: If a supplied matrix has the wrong shape, or if
                ``init_weights`` contains a non-positive entry.
        """
        self.logw = self._initial_logweights(
            self.n_arms, init_weights, init_logweights, prior_strength
        )


class Principal:
    """A price-setting player: a price grid plus the learner that picks from it.

    Attributes set by the constructor:
        id: Identifier supplied by the caller.
        arms: The price grid (``discretization``), indexed by arm number.
        alg: The learning algorithm that chooses arm indices.
        cost: The cost passed in, converted to ``float``.
        price_history: Prices returned by ``select_price``, in order.
        profit_history: Starts empty; this class never appends to it.
        last_arm: Index of the most recently chosen arm, or ``None``.
    """

    def __init__(self, id, discretization, algorithm: BaseAlgorithm, cost: float = 0.0):
        self.id = id
        self.arms = discretization
        self.alg = algorithm
        self.cost = float(cost)
        self.last_arm = None
        self.price_history = []
        self.profit_history = []

    def select_price(self):
        """Ask the learner for an arm, record it, and return the matching price."""
        self.last_arm = self.alg.choose_arm()
        chosen_price = self.arms[self.last_arm]
        self.price_history.append(chosen_price)
        return chosen_price

    def update(self, fb):
        """Forward the round's feedback ``fb`` unchanged to the learner."""
        self.alg.update(fb)


class PricingGame:
    """Repeated pricing game in which the lowest posted price wins the sale.

    Each round every principal posts a price.  For every principal the game
    then computes, for each price on that principal's grid, the profit it
    would have earned against the other principals' posted prices:

    * strictly below the lowest competing price: ``price - cost``;
    * equal to it: ``(price - cost)`` split evenly with the tied competitors;
    * above it: ``0``.

    The loss handed to the principal for each arm is ``0.5 * (1 - profit)``.

    Attributes:
        principals: The participating principals, as passed in.
        accepted_price_history: Minimum posted price of each round played.
    """

    def __init__(self, principals):
        self.principals = principals
        self.accepted_price_history = []

    def run(self, T):
        """Play ``T`` rounds, updating every principal after each round.

        Each principal must provide ``arms``, ``cost``, ``last_arm``,
        ``profit_history``, ``select_price()`` and ``update(losses)``.  The
        price grid ``arms`` is read once before the first round, so it is
        assumed not to change while the game runs; the loss vector passed to
        ``update`` has one entry per grid price.

        Args:
            T: Number of rounds to play.
        """
        principals = self.principals
        # The price grids are loop-invariant: convert them to arrays once so
        # the per-round counterfactual computation is a vectorised expression.
        grids = [np.asarray(p.arms, dtype=float) for p in principals]

        for _ in range(T):
            prices = [p.select_price() for p in principals]
            self.accepted_price_history.append(min(prices))

            for i, p in enumerate(principals):
                other_prices = prices[:i] + prices[i + 1:]
                if other_prices:
                    min_other = min(other_prices)
                    num_min_others = sum(1 for q in other_prices if q == min_other)
                else:
                    # No competitor: every price counts as undercutting.
                    min_other = np.inf
                    num_min_others = 0

                arm_prices = grids[i]
                margin = arm_prices - p.cost
                profits = np.where(
                    arm_prices < min_other,
                    margin,
                    np.where(
                        arm_prices == min_other,
                        # Ties split the sale with the competitors at that price.
                        margin / (num_min_others + 1),
                        0.0,
                    ),
                )
                losses = 0.5 * (1.0 - profits)

                # Realised profit is the counterfactual entry of the arm played.
                p.profit_history.append(profits[p.last_arm])
                p.update(losses)


def run_and_plot(
    discretizations, costs, T, sample_iv, title,
    station_iters=20, init_swap_weights=None, prior_strength: float = 1.0,
    save=True, out_dir="outputs", basename="run_and_plot", dpi=300, show=True
):
    """Simulate a pricing game between swap-regret learners and plot the outcome.

    One ``Principal`` driven by a ``SwapRegretHedge`` learner is created per
    entry of ``discretizations``; the game is run for ``T`` rounds and a set of
    time-series and histogram figures is produced.

    Args:
        discretizations: One price grid per principal.
        costs: One cost per principal; must have the same length as
            ``discretizations``.
        T: Number of rounds to simulate.
        sample_iv: Step between the rounds shown in the time-series plots.
        title: Figure super-title.
        station_iters: Forwarded to each ``SwapRegretHedge``.
        init_swap_weights: ``None``, a callable ``f(i, prices, cost)`` returning
            the initial weights for principal ``i``, or a sequence indexed by
            principal (principals beyond its length get ``None``).  The result
            is passed to ``SwapRegretHedge`` as ``init_weights``.
        prior_strength: Forwarded to each ``SwapRegretHedge``.
        save: When true, every figure is written as both PNG and PDF.
        out_dir: Output directory, joined onto this script's directory (or the
            current working directory when ``__file__`` is not defined).
        basename: Prefix of every output file name.
        dpi: Resolution passed to ``savefig``.
        show: When true each figure is shown with ``plt.show()``; otherwise it
            is closed after the optional save.

    Returns:
        A dict with keys ``"principals"``, ``"accepted_price_history"`` and
        ``"x"`` (the sampled round indices used for the time-series plots).
    """
    assert len(discretizations) == len(costs)
    etas = [np.sqrt(np.log(len(d)) / T) for d in discretizations]

    def _script_dir():
        # __file__ is not defined in some interactive sessions; use the CWD then.
        try:
            return os.path.dirname(os.path.abspath(__file__))
        except NameError:
            return os.getcwd()

    def _tag_num(value: float) -> str:
        # File-name-safe number: "-" becomes "m" and "." becomes "p".
        text = f"{float(value):g}"
        return text.replace("-", "m").replace(".", "p")

    def W0_for(i):
        if init_swap_weights is None:
            return None
        if callable(init_swap_weights):
            return init_swap_weights(i, np.asarray(discretizations[i], dtype=float), float(costs[i]))
        return init_swap_weights[i] if i < len(init_swap_weights) else None

    principals = [
        Principal(
            i,
            discretizations[i],
            SwapRegretHedge(
                n_arms=len(discretizations[i]),
                # NOTE(review): principal i learns (1 + 10 * i) times faster than
                # the base rate; the factor is hard-coded here - confirm intent.
                eta=etas[i] * (1 + 10 * i),
                station_iters=station_iters,
                init_weights=W0_for(i),
                prior_strength=prior_strength,
            ),
            cost=costs[i],
        )
        for i in range(len(discretizations))
    ]

    game = PricingGame(principals)
    game.run(T)

    # Rounds sampled for the time-series plots.
    x = np.arange(0, T, sample_iv)

    k_tag = f"k_{len(discretizations[0]) - 1}"
    T_tag = f"T_{int(T)}"
    iv_tag = f"iv_{int(sample_iv)}"
    st_tag = f"station_{int(station_iters)}"
    cost_tag = "costs_" + "_".join(f"P{i}_{_tag_num(c)}" for i, c in enumerate(costs))
    base_tag = f"{basename}_{k_tag}_{T_tag}_{iv_tag}_{st_tag}_{cost_tag}"

    out_path = None
    if save:
        out_path = os.path.join(_script_dir(), out_dir)
        os.makedirs(out_path, exist_ok=True)

    def _new_figure(suptitle):
        # Every figure shares the same single-axes layout and title styling.
        fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.5))
        fig.suptitle(suptitle, fontsize=14)
        return fig, ax

    def _finish(fig, suffix):
        # Lay out, optionally save as PNG and PDF, then show or close the figure.
        fig.tight_layout(rect=[0, 0, 1, 0.92])
        if save:
            png_path = os.path.join(out_path, f"{base_tag}_{suffix}.png")
            pdf_path = os.path.join(out_path, f"{base_tag}_{suffix}.pdf")
            fig.savefig(png_path, dpi=dpi, bbox_inches="tight")
            fig.savefig(pdf_path, dpi=dpi, bbox_inches="tight")
            print(f"Saved plot:\n  {png_path}\n  {pdf_path}")
        if show:
            plt.show()
        else:
            plt.close(fig)

    # ---- Time-series plots, sampled every ``sample_iv`` rounds ----
    fig1, ax1 = _new_figure(title)
    for p in principals:
        ax1.scatter(x, np.asarray(p.price_history)[x], s=10, label=f"P{p.id}")
    ax1.set_title("Chosen Prices")
    ax1.set_xlabel("Round")
    ax1.set_ylabel("Price")
    ax1.legend()
    _finish(fig1, "timeseries_chosen_prices")

    fig2, ax2 = _new_figure(title)
    ys_acc = np.asarray(game.accepted_price_history)[x]
    ax2.plot(x, ys_acc, marker="o", linestyle="None", markersize=3)
    ax2.set_title("Accepted Price (min over principals)")
    ax2.set_xlabel("Round")
    ax2.set_ylabel("Accepted Price")
    _finish(fig2, "timeseries_accepted_price")

    fig3, ax3 = _new_figure(title)
    rounds_played = np.arange(1, T + 1)
    for p in principals:
        running_avg = np.cumsum(np.array(p.profit_history)) / rounds_played
        ax3.plot(x, running_avg[x], label=f"P{p.id}")
    ax3.set_title("Running Average Profit")
    ax3.set_xlabel("Round")
    ax3.set_ylabel("Average Profit")
    ax3.axhline(0.0, linestyle="--", linewidth=1)
    ax3.legend()
    _finish(fig3, "timeseries_running_avg_profit")

    fig4, ax4 = _new_figure(title)
    for p in principals:
        cum = np.cumsum(p.profit_history)
        ax4.plot(x, cum[x], label=f"P{p.id}")
    ax4.set_title("Cumulative Profit")
    ax4.set_xlabel("Round")
    ax4.set_ylabel("Cumulative Profit")
    ax4.legend()
    _finish(fig4, "timeseries_cumulative_profit")

    # ---- Histograms of raw counts over all rounds (density=False) ----
    # One histogram per principal, numbered from 1, so runs with any number of
    # principals are covered rather than exactly two.
    for number, p in enumerate(principals, start=1):
        fig_h, ax_h = _new_figure(f"{title} | Price Histogram")
        ax_h.hist(p.price_history, bins=50, density=False, alpha=0.8)
        ax_h.set_title(f"Chosen Price: Player {number}")
        ax_h.set_xlabel("Price")
        ax_h.set_ylabel("Count")
        _finish(fig_h, f"hist_player{number}_chosen_price_counts")

    figh3, axh3 = _new_figure(f"{title} | Price Histogram")
    axh3.hist(game.accepted_price_history, bins=50, density=False, alpha=0.8)
    axh3.set_title("Accepted Price (min)")
    axh3.set_xlabel("Price")
    axh3.set_ylabel("Count")
    _finish(figh3, "hist_accepted_price_counts")

    return {
        "principals": principals,
        "accepted_price_history": game.accepted_price_history,
        "x": x,
    }


if __name__ == "__main__":
    # Seed NumPy's global RNG so repeated runs are reproducible.
    np.random.seed(0)

    # Two players share the same price grid {0, 1/k, ..., 1}.
    k = 100
    T = 10_000_000
    sample_iv = 20
    grid = [step / k for step in range(k + 1)]
    discretizations = [grid, grid]

    run_and_plot(
        discretizations=discretizations,
        costs=[0.0, 1],
        T=T,
        sample_iv=sample_iv,
        title="No-Swap Regret, Different Costs",
        station_iters=20,
        init_swap_weights=None,
    )
