import os

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import beta as beta_dist
from tap import Tap

from meta_alignment.reward_transformation import (
    calc_BoN_linearized_reward,
)


def calc_approximated_reward_diff(
    linearized_reward_func,
    rewards: np.ndarray,
    n: int,
    baseline=False,
    *,
    alpha=None,
    beta=None,
) -> float:
    """Estimate the change in best-of-``n`` reward via density-ratio weighting.

    Each reward sample is weighted by ``ratio = Beta(alpha, beta).pdf(reward)``.
    The returned estimate is ``mean(linearized_reward * (ratio - 1))``, where
    ``linearized_reward = linearized_reward_func(rewards, n)``. The script
    draws ``rewards`` with ``np.random.rand``, i.e. from Uniform(0, 1), whose
    density is 1. So ``ratio`` is the Beta-to-uniform density ratio.

    Args:
        linearized_reward_func: Callable ``(rewards, n) -> array`` giving the
            linearized reward of every sample.
        rewards: Reward samples.
        n: Best-of-n parameter forwarded to ``linearized_reward_func``.
        baseline: If True, subtract the sample mean of the linearized reward
            before weighting.
        alpha: First Beta shape parameter. If omitted, the module-level global
            ``alpha`` is used (the script's original calling convention).
        beta: Second Beta shape parameter. If omitted, the module-level global
            ``beta`` is used.

    Returns:
        The scalar estimate as a Python ``float``.
    """
    # Backward compatibility: the script sets `alpha`/`beta` at module level
    # before calling this function without passing them explicitly.
    if alpha is None:
        alpha = globals()["alpha"]
    if beta is None:
        beta = globals()["beta"]

    ratio = beta_dist.pdf(rewards, a=alpha, b=beta)
    linearized_reward = linearized_reward_func(rewards, n)
    if baseline:
        # Out-of-place on purpose. `-=` would mutate the callback's returned
        # array (possibly `rewards` itself) and fails for integer dtypes.
        linearized_reward = linearized_reward - np.mean(linearized_reward)
    return float(np.mean(linearized_reward * (ratio - 1)))


# Command-line options for the bias/variance comparison below.
# Kept comment-only (no class docstring, no trailing field comments) so the
# generated `--help` text is unchanged.
class CompareConfig(Tap):
    # Monte Carlo repetitions used for each value of M.
    num_samples: int = 100_000
    # Largest number of samples M evaluated; the grid starts at M = 2.
    max_M: int = 32
    # Offset of the Beta target from uniform: alpha = 1 + eps.
    eps: float = 1e-4


args = CompareConfig().parse_args()

# Grid of sample counts M. It is identical for every N, so build it once.
ms = list(range(2, args.max_M + 1))

for n in [1, 4, 8, 16]:
    result_path = f"results/errors/bias_variance_n={n}.npz"
    fig_path = f"figs/errors/compare_linearized_reward_n={n}.png"

    # Reuse cached results only when they match the current M grid. A cache
    # written with a different --max_M would otherwise crash stackplot below.
    biases = variances = None
    if os.path.exists(result_path):
        # Close the NpzFile once the arrays have been read out of it.
        with np.load(result_path) as data:
            biases = data["biases"]
            variances = data["variances"]
        if len(biases) != len(ms) or len(variances) != len(ms):
            print(
                f"Cached {result_path} has {len(biases)} points, "
                f"expected {len(ms)}; recomputing."
            )
            biases = variances = None

    if biases is None or variances is None:
        # NOTE: `alpha` and `beta` must remain module-level globals, because
        # calc_approximated_reward_diff falls back to reading them from there.
        eps = args.eps
        alpha = 1 + eps
        beta = 1
        num_samples = args.num_samples

        # Exact expected best-of-n reward. The max of n draws from Beta(alpha, 1)
        # has CDF x**(n * alpha), hence mean n*alpha / (n*alpha + 1); alpha = 1
        # is the uniform case. None of this depends on M, so compute it once.
        bon_reward_unif = n / (n + 1)
        bon_reward_beta = (n * alpha) / (n * alpha + 1)
        bon_reward_diff = bon_reward_beta - bon_reward_unif
        # Errors are reported relative to the squared true difference.
        normalization = bon_reward_diff**2
        if normalization == 0:
            raise ValueError(
                "eps must be non-zero: relative errors are undefined when the "
                "two best-of-n rewards coincide."
            )

        bias_list = []
        variance_list = []
        for M in ms:
            estimates = np.array(
                [
                    calc_approximated_reward_diff(
                        calc_BoN_linearized_reward,
                        np.random.rand(M),
                        n,
                        baseline=True,
                    )
                    for _ in range(num_samples)
                ]
            )
            bias_list.append((estimates.mean() - bon_reward_diff) ** 2 / normalization)
            variance_list.append(estimates.var() / normalization)
            print(f"M={M}: bias={bias_list[-1]:.6g}, var={variance_list[-1]:.6g}, ")

        biases = np.array(bias_list)
        variances = np.array(variance_list)
        os.makedirs(os.path.dirname(result_path), exist_ok=True)
        np.savez(result_path, biases=biases, variances=variances)

    errors = biases + variances
    with plt.style.context("./config/paper.mplstyle"):
        plt.rcParams["font.size"] = 12
        fig = plt.figure(figsize=(4, 3))
        ax = fig.add_subplot(1, 1, 1)
        ax.stackplot(
            ms,
            biases,
            variances,
            labels=["Bias", "Variance"],
            alpha=0.5,
        )
        ax.plot(ms, errors, label="MSE", color="black", linestyle="--")
        ax.set_xlabel(r"Number of samples $M$")
        ax.set_ylabel("Relative error")
        ax.legend()
        ax.set_title(rf"$N={n}$")
        ax.set_ylim(0, 1.0)
        # Save while the style is still active. rcParams (including the
        # font.size override above) revert on leaving the context, and
        # savefig.* settings, plus anything resolved at draw time, come from
        # the active rcParams.
        os.makedirs(os.path.dirname(fig_path), exist_ok=True)
        fig.savefig(fig_path, bbox_inches="tight")
    # Release the figure so pyplot does not keep one open per N.
    plt.close(fig)
