import os
import json
import random
import numpy as np
import matplotlib.pyplot as plt
from multiprocessing import Pool


class BaseAlgorithm:
    """
    Minimal interface for a learner that picks one of ``n_arms`` arms per round.

    Subclasses must override ``choose_arm`` and ``update``; the base versions
    raise NotImplementedError.
    """

    def __init__(self, n_arms):
        # Number of arms the learner chooses among.
        self.n_arms = n_arms

    def choose_arm(self):
        """Return the index of the arm to play this round (must be overridden)."""
        raise NotImplementedError()

    def update(self, feedback):
        """Consume the feedback for the last round (must be overridden)."""
        raise NotImplementedError()


class SwapRegretHedge(BaseAlgorithm):
    """
    Swap-regret learner: one Hedge (exponential-weights) instance per arm,
    combined through a stationary distribution (Blum-Mansour-style reduction).

    State:
      logw -- (n_arms, n_arms) log-weights; row i is the Hedge instance
              associated with arm i.
      P    -- distribution used by the most recent ``choose_arm`` call
              (uniform before the first call); ``update`` charges row i in
              proportion to P[i].

    Args:
        n_arms: number of arms.
        eta: learning rate applied in ``update``.
        station_iters: power-iteration steps used to approximate the
            stationary distribution in ``choose_arm``.
        init_weights: optional strictly positive (n_arms, n_arms) initial
            weights; stored as ``prior_strength * log(init_weights)``.
        init_logweights: optional (n_arms, n_arms) initial log-weights, copied
            as-is (``prior_strength`` is NOT applied). Takes precedence over
            ``init_weights`` when both are given.
        prior_strength: scale applied to ``log(init_weights)``.

    Raises:
        ValueError: on a wrongly shaped initialiser or non-positive
            ``init_weights``.
    """

    def __init__(self, n_arms, eta, station_iters=20,
                 init_weights=None, init_logweights=None, prior_strength: float = 1.0):
        super().__init__(n_arms)
        self.eta = eta
        self.station_iters = station_iters
        self.logw = self._initial_logweights(
            n_arms, init_weights, init_logweights, prior_strength
        )
        self.P = np.ones(n_arms) / n_arms
        self.last_arm = None

    @staticmethod
    def _initial_logweights(n_arms, init_weights, init_logweights, prior_strength):
        """Build the starting (n_arms, n_arms) log-weight matrix."""
        expected = (n_arms, n_arms)
        if init_logweights is not None:
            given = np.asarray(init_logweights, dtype=float)
            if given.shape != expected:
                raise ValueError("init_logweights must have shape (n_arms, n_arms)")
            return given.copy()
        if init_weights is None:
            # No prior: all rows start uniform.
            return np.zeros(expected)
        weights = np.asarray(init_weights, dtype=float)
        if weights.shape != expected:
            raise ValueError("init_weights must have shape (n_arms, n_arms)")
        if (weights <= 0).any():
            raise ValueError("init_weights must be strictly positive")
        return prior_strength * np.log(weights)

    @staticmethod
    def _row_softmax(logw):
        """Row-wise softmax; subtracting each row's max keeps exp() from overflowing."""
        shifted = logw - logw.max(axis=1, keepdims=True)
        expw = np.exp(shifted)
        return expw / expw.sum(axis=1, keepdims=True)

    def choose_arm(self):
        """
        Sample an arm from an approximate stationary distribution of Q.

        Q is the row-wise softmax of ``logw``. The distribution p with
        p = p Q is approximated by ``station_iters`` power-iteration steps
        starting from uniform, renormalised, and stored in ``self.P`` for the
        next ``update``.
        """
        Q = self._row_softmax(self.logw)
        dist = np.ones(self.n_arms) / self.n_arms
        for _ in range(self.station_iters):
            dist = dist @ Q
        dist /= dist.sum()

        self.P = dist
        arm = np.random.choice(self.n_arms, p=dist)
        self.last_arm = arm
        return arm

    def update(self, losses):
        """
        Apply one round of full-information feedback.

        Row i's Hedge instance is charged the loss vector scaled by P[i]
        (the probability mass arm i had in the last ``choose_arm``).
        """
        self.logw -= self.eta * np.outer(self.P, losses)


