from libmab.learners import CUCB, CRandom, Fixed
from libmab.attackers import BPOracleCombinatorialAttacker
from libmab.envs.combinatorial import CombinatorialBoundedGaussianEnv
from libmab.visualization import Colors
from libmab.utils import save
from tqdm import tqdm
from typing import List

import matplotlib.pyplot as plt
import tikzplotlib as tkz
import numpy as np
import os

# Matplotlib style sheet; presumably a file looked up relative to the working directory.
plt.style.use("config.mlpstyle")
# ------------------------------
# Parameters
# ------------------------------
# Seed the global NumPy RNG before any environment/learner/attacker is built.
np.random.seed(123)
name = "oa_bounded_positive"  # prefix of every output file name
exp_name = "exps"  # output directory

E = 10  # number of independent episodes (repetitions)
T = 10**5  # rounds per episode
# Per-arm parameters handed to the environment (presumably the arm means).
arms = np.array([0.5, 0.5, 0.25, 0.25])
# Indicator vector of the target super-arm handed to the attacker.
target = np.array([0, 0, 1, 1])
K = len(arms)  # number of base arms
d = int(np.sum(target))  # super-arm size, derived from the number of ones in `target`
sigma = 0.1  # noise parameter shared by the environment and the learners
epsilon = 0.01  # attacker parameter


env = CombinatorialBoundedGaussianEnv(arms, sigma=sigma, d=d)

# Two identically configured learners: the first runs unattacked, the second
# is paired with the attacker at the same index in `attackers`.
bandits = [
    CUCB(K, T, sigma=sigma, d=d),
    CUCB(K, T, sigma=sigma, d=d),
]

labels = [
    "CUCB",
    "OA CUCB",
]

colors = [
    Colors.orange,
    Colors.blue,
]

markers = [
    "s",
    "*",
]

# `None` means the learner at that index receives unmodified rewards.
attackers = [
    None,
    BPOracleCombinatorialAttacker(K, T, target, arms, d=d, epsilon=epsilon),
]

# The simulation zips `bandits` with `attackers` (zip silently truncates) and
# the plots index `labels`/`colors`/`markers` by the same position, so every
# per-algorithm list must have the same length.
if len({len(bandits), len(attackers), len(labels), len(colors), len(markers)}) != 1:
    raise ValueError(
        "bandits, attackers, labels, colors and markers must have the same length"
    )

# Result buffers, indexed [algorithm, episode, round] (armpull: [algorithm, episode, arm]).
regrets = np.zeros((len(bandits), E, T))
rewards = np.zeros((len(bandits), E, T))
armpull = np.zeros((len(bandits), E, K))
# Sized by len(bandits) like its siblings: it is indexed by the same b_id.
attacks = np.zeros((len(bandits), E, T))


for e in tqdm(range(E)):
    # leave=False: discard each finished per-episode bar so only the outer bar stays.
    for t in tqdm(range(T), leave=False):
        # One reward vector per round, shared by every (bandit, attacker) pair.
        rewardvec = env.rewardvec(e, t)
        for b_id, (bandit, attacker) in enumerate(zip(bandits, attackers)):
            arm = bandit.pull_arm()
            # Element-wise mask by the chosen super-arm (presumably a 0/1 vector).
            reward = rewardvec * arm
            attack = (
                attacker.attack(reward, arm) if attacker is not None else np.zeros(K)
            )
            # The learner only ever observes the (possibly) attacked reward.
            bandit.update(reward - attack, arm)

            # Update data for visualization. The environment's pseudo reward is
            # logged rather than the realized `reward`.
            rewards[b_id, e, t] = np.sum(env.pseudo_reward(arm))
            regrets[b_id, e, t] = np.sum(
                env.pseudo_reward(env.opt_arm()) - env.pseudo_reward(arm)
            )
            armpull[b_id, e, :] += arm
            # Each (b_id, e, t) cell is written exactly once.
            attacks[b_id, e, t] = np.sum(attack)

    # End of episode: report state, then reset learners/attackers for the next one.
    # tqdm.write prints without breaking the active progress bar.
    for b, a in zip(bandits, attackers):
        tqdm.write(str(b))
        b.reset()
        if a is not None:
            tqdm.write(str(a))
            a.reset()


