import os
import json
import random
import numpy as np
import matplotlib.pyplot as plt
from multiprocessing import Pool


class BaseAlgorithm:
    """Minimal interface for the arm-selection algorithms in this file.

    The base implementations of ``choose_arm`` and ``update`` only raise
    ``NotImplementedError``; concrete algorithms override both.
    """

    def __init__(self, n_arms: int) -> None:
        # Number of arms (actions) the algorithm selects among.
        self.n_arms = n_arms

    def choose_arm(self):
        """Return the arm to play this round (not implemented here)."""
        raise NotImplementedError()

    def update(self, feedback):
        """Consume the feedback for the last round (not implemented here)."""
        raise NotImplementedError()


class RegretMatching(BaseAlgorithm):
    """Regret matching over ``n_arms`` arms with full-information feedback.

    The sampling distribution ``P`` is proportional to the positive part of
    the cumulative regrets (uniform when no regret is positive).  It is only
    recomputed once every ``freeze_rounds`` calls to ``choose_arm``; in
    between, the same distribution keeps being sampled.

    With ``use_rm_plus`` set, cumulative regrets are clipped at zero after
    every update instead of being allowed to go negative.

    Raises:
        ValueError: if ``int(freeze_rounds)`` is smaller than 1.
    """

    def __init__(self, n_arms, use_rm_plus=False, freeze_rounds=1):
        super().__init__(n_arms)
        self.use_rm_plus = bool(use_rm_plus)

        period = int(freeze_rounds)
        if period < 1:
            raise ValueError("freeze_rounds must be >= 1")
        self.freeze_rounds = period
        # Draws remaining before the distribution is refreshed; 0 forces a
        # refresh on the very first call to choose_arm.
        self.freeze_left = 0

        self.regrets = np.zeros(n_arms, dtype=float)
        self.P = self._uniform()
        self.last_arm = None

    def _uniform(self):
        """Return the uniform distribution over all arms."""
        return np.ones(self.n_arms, dtype=float) / self.n_arms

    def _recompute_distribution(self):
        """Set ``P`` proportional to positive regrets, or uniform if none."""
        positive = np.maximum(self.regrets, 0.0)
        total = float(positive.sum())
        self.P = positive / total if total > 0.0 else self._uniform()

    def choose_arm(self):
        """Sample an arm from ``P``, refreshing ``P`` when the freeze expires."""
        if self.freeze_left <= 0:
            self._recompute_distribution()
            self.freeze_left = self.freeze_rounds

        arm = np.random.choice(self.n_arms, p=self.P)
        self.last_arm = arm
        self.freeze_left -= 1
        return arm

    def update(self, losses):
        """Accumulate regret of the last played arm against every arm.

        Args:
            losses: one loss per arm for the round just played.

        Raises:
            ValueError: if ``losses`` does not have shape ``(n_arms,)``.
            RuntimeError: if no arm has been chosen yet.
        """
        loss_vec = np.asarray(losses, dtype=float)
        if loss_vec.shape != (self.n_arms,):
            raise ValueError(f"losses must have shape ({self.n_arms},)")

        played = self.last_arm
        if played is None:
            raise RuntimeError("choose_arm must be called before update")

        # Positive entries mark arms that would have incurred a lower loss
        # than the arm actually played.
        delta = loss_vec[played] - loss_vec

        if not self.use_rm_plus:
            self.regrets += delta
            return
        self.regrets = np.maximum(0.0, self.regrets + delta)


class Principal:
    """A seller that posts one price per round, chosen by a learning algorithm.

    ``arms`` is the price grid; the wrapped algorithm returns an index into
    it.  Posted prices are recorded in ``price_history``.  ``profit_history``
    starts empty and is not modified by this class.
    """

    def __init__(self, id, discretization, algorithm: BaseAlgorithm, cost=0.0):
        self.id = id
        self.arms = discretization
        self.alg = algorithm
        self.cost = float(cost)

        # Index of the most recently posted price; None before the first round.
        self.last_arm = None

        self.price_history = []
        self.profit_history = []

    def select_price(self):
        """Ask the algorithm for an arm, record the price and return it."""
        self.last_arm = self.alg.choose_arm()
        posted = self.arms[self.last_arm]
        self.price_history.append(posted)
        return posted

    def update(self, fb):
        """Forward the round's feedback to the underlying algorithm."""
        self.alg.update(fb)


