import os
import numpy as np
import random
import matplotlib.pyplot as plt


class BaseAlgorithm:
    """Minimal interface for an arm-selection algorithm.

    The base class only records the number of arms. Both ``choose_arm`` and
    ``update`` raise ``NotImplementedError`` here and are meant to be
    overridden by concrete algorithms.
    """

    def __init__(self, n_arms):
        # Stored exactly as given; no validation is performed here.
        self.n_arms = n_arms

    def choose_arm(self):
        """Select an arm. Not implemented in the base class."""
        raise NotImplementedError()

    def update(self, feedback):
        """Incorporate ``feedback``. Not implemented in the base class."""
        raise NotImplementedError()


class RegretMatching(BaseAlgorithm):
    """
    Regret Matching fed with full loss vectors, with optional frozen play.

    With freeze_rounds = 1 the mixed strategy is rebuilt from the cumulative
    regrets before every draw (the standard scheme). With freeze_rounds = c
    the strategy is rebuilt once and then reused for c consecutive draws.

    When use_rm_plus is true the cumulative regrets are clipped at zero after
    every update; otherwise they accumulate without clipping.
    """

    def __init__(self, n_arms, use_rm_plus: bool = False, freeze_rounds: int = 1):
        super().__init__(n_arms)
        self.use_rm_plus = bool(use_rm_plus)

        rounds = int(freeze_rounds)
        if rounds < 1:
            raise ValueError("freeze_rounds must be >= 1")
        self.freeze_rounds = rounds
        # Draws remaining before the distribution is rebuilt; starting at 0
        # forces a rebuild on the first call to choose_arm.
        self.freeze_left = 0

        self.regrets = np.zeros(n_arms, dtype=float)
        self.P = np.full(n_arms, 1.0, dtype=float) / n_arms
        self.last_arm = None

    def _recompute_distribution(self):
        """Set P proportional to the positive part of the regrets.

        Falls back to the uniform distribution when no regret is positive.
        """
        clipped = np.maximum(self.regrets, 0.0)
        total = float(clipped.sum())
        self.P = (
            clipped / total
            if total > 0.0
            else np.full(self.n_arms, 1.0, dtype=float) / self.n_arms
        )

    def choose_arm(self):
        """Draw an arm from the current (possibly frozen) distribution."""
        if self.freeze_left < 1:
            # Freeze window exhausted: rebuild and open a new window.
            self._recompute_distribution()
            self.freeze_left = self.freeze_rounds

        picked = np.random.choice(self.n_arms, p=self.P)
        self.freeze_left -= 1
        self.last_arm = picked
        return picked

    def update(self, losses):
        """Accumulate regrets of the last drawn arm against every arm.

        Raises ValueError if ``losses`` does not have one entry per arm, and
        RuntimeError if no arm has been drawn yet.
        """
        loss_vec = np.asarray(losses, dtype=float)
        if loss_vec.shape != (self.n_arms,):
            raise ValueError(f"losses must have shape ({self.n_arms},)")
        if self.last_arm is None:
            raise RuntimeError("choose_arm must be called before update")

        # Positive entries mark arms that would have incurred a smaller loss
        # than the arm that was actually played.
        delta = loss_vec[self.last_arm] - loss_vec

        if not self.use_rm_plus:
            self.regrets += delta
            return
        self.regrets = np.maximum(0.0, self.regrets + delta)


class Principal:
    """A player that picks prices from a fixed grid through a learning algorithm.

    ``discretization`` is the price grid, indexed by the arm the algorithm
    returns. Every selected price is appended to ``price_history``;
    ``profit_history`` starts empty and is not written by this class.
    """

    def __init__(self, id, discretization, algorithm: BaseAlgorithm, cost: float = 0.0):
        self.id = id
        self.alg = algorithm
        self.arms = discretization
        self.cost = float(cost)
        # Index of the most recently selected arm; None until select_price runs.
        self.last_arm = None
        self.price_history = []
        self.profit_history = []

    def select_price(self):
        """Ask the algorithm for an arm, record it, and return its price."""
        self.last_arm = self.alg.choose_arm()
        chosen_price = self.arms[self.last_arm]
        self.price_history.append(chosen_price)
        return chosen_price

    def update(self, fb):
        """Forward the feedback ``fb`` unchanged to the underlying algorithm."""
        self.alg.update(fb)


