import time
from typing import Any, Sequence

import torch
from torch import Tensor, nn
from torch._prims_common import DeviceLikeType

from args import AdaptiveAttackerConfig, AdversarialTrainingConfig
from bandit2.attacker_net import AttackerTransformer
from bandit2.bandit_dataset import BanditDataset, BanditDatasetTorch
from util.tensor import uniform_random


class BanditAttacker(nn.Module):
    """Base class for reward-poisoning attackers acting on a batch of bandit environments.

    Subclasses must override ``get_reward``. The data-collection hooks
    (``append``, ``clear_dataset``) and the training hook (``update``) are
    optional: the base implementations do nothing and report no metrics.
    """

    initialized: bool
    device: DeviceLikeType | None

    def __init__(self, device=None, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Plain attributes; subclasses read ``self.device`` when allocating tensors.
        self.initialized, self.device = False, device

    def get_reward(self, rewards_original: Tensor, actions: Tensor) -> Tensor:
        """Return the rewards the victim observes for ``actions``; must be overridden."""
        raise NotImplementedError()

    def append(self, actions: Tensor, rewards: Tensor, rewards_original: Tensor) -> None:
        """Hook for recording interactions; the base class ignores its arguments."""

    def clear_dataset(self) -> None:
        """Hook for discarding recorded interactions; a no-op in the base class."""

    def update(self, dataset: BanditDatasetTorch, adv_train_config: AdversarialTrainingConfig) -> list[dict[str, Any]]:
        """Train the attacker on ``dataset`` and return one metrics dict per iteration.

        The base attacker does not learn, so the returned list is always empty.
        """
        no_metrics: list[dict[str, Any]] = []
        return no_metrics


class BanditUniformRandomAttacker(BanditAttacker):
    """Attacker that returns Gaussian samples around fixed, randomly drawn per-arm means.

    The means are drawn once at construction with ``uniform_random``. Each
    environment's mean vector is then rescaled so that its L2 norm does not
    exceed ``max_poison_diff``. ``rewards_original`` is not used when sampling.
    """

    n_envs: int
    n_arms: int
    variance: float
    means: Tensor

    def __init__(self, n_envs: int, n_arms: int, variance: float, max_poison_diff: float, scale_means: float = 1.0, device=None, dtype=None, *args, **kwargs):
        """
        Args:
            n_envs: Number of parallel bandit environments.
            n_arms: Number of arms per environment.
            variance: Spread of the sampled rewards (passed to ``torch.normal`` as its std argument).
            max_poison_diff: Upper bound on the L2 norm of each environment's mean vector.
            scale_means: Scale forwarded to ``uniform_random`` when drawing the means.
            device: Device of ``self.means``.
            dtype: dtype of ``self.means``.
        """
        super().__init__(device, *args, **kwargs)

        self.n_envs = n_envs
        self.n_arms = n_arms
        self.variance = variance
        self.max_poison_diff = max_poison_diff

        means = uniform_random((self.n_envs, self.n_arms), scale_means, device=self.device, dtype=dtype)

        # Bound each environment's mean vector to L2 norm <= max_poison_diff.
        # This is the same per-row rescaling torch.nn.utils.clip_grad_norm_ applies
        # (coef = max_norm / (norm + 1e-6), clamped to at most 1). It is done in one
        # vectorized step instead of faking ``.grad`` on every row.
        row_norms = torch.linalg.vector_norm(means, dim=-1, keepdim=True)
        clip_coef = torch.clamp(max_poison_diff / (row_norms + 1e-6), max=1.0)
        self.means = means * clip_coef

    def get_reward(self, rewards_original: Tensor, actions: Tensor) -> Tensor:
        """Sample rewards for ``actions`` around the fixed means (``rewards_original`` is ignored).

        ``actions`` is contracted with ``self.means`` over its last (arm) axis. It is
        presumably a one-hot encoding of the pulled arm (TODO: confirm).
        """
        selected_means = torch.einsum("ea,e...a->e...", self.means, actions)

        # NOTE(review): ``self.variance`` is used directly as the standard deviation
        # of torch.normal. If a true variance is intended, this should be
        # ``self.variance ** 0.5``; confirm the intended semantics.
        rewards_mod = torch.normal(selected_means, torch.ones_like(selected_means) * self.variance)
        return rewards_mod


class BanditNaiveAttacker(BanditAttacker):
    """Attacker with learnable, context-independent Gaussian reward distributions.

    Each (environment, arm) pair has a learnable mean (initialised to 0) and a
    learnable standard deviation (initialised to 1, used through ``abs``).
    ``update`` trains them with a REINFORCE-style loss whose return is
    ``-rewards_original``. It adds hinge penalties on the per-environment norms
    of the means and the stds.
    """

    n_envs: int
    n_arms: int
    means: Tensor
    stds: Tensor
    means_original: Tensor
    _optimizer: torch.optim.Optimizer | None

    def __init__(self, n_envs: int, n_arms: int, initial_means: Tensor, lr: float | None = None, device=None, dtype=None, *args, **kwargs):
        """
        Args:
            n_envs: Number of parallel bandit environments.
            n_arms: Number of arms per environment.
            initial_means: Means of the clean environments. A copy is stored as the frozen
                parameter ``means_original``.
            lr: Adam learning rate. If ``None``, no optimizer is created and ``update`` fails its assertion.
            device: Device of all parameters.
            dtype: dtype of all parameters.
        """
        super().__init__(device, *args, **kwargs)

        self.n_envs = n_envs
        self.n_arms = n_arms

        self.means = nn.Parameter(torch.zeros((self.n_envs, self.n_arms), device=self.device, dtype=dtype))
        self.stds = nn.Parameter(torch.ones((self.n_envs, self.n_arms), device=self.device, dtype=dtype))
        # Frozen copy of the clean means. It is registered as a parameter so it ends up in state_dict().
        self.means_original = nn.Parameter(initial_means.clone().detach().to(device=self.device, dtype=dtype), requires_grad=False)

        self._optimizer = None
        if lr is not None:
            self._optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def _selected_params(self, actions: Tensor) -> tuple[Tensor, Tensor]:
        """Contract means and ``abs(stds)`` with ``actions`` over the arm axis."""
        selected_means = torch.einsum("ea,e...a->e...", self.means, actions)
        selected_stds = torch.einsum("ea,e...a->e...", torch.abs(self.stds), actions)
        return selected_means, selected_stds

    def get_reward(self, rewards_original: Tensor, actions: Tensor) -> Tensor:
        """Sample rewards for ``actions`` from the learned Gaussians (``rewards_original`` is ignored)."""
        selected_means, selected_stds = self._selected_params(actions)
        return torch.normal(selected_means, selected_stds)

    def get_probability(self, realized_rewards: Tensor, actions: Tensor) -> Tensor:
        """Return the Gaussian log-density of ``realized_rewards`` under the selected arms' distributions."""
        selected_means, selected_stds = self._selected_params(actions)
        # Build the constant on the parameters' device and dtype instead of a default CPU float32 tensor.
        log_two_pi = torch.log(torch.tensor(2 * torch.pi, device=selected_means.device, dtype=selected_means.dtype))
        return -0.5 * log_two_pi - torch.log(selected_stds) - 0.5 * ((realized_rewards - selected_means) / selected_stds) ** 2

    def update(self, dataset: BanditDatasetTorch, adv_train_config: AdversarialTrainingConfig) -> list[dict[str, Any]]:
        """Run ``adv_train_config.attacker_iters`` optimizer steps on the attacker's parameters.

        The loss is the sum of three terms:

        - a REINFORCE term: the log-density of ``dataset.rewards``, weighted by
          ``-dataset.rewards_original``;
        - a hinge penalty on per-environment mean norms above
          ``adv_train_config.max_poison_diff``;
        - a hinge penalty on per-environment std norms above ``max_std``.

        Both penalties are scaled by ``adv_train_config.budget_regularizer``.

        Returns:
            One metrics dict per iteration.
        """
        assert dataset.rewards_original is not None
        assert self._optimizer is not None, "Optimizer is not initialized; please pass a suitable lr parameter during initialization"

        max_std = 1.0
        att_budget_max = adv_train_config.max_poison_diff
        # The attacker's return is the negated clean reward. It is a constant w.r.t. the attacker parameters.
        attacker_rewards = -dataset.rewards_original.detach()

        metrics = []

        for _ in range(adv_train_config.attacker_iters):
            start_time = time.time()

            log_policy_prob = self.get_probability(dataset.rewards, dataset.actions)

            att_budget_current = torch.norm(self.means, dim=-1)
            std_diff = torch.norm(self.stds, dim=-1)

            reinforce_loss = -(log_policy_prob * attacker_rewards).mean()
            regularization_loss = adv_train_config.budget_regularizer * torch.relu(att_budget_current - att_budget_max).sum()
            std_loss = adv_train_config.budget_regularizer * torch.relu(std_diff - max_std).sum()

            loss = reinforce_loss + regularization_loss + std_loss

            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

            end_time = time.time()
            metrics.append(
                {
                    "train/attacker_loss": loss.item(),
                    "train/attacker_loss/budget_reg_loss": regularization_loss.item(),
                    "train/attacker_loss/std_reg_loss": std_loss.item(),
                    "train/attacker_loss/reinforce_loss": reinforce_loss.item(),
                    "train/attacker_time": end_time - start_time,
                    # Log a plain float rather than a tensor that still references the autograd graph.
                    "train/attacker_avg_env_att_dist": att_budget_current.mean().item(),
                }
            )

        return metrics


class BanditAdaptiveAttacker(BanditAttacker):
    """Attacker whose reward distribution is predicted by an ``AttackerTransformer``.

    The network is conditioned on the interactions collected in ``self.dataset``
    (via ``append``). It outputs per-arm means and standard deviations from which
    poisoned rewards are sampled. ``update`` trains it with a REINFORCE-style
    loss plus norm penalties.
    """

    _net: AttackerTransformer
    _optimizer: torch.optim.Optimizer

    n_envs: int
    n_actions: int

    dataset: BanditDataset
    means_original: Tensor

    def __init__(self, config: AdaptiveAttackerConfig, n_envs: int, n_actions: int, initial_means: Tensor, device=None, *args, **kwargs):
        """
        Args:
            config: Attacker configuration. Supplies ``context_len``, ``lr`` and the transformer parameters.
            n_envs: Number of parallel bandit environments.
            n_actions: Number of arms per environment.
            initial_means: Means of the environments the attacker is trained on. A copy is stored in ``means_original``.
            device: Device of the network, the dataset and ``means_original``.
        """
        super().__init__(device, *args, **kwargs)

        self.config = config
        self.n_actions = n_actions
        self.n_envs = n_envs

        transformer_config = config.get_params_unprefixed({"H": config.context_len, "state_dim": 1, "action_dim": n_actions})
        self._net = AttackerTransformer(transformer_config).to(device)
        self.dataset = BanditDataset(self.n_envs, config.context_len, n_actions, device=device)

        # Created before ``means_original`` is registered, so that tensor is not among the optimized parameters.
        self._optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)

        # Just to be able to restore the env the attacker was trained on (frozen, but part of state_dict()):
        self.means_original = nn.Parameter(initial_means.clone().detach().to(device=self.device), requires_grad=False)

    def clear_dataset(self):
        """Discard all interactions collected in ``self.dataset``."""
        self.dataset.clear()

    def append(self, actions: Tensor, rewards: Tensor, rewards_original):
        """Add one batch of interactions to ``self.dataset``."""
        self.dataset.append(actions, rewards, rewards_original)

    def get_context(self) -> Tensor:
        """Return the transformer context built from ``self.dataset``, including original rewards.

        If the context has length 0 along dim 1, a single all-zero step of width
        ``2 + n_actions + 1`` is returned instead. That width presumably matches the
        row layout of ``get_context_for_transformer(with_rewards_original=True)``
        (TODO: confirm).
        """
        context = self.dataset.get_context_for_transformer(with_rewards_original=True)
        if context.size(dim=1) == 0:
            context = torch.zeros((self.n_envs, 1, 2 + self.n_actions + 1), device=self.device)
        return context

    def get_reward(self, rewards_original: Tensor, actions: Tensor) -> Tensor:
        """Sample rewards for ``actions`` from the network's predicted Gaussians.

        ``rewards_original`` is ignored. The network's ``test`` flag is set during
        prediction and is always reset to ``False`` afterwards, even on error. The
        returned tensor carries no autograd history.
        """
        self._net.test = True
        try:
            # The result is detached anyway, so avoid building an autograd graph.
            with torch.no_grad():
                means, stds = self.get_means_stds()
                selected_means = torch.einsum("ea,e...a->e...", means, actions)
                selected_stds = torch.einsum("ea,e...a->e...", torch.abs(stds), actions)
                rewards = torch.normal(selected_means, selected_stds)
        finally:
            self._net.test = False
        return rewards.detach()

    def get_means_stds(self):
        """Return the network's predicted ``(means, stds)`` for the current context."""
        context = self.get_context()
        query_line = self._net.make_query_line(torch.ones((self.n_envs, 1), device=self.device), self.n_envs)
        means, stds = self._net.predict_rewards(context, query_line)
        return means, stds

    def get_probability(self, realized_rewards: Tensor, actions: Tensor, all_means: Tensor, all_stds: Tensor) -> Tensor:
        """Return the Gaussian log-density of ``realized_rewards`` under the arms selected by ``actions``.

        ``1e-5`` is added to the selected std to keep ``log`` and the division finite.
        """
        selected_means = (all_means * actions).sum(dim=-1)
        selected_stds = (all_stds * actions).sum(dim=-1) + 1e-5
        return -0.5 * torch.log(2 * torch.tensor(torch.pi, device=self.device)) - torch.log(selected_stds) - 0.5 * ((realized_rewards - selected_means) / selected_stds) ** 2

    def update(self, dataset: BanditDatasetTorch, adv_train_config: AdversarialTrainingConfig) -> list[dict[str, Any]]:
        """Run ``adv_train_config.attacker_iters`` optimizer steps on the attacker network.

        The loss is the sum of three terms:

        - a REINFORCE term weighted by ``-dataset.rewards_original``;
        - a hinge penalty on per-environment predicted-mean norms above
          ``adv_train_config.max_poison_diff``;
        - a hinge penalty on per-environment predicted-std norms above ``max_std``.

        Both penalties are scaled by ``adv_train_config.budget_regularizer``.

        Returns:
            One metrics dict per iteration.
        """
        assert dataset.rewards_original is not None
        max_std = 1.0
        att_budget_max = adv_train_config.max_poison_diff
        # The attacker's return is the negated clean reward. It is a constant w.r.t. the network.
        attacker_rewards = -dataset.rewards_original.detach()

        metrics = []

        for _ in range(adv_train_config.attacker_iters):
            start_time = time.time()

            attacker_means, attacker_stds = self.get_means_stds()
            # NOTE(review): the log-density is evaluated at ``rewards_original``, while
            # BanditNaiveAttacker.update evaluates it at the poisoned ``dataset.rewards``.
            # Confirm which sample this policy gradient should use.
            log_policy_prob = self.get_probability(dataset.rewards_original.detach(), dataset.actions, attacker_means, attacker_stds)

            att_budget_current = torch.norm(attacker_means, dim=-1)
            std_diff = torch.norm(attacker_stds, dim=-1)

            reinforce_loss = -(log_policy_prob * attacker_rewards).mean()
            regularization_loss = adv_train_config.budget_regularizer * torch.relu(att_budget_current - att_budget_max).sum()
            std_loss = adv_train_config.budget_regularizer * torch.relu(std_diff - max_std).sum()

            loss = reinforce_loss + regularization_loss + std_loss

            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

            end_time = time.time()
            metrics.append(
                {
                    "train/attacker_loss": loss.item(),
                    "train/attacker_loss/budget_reg_loss": regularization_loss.item(),
                    "train/attacker_loss/std_reg_loss": std_loss.item(),
                    "train/attacker_loss/reinforce_loss": reinforce_loss.item(),
                    "train/attacker_time": end_time - start_time,
                    "train/attacker_avg_env_att_dist": att_budget_current.mean().item(),
                }
            )

        return metrics