class Principal:
    """
    A price-setting player that delegates arm selection to a learning algorithm.

    Arm ``a`` corresponds to the posted price ``discretization[a]``.

    Args:
        id: player identifier (stored as ``self.id``).
        discretization: indexable sequence of prices, one per arm.
        algorithm: learner providing ``choose_arm()`` and ``update(feedback)``.
        cost: per-sale cost, stored as a float.
    """

    def __init__(self, id, discretization, algorithm: BaseAlgorithm, cost: float = 0.0):
        self.id = id
        self.arms = discretization
        self.alg = algorithm
        self.cost = float(cost)
        self.last_arm = None
        # Per-round logs. profit_history is appended to by the game loop,
        # not by this class.
        self.price_history, self.profit_history = [], []

    def select_price(self):
        """Ask the learner for an arm, remember it, log and return its price."""
        self.last_arm = arm = self.alg.choose_arm()
        chosen_price = self.arms[arm]
        self.price_history.append(chosen_price)
        return chosen_price

    def update(self, fb):
        """Forward this round's feedback unchanged to the learner."""
        self.alg.update(fb)


class PricingGame:
    """
    Buyer picks the lowest posted price (uniform tie-split).
    Principal i's payoff at price x is:
      0 if not selected,
      x - cost_i if selected (split evenly among all tied minimums).

    Feedback is full-information: every principal learns the counterfactual
    payoff of each arm in its grid against the other principals' realised
    prices. That payoff is mapped to a loss via ``loss = 0.5 * (1 - payoff)``,
    which sends payoffs in [-1, 1] to losses in [0, 1]. That range holds when
    prices and costs lie in [0, 1]; this is assumed, not checked.

    Each principal must expose: ``arms``, ``cost``, ``last_arm``,
    ``profit_history``, ``select_price()`` and ``update(losses)``. Loss vectors
    have one entry per element of ``arms``.
    """

    def __init__(self, principals):
        self.principals = principals
        # Winning (minimum) posted price of each round.
        self.accepted_price_history = []

    def run(self, T):
        """
        Play ``T`` rounds.

        Each round every principal posts a price and the minimum is appended to
        ``accepted_price_history``. Each principal then receives its loss
        vector over all arms via ``update`` and logs the payoff of the arm it
        actually played in ``profit_history``.

        Payoffs are computed with NumPy over the whole price grid. A per-arm
        Python loop would execute T * n_principals * n_arms times and
        dominates runtime for large T.
        """
        principals = self.principals
        # Price grids do not change during a run; convert them once.
        grids = [np.asarray(p.arms, dtype=float) for p in principals]

        for _ in range(T):
            prices = [p.select_price() for p in principals]
            self.accepted_price_history.append(min(prices))

            for i, (p, arms) in enumerate(zip(principals, grids)):
                other_prices = prices[:i] + prices[i + 1:]
                # With no opponents every arm wins outright.
                min_other = min(other_prices) if other_prices else np.inf
                num_min_others = sum(1 for q in other_prices if q == min_other)

                win_profit = arms - p.cost
                profits = np.where(
                    arms < min_other,
                    win_profit,  # strictly lowest: sells alone
                    np.where(
                        arms == min_other,
                        win_profit / (num_min_others + 1),  # tie: split evenly
                        0.0,  # undercut: no sale
                    ),
                )
                losses = 0.5 * (1.0 - profits)

                p.profit_history.append(profits[p.last_arm])
                p.update(losses)


def _script_dir():
    """
    Return the directory containing this module's file.

    Falls back to the current working directory when ``__file__`` is not
    defined (for example, when the code is run interactively).
    """
    try:
        here = __file__
    except NameError:
        return os.getcwd()
    return os.path.dirname(os.path.abspath(here))


def _ensure_dir(path):
    """Create ``path`` (including parents) if it does not exist; return ``path``."""
    os.makedirs(name=path, exist_ok=True)
    return path


def _init_W0(init_swap_weights, i, discretizations, costs):
    """
    Resolve player ``i``'s initial swap-weight matrix.

    ``init_swap_weights`` may be:
      - None: no initial weights (returns None);
      - a callable: called as ``f(i, price_grid_as_float_array, float(cost_i))``
        and its result returned;
      - an indexable sequence: entry ``i`` is returned, or None when
        ``i`` is past its end.
    """
    if init_swap_weights is None:
        return None
    if not callable(init_swap_weights):
        return init_swap_weights[i] if i < len(init_swap_weights) else None
    price_grid = np.asarray(discretizations[i], dtype=float)
    return init_swap_weights(i, price_grid, float(costs[i]))