os.makedirs(exp_name, exist_ok=True)
# Extra options passed to tikzplotlib when exporting the .tex figures.
EXTRA_TIKZ_PARAM = ["baseline", "every node/.append style={font=\\Large}"]
EXTRA_AXIS_PARAM = [
    "width=.6\\linewidth",
    "height=.2\\textheight",
    "max space between ticks=1000pt",
    "try min ticks=5",
    "grid=major",
    "grid style={dashed, gray!30}",
]
EXT = "pdf"  # file extension for exported matplotlib figures
SKIP = 1000  # keep one round out of every SKIP to keep the figures light
# Plotted rounds 0, SKIP, 2*SKIP, ... < T, built without materialising all T rounds.
x = list(range(0, T, SKIP))

# ----- Regrets -----
# Cumulative regret per [algorithm, episode, round], computed once rather than
# twice per algorithm inside the plotting loop.
cum_regrets = np.cumsum(regrets, axis=2)
mean_regret = cum_regrets.mean(axis=1)  # average over episodes
# 1.96 * standard error over the E episodes (normal-approximation 95% band).
ci_regret = 1.96 * cum_regrets.std(axis=1) / np.sqrt(E)

fig, ax = plt.subplots()
for b_id, _ in enumerate(bandits):
    y = mean_regret[b_id]
    c = ci_regret[b_id]
    ax.plot(
        x,
        y[::SKIP],
        label=labels[b_id],
        color=colors[b_id],
        marker=markers[b_id],
        # About one marker per x tick; max(1, ...) prevents an invalid step of 0.
        markevery=max(1, len(x) // len(ax.get_xticks())),
    )
    ax.fill_between(x, (y - c)[::SKIP], (y + c)[::SKIP], alpha=0.5, color=colors[b_id])
ax.legend()
ax.set_xlabel("t")
ax.set_ylabel("Regret")
ax.grid(True, ls="--", lw=0.5)
# Explicit extension, so the format is EXT rather than rcParams' savefig.format default.
fig.savefig(f"{exp_name}/{name}_regret.{EXT}")
tkz.save(
    f"{exp_name}/{name}_regret.tex",
    figure=fig,
    extra_tikzpicture_parameters=EXTRA_TIKZ_PARAM,
    extra_axis_parameters=EXTRA_AXIS_PARAM,
)
plt.close(fig)

# ----- Attack Cost -----
# Cumulative attack cost per [algorithm, episode, round], computed once rather
# than twice per algorithm inside the plotting loop.
cum_attacks = np.cumsum(attacks, axis=2)
mean_attack = cum_attacks.mean(axis=1)  # average over episodes
# 1.96 * standard error over the E episodes (normal-approximation 95% band).
ci_attack = 1.96 * cum_attacks.std(axis=1) / np.sqrt(E)

fig, ax = plt.subplots()
for b_id, _ in enumerate(bandits):
    # Unattacked learners have no attack cost to show.
    if attackers[b_id] is None:
        continue
    y = mean_attack[b_id]
    c = ci_attack[b_id]
    ax.plot(
        x,
        y[::SKIP],
        label=labels[b_id],
        color=colors[b_id],
        marker=markers[b_id],
        # About one marker per x tick; max(1, ...) prevents an invalid step of 0.
        markevery=max(1, len(x) // len(ax.get_xticks())),
    )
    ax.fill_between(x, (y - c)[::SKIP], (y + c)[::SKIP], alpha=0.5, color=colors[b_id])
ax.legend()
ax.set_xlabel("t")
ax.set_ylabel("Attack Cost")
ax.grid(True, ls="--", lw=0.5)
# Explicit extension, so the format is EXT rather than rcParams' savefig.format default.
fig.savefig(f"{exp_name}/{name}_attack_cost.{EXT}")
tkz.save(
    f"{exp_name}/{name}_attack_cost.tex",
    figure=fig,
    extra_tikzpicture_parameters=EXTRA_TIKZ_PARAM,
    extra_axis_parameters=EXTRA_AXIS_PARAM,
)
plt.close(fig)
