from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Batch

from rl.agents.replay_buffer import BipartiteNodeData
from rl.networks.loss import TreeHLGaussLoss


def pad_tensor(input_: torch.Tensor, pad_sizes: torch.Tensor, pad_value: float = -1e9) -> torch.Tensor:
    """
    Split a flat, batched tensor into per-sample chunks and pad them to a
    common length. Breaks torch_geometric Batch data into separate samples.

    The input is split along dim 0 into consecutive chunks whose lengths are
    given by ``pad_sizes``; every chunk is then right-padded along dim 0 up
    to the largest chunk length and the chunks are stacked on a new leading
    dimension. Trailing dimensions (e.g. a per-node logits dimension) are
    left untouched, whatever the rank of ``input_``.

    TODO: Modify to make it compatible with new NN!

    Args:
        input_: tensor whose first dimension concatenates all samples.
        pad_sizes: 1-D integer tensor with the length of each sample.
        pad_value: filler written in the padded slots. The default is a very
            negative number so that padded slots never win an ``argmax``.

    Returns:
        Tensor of shape ``(len(pad_sizes), pad_sizes.max(), *input_.shape[1:])``.
    """
    # Read the sizes back to the host once: keeping the maximum as a tensor
    # would force an implicit tensor -> int conversion (and a device sync)
    # for every single chunk below.
    max_pad_size = int(pad_sizes.max())
    chunk_sizes = pad_sizes.tolist()

    # F.pad reads its spec from the last dimension backwards, so every
    # trailing dimension gets a (0, 0) pair and only the end of dim 0 grows.
    untouched_dims = (0, 0) * (input_.dim() - 1)

    return torch.stack(
        [
            F.pad(chunk, untouched_dims + (0, max_pad_size - chunk.size(0)), "constant", pad_value)
            for chunk in input_.split(chunk_sizes)
        ],
        dim=0,
    )


