import math
import time
from math import sqrt
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch._prims_common import DeviceLikeType
from torch.nn import init

from args import AdversarialTrainingConfig
from mdp.mdp_dataset import MDPDatasetImagesTorch, MDPDatasetTorch


class MDPAttacker(nn.Module):
    device: DeviceLikeType | None

    def __init__(self, device: DeviceLikeType | None) -> None:
        super().__init__()
        self.device = device

    def get_reward(self, rewards_original: Tensor, states: Tensor, actions: Tensor) -> Tensor:
        raise NotImplementedError()


class MDPNaiveAttacker(MDPAttacker):
    """Attacker with a learnable Gaussian reward per (environment, state, action).

    ``rewards_mean`` and ``rewards_std`` are ``(n_envs, n_states, n_actions)``
    parameters. The absolute value of ``rewards_std`` is used as the standard
    deviation, so the raw parameter may take either sign.
    """

    n_envs: int
    n_states: int
    n_actions: int

    # 0.5 * log(2*pi): constant term of the Gaussian log-density, computed once
    # instead of allocating a new tensor on every get_probability call.
    _HALF_LOG_2PI = 0.5 * math.log(2.0 * math.pi)

    def __init__(self, n_envs: int, n_states: int, n_actions: int, device: DeviceLikeType | None) -> None:
        super().__init__(device)

        self.n_envs = n_envs
        self.n_states = n_states
        self.n_actions = n_actions

        shape = (self.n_envs, self.n_states, self.n_actions)
        # Small N(0, 1) / 100 initialisation; mean is drawn before std (same RNG order as before).
        self.rewards_mean = nn.Parameter(torch.randn(shape, device=self.device) / 100)
        self.rewards_std = nn.Parameter(torch.randn(shape, device=self.device) / 100)

    def _selected_params(self, states: Tensor, actions: Tensor) -> tuple[Tensor, Tensor]:
        """Return the Gaussian mean and non-negative std selected for each step.

        Row ``e`` of ``states`` is looked up in environment ``e``; the action axis is
        contracted against ``actions`` (which picks a single action's parameters when
        ``actions`` is one-hot).
        """
        # Build the env index on the parameters' device to avoid a host->device copy.
        env_index = torch.arange(states.shape[0], device=self.rewards_mean.device)[:, None]
        means_at_states = self.rewards_mean[env_index, states, :]
        stds_at_states = self.rewards_std[env_index, states, :]

        selected_means = torch.einsum("esa,esa->es", means_at_states, actions)
        selected_stds = torch.einsum("esa,esa->es", torch.abs(stds_at_states), actions)
        return selected_means, selected_stds

    def get_reward(self, rewards_original: Tensor, states: Tensor, actions: Tensor) -> Tensor:
        """Sample one attacker reward per step from the selected Gaussians.

        ``rewards_original`` is accepted for interface compatibility and is not used.
        """
        selected_means, selected_stds = self._selected_params(states, actions)
        return torch.normal(selected_means, selected_stds)

    def get_probability(self, states: Tensor, actions: Tensor, realized_rewards: Tensor) -> Tensor:
        """Return the Gaussian log-density (not the probability) of ``realized_rewards``.

        NOTE(review): a selected std of exactly 0 yields -inf/nan here.
        """
        selected_means, selected_stds = self._selected_params(states, actions)
        z = (realized_rewards - selected_means) / selected_stds
        return -self._HALF_LOG_2PI - torch.log(selected_stds) - 0.5 * z**2


