import random

import gym
import torch
import torch.nn.functional
from einops import rearrange, reduce

from policy.base import BasePolicy


def custom_decay_schedule(epoch: int):
    """
    Map a training epoch to an exploration rate epsilon.

    Epsilon starts at 1.0 for epoch 0 and drops by 0.2 per epoch down to
    0.2 at epoch 4. Any epoch not in 0..4 (including every epoch >= 5)
    gets 0.1.
    """
    # (epoch, epsilon) pairs checked in order; first equal epoch wins
    steps = ((0, 1.0), (1, 0.8), (2, 0.6), (3, 0.4), (4, 0.2))
    for step_epoch, step_epsilon in steps:
        if epoch == step_epoch:
            return step_epsilon
    return 0.1


def decayed_epsilon_greedy(act_logits, action_space, n_backprop_steps=None, n_train_epochs=None):
    """
    Epsilon-greedy action selection with a stepwise-decayed epsilon.

    Every element of the leading dimensions of ``act_logits`` (e.g. each
    batch/time position) independently takes a uniformly-sampled action with
    probability epsilon, and the argmax action otherwise. Epsilon is
    ``custom_decay_schedule(int(n_train_epochs))``.

    Args:
        act_logits: action values; the last dimension indexes actions.
        action_space: discrete action space exposing ``n`` and ``start``.
        n_backprop_steps: unused; accepted for call-site compatibility.
        n_train_epochs: training progress fed to the decay schedule
            (truncated to an int). Required.

    Returns:
        Integer tensor of shape ``act_logits.shape[:-1]`` holding actions
        offset by ``action_space.start``.

    Raises:
        TypeError: if ``n_train_epochs`` is None.
    """
    if n_train_epochs is None:
        raise TypeError("decayed_epsilon_greedy requires n_train_epochs")

    epsilon = custom_decay_schedule(int(n_train_epochs))

    # one action per leading position, sampled on the same device as the logits
    batch_shape = act_logits.shape[:-1]
    device = act_logits.device

    greedy_idx = torch.argmax(act_logits, -1)
    random_idx = torch.randint(action_space.n, batch_shape, device=device)

    # independent explore/exploit decision per batch/time element
    # (previously a single coin flip applied to the whole tensor)
    explore_mask = torch.rand(batch_shape, device=device) < epsilon
    act_idx = torch.where(explore_mask, random_idx, greedy_idx)

    return act_idx + action_space.start


