import math
import time
from typing import Any

import torch
from torch import Tensor
from torch.nn.functional import one_hot
from torch.utils.data import DataLoader

from args import AdversarialTrainingConfig
from bandit2.bandit_dataset import BanditDataset, BanditDatasetTorch
from bandit2.bandit_env import BanditController
from net import Transformer


class BanditTransformerController(BanditController):
    """Bandit controller whose arm choices come from a ``Transformer`` model.

    The controller records the interaction history in a ``BanditDataset``,
    feeds that history to the model as context when choosing actions, and can
    optionally fine-tune the model via :meth:`update`.

    Attributes:
        model: Network queried for per-action scores.
        sample: When true, actions are drawn from the softmax of the scores;
            otherwise the highest-scoring action is taken.
        frozen: When true, :meth:`update` is a no-op.
        dataset: History buffer that supplies the model context.
    """

    model: Transformer
    sample: bool
    frozen: bool
    dataset: BanditDataset

    _optimizer: torch.optim.Optimizer | None

    def __init__(
        self,
        model: Transformer,
        n_envs: int,
        n_steps: int,
        n_actions: int,
        sample: bool = False,
        *,
        frozen: bool = False,
        lr: float | None = None,
        device=None,
    ):
        """Create the controller and, if ``lr`` is given, its optimizer.

        Args:
            model: Transformer used for action prediction and training.
            n_envs: Forwarded to ``BanditController`` and ``BanditDataset``.
            n_steps: Forwarded to ``BanditController`` and ``BanditDataset``.
            n_actions: Forwarded to ``BanditController`` and ``BanditDataset``.
            sample: Draw actions stochastically instead of taking the argmax.
            frozen: Disable training in :meth:`update`.
            lr: Learning rate for AdamW (weight decay ``1e-4``). When ``None``
                no optimizer is created.
            device: Forwarded to ``BanditController`` and ``BanditDataset``.
        """
        super().__init__(n_envs, n_steps, n_actions, device=device)

        self.model = model
        self.sample = sample
        self.frozen = frozen

        self.dataset = BanditDataset(n_envs, n_steps, n_actions, device=device)

        self._optimizer = None
        if lr is not None:
            self._optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=1e-4)

    def clear_dataset(self) -> None:
        """Discard the recorded interaction history."""
        self.dataset.clear()

    def append(self, actions: Tensor, rewards: Tensor, rewards_original: Tensor) -> None:
        """Record one step of interaction in the history buffer."""
        self.dataset.append(actions, rewards, rewards_original)

    def sample_actions(self) -> Tensor:
        """Choose one action per environment from the model's scores.

        Returns:
            One-hot float tensor of the chosen actions with ``n_actions``
            entries along the last dimension.
        """
        # Action selection is never back-propagated through (the result is a
        # one-hot of integer indices), so skip building the autograd graph
        # instead of building it and detaching afterwards.
        with torch.no_grad():
            context = self.dataset.get_context_for_transformer()
            query_line = self.model.make_query_line(torch.ones((self.n_envs, 1)), self.n_envs)
            outputs = self.model.predict_actions(context, query_line)

            if self.sample:
                # Scores are treated as unnormalised logits over actions.
                probs = torch.softmax(outputs, dim=-1)
                action_indices = torch.multinomial(probs, 1)[..., 0]
            else:
                action_indices = torch.argmax(outputs, dim=-1)

        return one_hot(action_indices, self.n_actions).float()

    def update(
        self,
        dataset: BanditDatasetTorch,
        adv_train_config: AdversarialTrainingConfig,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Train the model on ``dataset`` for ``adv_train_config.victim_iters`` passes.

        Args:
            dataset: Torch dataset yielding ``(batch, true_actions)`` pairs.
            adv_train_config: Only ``victim_iters`` (number of passes over the
                data) is read here.

        Returns:
            ``(metrics, {})`` where ``metrics`` holds one dict per pass with
            keys ``"train/victim_loss"`` and ``"train/victim_time"``. When the
            controller is frozen, ``([], {})`` is returned and nothing is
            trained.

        Raises:
            AssertionError: If the controller is not frozen but was built
                without a learning rate (no optimizer).
        """
        if self.frozen:
            return [], {}

        assert self._optimizer is not None, "update() needs an optimizer: pass lr=... or frozen=True"

        metrics: list[dict[str, Any]] = []

        loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
        train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

        for _ in range(adv_train_config.victim_iters):
            # perf_counter is monotonic, so the measured duration cannot be
            # distorted by system clock adjustments.
            start_time = time.perf_counter()
            epoch_train_loss = 0.0

            for batch, true_actions in train_loader:
                pred_actions = self.model(batch)
                del batch

                # Flatten so that every row is one prediction over n_actions.
                true_actions = true_actions.reshape(-1, self.n_actions)
                pred_actions = pred_actions.reshape(-1, self.n_actions)

                self._optimizer.zero_grad()
                loss: Tensor = loss_fn(pred_actions, true_actions)
                loss.backward()
                self._optimizer.step()
                # Sum-reduced loss, scaled by the number of steps.
                epoch_train_loss += loss.item() / self.n_steps

            metrics.append(
                {
                    "train/victim_loss": epoch_train_loss,
                    "train/victim_time": time.perf_counter() - start_time,
                }
            )

        return metrics, {}


class BanditOptimalController(BanditController):
    """Controller that always returns a fixed, precomputed action tensor.

    Attributes:
        optimal_actions: Tensor handed back unchanged by
            :meth:`sample_actions`.
    """

    optimal_actions: Tensor

    def __init__(self, n_envs: int, n_steps: int, n_actions: int, optimal_actions: Tensor, device=None):
        """Store the problem dimensions and the actions to replay.

        Args:
            n_envs: Forwarded to ``BanditController`` and stored on the instance.
            n_steps: Forwarded to ``BanditController`` and stored on the instance.
            n_actions: Forwarded to ``BanditController`` and stored on the instance.
            optimal_actions: Tensor returned by every :meth:`sample_actions` call.
            device: Forwarded to ``BanditController``.
        """
        super().__init__(n_envs, n_steps, n_actions, device=device)

        # Explicitly (re)assigned after the base initialiser, in this order.
        self.n_envs, self.n_steps, self.n_actions = n_envs, n_steps, n_actions
        self.optimal_actions = optimal_actions

    def sample_actions(self) -> Tensor:
        """Return the stored action tensor (the same object every call)."""
        return self.optimal_actions


class BanditThompsonSamplingController(BanditController):
    """Thompson-sampling controller with a Gaussian posterior per arm.

    Every (environment, arm) pair keeps a posterior mean and variance that are
    refreshed in :meth:`append` from the pull counts and accumulated rewards.

    Attributes:
        means: Posterior means, shape ``(n_envs, n_actions)``.
        variances: Posterior variances, shape ``(n_envs, n_actions)``.
        counts: Accumulated action tensors, shape ``(n_envs, n_actions)``.
        arm_totals: Accumulated rewards, shape ``(n_envs, n_actions)``.
    """

    means: Tensor
    variances: Tensor
    counts: Tensor
    arm_totals: Tensor

    def __init__(self, n_envs: int, n_steps: int, n_actions: int, device=None, std=0.1, sample=False, prior_mean=0.5, prior_var=1 / 12.0):
        """Configure the noise model and prior, then reset the statistics.

        Args:
            n_envs: Forwarded to ``BanditController``.
            n_steps: Forwarded to ``BanditController``.
            n_actions: Forwarded to ``BanditController``.
            device: Forwarded to ``BanditController``.
            std: Reward noise standard deviation; its square is stored as
                ``self.variance``.
            sample: Must be true for :meth:`sample_actions` to work; the
                non-sampling path raises ``NotImplementedError``.
            prior_mean: Prior mean assigned to every arm.
            prior_var: Prior variance assigned to every arm.
        """
        super().__init__(n_envs, n_steps, n_actions, device=device)

        self.sample = sample
        self.prior_mean = prior_mean
        self.prior_variance = prior_var
        self.variance = std**2

        self.clear_dataset()

    def clear_dataset(self) -> None:
        """Reset counts and reward totals and restore the prior for every arm."""
        # NOTE(review): totals start at one, not zero, so `arm_totals / counts`
        # in `append` equals (1 + sum of rewards) / count rather than the plain
        # sample mean -- confirm this offset is intended.
        self.arm_totals = torch.ones((self.n_envs, self.n_actions), device=self.device)
        self.counts = torch.zeros_like(self.arm_totals)

        self.means = self.prior_mean * torch.ones_like(self.arm_totals)
        self.variances = self.prior_variance * torch.ones_like(self.arm_totals)

    def append(self, actions: Tensor, rewards: Tensor, rewards_original: Tensor) -> None:
        """Record one pull per environment and refresh the posteriors.

        Args:
            actions: Added to ``counts``; its argmax along the last dimension
                selects the arm credited with the reward.
            rewards: Added to the selected arm's running total.
            rewards_original: Unused by this controller.
        """
        self.counts += actions

        env_idx = torch.arange(self.n_envs)
        arm_idx = actions.argmax(dim=-1)
        self.arm_totals[env_idx, arm_idx] += rewards

        # Conjugate Gaussian update: blend the prior mean with the observed
        # average, weighting the prior less as the count grows.
        pulls = self.counts
        prior_weight = self.variance / (self.variance + (pulls * self.prior_variance))
        observed_mean = self.arm_totals / pulls
        posterior_means = prior_weight * self.prior_mean + (1 - prior_weight) * observed_mean
        posterior_variances = 1 / (1 / self.prior_variance + pulls / self.variance)

        # Arms with a zero count keep their current values (the division above
        # is not finite for them).
        visited = pulls > 0
        self.means[visited] = posterior_means[visited]
        self.variances[visited] = posterior_variances[visited]

    def sample_actions(self) -> Tensor:
        """Draw one value per arm from its posterior and pick the largest.

        Returns:
            One-hot float tensor of the chosen arm per environment.

        Raises:
            NotImplementedError: If the controller was built with
                ``sample=False``.
        """
        if not self.sample:
            raise NotImplementedError()

        noise = torch.randn(self.means.shape, device=self.means.device)
        values = self.means + noise * torch.sqrt(self.variances)

        return one_hot(torch.argmax(values, dim=-1), self.n_actions).float()


class BanditRobustThompsonSamplingController(BanditThompsonSamplingController):
    """Robust Thompson Sampling algorithm for multi-armed bandits (Xu et al., 2024).

    Behaves like the parent controller, except that each arm's sampled value is
    increased by ``corruption_level / (counts + 1)`` before the argmax.

    Attributes:
        corruption_level: Numerator of the per-arm bonus.
    """

    corruption_level: float

    def __init__(self, n_envs: int, n_steps: int, n_actions: int, corruption_level: float, device=None, std=0.1, sample=False, prior_mean=0.5, prior_var=1 / 12):
        """Store the corruption level, then initialise the parent controller.

        Args:
            n_envs: Forwarded to the parent initialiser.
            n_steps: Forwarded to the parent initialiser.
            n_actions: Forwarded to the parent initialiser.
            corruption_level: Numerator of the bonus added in
                :meth:`sample_actions`.
            device: Forwarded to the parent initialiser.
            std: Forwarded to the parent initialiser.
            sample: Forwarded to the parent initialiser.
            prior_mean: Forwarded to the parent initialiser.
            prior_var: Forwarded to the parent initialiser.
        """
        self.corruption_level = corruption_level
        super().__init__(
            n_envs,
            n_steps,
            n_actions,
            device=device,
            std=std,
            sample=sample,
            prior_mean=prior_mean,
            prior_var=prior_var,
        )

    def sample_actions(self) -> Tensor:
        """Sample posterior values, add the corruption bonus, and take the argmax.

        Returns:
            One-hot float tensor of the chosen arm per environment.

        Raises:
            NotImplementedError: If the controller was built with
                ``sample=False``.
        """
        if not self.sample:
            raise NotImplementedError()

        # The bonus shrinks as an arm's count grows.
        bonus = self.corruption_level / (self.counts + 1)
        noise = torch.randn(self.means.shape, device=self.means.device)
        values = bonus + self.means + noise * torch.sqrt(self.variances)

        return one_hot(torch.argmax(values, dim=-1), self.n_actions).float()


class BanditUCBController(BanditController):
    """Upper-confidence-bound controller with a ``const / sqrt(count)`` bonus.

    Attributes:
        counts: Accumulated action tensors, shape ``(n_envs, n_actions)``.
        arm_totals: Accumulated rewards, shape ``(n_envs, n_actions)``.
    """

    counts: Tensor
    arm_totals: Tensor

    def __init__(self, n_envs: int, n_steps: int, n_actions: int, const: float = 1.0, *, device=None):
        """Store the bonus scale and reset the statistics.

        Args:
            n_envs: Forwarded to ``BanditController``.
            n_steps: Forwarded to ``BanditController``.
            n_actions: Forwarded to ``BanditController``.
            const: Scale of the exploration bonus.
            device: Forwarded to ``BanditController``.
        """
        super().__init__(n_envs, n_steps, n_actions, device)

        self.const = const
        self.clear_dataset()

    def clear_dataset(self) -> None:
        """Reset the per-arm reward totals and counts."""
        stats_shape = (self.n_envs, self.n_actions)
        # NOTE(review): totals start at one while counts start at zero, so the
        # ratio used in `sample_actions` carries a +1 offset in the numerator
        # -- confirm this is intended.
        self.arm_totals = torch.ones(stats_shape, device=self.device)
        self.counts = torch.zeros(stats_shape, device=self.device)

    def append(self, actions: Tensor, rewards: Tensor, rewards_original: Tensor) -> None:
        """Add one pull per environment to the running statistics.

        Args:
            actions: Added to ``counts``; its argmax along the last dimension
                selects the arm credited with the reward.
            rewards: Added to the selected arm's running total.
            rewards_original: Unused by this controller.
        """
        self.counts += actions

        env_idx = torch.arange(self.n_envs)
        pulled = actions.argmax(dim=-1)
        self.arm_totals[env_idx, pulled] += rewards

    def sample_actions(self):
        """Pick the arm with the highest bonus-augmented mean per environment.

        Any environment that still has an arm with a zero count pulls its
        least-counted arm instead.

        Returns:
            One-hot float tensor of the chosen arm per environment.
        """
        pulls = self.counts

        # Denominators are floored at one so zero-count arms do not divide by zero.
        empirical_mean = self.arm_totals / pulls.clamp(min=1)
        bonus = self.const / torch.sqrt(pulls).clamp(min=1)
        greedy = torch.argmax(empirical_mean + bonus, dim=-1)

        # Forced exploration of arms that have not been pulled yet.
        least_pulled = torch.argmin(pulls, dim=-1)
        has_unvisited = pulls[torch.arange(self.n_envs), least_pulled] == 0
        chosen_action = torch.where(has_unvisited, least_pulled, greedy)

        return one_hot(chosen_action, self.n_actions).float()


class BanditCrUCBController(BanditController):
    """UCB controller that estimates arm means with a trimmed mean.

    Individual rewards are stored per pull, so that a fraction ``alpha_frac``
    of the smallest and of the largest observations of each arm can be dropped
    before averaging.

    Attributes:
        counts: Integer pull counts, shape ``(n_envs, n_actions)``.
        arm_totals: Reward storage, shape ``(n_envs, n_actions, n_steps)``;
            the first ``counts[env, arm]`` slots of a row hold that arm's
            observed rewards.
        step: Number of :meth:`append` calls since the last reset.
        flag_p: Selects which exploration-bonus formula is used.
        const: Scale of the exploration bonus.
        alpha_frac: Fraction of observations trimmed from each tail.
    """

    counts: Tensor
    arm_totals: Tensor
    step: int

    flag_p: bool
    const: float
    alpha_frac: float

    def __init__(self, n_envs: int, n_steps: int, n_actions: int, alpha_frac: float, const: float = 1.0, *, flag_p: bool = False, device=None):
        """Store the hyper-parameters and reset the statistics.

        Args:
            n_envs: Forwarded to ``BanditController``.
            n_steps: Forwarded to ``BanditController``.
            n_actions: Forwarded to ``BanditController``.
            alpha_frac: Fraction of observations trimmed from each tail.
            const: Scale of the exploration bonus.
            flag_p: If true, the bonus uses the post-trimming counts;
                otherwise it uses the raw counts scaled by
                ``1 / (1 - 2 * alpha_frac)``.
            device: Forwarded to ``BanditController``.
        """
        super().__init__(n_envs, n_steps, n_actions, device)

        self.alpha_frac = alpha_frac
        self.const = const
        self.flag_p = flag_p
        self.clear_dataset()

    def clear_dataset(self) -> None:
        """Empty the reward storage and reset counts and the step counter."""
        self.step = 0
        self.counts = torch.zeros((self.n_envs, self.n_actions), device=self.device, dtype=torch.long)
        self.arm_totals = torch.zeros((self.n_envs, self.n_actions, self.n_steps), device=self.device)

    def append(self, actions: Tensor, rewards: Tensor, rewards_original: Tensor) -> None:
        """Store each environment's reward in the next free slot of the pulled arm.

        Args:
            actions: Its argmax along the last dimension selects the pulled
                arm; its integer cast is added to ``counts``.
            rewards: Reward written for the pulled arm of each environment.
            rewards_original: Unused by this controller.
        """
        env_idx = torch.arange(self.n_envs)
        arm_idx = actions.argmax(dim=-1)
        # The current count (read before it is incremented) is the index of
        # the first unused slot for that arm.
        slot_idx = self.counts[env_idx, arm_idx]

        self.arm_totals[env_idx, arm_idx, slot_idx] += rewards
        self.counts += actions.long()
        self.step += 1

    def estimate_mean(self) -> tuple[Tensor, Tensor]:
        """Compute the trimmed mean of every arm.

        As a side effect the stored observations of each averaged arm are
        sorted in place.

        Returns:
            ``(means, modified_counts)``: the trimmed means (zero where nothing
            remains after trimming) and the number of observations left after
            trimming.
        """
        means = torch.zeros_like(self.counts, dtype=torch.float)
        modified_counts = self.counts.clone()

        for env in range(self.n_envs):
            for action in range(self.n_actions):
                n_seen = self.counts[env, action].item()
                n_trim = math.ceil(n_seen * self.alpha_frac)
                n_kept = n_seen - 2 * n_trim

                if n_kept <= 0:
                    # Unseen arm, or trimming would discard every observation.
                    modified_counts[env, action] = 0
                    continue

                history = self.arm_totals[env, action]  # view into the storage
                history[:n_seen] = torch.sort(history[:n_seen]).values
                means[env, action] = torch.mean(history[n_trim : n_seen - n_trim])
                modified_counts[env, action] = n_kept

        return means, modified_counts

    def estimate_mean_original(self) -> Tensor:
        """Return the untrimmed mean reward per arm (count floored at one)."""
        reward_sums = self.arm_totals.sum(-1)
        safe_counts = torch.maximum(torch.ones(1, device=self.device), self.counts)
        return reward_sums / safe_counts

    def sample_actions(self):
        """Pick the arm with the highest trimmed mean plus exploration bonus.

        Any environment that has an arm with zero post-trimming observations
        pulls its arm with the fewest such observations instead.

        Returns:
            One-hot float tensor of the chosen arm per environment.
        """
        empirical_mean, modified_counts = self.estimate_mean()

        log_step = math.log(self.step) if self.step else 0
        if not self.flag_p:
            scale = self.const / (1 - 2 * self.alpha_frac)
            bonus = scale * torch.sqrt(4 * log_step / self.counts)
        else:
            # Floor the denominator at a small positive value to avoid
            # dividing by zero.
            floored_counts = torch.maximum(torch.ones_like(modified_counts) * 0.001, modified_counts)
            bonus = self.const * torch.sqrt(4 * log_step / floored_counts)

        greedy = torch.argmax(empirical_mean + bonus, dim=-1)

        # Forced exploration: arms with no usable observations go first.
        least_observed = torch.argmin(modified_counts, dim=-1)
        needs_exploration = modified_counts[torch.arange(self.n_envs), least_observed] == 0
        chosen_actions = torch.where(needs_exploration, least_observed, greedy)

        return one_hot(chosen_actions, self.n_actions).float()
