# This data is extracted from figure 4a of Greenblatt et al. (2024)
import matplotlib.pyplot as plt
import numpy as np

from ai_control_games.trusted_monitor.tm_game import TMModellingAssumptions
from ai_control_games.distributions import ContFromDiscrete, DiscreteDist, BetaDistOver0to1

# Beta-distribution approximations of the safe (SF) and backdoored (BD) histograms.
# The (a, b) values are the fitted parameters recorded in the comment that
# precedes fit_beta_to_disc further down in this file.
BETA_APPROX_OF_RW_SF = BetaDistOver0to1(a=7.385681061563469, b=6.85811507578284)
BETA_APPROX_OF_RW_BD = BetaDistOver0to1(a=8.907878167987224, b=2.8202215050915025)

# Bar heights of the safe (SF) histogram, one entry per bar (20 bars).
# Despite the `_probs` suffix these are NOT normalised (several exceed 1);
# get_tm_dist_from_bars divides each height by the total before use.
_tm_sf_probs = [
    0.2596219986385344,
    0.11574197825226595,
    0.2700860001211716,
    0.2700860001211716,
    0.31194200605172495,
    0.31194200605172495,
    1.6068621895281545,
    10.082703390464784,
    18.44082457472174,
    23.777465330867027,
    22.364825130710916,
    14.200287973882764,
    7.021982956793223,
    4.039742534241446,
    3.2157024174837185,
    3.019502389684259,
    1.8871479435273804,
    1.135982122809453,
    0.31194200605172495,
    0.07650197269237415,
]


def get_tm_dist_from_bars(bars):
    """Normalise histogram bar heights and wrap them in two distribution objects.

    Args:
        bars: Iterable of non-negative, finite bar heights. They do not need
            to sum to 1; each height is divided by the total of all heights.

    Returns:
        A ``(ContFromDiscrete, DiscreteDist)`` tuple; both objects are built
        from the same normalised probability array.

    Raises:
        ValueError: If ``bars`` is empty, contains a negative or non-finite
            height, or its heights sum to zero. Without this check such input
            would yield empty, negative or NaN "probabilities".
    """
    bar_heights = np.array(list(bars), dtype=float)

    if bar_heights.size == 0:
        raise ValueError("bars must contain at least one bar height")
    if not np.all(np.isfinite(bar_heights)):
        raise ValueError("bar heights must all be finite")
    if np.any(bar_heights < 0):
        raise ValueError("bar heights must be non-negative")

    # Same element-by-element summation as before, so the normalised values
    # produced for valid input are unchanged.
    total_height = sum(bar_heights)
    if total_height <= 0:
        raise ValueError("bar heights must sum to a positive value")

    probs = bar_heights / total_height
    return ContFromDiscrete(probs=probs), DiscreteDist(probs=probs)


# Continuous and discrete forms of the safe (SF) histogram, built at import time.
(
    CONT_TM_SF_BARS_READ_OFF_FIG_4A,
    DISC_TM_SF_BARS_READ_OFF_FIG_4A,
) = get_tm_dist_from_bars(_tm_sf_probs)

# Bar heights of the backdoored (BD) histogram, one entry per bar (20 bars).
# Like the safe-histogram list, these are unnormalised heights rather than
# probabilities; get_tm_dist_from_bars divides each height by the total.
_tm_bd_probs = [
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    1.6068621895281545,
    0.0,
    1.6068621895281545,
    4.78530263987939,
    6.354902862275062,
    4.78530263987939,
    7.9637430902306265,
    20.59902488051579,
    15.890224213328771,
    22.166009102540798,
    15.890224213328771,
    9.5333433126263,
    1.6068621895281545,
]

# Continuous and discrete forms of the backdoored (BD) histogram, built at import time.
(
    CONT_TM_BD_BARS_READ_OFF_FIG_4A,
    DISC_TM_BD_BARS_READ_OFF_FIG_4A,
) = get_tm_dist_from_bars(_tm_bd_probs)

# Modelling assumptions pairing the continuous safe/backdoored distributions.
CONT_MA = TMModellingAssumptions(
    safe_dist=CONT_TM_SF_BARS_READ_OFF_FIG_4A,
    bd_dist=CONT_TM_BD_BARS_READ_OFF_FIG_4A,
)

# Modelling assumptions pairing the discrete safe/backdoored distributions.
DISC_MA = TMModellingAssumptions(
    safe_dist=DISC_TM_SF_BARS_READ_OFF_FIG_4A,
    bd_dist=DISC_TM_BD_BARS_READ_OFF_FIG_4A,
)

# Split of 1500 problems into 19 and 1481.
# NOTE(review): presumably backdoored vs. safe problem counts from the source
# experiment; neither value is used in the visible part of this file.
_num_bd_problems = 19
_num_sf_problems = 1500 - _num_bd_problems

# "Weaker signal" variant: the safe distribution is reused unchanged (same
# object), while the backdoored beta has `modifier` subtracted from `a` and
# added to `b`.
# NOTE(review): for a standard beta this lowers the mean a / (a + b) — confirm
# BetaDistOver0to1 uses the standard parameterisation.
modifier = 1.4
BETA_WEAKER_SIGNAL_SF = BETA_APPROX_OF_RW_SF
BETA_WEAKER_SIGNAL_BD = BetaDistOver0to1(
    a=BETA_APPROX_OF_RW_BD.a - modifier,
    b=BETA_APPROX_OF_RW_BD.b + modifier,
)


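# Calling fit_beta_to_disc() prints the fitted parameters shown below; these are
# the values hard-coded in BETA_APPROX_OF_RW_SF and BETA_APPROX_OF_RW_BD:
# Safe: a=7.385681061563469, b=6.85811507578284, loc=0, scale=1
# Backdoored: a=8.907878167987224, b=2.8202215050915025, loc=0.0, scale=1.0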

def fit_beta_to_disc():
    """Fit beta distributions to the safe and backdoored histograms and plot them.

    The two discrete histograms are drawn as overlaid, semi-transparent bars on
    the left-hand axis; the fitted beta pdfs are drawn on a twin right-hand
    axis. The fitted parameters are also printed to stdout.

    Both fits hold location and scale fixed (via ``floc`` / ``fscale``), so only
    the two shape parameters are estimated.

    Returns:
        The tuple ``(fig, ax1, ax2)``: the figure, the bar-chart axis and the
        twin axis carrying the fitted curves.
    """
    # Both histograms are plotted and fitted against the safe distribution's
    # support points.
    support = DISC_TM_SF_BARS_READ_OFF_FIG_4A.vals
    sf_masses = DISC_TM_SF_BARS_READ_OFF_FIG_4A.probs
    bd_masses = DISC_TM_BD_BARS_READ_OFF_FIG_4A.probs

    fig, ax1 = plt.subplots()
    ax1.set_ylim(0, 0.22)
    ax2 = ax1.twinx()  # second Axes sharing ax1's x-axis
    ax2.set_ylim(0, 3.5)

    grid = np.linspace(0, 1, 1001)
    # Bars are drawn at half the spacing between the first two support points.
    half_spacing = (support[1] - support[0]) / 2
    for masses, bar_label in ((sf_masses, 'Safe dist'), (bd_masses, 'Backdoored dist')):
        ax1.bar(support, masses, half_spacing, alpha=0.3, label=bar_label)

    # Kept function-local, as in the original, so scipy is only resolved when
    # this function actually runs.
    from scipy.stats import beta

    def pseudo_sample(points, masses):
        # beta.fit needs raw observations in (0, 1), so synthesise them from
        # the discrete pmf: each point is repeated int(mass * 100000) times.
        observations = []
        for point, mass in zip(points, masses):
            observations += [point] * int(mass * 100000)
        return observations

    def fit_and_draw(masses, title, loc, scale):
        # Fit with fixed loc/scale, draw the resulting pdf on ax2, return the fit.
        a, b, fit_loc, fit_scale = beta.fit(
            data=pseudo_sample(support, masses), floc=loc, fscale=scale
        )
        curve_label = f'{title} \n $\\beta (a={a:.2f}, b={b:.2f})$'
        ax2.plot(grid, beta.pdf(grid, a, b, loc=fit_loc, scale=fit_scale), label=curve_label)
        return a, b, fit_loc, fit_scale

    # The int vs. float fixed values are passed exactly as the original did.
    a_sf, b_sf, loc_sf, scale_sf = fit_and_draw(sf_masses, 'Safe', 0, 1)
    a_bd, b_bd, loc_bd, scale_bd = fit_and_draw(bd_masses, 'Backdoored', 0.0, 1.0)

    # Legends are placed outside the plot area, to the right of the axes.
    ax1.legend(loc=(1.1, 0.1), title='Data from Greenblatt et al. (2024) fig 4a')
    ax2.legend(loc=(1.1, 0.6), title='Fitted beta distributions')

    print(f'Safe: a={a_sf}, b={b_sf}, loc={loc_sf}, scale={scale_sf}')
    print(f'Backdoored: a={a_bd}, b={b_bd}, loc={loc_bd}, scale={scale_bd}')

    return fig, ax1, ax2