def _count_hist(x, bin_edges):
    """Return the per-bin counts of ``x`` over ``bin_edges`` as a float array."""
    per_bin = np.histogram(x, bins=bin_edges)[0]
    return per_bin.astype(float)


def _tag_num(x: float) -> str:
    """
    Filename-safe number tag:
      0.0 -> "0"
      1.25 -> "1p25"
      -0.5 -> "m0p5"

    The compact ``:g`` form is used when it converts back to exactly the same
    float. ``:g`` keeps only 6 significant digits, so otherwise the shortest
    exact ``repr`` is used instead, and distinct values never share a tag
    (0.1234561 -> "0p1234561" rather than colliding with 0.1234564 on
    "0p123456"). '-' becomes 'm' and '.' becomes 'p'.
    """
    value = float(x)
    s = f"{value:g}"
    if float(s) != value:
        s = repr(value)
        # repr of an integral float ends in ".0"; drop it to match :g style.
        if s.endswith(".0"):
            s = s[:-2]
    return s.replace("-", "m").replace(".", "p")


def _eta_value(eta_base: float, i: int, eta_mode: str) -> float:
    """
    Scale player ``i``'s base learning rate according to ``eta_mode``.

    Modes:
      - "eta_11_minus_10i": eta_base * (11 - 10*i)  (11x for i=0, 1x for i=1,
        negative for i >= 2)
      - "eta_1_plus_10i":   eta_base * (1 + 10*i)   (1x for i=0, 11x for i=1)
      - "eta_base":         eta_base unchanged

    Raises:
        ValueError: for any other ``eta_mode``.
    """
    scalings = (
        ("eta_11_minus_10i", lambda: eta_base * (11 - 10 * i)),
        ("eta_1_plus_10i", lambda: eta_base * (1 + 10 * i)),
        ("eta_base", lambda: eta_base),
    )
    for mode_name, compute in scalings:
        if eta_mode == mode_name:
            return compute()
    raise ValueError(f"Unknown eta_mode: {eta_mode}")


def simulate_one_seed_hist(args):
    """
    Pool worker: simulate one seed of the pricing game and return histograms.

    ``args`` is the 10-tuple
        (seed, discretizations, costs, T, station_iters, init_swap_weights,
         prior_strength, bin_edges, eta_mode, verbose)
    as packed by ``run_parallel_histograms_save_separate``.

    Each player's base learning rate is sqrt(log(n_arms) / T), rescaled per
    player by ``_eta_value(..., eta_mode)``.

    Returns:
        (player-1 chosen-price counts, player-2 chosen-price counts,
         accepted-price counts), each binned over ``bin_edges``. Only these
        bin-length arrays are returned, not the length-T price traces.
    """
    (seed, discretizations, costs, T, station_iters, init_swap_weights,
     prior_strength, bin_edges, eta_mode, verbose) = args

    if verbose:
        print(seed)

    # Seed both RNGs; arm sampling goes through np.random.
    np.random.seed(seed)
    random.seed(seed)

    base_etas = [float(np.sqrt(np.log(len(grid)) / T)) for grid in discretizations]

    principals = [
        Principal(
            i,
            grid,
            SwapRegretHedge(
                n_arms=len(grid),
                eta=_eta_value(base_etas[i], i, eta_mode),
                station_iters=station_iters,
                init_weights=_init_W0(init_swap_weights, i, discretizations, costs),
                prior_strength=prior_strength,
            ),
            cost=costs[i],
        )
        for i, grid in enumerate(discretizations)
    ]

    game = PricingGame(principals)
    game.run(T)

    traces = (
        principals[0].price_history,
        principals[1].price_history,
        game.accepted_price_history,
    )
    return tuple(_count_hist(np.asarray(trace, dtype=float), bin_edges) for trace in traces)