class PricingGame:
    """Repeated pricing game in which the lowest posted price wins.

    Each round every principal posts a price. A principal that strictly
    undercuts all rivals earns ``price - cost``; principals tied at the
    lowest price split that margin equally; everyone else earns 0. After
    each round every principal receives a full loss vector (one entry per
    arm of its own grid) computed counterfactually against the rivals'
    posted prices, with ``loss = 0.5 * (1 - profit)``.
    """

    def __init__(self, principals):
        self.principals = principals
        # Lowest posted price of every round, in order.
        self.accepted_price_history = []

    def run(self, T):
        """Play ``T`` rounds, updating every principal after each round.

        The price grids are converted to float arrays once, before the first
        round, so they are treated as fixed for the duration of the run.

        Raises:
            ValueError: if a principal's grid does not have exactly
                ``alg.n_arms`` entries.
        """
        players = self.principals

        # Hoisted out of the round loop: the grids do not change per round.
        grids = []
        for p in players:
            grid = np.asarray(p.arms, dtype=float)
            if grid.shape != (p.alg.n_arms,):
                raise ValueError(
                    "each principal's price grid must have exactly alg.n_arms entries"
                )
            grids.append(grid)

        for _ in range(T):
            prices = [p.select_price() for p in players]
            self.accepted_price_history.append(min(prices))

            for i, (p, grid) in enumerate(zip(players, grids)):
                rivals = prices[:i] + prices[i + 1:]
                if rivals:
                    min_other = min(rivals)
                    n_tied = sum(1 for q in rivals if q == min_other)
                else:
                    # A lone principal is never undercut or tied.
                    min_other = np.inf
                    n_tied = 0

                # Counterfactual profit of every arm, computed in one
                # vectorised pass instead of a Python loop per arm.
                margin = grid - p.cost
                profits = np.where(
                    grid < min_other,
                    margin,
                    np.where(grid == min_other, margin / (n_tied + 1), 0.0),
                )
                losses = 0.5 * (1.0 - profits)

                p.profit_history.append(profits[p.last_arm])
                p.update(losses)


