import gym
import torch
import torch.nn.functional
from einops import rearrange, reduce

from policy.base import BasePolicy


class A2CPolicy(BasePolicy):
    """On-policy advantage actor-critic policy over a discrete action space.

    The wrapped network is called as ``net(obs, h)`` and must return
    ``(act_values, state_value, h, penalties)``. ``act_values`` are unnormalised
    action logits. The network must also expose ``compute_weight_penalty()``,
    returning ``(penalty_tensor, info_dict)``.

    Training is done one full trajectory at a time. ``explore`` keeps the
    autograd graph, so the logits and values stored in the trajectory can be
    backpropagated through in ``learn``.
    """

    def __init__(self,
                 net_class: type[torch.nn.Module],
                 net_kwargs: dict,
                 gamma: float,
                 action_space: gym.spaces.Discrete,
                 learning_rate: float):
        super().__init__()

        self.gamma = gamma
        self.action_space = action_space
        self.learning_rate = learning_rate

        self.net = net_class(**net_kwargs, action_space=action_space)  # noqa

        # bind optimizer to a training policy
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=learning_rate)

        self._n_backprop_steps = 0

    @property
    def ON_POLICY(self) -> bool:
        return True

    @property
    def REPLAY_BUFFER_CAPACITY(self) -> int:
        # only the most recent trajectory is kept; learn() samples exactly it
        return 1

    @property
    def LEARN_BATCH_SIZE(self) -> int:
        return 1

    @property
    def n_backprop_steps(self) -> int:
        """Number of optimizer steps taken by ``learn`` so far."""
        return self._n_backprop_steps

    def _get_expected_returns(self, rewards: torch.Tensor) -> torch.Tensor:
        """Compute discounted returns-to-go ``G_t = r_t + gamma * G_{t+1}``.

        Args:
            rewards: tensor of shape ``(batch, time)``.

        Returns:
            Tensor of shape ``(batch, time)``. The return after the last step is
            taken to be 0, so there is no bootstrapping from a value estimate.
        """
        rewards = rearrange(rewards, "b t -> t b")
        returns = []
        discounted_sum = 0.0
        for rew in rewards.flip(0):
            discounted_sum = rew + self.gamma * discounted_sum
            returns.append(discounted_sum)
        returns = torch.stack(returns[::-1], 1)
        return returns

    def explore(self, obs, h) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample an action from the policy distribution (training mode).

        Returns:
            ``(act, h, act_values, state_value, penalties)``. ``act`` is offset by
            ``action_space.start``. Tensors are not detached, so ``learn`` can
            backpropagate through them.
        """
        # on-policy - does not detach tensors
        self.net.train()

        act_values, state_value, h, penalties = self.net(obs, h)

        # sample from softmax(act_values); using logits normalises in log-space
        act_idx = torch.distributions.Categorical(logits=act_values).sample()
        act = act_idx + self.action_space.start
        return act, h, act_values, state_value, penalties

    def greedy(self, obs, h) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Pick the highest-scoring action without tracking gradients (eval mode).

        Returns:
            ``(act, h, act_values, state_value, penalties)``, with ``act`` offset
            by ``action_space.start``.
        """
        self.net.eval()

        with torch.no_grad():
            act_values, state_value, h, penalties = self.net(obs, h)

            # greedy action
            act_idx = torch.argmax(act_values, -1)
            act = act_idx + self.action_space.start
        return act, h, act_values, state_value, penalties

    @staticmethod
    def _compute_actor_critic_losses(act_values: torch.Tensor,
                                     act_idxs: torch.Tensor,
                                     state_values: torch.Tensor,
                                     returns: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(actor_loss, critic_loss)``, each summed over the time axis.

        Args:
            act_values: action logits, shape ``(batch, time, n_actions)``.
            act_idxs: zero-based chosen action indices, shape ``(batch, time)``.
            state_values: critic estimates, shape ``(batch, time)``.
            returns: discounted returns-to-go, shape ``(batch, time)``.
        """
        # The advantage is a constant weight in the policy-gradient term. Without
        # detach(), the actor loss would also push gradients into the critic and
        # bias V upward, independent of the regression target.
        advantage = (returns - state_values).detach()

        # log_softmax avoids log(0) = -inf when a probability underflows
        log_probs = torch.log_softmax(act_values, -1)
        chosen_log_probs = log_probs.gather(-1, act_idxs.long().unsqueeze(-1)).squeeze(-1)
        actor_loss = (-chosen_log_probs * advantage).sum(-1)  # sum over time

        critic_step_losses = torch.nn.functional.mse_loss(state_values, returns, reduction="none")
        critic_loss = critic_step_losses.sum(-1)  # sum over time
        return actor_loss, critic_loss

    def learn(self, memory) -> tuple[dict[str, float], dict[str, float]]:
        """Take one optimizer step on the single stored trajectory.

        The optimised objective is actor loss + critic loss + the network's
        weight penalty + activation penalties summed over time and penalty kinds.

        Returns:
            ``(loss_info, penalty_info)``. In ``loss_info``, ``"total"`` is
            actor + critic loss only and excludes the penalties. ``penalty_info``
            merges the network's weight-penalty info with the activation
            penalties.
        """
        self.net.train()

        traj = memory.sample(1)[0]  # memory capacity must be 1

        # stack per-step tensors into a batch of one trajectory: (1, t[, ...])
        rewards = rearrange([step.reward for step in traj], 't -> 1 t')
        acts = rearrange([step.act for step in traj], 't -> 1 t')
        act_idxs = acts - self.action_space.start
        act_values = rearrange([step.act_logits for step in traj], 't a -> 1 t a')
        state_values = rearrange([step.state_value for step in traj], 't -> 1 t')

        # activation penalty, summed over time
        # NOTE(review): assumes the network emits at least two penalties, ordered
        # (encoder, memory) — confirm against the net implementation
        activation_penalties = rearrange([step.penalties for step in traj], 't p -> t p')
        activation_penalties = reduce(activation_penalties, "t p -> p", "sum")
        activation_penalties_info = {
            "encoder_activation": activation_penalties[0].detach().cpu().item(),
            "memory_activation": activation_penalties[1].detach().cpu().item()
        }
        activation_penalty = reduce(activation_penalties, "p -> ", "sum")

        weight_penalty, weight_penalties_info = self.net.compute_weight_penalty()

        returns = self._get_expected_returns(rewards)
        actor_loss, critic_loss = self._compute_actor_critic_losses(
            act_values, act_idxs, state_values, returns)
        data_loss = actor_loss + critic_loss

        self.optimizer.zero_grad()
        (data_loss + weight_penalty + activation_penalty).backward()
        self.optimizer.step()
        self._n_backprop_steps += 1

        loss_info = {
            "total": data_loss.detach().cpu().item(),
            "actor": actor_loss.detach().cpu().item(),
            "critic": critic_loss.detach().cpu().item()
        }

        return loss_info, weight_penalties_info | activation_penalties_info
