import math

import torch
from torch import Tensor
from torch._prims_common import DeviceLikeType

from bandit2.bandit_ctrl import (
    BanditCrUCBController,
    BanditOptimalController,
    BanditRobustThompsonSamplingController,
    BanditThompsonSamplingController,
    BanditTransformerController,
    BanditUCBController,
)
from bandit2.bandit_env import BanditController


def get_bandit_algs(
    n_envs: int,
    n_steps: int,
    n_actions: int,
    optimal_actions: Tensor | None,
    dpt_policy: BanditTransformerController | None,
    dpt_frozen_policy: BanditTransformerController | None,
    crucb_alpha: float,
    crucb_variance: float,
    rts_eps_steps: float,
    rts_max_poison_diff: float,
    *,
    device: DeviceLikeType | None = None
) -> dict[str, BanditController]:
    """Build the named set of bandit controllers to run side by side.

    The returned dict preserves insertion order and contains, in order:

    * ``"opt"`` -- ``BanditOptimalController``; only when ``optimal_actions``
      is not ``None``.
    * ``"dpt"`` / ``"dpt_frozen"`` -- the supplied policies, passed through
      unchanged; each only when not ``None``.
    * ``"ts"`` -- ``BanditThompsonSamplingController`` with ``sample=True``.
    * ``"rts"``, ``"rts_u"``, ``"rts_k"`` --
      ``BanditRobustThompsonSamplingController`` with ``sample=True`` and
      ``corruption_level`` set respectively to the fixed value ``0.5``, to
      ``sqrt(n_steps * ln(n_actions) / n_actions)``, and to
      ``n_steps * rts_eps_steps * rts_max_poison_diff / n_actions``.
    * ``"ucb"`` -- ``BanditUCBController``.
    * ``"crucb_p"`` -- ``BanditCrUCBController`` with ``const=crucb_variance``
      and ``flag_p=True``.
    * ``"crucb_v"`` -- ``BanditCrUCBController`` with
      ``const=sqrt(1 - 2 * crucb_alpha) * crucb_variance``.
    * ``"crucb"`` -- ``BanditCrUCBController`` with its default ``const``.

    Args:
        n_envs: Forwarded as the first positional argument to every
            controller built here.
        n_steps: Forwarded to every controller; also scales the ``rts_u``
            and ``rts_k`` corruption levels.
        n_actions: Forwarded to every controller; also used in the ``rts_u``
            and ``rts_k`` corruption levels.
        optimal_actions: Forwarded to ``BanditOptimalController``; ``None``
            omits the ``"opt"`` entry.
        dpt_policy: Ready-made controller stored under ``"dpt"``; ``None``
            omits the entry.
        dpt_frozen_policy: Ready-made controller stored under
            ``"dpt_frozen"``; ``None`` omits the entry.
        crucb_alpha: Forwarded to every ``BanditCrUCBController``; must be
            at most ``0.5`` so that ``1 - 2 * crucb_alpha`` is non-negative.
        crucb_variance: ``const`` for ``"crucb_p"`` and the factor scaled by
            ``sqrt(1 - 2 * crucb_alpha)`` for ``"crucb_v"``.
        rts_eps_steps: Factor of the ``rts_k`` corruption level.
        rts_max_poison_diff: Factor of the ``rts_k`` corruption level.
        device: Forwarded as ``device=`` to every controller built here.

    Returns:
        Mapping from short algorithm name to controller instance.

    Raises:
        ZeroDivisionError: If ``n_actions`` is ``0``.
        ValueError: If ``crucb_alpha`` is greater than ``0.5``, or if
            ``n_steps`` / ``n_actions`` make the logarithm or square-root
            argument invalid (e.g. negative values).
    """
    rts_corruption_level_known = n_steps * rts_eps_steps * rts_max_poison_diff * (1 / n_actions)
    # The inputs are plain Python scalars, so use float64 `math` routines
    # rather than a tensor round-trip; this also fails loudly on invalid
    # (negative) inputs instead of handing NaN to the controller.
    rts_corruption_level_unknown = math.sqrt(n_steps * math.log(n_actions) / n_actions)
    rts_corruption_level_tuned = 0.5

    crucb_radicand = 1 - 2 * crucb_alpha
    if crucb_radicand < 0:
        # Same exception type math.sqrt would raise, but naming the culprit.
        raise ValueError(
            f"crucb_alpha must be <= 0.5 so that 1 - 2 * crucb_alpha is non-negative, got {crucb_alpha!r}"
        )
    crucb_sigma_scaled = math.sqrt(crucb_radicand) * crucb_variance

    # Positional arguments shared by every controller constructor.
    sizes = (n_envs, n_steps, n_actions)

    def make_rts(corruption_level: float) -> BanditRobustThompsonSamplingController:
        return BanditRobustThompsonSamplingController(
            *sizes, corruption_level=corruption_level, sample=True, device=device
        )

    algs: dict[str, BanditController] = {}

    # Optional entries come first so the key order stays stable for callers
    # that iterate over the result.
    if optimal_actions is not None:
        algs["opt"] = BanditOptimalController(*sizes, optimal_actions, device=device)
    if dpt_policy is not None:
        algs["dpt"] = dpt_policy
    if dpt_frozen_policy is not None:
        algs["dpt_frozen"] = dpt_frozen_policy

    # Dict-literal values are evaluated top to bottom, so controllers are
    # constructed in the listed order.
    algs.update({
        "ts": BanditThompsonSamplingController(*sizes, sample=True, device=device),
        "rts": make_rts(rts_corruption_level_tuned),
        "rts_u": make_rts(rts_corruption_level_unknown),
        "rts_k": make_rts(rts_corruption_level_known),
        "ucb": BanditUCBController(*sizes, device=device),
        "crucb_p": BanditCrUCBController(*sizes, crucb_alpha, const=crucb_variance, flag_p=True, device=device),
        "crucb_v": BanditCrUCBController(*sizes, crucb_alpha, const=crucb_sigma_scaled, device=device),
        "crucb": BanditCrUCBController(*sizes, crucb_alpha, device=device),
    })
    return algs