def plot_hist_mean_counts(ax, bin_edges, runs_hist, title):
    """
    Draw the across-seed mean of per-bin counts as a bar chart on ``ax``.

    Args:
        ax: axes exposing ``bar``/``set_title``/``set_xlabel``/``set_ylabel``
            (a matplotlib Axes where this file calls it).
        bin_edges: 1-D NumPy array of ``n_bins + 1`` edges.
        runs_hist: array-like of shape (n_runs, n_bins); row r holds run r's
            counts.
        title: axes title.
    """
    mean_counts = np.asarray(runs_hist, dtype=float).mean(axis=0)

    lower, upper = bin_edges[:-1], bin_edges[1:]
    bar_centers = 0.5 * (lower + upper)
    bar_widths = np.diff(bin_edges)

    ax.bar(bar_centers, mean_counts, width=bar_widths, align="center", alpha=0.8)
    ax.set_title(title)
    ax.set_xlabel("Price")
    ax.set_ylabel("Count (mean per seed)")


def run_parallel_histograms_save_separate(
    discretizations,
    costs,
    T,
    n_seeds=50,
    station_iters=20,
    init_swap_weights=None,
    prior_strength=1.0,
    eta_mode="eta_base",
    bins=50,
    processes=None,
    title_prefix="SwapRegretHedge (mean counts per seed)",
    save_basename="swap_regret",
    dpi=300,
    show=False,
    out_subdir="outputs",
    verbose_seeds=False,
):
    """
    Run one eta_mode over ``n_seeds`` seeds in a process pool and save results.

    Outputs (written to ``<script dir>/<out_subdir>``, created if missing):
      - ``{save_basename}_{run_tag}_data.npz`` with ``bin_edges``, the per-seed
        count histograms ``p1_hists``, ``p2_hists``, ``acc_hists`` (each of
        shape (n_seeds, bins)) and ``meta_json`` (JSON text of the run
        parameters);
      - three plots (Player 1, Player 2, accepted price), each as PNG and PDF.
    ``run_tag`` encodes both players' costs, eta_mode and station_iters.

    ``meta_json`` is stored as a plain unicode string array, so the file can
    be read with ``np.load(path)`` without ``allow_pickle=True``, and
    ``json.loads(d["meta_json"].item())`` recovers the metadata. Loading with
    ``allow_pickle=True`` still works.

    Args:
        discretizations: the two players' price grids.
        costs: the two players' costs.
        T: rounds per simulation.
        n_seeds: number of simulations, using seeds 0..n_seeds-1.
        station_iters, prior_strength, eta_mode: forwarded to every worker.
        init_swap_weights: forwarded to every worker. It is sent to the pool
            processes, so it must be picklable (e.g. a module-level function,
            not a lambda).
        bins: number of equal-width histogram bins spanning the min..max of
            all prices in ``discretizations``.
        processes: pool size; defaults to ``min(n_seeds, os.cpu_count())``.
        title_prefix, save_basename, dpi: plot-title prefix, filename prefix
            and raster resolution.
        show: call ``plt.show()`` after each plot is saved.
        out_subdir: output subdirectory next to this script.
        verbose_seeds: workers print their seed.

    Returns:
        ``{"data": npz_path, "plots": [(png_path, pdf_path), ...]}``.

    Raises:
        ValueError: if ``discretizations`` or ``costs`` does not have exactly
            two entries, or if ``n_seeds < 1``. Both are checked before any
            simulation runs.
    """
    # Fail fast. The outputs below are strictly two-player, and a mismatch
    # would otherwise only surface after every seed had been simulated.
    if len(discretizations) != 2 or len(costs) != 2:
        raise ValueError(
            "discretizations and costs must each contain exactly two entries "
            f"(got {len(discretizations)} and {len(costs)})"
        )
    if n_seeds < 1:
        raise ValueError(f"n_seeds must be at least 1 (got {n_seeds})")

    # Shared bin edges covering every price either player can post.
    all_prices = np.concatenate(
        [np.asarray(disc, dtype=float).ravel() for disc in discretizations]
    )
    bin_edges = np.linspace(float(all_prices.min()), float(all_prices.max()), int(bins) + 1)

    if processes is None:
        processes = min(n_seeds, os.cpu_count() or 1)

    worker_args = [
        (seed, discretizations, costs, T, station_iters, init_swap_weights,
         prior_strength, bin_edges, eta_mode, verbose_seeds)
        for seed in range(n_seeds)
    ]

    with Pool(processes=processes) as pool:
        results = pool.map(simulate_one_seed_hist, worker_args)

    # results[s] == (p1_counts, p2_counts, accepted_counts) for seed s.
    p1_runs, p2_runs, acc_runs = (np.stack(series, axis=0) for series in zip(*results))

    out_dir = _ensure_dir(os.path.join(_script_dir(), out_subdir))

    cost_p1, cost_p2 = costs
    cost_tag = f"costP1_{_tag_num(cost_p1)}_costP2_{_tag_num(cost_p2)}"
    eta_tag = f"etaMode_{eta_mode}"
    run_tag = f"{cost_tag}_{eta_tag}_stationIters_{int(station_iters)}"
    stem = os.path.join(out_dir, f"{save_basename}_{run_tag}")

    meta = {
        "T": int(T),
        "n_seeds": int(n_seeds),
        "bins": int(bins),
        "costs": [float(c) for c in costs],
        "eta_mode": str(eta_mode),
        "station_iters": int(station_iters),
        "prior_strength": float(prior_strength),
        "save_basename": str(save_basename),
    }

    data_path = f"{stem}_data.npz"
    np.savez_compressed(
        data_path,
        bin_edges=bin_edges,
        p1_hists=p1_runs,
        p2_hists=p2_runs,
        acc_hists=acc_runs,
        # 0-d unicode array (not dtype=object), so loading needs no pickle.
        meta_json=np.array(json.dumps(meta)),
    )

    plots = [
        ("player1", p1_runs, "Chosen Price: Player 1"),
        ("player2", p2_runs, "Chosen Price: Player 2"),
        ("accepted", acc_runs, "Accepted Price (min)"),
    ]

    saved = {"data": data_path, "plots": []}

    for suffix, runs_hist, plot_title in plots:
        fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.0))
        try:
            full_title = (
                f"{title_prefix}\n"
                f"{cost_tag} | {eta_tag} | station_iters={int(station_iters)}\n"
                f"{plot_title}"
            )
            plot_hist_mean_counts(ax, bin_edges, runs_hist, full_title)
            fig.tight_layout()

            png_path = f"{stem}_{suffix}.png"
            pdf_path = f"{stem}_{suffix}.pdf"
            fig.savefig(png_path, dpi=dpi, bbox_inches="tight")
            fig.savefig(pdf_path, dpi=dpi, bbox_inches="tight")
            saved["plots"].append((png_path, pdf_path))

            if show:
                plt.show()
        finally:
            # Close even when show=True. Under a non-interactive backend
            # plt.show() returns immediately and the figure would stay open.
            plt.close(fig)

    print(f"\nRun tag: {run_tag}")
    print(f"Saved data:\n  {data_path}")
    print("Saved plots:")
    for png_path, pdf_path in saved["plots"]:
        print(f"  {png_path}\n  {pdf_path}")

    return saved