def run_and_plot(
    discretizations, costs, T, sample_iv, title,
    use_rm_plus: bool = False,
    freeze_rounds_per_player=None,
    save=True, out_dir="outputs", basename="rm_run", dpi=300, show=True
):
    """Simulate a regret-matching pricing game and plot its history.

    One ``Principal`` driven by a ``RegretMatching`` learner is created per
    entry of ``discretizations``. The game is played for ``T`` rounds, then
    four time-series figures (chosen prices, accepted price, running average
    profit, cumulative profit) and one price histogram per player plus one
    for the accepted price are produced.

    Args:
        discretizations: One price grid per player.
        costs: One cost per player; same length as ``discretizations``.
        T: Number of rounds to simulate.
        sample_iv: Only every ``sample_iv``-th round is drawn in the
            time-series figures.
        title: Super-title placed on every figure.
        use_rm_plus: Forwarded to every ``RegretMatching`` learner.
        freeze_rounds_per_player: ``freeze_rounds`` for each player; ``None``
            means 1 for every player.
        save: When true, every figure is written as both PNG and PDF.
        out_dir: Output directory, joined onto this script's directory (or
            the current working directory when ``__file__`` is undefined).
        basename: Leading part of every output file name.
        dpi: Resolution passed to ``savefig``.
        show: When true ``plt.show()`` is called after each figure;
            otherwise the figure is closed.

    Returns:
        A dict with keys ``"principals"``, ``"accepted_price_history"``,
        ``"x"`` (the sampled round indices) and ``"base_tag"``.

    Raises:
        ValueError: if no player is given, or if ``costs`` or
            ``freeze_rounds_per_player`` does not have one entry per player.
    """
    # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
    if len(discretizations) != len(costs):
        raise ValueError("discretizations and costs must have the same length")
    n_players = len(discretizations)
    if n_players == 0:
        raise ValueError("at least one player is required")

    if freeze_rounds_per_player is None:
        freeze_rounds_per_player = [1] * n_players
    if len(freeze_rounds_per_player) != n_players:
        raise ValueError("freeze_rounds_per_player must have length = number of players")

    def _script_dir():
        try:
            return os.path.dirname(os.path.abspath(__file__))
        except NameError:
            # __file__ is undefined in some interactive sessions.
            return os.getcwd()

    def _ensure_dir(path):
        os.makedirs(path, exist_ok=True)
        return path

    def _tag_num(x: float) -> str:
        # File-name-safe number: "-" becomes "m" and "." becomes "p".
        s = f"{float(x):g}"
        return s.replace("-", "m").replace(".", "p")

    def _maybe_save(fig, base_tag, suffix):
        if not save:
            return
        png_path = os.path.join(out_path, f"{base_tag}_{suffix}.png")
        pdf_path = os.path.join(out_path, f"{base_tag}_{suffix}.pdf")
        fig.savefig(png_path, dpi=dpi, bbox_inches="tight")
        fig.savefig(pdf_path, dpi=dpi, bbox_inches="tight")
        print(f"Saved plot:\n  {png_path}\n  {pdf_path}")

    def _new_figure(suptitle):
        # Every figure shares the same size and super-title styling.
        fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.5))
        fig.suptitle(suptitle, fontsize=14)
        return fig, ax

    def _finish(fig, suffix):
        # Lay out, optionally save, then either show or close the figure.
        fig.tight_layout(rect=[0, 0, 1, 0.92])
        _maybe_save(fig, base_tag, suffix)
        if show:
            plt.show()
        else:
            plt.close(fig)

    def _price_histogram(values, axes_title, suffix):
        # Histogram of raw counts (not densities) over 50 bins.
        fig, ax = _new_figure(f"{title} | Price Histogram")
        ax.hist(values, bins=50, density=False, alpha=0.8)
        ax.set_title(axes_title)
        ax.set_xlabel("Price")
        ax.set_ylabel("Count")
        _finish(fig, suffix)

    principals = []
    for i in range(n_players):
        alg = RegretMatching(
            n_arms=len(discretizations[i]),
            use_rm_plus=use_rm_plus,
            freeze_rounds=int(freeze_rounds_per_player[i]),
        )
        principals.append(Principal(i, discretizations[i], alg, cost=costs[i]))

    game = PricingGame(principals)
    game.run(T)

    x = np.arange(0, T, sample_iv)
    # Selects exactly the rounds listed in ``x`` from any length-T history.
    sampled = slice(0, T, sample_iv)

    # The tag encodes the run configuration; k is taken from the first grid only.
    k_tag = f"k_{len(discretizations[0]) - 1}"
    T_tag = f"T_{int(T)}"
    iv_tag = f"iv_{int(sample_iv)}"
    costs_tag = "costs_" + "_".join(f"P{i}_{_tag_num(c)}" for i, c in enumerate(costs))
    rm_tag = "rmPlus_1" if use_rm_plus else "rmPlus_0"
    freeze_tag = "freeze_" + "_".join(f"P{i}_{int(fr)}" for i, fr in enumerate(freeze_rounds_per_player))
    base_tag = f"{basename}_{k_tag}_{T_tag}_{iv_tag}_{costs_tag}_{rm_tag}_{freeze_tag}"

    out_path = None
    if save:
        out_path = _ensure_dir(os.path.join(_script_dir(), out_dir))

    # --- Time-series figures -------------------------------------------------
    fig1, ax1 = _new_figure(title)
    for p in principals:
        ax1.scatter(x, p.price_history[sampled], s=10, label=f"P{p.id}")
    ax1.set_title("Chosen Prices")
    ax1.set_xlabel("Round")
    ax1.set_ylabel("Price")
    ax1.legend()
    _finish(fig1, "timeseries_chosen_prices")

    fig2, ax2 = _new_figure(title)
    ax2.plot(x, game.accepted_price_history[sampled], marker="o", linestyle="None", markersize=3)
    ax2.set_title("Accepted Price (min over principals)")
    ax2.set_xlabel("Round")
    ax2.set_ylabel("Accepted Price")
    _finish(fig2, "timeseries_accepted_price")

    fig3, ax3 = _new_figure(title)
    for p in principals:
        cum = np.cumsum(np.array(p.profit_history))
        avg = cum / np.arange(1, T + 1)
        ax3.plot(x, avg[sampled], label=f"P{p.id}")
    ax3.set_title("Running Average Profit")
    ax3.set_xlabel("Round")
    ax3.set_ylabel("Average Profit")
    ax3.axhline(0.0, linestyle="--", linewidth=1)
    ax3.legend()
    _finish(fig3, "timeseries_running_avg_profit")

    fig4, ax4 = _new_figure(title)
    for p in principals:
        cum = np.cumsum(p.profit_history)
        ax4.plot(x, cum[sampled], label=f"P{p.id}")
    ax4.set_title("Cumulative Profit")
    ax4.set_xlabel("Round")
    ax4.set_ylabel("Cumulative Profit")
    ax4.legend()
    _finish(fig4, "timeseries_cumulative_profit")

    # --- Histograms ------------------------------------------------------------
    # One histogram per player, numbered from 1, so a two-player run yields
    # the "hist_player1_..." and "hist_player2_..." files and any other
    # player count is handled as well.
    for idx, p in enumerate(principals):
        _price_histogram(
            p.price_history,
            f"Chosen Price: Player {idx + 1}",
            f"hist_player{idx + 1}_chosen_price_counts",
        )

    _price_histogram(
        game.accepted_price_history,
        "Accepted Price (min)",
        "hist_accepted_price_counts",
    )

    return {
        "principals": principals,
        "accepted_price_history": game.accepted_price_history,
        "x": x,
        "base_tag": base_tag,
    }


if __name__ == "__main__":
    # Seed NumPy's global RNG before any arm is drawn.
    np.random.seed(10)

    # Grid resolution, number of rounds, and plotting stride.
    k = 100
    T = 10_000_000
    sample_iv = 20

    # k + 1 evenly spaced prices from 0 to 1, shared by both players.
    grid = [step / k for step in range(k + 1)]

    # Player 1 rebuilds its strategy every round; player 2 every 20 rounds.
    freeze_p1, freeze_p2 = 1, 20

    run_and_plot(
        discretizations=[grid, grid],
        costs=[0.0, 1],
        T=T,
        sample_iv=sample_iv,
        title=f"Regret Matching with Frozen Strategies (P1={freeze_p1}, P2={freeze_p2})",
        use_rm_plus=False,
        freeze_rounds_per_player=[freeze_p1, freeze_p2],
    )