class PricingGame:
    """
    Repeated posted-price game with counterfactual (full-information) feedback.

    Buyer picks the lowest posted price (uniform tie-split).
    Principal i's payoff at price x is:
      0 if not selected,
      x - cost_i if selected (split evenly among all tied minimums).
    """

    def __init__(self, principals):
        self.principals = principals
        # Lowest posted price of every round played so far.
        self.accepted_price_history = []

    def run(self, T):
        """Play ``T`` rounds, updating every principal after each round.

        Each round, every principal posts a price.  For each principal the
        profit it *would* have earned at every price on its grid (holding the
        rivals' posted prices fixed) is computed, the profit of the price
        actually posted is appended to ``profit_history``, and the vector
        ``0.5 * (1 - profit)`` is passed to ``update`` as losses.

        NOTE: the loss vector has one entry per element of ``p.arms``; this
        assumes the grid length equals the algorithm's number of arms.
        """
        principals = self.principals
        # Loop-invariant: each price grid as a float array, so the per-arm
        # payoffs are computed with array operations instead of a Python loop.
        grids = [np.asarray(p.arms, dtype=float) for p in principals]

        for _ in range(T):
            prices = [p.select_price() for p in principals]
            self.accepted_price_history.append(min(prices))

            for i, (p, grid) in enumerate(zip(principals, grids)):
                rivals = prices[:i] + prices[i + 1:]
                if rivals:
                    min_other = min(rivals)
                    num_min_others = rivals.count(min_other)
                else:
                    # A lone seller wins at every price.
                    min_other = np.inf
                    num_min_others = 0

                margin = grid - p.cost
                profits = np.where(
                    grid < min_other,
                    margin,  # strictly undercuts every rival: full margin
                    np.where(
                        grid == min_other,
                        margin / (num_min_others + 1),  # tie: split evenly
                        0.0,  # undercut by a rival: no sale
                    ),
                )
                # Affine map from profit to loss (profit 1 -> 0, profit -1 -> 1).
                losses = 0.5 * (1.0 - profits)

                p.profit_history.append(profits[p.last_arm])
                p.update(losses)


def _script_dir():
    """Return the directory of this file, or the CWD if ``__file__`` is unset."""
    try:
        this_file = __file__
    except NameError:
        # No __file__ in this namespace (e.g. some interactive sessions).
        return os.getcwd()
    return os.path.dirname(os.path.abspath(this_file))


def _ensure_dir(path):
    """Create ``path`` (with any missing parents) and return it unchanged."""
    # exist_ok makes this a no-op when the directory is already present.
    os.makedirs(name=path, exist_ok=True)
    return path


def _count_hist(x, bin_edges):
    """Return the per-bin counts of ``x`` for ``bin_edges`` as a float array."""
    return np.histogram(x, bins=bin_edges)[0].astype(float)


def _tag_num(x: float) -> str:
    """
    Turn a number into a short tag for use inside filenames.

    The value is formatted with ``:g``; every "-" then becomes "m" and
    every "." becomes "p":
      0.0 -> "0"
      1.25 -> "1p25"
      -0.5 -> "m0p5"
      1e-3 -> "0p001"
    """
    substitutions = str.maketrans({"-": "m", ".": "p"})
    return f"{float(x):g}".translate(substitutions)


def simulate_one_seed_hist(args):
    """
    Worker: run one seed and return count histograms (not raw samples).

    Args:
        args: tuple ``(seed, discretizations, costs, T, use_rm_plus,
            freeze_rounds_per_player, bin_edges, verbose)``.  There is one
            discretization (price grid) and one cost per player;
            ``freeze_rounds_per_player`` may be ``None`` (every player then
            uses 1).  Packed into one tuple so the function can be used
            with ``Pool.map``.

    Returns:
        ``(h_p1, h_p2, h_acc)``: float count histograms over ``bin_edges``
        of player 1's posted prices, player 2's posted prices, and the
        accepted (minimum) price of each round.

    Raises:
        ValueError: if there are fewer than two players, or if ``costs`` or
            ``freeze_rounds_per_player`` does not have one entry per player.
    """
    (
        seed,
        discretizations,
        costs,
        T,
        use_rm_plus,
        freeze_rounds_per_player,
        bin_edges,
        verbose,
    ) = args
    if verbose:
        print(seed)

    np.random.seed(seed)
    random.seed(seed)

    n_players = len(discretizations)
    # Validate before simulating: histograms are returned for players 1 and
    # 2, so a bad configuration should not cost a full T-round run first.
    if n_players < 2:
        raise ValueError("at least two players are required")
    if len(costs) != n_players:
        raise ValueError("costs must have length = number of players")
    if freeze_rounds_per_player is None:
        freeze_rounds_per_player = [1] * n_players
    if len(freeze_rounds_per_player) != n_players:
        raise ValueError("freeze_rounds_per_player must have length = number of players")

    principals = [
        Principal(
            i,
            grid,
            RegretMatching(
                n_arms=len(grid),
                use_rm_plus=use_rm_plus,
                freeze_rounds=int(freeze),
            ),
            cost=cost,
        )
        for i, (grid, cost, freeze) in enumerate(
            zip(discretizations, costs, freeze_rounds_per_player)
        )
    ]

    game = PricingGame(principals)
    game.run(T)

    h_p1 = _count_hist(np.asarray(principals[0].price_history, dtype=float), bin_edges)
    h_p2 = _count_hist(np.asarray(principals[1].price_history, dtype=float), bin_edges)
    h_acc = _count_hist(np.asarray(game.accepted_price_history, dtype=float), bin_edges)

    return h_p1, h_p2, h_acc