class TreeDQNAgent:
    """DQN Agent interacting with environment.
    Version : DDQN + n-step learning + PER

    Attributes:
        target_hard_update (int): period for target model's hard update
        target_soft_update (float): weight for target model's soft update
        dqn (nn.Module): model to train and select actions
        dqn_target (nn.Module): target model to update
        classification (bool): whether or not to use HL-Gauss
    """

    def __init__(
        self,
        value_network: nn.Module,
        target_network: nn.Module,
        target_hard_update: int = 0,
        target_soft_update: float = 1e-4,
        train: bool = True,
        classification: bool = False,
    ):
        """Initialization.

        Args:
            env (gym.Env): openAI Gym environment
            memory_size (int): length of memory
            target_hard_update (int): period for target model's hard update
            lr (float): learning rate
            gamma (float): discount factor
            alpha (float): determines how much prioritization is used
            beta (float): determines how much importance sampling is used
            prior_eps (float): guarantees every transition can be sampled
            n_step (int): step number to calculate n-step td error
        """

        # device: cpu / gpu
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # self.device = torch.device("cpu")
        self.train = train

        # networks: dqn, dqn_target
        self.dqn = value_network
        self.dqn_target = target_network
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn.train() if self.train else self.dqn.eval()
        self.dqn_target.eval()
        self.dqn.to(self.device)
        self.dqn_target.to(self.device)

        # target update
        self.target_hard_update = target_hard_update
        self.target_soft_update = target_soft_update

        # Q_function
        self.classification = classification
        if self.classification:
            min_value, max_value, sigma = -1.5, 16.5, 0.75
            num_bins = self.dqn.num_bins
            self.loss_fn = TreeHLGaussLoss(min_value, max_value, num_bins, sigma, self.device)
            self.last_dim = num_bins

    def _tensorize_state(self, state: Tuple[np.ndarray]) -> Tuple[torch.Tensor]:
        constraint_features, edge_indices, edge_features, variable_features = state
        edge_features = np.expand_dims(edge_features, axis=-1)
        tensor_state = (
            torch.FloatTensor(constraint_features).to(self.device),
            torch.LongTensor(edge_indices).to(self.device),
            torch.FloatTensor(edge_features).to(self.device),
            torch.FloatTensor(variable_features).to(self.device),
        )

        return tensor_state

    def _tensorize_state_batch(self, state_batch: np.ndarray[object]) -> torch_geometric.data.Batch:
        batch = []
        for sample in state_batch:
            constraint_features, edge_indices, edge_features, variable_features, action_set = sample
            edge_features = np.expand_dims(edge_features, axis=-1)
            graph = BipartiteNodeData(
                torch.FloatTensor(constraint_features),
                torch.LongTensor(edge_indices),
                torch.FloatTensor(edge_features),
                torch.FloatTensor(variable_features),
                torch.LongTensor(action_set),
            )
            graph.num_nodes = constraint_features.shape[0] + variable_features.shape[0]
            batch += [graph]

        return Batch.from_data_list(batch)

    def _compute_Q_value(self, state: Tuple[np.ndarray]) -> torch.Tensor:
        with torch.no_grad():
            tensor_state = self._tensorize_state(state)
            Q_value = self.dqn(*tensor_state)
            Q_value = self._to_value_from_logits(Q_value) if self.classification else Q_value
        return Q_value

    def select_greedy_action(self, state: Tuple[np.ndarray]) -> int:
        """Select an action from the input state."""
        state, action_set = state[:-1], torch.LongTensor(state[-1])
        with torch.no_grad():
            Q_value = self._compute_Q_value(state).cpu().detach()
            greedy_action = Q_value[action_set].argmax().numpy()
            action = action_set[greedy_action].item()
        return action

    def compute_dqn_loss(
        self,
        samples: Dict[str, object],
        gamma: float,
        n_step: bool,
    ) -> torch.Tensor:
        if self.bb_mdp:
            # DQN-BBMDP or DQN-TreeLDO
            return self._compute_tree_dqn_loss(samples, gamma, n_step)
        else:
            # retro branching DQN-Retro
            return self._compute_dqn_loss(samples, gamma, n_step)

    def _compute_dqn_loss(self, samples: Dict[str, object], gamma: float, n_step: bool) -> torch.Tensor:
        """Return MSE dqn loss."""
        device = self.device  # for shortening the following lines
        obs = samples["state"]
        next_obs = samples["next_state"] if not n_step else samples["n_step_next_state"]
        action = samples["action"]
        reward = samples["reward"] if not n_step else samples["n_step_reward"]
        done = samples["done"] if not n_step else samples["n_step_done"]
        state = self._tensorize_state_batch(obs).to(device)
        next_state = self._tensorize_state_batch(next_obs).to(device)
        action = torch.LongTensor(action.reshape(-1, 1)).to(device)
        reward = torch.FloatTensor(reward.reshape(-1, 1)).to(device)
        done = torch.LongTensor(done.reshape(-1, 1)).to(device)

        current_q_value = self.dqn(
            state.constraint_features,
            state.edge_index,
            state.edge_attr,
            state.variable_features,
        )
        next_q_value = self.dqn(
            next_state.constraint_features,
            next_state.edge_index,
            next_state.edge_attr,
            next_state.variable_features,
        )
        next_target_q_value = self.dqn_target(
            next_state.constraint_features,
            next_state.edge_index,
            next_state.edge_attr,
            next_state.variable_features,
        )

        current_q_value = pad_tensor(current_q_value, state.nb_variables).gather(1, action)
        next_q_value = pad_tensor(next_q_value[next_state.candidates], next_state.nb_candidates)
        next_target_q_value = pad_tensor(
            next_target_q_value[next_state.candidates], next_state.nb_candidates
        )
        next_actions = next_q_value.argmax(dim=1, keepdim=True)
        next_target_q_value = next_target_q_value.gather(1, next_actions).detach()

        mask = 1 - done
        target = torch.addcmul(reward, mask, next_target_q_value, value=gamma).to(self.device)

        # calculate element-wise dqn loss
        elementwise_loss = F.smooth_l1_loss(current_q_value, target, reduction="none")

        return elementwise_loss

    def _compute_tree_dqn_loss(self, samples: Dict[str, object], gamma: float, n_step: bool) -> torch.Tensor:
        """
        Return categorical dqn loss, if classification.
        Else, return MSE dqn loss.
        """
        device = self.device  # for shortening the following lines
        obs = samples["state"]
        next_obss = samples["next_state"] if not n_step else samples["n_step_next_state"]
        action = samples["action"]
        action_idx = samples["action_idx"]
        reward = samples["reward"] if not n_step else samples["n_step_reward"]
        done = samples["done"] if not n_step else samples["n_step_done"]
        children_mask = samples["children_mask"] if not n_step else samples["n_step_children_mask"]
        state = self._tensorize_state_batch(obs).to(device)
        action = torch.LongTensor(action.reshape(-1, 1)).to(device)
        action_idx = torch.LongTensor(action_idx.reshape(-1, 1)).to(device)
        reward = torch.FloatTensor(reward.reshape(-1, 1)).to(device)
        done = torch.LongTensor(done.reshape(-1, 1)).to(device)

        if self.classification:
            batch_size = action.size(0)
            action = action.view(-1, 1, 1).expand(batch_size, 1, self.last_dim)

        current_q_value = self.dqn(
            state.constraint_features,
            state.edge_index,
            state.edge_attr,
            state.variable_features,
        )
        current_q_value = pad_tensor(current_q_value, state.nb_variables).gather(1, action)

        if children_mask is not None:
            next_states = [self._tensorize_state_batch(next_obs).to(device) for next_obs in next_obss]
            children_mask = torch.tensor(children_mask).unsqueeze(-1).to(device)

            # Compute Q values for next states
            next_target_q_values = [
                self.dqn_target(
                    next_state.constraint_features,
                    next_state.edge_index,
                    next_state.edge_attr,
                    next_state.variable_features,
                ).detach()
                for next_state in next_states
            ]

            if self.classification:
                # Convert logits scores to actual Q values
                next_target_q_values = [
                    self._to_value_from_logits(next_target_q_value)
                    for next_target_q_value in next_target_q_values
                ]

            # Pad Q values according to the max number of candidate variables
            next_target_q_values = [
                pad_tensor(next_target_q_value[next_state.candidates], next_state.nb_candidates)
                for next_target_q_value, next_state in zip(next_target_q_values, next_states)
            ]

            # DDQN
            next_q_values = [
                self.dqn(
                    next_state.constraint_features,
                    next_state.edge_index,
                    next_state.edge_attr,
                    next_state.variable_features,
                ).detach()
                for next_state in next_states
            ]

            if self.classification:
                next_q_values = [self._to_value_from_logits(next_q_value) for next_q_value in next_q_values]

            next_q_values = [
                pad_tensor(next_q_value[next_state.candidates], next_state.nb_candidates)
                for next_q_value, next_state in zip(next_q_values, next_states)
            ]

            # Retrieve optimal actions according to value network
            next_actions = self._get_next_actions(next_q_values, next_states)

            # Deduce associated Q values according to the target network
            next_target_q_values = torch.stack(
                [
                    next_target_q_value.gather(1, next_action).detach()
                    for next_target_q_value, next_action in zip(next_target_q_values, next_actions)
                ]
            )

            # next_target_q_value = (next_target_q_values * children_mask.squeeze(-1)).sum(dim=0).unsqueeze(-1)
            next_target_q_value = (next_target_q_values * children_mask).sum(dim=0)
            mask = 1 - done
            target = torch.addcmul(reward, mask, next_target_q_value, value=gamma).to(device).detach()

        else:
            # if no children
            target = reward.detach()

        # calculate element-wise dqn loss
        if self.classification:
            assert target.max() <= 0
            current_q_value = current_q_value.squeeze(1)
            target = torch.log2(-target + 1e-8).squeeze(-1)
            assert target.min() >= self.loss_fn.min_value and target.max() <= self.loss_fn.max_value
            elementwise_loss = self.loss_fn(current_q_value, target)
        else:
            elementwise_loss = F.smooth_l1_loss(current_q_value, target, reduction="none")

        return elementwise_loss

    def _get_next_actions(
        self, next_q_values: List[torch.Tensor], next_states: List[torch.Tensor]
    ) -> torch.Tensor:
        device = self.device
        # Build masks
        padding_masks = [
            torch.arange(next_q_value.size(-1)).unsqueeze(0).to(device)
            < next_state.nb_candidates.unsqueeze(1)
            for next_q_value, next_state in zip(next_q_values, next_states)
        ]
        # Mask padded values to compute optimal actions !
        for pad_mask, next_q_value in zip(padding_masks, next_q_values):
            next_q_value[~pad_mask] = -1e9

        # Retrieve optimal actions according to value network
        next_actions = [next_q_value.argmax(dim=1, keepdim=True) for next_q_value in next_q_values]
        for next_action, next_state in zip(next_actions, next_states):
            assert (next_action.squeeze(-1) < next_state.nb_candidates).all()

        return next_actions

    def _to_value_from_logits(self, q_value: torch.Tensor):
        return -torch.pow(2, self.loss_fn.transform_from_probs(q_value.softmax(dim=-1)))

    def _target_soft_update(self):
        """Soft update: target <- target * (1 - tau) + local * tau"""
        tau = self.target_soft_update
        for target_param, value_param in zip(self.dqn_target.parameters(), self.dqn.parameters()):
            target_param.data.copy_(tau * value_param.data + (1.0 - tau) * target_param.data)

    def _target_hard_update(self):
        """Hard update: target <- local."""
        self.dqn_target.load_state_dict(self.dqn.state_dict())

    def _target_update(self, update_cnt: int):
        """Target update to perform : soft or hard"""
        if self.target_hard_update > 0 and update_cnt % self.target_hard_update == 0:
            self._target_hard_update()
        elif self.target_hard_update == 0:
            self._target_soft_update()

    def _save(self, save_path: str, n_epoch: int, best: bool = False, final: bool = True):
        if best:
            torch.save(self.dqn.state_dict(), f"{save_path}/trained_network_best.pkl")
        elif final:
            torch.save(self.dqn.state_dict(), f"{save_path}/trained_network_final.pkl")
            torch.save(self.dqn.state_dict(), f"{save_path}/trained_network_{n_epoch}.pkl")