class MDPGridRandomAttacker(MDPAttacker):
    """Non-learned attacker that poisons a random set of grid cells per environment.

    Each environment gets ``min(n_poisoned_states, square_len**2)`` distinct cells of
    its ``square_len x square_len`` reward map, each set to -1 or +1 uniformly at
    random; all other cells are 0.
    """

    n_envs: int
    square_len: int
    n_poisoned_states: int

    def __init__(self, n_envs: int, square_len: int, n_poisoned_states: float, device: DeviceLikeType | None = None) -> None:
        super().__init__(device)

        self.n_envs = n_envs
        self.square_len = square_len
        self.n_poisoned_states = math.floor(n_poisoned_states)

        n_cells = square_len * square_len
        # Cannot poison more distinct cells than the grid contains.
        n_draw = min(self.n_poisoned_states, n_cells)

        # Sample cells WITHOUT replacement via a per-row random permutation. Sampling with
        # replacement poisoned fewer distinct cells than requested and wrote duplicate
        # indices in the assignment below, which PyTorch leaves undefined.
        cell_ids = torch.rand((n_envs, n_cells), device=device).argsort(dim=1)[:, :n_draw]
        # Uniform sign per poisoned cell: {0, 1} -> {-1.0, +1.0}.
        attack = torch.randint(0, 2, (n_envs, n_draw), device=device).float() * 2 - 1

        reward_map = torch.zeros((n_envs, square_len, square_len), device=device)
        env_index = torch.arange(n_envs, device=device)[:, None]  # broadcasts against cell_ids
        # Flat cell id -> (first, second) grid coordinate, row-major.
        reward_map[env_index, cell_ids // square_len, cell_ids % square_len] = attack

        # Non-persistent buffer: follows module.to(...) but stays out of state_dict,
        # so previously saved checkpoints still load.
        self.register_buffer("reward_map", reward_map, persistent=False)

    def get_reward(self, rewards_original: Tensor, states: Tensor, actions: Tensor) -> Tensor:
        """Look up the poisoned reward of each visited cell.

        ``states[:, :, 0]`` and ``states[:, :, 1]`` index the first and second grid
        coordinate; row ``e`` is looked up in environment ``e``'s map. A trailing
        singleton axis is appended to the result. ``rewards_original`` and ``actions``
        are not used.
        """
        env_index = torch.arange(self.n_envs, device=self.reward_map.device)[:, None]
        rewards = self.reward_map[env_index, states[:, :, 0], states[:, :, 1]]
        return rewards.unsqueeze(2)


class MDPGridClassifierBaseAttacker(MDPAttacker):
    """Learned per-cell attacker that picks a reward in {-1, 0, +1} for each grid cell.

    ``weights`` holds, for every environment and grid cell, three logits aligned with
    ``reward_values``. The attacker is trained with REINFORCE in :meth:`update`.
    """

    n_envs: int
    square_len: int

    # Logits of shape (n_envs, square_len, square_len, 3); last axis aligned with reward_values.
    weights: Tensor

    # Adam over the parameters registered at construction time; None when no lr was given.
    _optimizer: torch.optim.Optimizer | None

    def __init__(self, n_envs: int, square_len: int, *, lr: float | None = None, device: DeviceLikeType | None = None) -> None:
        super().__init__(device)

        self.n_envs = n_envs
        self.square_len = square_len
        self.weights = nn.Parameter(torch.empty((self.n_envs, self.square_len, self.square_len, 3), device=self.device))
        init.kaiming_uniform_(self.weights, a=sqrt(5))
        self.weights.data[:, :, :, 1] += 0.5  # Bias initial attacker towards no attack (i.e., start with predicting 0)
        # Candidate attack rewards; index k corresponds to weights[..., k]. Non-persistent buffer
        # so it follows module.to(...) without adding a state_dict entry.
        self.register_buffer("reward_values", torch.tensor([-1.0, 0.0, 1.0], device=self.device), persistent=False)

        self._optimizer = None
        if lr is not None:
            # Built here, so parameters registered later (e.g. by subclasses) are not optimised.
            self._optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def get_reward(self, rewards_original: Tensor, states: Tensor, actions: Tensor) -> Tensor:
        """Return the attacker's reward for each visited cell.

        Chooses, per step, the entry of ``reward_values`` with the largest logit. The
        argmax makes the result non-differentiable w.r.t. ``weights``.
        ``rewards_original`` and ``actions`` are not used.

        Returns:
            Rewards with a singleton axis inserted at position 1; for ``states`` of shape
            ``(n_envs, n_steps, 2)`` this is ``(n_envs, 1, n_steps)``.
        """
        reward_weights = self.weights[torch.arange(self.n_envs, device=self.weights.device)[:, None], states[:, :, 0], states[:, :, 1], :]
        reward_indices = reward_weights.argmax(-1)

        rewards = self.reward_values[reward_indices]
        return rewards[:, None]

    def get_probability(self, states: Tensor, actions: Tensor, realized_rewards: Tensor) -> Tensor:
        """Return ``log p(realized_rewards)`` under the per-cell softmax over ``reward_values``.

        ``actions`` is not used. ``1e-10`` is added before the log to avoid ``log(0)``.
        """
        reward_weights = self.weights[torch.arange(self.n_envs, device=self.weights.device)[:, None], states[:, :, 0], states[:, :, 1]]
        reward_probabilities = F.softmax(reward_weights, dim=-1)
        # Map each realized reward to its NEAREST candidate value. torch.bucketize sent float
        # noise such as 1e-8 to the +1 bucket and 1.0000001 to index 3 (out of range);
        # exact -1/0/+1 values map to the same indices as before.
        reward_indices = (realized_rewards[..., None] - self.reward_values).abs().argmin(dim=-1)

        n_steps = states.shape[1]
        probabilities = reward_probabilities[
            torch.arange(self.n_envs, device=self.weights.device)[:, None, None],
            torch.arange(n_steps, device=self.weights.device)[None, :, None],
            reward_indices[..., None],
        ]
        return torch.log(probabilities.squeeze(2) + 1e-10)

    def get_dataset_states(self, dataset: MDPDatasetTorch) -> Tensor:
        """Return ``dataset.states`` cast to int32 for use as grid indices."""
        return dataset.states.int()

    def update(
        self,
        datasets: MDPDatasetTorch | MDPDatasetImagesTorch | list[MDPDatasetTorch | MDPDatasetImagesTorch],
        adv_train_config: AdversarialTrainingConfig,
    ) -> list[dict[str, Any]]:
        """Run REINFORCE updates of the attacker on one or more rollout datasets.

        Performs ``adv_train_config.attacker_iters`` passes; each pass takes one optimiser
        step per dataset.

        Args:
            datasets: A single dataset or a list of them.
            adv_train_config: Supplies ``attacker_iters``, ``max_poison_diff`` and
                ``budget_regularizer``.

        Returns:
            One metrics dict per optimiser step; every value is a Python float.

        Raises:
            AssertionError: if the last dataset has no ``rewards_original`` or no optimiser
                was created (no ``lr`` passed to the constructor).
        """
        if isinstance(datasets, (MDPDatasetTorch, MDPDatasetImagesTorch)):
            datasets = [datasets]

        assert datasets[-1].rewards_original is not None
        assert self._optimizer is not None, "Optimizer is not initialized; please pass a suitable lr parameter during initialization"

        metrics = []

        # The attacker's return is the negated original reward of the LAST dataset. The same
        # tensor is reused for every dataset below, so all datasets must be shape-compatible.
        attacker_rewards = -datasets[-1].rewards_original.detach()  # only last episode reward is important?
        att_budget_max = adv_train_config.max_poison_diff

        for _ in range(adv_train_config.attacker_iters):
            for dataset in datasets:
                start_time = time.perf_counter()

                states = self.get_dataset_states(dataset)
                log_policy_prob = self.get_probability(states, dataset.actions, (dataset.rewards - dataset.rewards_original).detach())

                r_dagger = self.get_reward(dataset.rewards_original, states, dataset.actions)
                att_budget_current = torch.norm(r_dagger, dim=-1)

                reinforce_loss = -(log_policy_prob * attacker_rewards).mean()
                # NOTE(review): with this class's argmax-based get_reward, r_dagger carries no
                # gradient, so this term changes the reported loss but not the update. Confirm intent.
                budget_constraint_loss = adv_train_config.budget_regularizer * torch.relu(att_budget_current - att_budget_max).sum()
                loss = reinforce_loss + budget_constraint_loss

                self._optimizer.zero_grad()
                loss.backward()
                self._optimizer.step()

                elapsed = time.perf_counter() - start_time
                metrics.append(
                    {
                        "train/attacker_loss": loss.item(),
                        "train/attacker_loss/reinforce_loss": reinforce_loss.item(),
                        "train/attacker_loss/budget_reg_loss": budget_constraint_loss.item(),
                        "train/attacker_time": elapsed,
                        # .item() like the other metrics, instead of a device tensor.
                        "train/attacker_avg_env_att_dist": (att_budget_current.sum() / self.n_envs).item(),
                    },
                )

        return metrics


class MDPGridClassifierAttacker(MDPGridClassifierBaseAttacker):
    """Grid classifier attacker that also stores frozen copies of the original
    reward map and goals.

    Both copies are float32 ``nn.Parameter`` objects on the attacker's device with
    ``requires_grad=False``, so they receive no gradients.
    """

    _original_reward_map: Tensor
    _original_goals: Tensor

    def __init__(self, original_reward_map: Tensor, original_goals: Tensor, n_envs: int, square_len: int, *, lr: float | None = None, device: DeviceLikeType | None = None) -> None:
        super().__init__(n_envs, square_len, lr=lr, device=device)

        def frozen_copy(source: Tensor) -> nn.Parameter:
            # Independent, detached float32 copy on self.device, excluded from autograd.
            data = source.clone().detach().float().to(device=self.device)
            return nn.Parameter(data, requires_grad=False)

        self._original_reward_map = frozen_copy(original_reward_map)
        self._original_goals = frozen_copy(original_goals)