def plot_hist_mean_counts(ax, bin_edges, runs_hist, title):
    """Draw the across-run mean of per-bin counts as a bar chart on ``ax``.

    ``runs_hist`` holds one row of bin counts per run; rows are averaged
    bin by bin and each bar is centred on its bin and spans its full width.
    """
    per_run = np.asarray(runs_hist, dtype=float)  # (n_runs, n_bins)
    heights = np.mean(per_run, axis=0)

    left, right = bin_edges[:-1], bin_edges[1:]
    centers = 0.5 * (left + right)
    widths = np.diff(bin_edges)

    ax.bar(centers, heights, width=widths, align="center", alpha=0.8)
    ax.set_title(title)
    ax.set_xlabel("Price")
    ax.set_ylabel("Count (mean over seeds)")


def run_parallel_histograms_counts_save_separate(
    discretizations,
    costs,
    T,
    n_seeds=50,
    use_rm_plus=False,
    freeze_rounds_per_player=None,
    bins=50,
    processes=None,
    title_prefix="Regret Matching - hist mean counts over seeds",
    save_basename="rm_hist",
    dpi=300,
    show=False,
    out_subdir="outputs",
    verbose_seeds=False,
):
    """
    Runs one (freeze_p1, freeze_p2) combo for a two-player game:
      - simulates seeds 0..n_seeds-1 in a process pool
      - saves 3 separate plot files (P1, P2, Accepted), each as PNG and PDF
      - saves data (.npz) with bin_edges and per-seed hist counts for each series
      - includes per-player costs in the plot filenames

    Files are written to ``out_subdir`` under the script directory.

    Args:
        discretizations: one price grid per player (exactly two players).
        costs: one cost per player.
        T: number of rounds per seed.
        n_seeds: number of independent seeds (>= 1).
        use_rm_plus: forwarded to every player's algorithm.
        freeze_rounds_per_player: freeze period per player; ``None`` means
            1 for both players.
        bins: number of equal-width bins between the smallest and largest
            price found in any grid (>= 1).
        processes: pool size; defaults to ``min(n_seeds, cpu_count)``.
        title_prefix, save_basename, dpi, show, out_subdir: plot/file options.
        verbose_seeds: if true, each worker prints its seed.

    Returns:
        ``{"data": npz_path, "plots": [(png_path, pdf_path), ...]}``.

    Raises:
        ValueError: if the configuration is inconsistent (player count,
            list lengths, freeze periods, n_seeds or bins out of range).
    """
    # Validate everything up front: the simulation is expensive, so a bad
    # configuration must not surface only after the pool has finished.
    n_players = len(discretizations)
    if n_players != 2:
        raise ValueError("exactly two players (discretizations) are supported")
    if len(costs) != n_players:
        raise ValueError("costs must have length = number of players")
    if freeze_rounds_per_player is None:
        freeze_rounds_per_player = [1] * n_players
    if len(freeze_rounds_per_player) != n_players:
        raise ValueError("freeze_rounds_per_player must have length = number of players")
    freeze_p1, freeze_p2 = (int(f) for f in freeze_rounds_per_player)
    if min(freeze_p1, freeze_p2) < 1:
        raise ValueError("freeze_rounds must be >= 1")
    n_seeds = int(n_seeds)
    if n_seeds < 1:
        raise ValueError("n_seeds must be >= 1")
    n_bins = int(bins)
    if n_bins < 1:
        raise ValueError("bins must be >= 1")
    cost_p1, cost_p2 = costs

    # Common binning for all plots and all seeds.
    # NOTE(review): np.histogram bins are half-open except the last, which
    # also includes its right edge; grid prices that coincide with a bin
    # edge are assigned according to floating-point comparison.
    all_prices = np.asarray(
        sorted({float(x) for disc in discretizations for x in disc}), dtype=float
    )
    lo = float(all_prices.min())
    hi = float(all_prices.max())
    bin_edges = np.linspace(lo, hi, n_bins + 1)

    if processes is None:
        processes = min(n_seeds, os.cpu_count() or 1)

    worker_args = [
        (
            seed,
            discretizations,
            costs,
            T,
            use_rm_plus,
            [freeze_p1, freeze_p2],
            bin_edges,
            verbose_seeds,
        )
        for seed in range(n_seeds)
    ]

    with Pool(processes=processes) as pool:
        results = pool.map(simulate_one_seed_hist, worker_args)

    p1_runs = np.stack([r[0] for r in results], axis=0)   # (n_seeds, n_bins)
    p2_runs = np.stack([r[1] for r in results], axis=0)
    acc_runs = np.stack([r[2] for r in results], axis=0)

    out_dir = _ensure_dir(os.path.join(_script_dir(), out_subdir))

    cost_tag = f"costP1_{_tag_num(cost_p1)}_costP2_{_tag_num(cost_p2)}"
    combo_tag = f"{cost_tag}_freezeP1_{freeze_p1}_freezeP2_{freeze_p2}"

    # Save data for reuse later
    meta = {
        "T": int(T),
        "n_seeds": n_seeds,
        "bins": n_bins,
        "costs": [float(c) for c in costs],
        "use_rm_plus": bool(use_rm_plus),
        "freeze_rounds_per_player": [freeze_p1, freeze_p2],
        "save_basename": save_basename,
    }

    data_path = os.path.join(out_dir, f"{save_basename}_{combo_tag}_data.npz")
    np.savez_compressed(
        data_path,
        bin_edges=bin_edges,
        p1_hists=p1_runs,
        p2_hists=p2_runs,
        acc_hists=acc_runs,
        # Stored as a plain string array (not dtype=object) so the archive
        # can be loaded without allow_pickle; ``.item()`` still yields the
        # JSON text.
        meta_json=np.array(json.dumps(meta)),
    )

    # Save three separate plots per combo
    plots = [
        ("player1", p1_runs, "Chosen Price - Player 1"),
        ("player2", p2_runs, "Chosen Price - Player 2"),
        ("accepted", acc_runs, "Accepted Price (min)"),
    ]

    saved = {"data": data_path, "plots": []}

    for suffix, runs_hist, plot_title in plots:
        fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.0))

        full_title = (
            f"{title_prefix}\n"
            f"{cost_tag} | freezeP1={freeze_p1}, freezeP2={freeze_p2}\n"
            f"{plot_title}"
        )

        plot_hist_mean_counts(ax, bin_edges, runs_hist, full_title)
        fig.tight_layout()

        png_path = os.path.join(out_dir, f"{save_basename}_{combo_tag}_{suffix}.png")
        pdf_path = os.path.join(out_dir, f"{save_basename}_{combo_tag}_{suffix}.pdf")
        fig.savefig(png_path, dpi=dpi, bbox_inches="tight")
        fig.savefig(pdf_path, dpi=dpi, bbox_inches="tight")
        saved["plots"].append((png_path, pdf_path))

        if show:
            plt.show()
        # Always release the figure, including after show(), so figures do
        # not accumulate across plots and combos.
        plt.close(fig)

    print(f"\nCombo: {combo_tag}")
    print(f"Saved data:\n  {data_path}")
    print("Saved plots:")
    for png_path, pdf_path in saved["plots"]:
        print(f"  {png_path}\n  {pdf_path}")

    return saved


if __name__ == "__main__":
    # Two sellers on the same uniform price grid {0, 1/k, ..., 1}.
    k = 100
    T = 10_000_000
    grid = [i / k for i in range(k + 1)]

    # Settings shared by every freeze combination below.
    common = dict(
        discretizations=[grid, grid],
        costs=[0.0, 0.5],
        T=T,
        n_seeds=100,
        use_rm_plus=False,
        bins=50,
        processes=None,
        title_prefix="Regret Matching (mean counts over seeds)",
        save_basename="rm_frozen",
        dpi=300,
        show=False,
        out_subdir="outputs",
        verbose_seeds=False,
    )

    # (freeze rounds for player 1, freeze rounds for player 2)
    for freeze_p1, freeze_p2 in ((1, 1), (10, 1), (1, 10)):
        run_parallel_histograms_counts_save_separate(
            freeze_rounds_per_player=[freeze_p1, freeze_p2],
            **common,
        )

    # Example reload:
    # import numpy as np, json
    # d = np.load("outputs/rm_frozen_costP1_0_costP2_0p5_freezeP1_10_freezeP2_1_data.npz", allow_pickle=True)
    # meta = json.loads(d["meta_json"].item())
    # print(meta)