class DQNPolicy(BasePolicy):
    """Off-policy (Double-)DQN policy over a recurrent Q-network.

    ``explore`` picks actions epsilon-greedily, ``greedy`` picks the argmax
    action, and ``learn`` performs one Adam step on a single trajectory
    sampled from replay memory. A separate target network is re-synced from
    the online network every ``sync_freq`` backprop steps.
    """

    def __init__(self,
                 net_class: type[torch.nn.Module],
                 net_kwargs: dict,
                 gamma: float,
                 double: bool,
                 learn_batch_size: int,
                 sync_freq: int,
                 action_space: gym.spaces.Discrete,
                 learning_rate: float,
                 replay_buffer_capacity: int = 500):
        """
        Args:
            net_class: Q-network class, instantiated twice (online and target)
                with ``net_kwargs`` plus ``action_space``.
            net_kwargs: keyword arguments forwarded to ``net_class``.
            gamma: discount factor for the bootstrapped target.
            double: if True use the Double-DQN target (online net selects the
                next action, target net evaluates it).
            learn_batch_size: exposed as ``LEARN_BATCH_SIZE``; also divides the
                backprop-step count to drive the exploration schedule.
            sync_freq: copy online weights into the target net every
                ``sync_freq`` backprop steps.
            action_space: discrete action space; action indices are offset by
                ``action_space.start``.
            learning_rate: Adam learning rate.
            replay_buffer_capacity: exposed as ``REPLAY_BUFFER_CAPACITY``.
        """
        super().__init__()

        self.gamma = gamma
        self.double = double
        self._learn_batch_size = learn_batch_size
        self.sync_freq = sync_freq
        self.action_space = action_space
        self.learning_rate = learning_rate
        self._replay_buffer_capacity = replay_buffer_capacity

        self.net = net_class(**net_kwargs, action_space=action_space)  # noqa
        self._target_net = net_class(**net_kwargs, action_space=action_space)  # noqa
        self._sync_target_weights()

        # bind optimizer to a training policy
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=learning_rate)

        self._n_backprop_steps = 0

        self.explore_mode = False  # eval

    @property
    def ON_POLICY(self) -> bool:
        return False

    @property
    def REPLAY_BUFFER_CAPACITY(self) -> int:
        return self._replay_buffer_capacity

    @property
    def LEARN_BATCH_SIZE(self) -> int:
        return self._learn_batch_size

    @property
    def n_backprop_steps(self) -> int:
        return self._n_backprop_steps

    def _sync_target_weights(self):
        """Copy the online network's weights into the target network (kept in eval mode)."""
        self._target_net.load_state_dict(self.net.state_dict())
        self._target_net.eval()

    def _q_values(self, states, act_idxs, h):
        """Compute Q(s_0...s_t, a_t) for all t with the online network.

        states.shape: (batch, time, feature); act_idxs.shape: (batch, time).

        Returns:
            (state_action_values, h, penalties) where state_action_values has
            shape (batch, time).
        """
        act_logits, h, penalties = self.net(states, h)
        act_idxs = act_idxs.unsqueeze(-1)
        state_action_values = act_logits.gather(-1, act_idxs).squeeze(-1)
        return state_action_values, h, penalties

    def _target_q_values(self, next_states, rewards, policy_h, target_h, terminal_valid):
        """Compute bootstrapped targets y_t for every step, without gradients.

        if double:
          y_t = r_t + gamma * Q_target(s_0...s_{t+1}, argmax_a Q_policy(s_0...s_{t+1}, a))
        else:
          y_t = r_t + gamma * max_a Q_target(s_0...s_{t+1}, a)

        ``next_states`` may be None when the trajectory is a single step that
        ends in a true terminal (there is no successor to evaluate). When
        ``terminal_valid`` is False, the final step's successor value is 0.

        Returns:
            (target_state_action_values, policy_h, target_h)
        """
        with torch.no_grad():
            # FIXME: handle ragged batch of trajectory lengths
            if next_states is None:
                # no successor states at all: only the terminal zero is appended below
                next_state_values = torch.zeros(
                    (rewards.shape[0], 0), dtype=torch.get_default_dtype(), device=rewards.device
                )
            else:
                target_act_logits, target_h, _ = self._target_net(next_states, target_h)

                if self.double:
                    policy_act_logits, policy_h, _ = self.net(next_states, policy_h)
                    max_act_idxs = policy_act_logits.argmax(-1, keepdim=True)
                else:
                    max_act_idxs = target_act_logits.argmax(-1, keepdim=True)

                next_state_values = target_act_logits.gather(-1, max_act_idxs).squeeze(-1)  # V(s_{t+1})

            if not terminal_valid:
                # value 0 for the terminal state; match batch size, device and dtype
                terminal_value = next_state_values.new_zeros((next_state_values.shape[0], 1))
                next_state_values = torch.cat([next_state_values, terminal_value], -1)

            target_state_action_values = (next_state_values * self.gamma) + rewards

            return target_state_action_values, policy_h, target_h

    def explore(self, obs, h) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
        """Select an epsilon-greedy action.

        Returns:
            (act, h, act_values, None, penalties)
        """
        # Off-policy: gradients are only taken in `learn` on replayed data, so
        # the rollout forward pass runs without autograd. This also prevents
        # the recurrent state `h` from chaining a graph across steps.
        self.net.train()

        with torch.no_grad():
            act_values, h, penalties = self.net(obs, h)

            # sample random action with decayed probability
            act = decayed_epsilon_greedy(
                act_values,
                self.action_space,
                n_backprop_steps=self._n_backprop_steps,
                n_train_epochs=self.n_backprop_steps / self.LEARN_BATCH_SIZE
            )

        return act.detach(), h, act_values.detach(), None, penalties

    def greedy(self, obs, h) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, tuple[torch.Tensor]]:
        """Select the argmax action in eval mode without gradients.

        Returns:
            (act, h, act_values, None, penalties)
        """
        self.net.eval()

        with torch.no_grad():
            act_values, h, penalties = self.net(obs, h)

            # greedy action
            act_idx = torch.argmax(act_values, -1)
            act = act_idx + self.action_space.start
        return act, h, act_values, None, penalties

    def learn(self, memory) -> tuple[dict[str, float], dict[str, float]]:
        """Take one optimizer step on a single trajectory sampled from ``memory``.

        Loss = Huber(Q(s_t, a_t), y_t) summed over time, plus the network's
        weight penalty and activation penalty. The target net is re-synced
        every ``sync_freq`` backprop steps.

        Returns:
            (loss_info, penalty_info): ``loss_info["total"]`` is the data loss;
            ``penalty_info`` merges the weight- and activation-penalty logs.
        """
        self.net.train()

        traj = memory.sample(1)[0]

        states = rearrange([step.state for step in traj], 't ... -> 1 t ...')
        rewards = rearrange([step.reward for step in traj], 't -> 1 t')
        acts = rearrange([step.act for step in traj], 't -> 1 t')
        act_idxs = acts - self.action_space.start

        if traj.done_reason == "timeout":
            # truncated episode: the last next_state is real, so bootstrap from it
            terminal_valid = True
            next_state_list = [step.next_state for step in traj]
        else:
            # remove invalid terminal next_state
            terminal_valid = False
            next_state_list = [step.next_state for step in traj[:-1]]
        # a one-step trajectory ending in a true terminal has no successors
        next_states = rearrange(next_state_list, 't ... -> 1 t ...') if next_state_list else None

        # Hidden states after consuming s_0; they seed the evaluation of
        # s_1... in the target computation, which runs without gradients, so
        # no autograd graph is built for them here.
        init_state = states[:, (0,), ...]
        with torch.no_grad():
            _, policy_h, _ = self.net(init_state, None)
            _, target_h, _ = self._target_net(init_state, None)

        # compute Q values from a fresh hidden state
        q_values, _, activation_penalties = self._q_values(states, act_idxs, None)

        # activation penalty: mean over batch, sum over time, then sum over penalty terms
        activation_penalties = reduce(activation_penalties, "b t p -> t p", "mean")
        activation_penalties = reduce(activation_penalties, "t p -> p", "sum")
        activation_penalties_info = {
            "encoder_activation": activation_penalties[0].detach().cpu().item(),
            "memory_activation": activation_penalties[1].detach().cpu().item()
        }
        activation_penalty = reduce(activation_penalties, "p -> ", "sum")

        # compute weight penalty
        weight_penalty, weight_penalties_info = self.net.compute_weight_penalty()

        # compute target Q values
        target_q_values, _, _ = self._target_q_values(
            next_states, rewards, policy_h, target_h, terminal_valid
        )

        # use Huber loss instead of clipping grad to [-1,1]
        step_losses = torch.nn.functional.huber_loss(q_values, target_q_values, reduction="none")

        data_loss = step_losses.sum(-1)  # sum over time

        self.optimizer.zero_grad()
        (data_loss + weight_penalty + activation_penalty).backward()
        self.optimizer.step()
        self._n_backprop_steps += 1

        loss_info = {
            "total": data_loss.detach().cpu().item()
        }

        # update the target model
        if self._n_backprop_steps % self.sync_freq == 0:
            self._sync_target_weights()

        return loss_info, weight_penalties_info | activation_penalties_info