if __name__ == "__main__":
    # Both players share the price grid {0, 0.01, ..., 1.00}.
    k = 100
    T = 10_000_000
    price_grid = [step / k for step in range(k + 1)]
    discretizations = [price_grid, price_grid]

    # Player 1 has zero cost; player 2 has cost 1.
    costs = [0.0, 1]

    shared_kwargs = dict(
        discretizations=discretizations,
        costs=costs,
        T=T,
        n_seeds=100,
        station_iters=20,
        init_swap_weights=None,
        prior_strength=1.0,
        bins=50,
        processes=None,
        title_prefix="SwapRegretHedge (mean counts per seed)",
        save_basename="swap_regret_costs_0_1",
        dpi=300,
        show=False,
        out_subdir="outputs",
        verbose_seeds=False,
    )

    # One full run (and set of output files) per learning-rate scheme:
    #   eta_11_minus_10i -> etas[i] * (11 - 10*i)
    #   eta_1_plus_10i   -> etas[i] * (1 + 10*i)
    #   eta_base         -> etas[i]
    for mode in ("eta_11_minus_10i", "eta_1_plus_10i", "eta_base"):
        run_parallel_histograms_save_separate(eta_mode=mode, **shared_kwargs)

    # Example of reloading a saved run:
    # import numpy as np, json
    # d = np.load("outputs/swap_regret_costs_0_1_costP1_0_costP2_1_etaMode_eta_base_stationIters_20_data.npz", allow_pickle=True)
    # meta = json.loads(d["meta_json"].item())
    # print(meta)
