import gin
import torch
from torch import nn

from rgfn.api.objective_base import ObjectiveBase, ObjectiveOutput
from rgfn.api.policy_base import PolicyBase
from rgfn.api.trajectories import Trajectories, TrajectoriesContainer
from rgfn.api.type_variables import TAction, TActionSpace, TState
from rgfn.gfns.reaction_gfn.api.reaction_api import (
    ReactionState0,
    ReactionStateTerminal,
)
from rgfn.gfns.reaction_gfn.proxies.path_cost_proxy import PathCostProxy


@gin.configurable()
class TrajectoryBalanceObjectiveReinforce(ObjectiveBase[TState, TActionSpace, TAction]):
    """
    Trajectory Balance (TB) objective for the forward policy combined with a REINFORCE-style
    objective for the backward policy.

    * Forward part: the squared TB residual
      ``logZ + log P_F(tau) - log P_B(tau) - log R(x)`` averaged over replay and
      forward-sampled trajectories whose source state is a ``ReactionState0``. ``logZ`` is
      the sum of the learnable vector ``forward_logZ``.
    * Backward part: ``-mean(log P_B(tau) * r(tau)) - entropy_weight * mean(H(tau))`` over
      forward- and backward-sampled trajectories whose last state is a
      ``ReactionStateTerminal``. Here ``r(tau) = cost_alpha * exp(-cost_beta * cost(tau))``
      if the trajectory's source state is a ``ReactionState0``, and ``-1`` otherwise.

    The total loss is ``forward_loss + backward_loss_weight * backward_loss``.

    Args:
        path_cost_proxy: stored as ``self.path_cost_proxy``; not used by the methods below.
        forward_policy: policy trained with the TB loss.
        backward_policy: policy trained with the REINFORCE loss (and also through the TB
            loss when ``detach_backward`` is False).
        detach_backward: if True, the backward log-probabilities are detached inside the TB
            loss, so the TB loss does not update the backward policy.
        cost_beta: multiplier of the cost inside the exponent of the backward reward.
        cost_alpha: multiplicative scale of the backward reward.
        backward_loss_weight: weight of the backward loss in the total loss.
        entropy_weight: weight of the entropy bonus in the backward loss.
        z_dim: number of entries of ``forward_logZ``; their sum is used as ``logZ``.
    """

    def __init__(
        self,
        path_cost_proxy: PathCostProxy,
        forward_policy: PolicyBase[TState, TActionSpace, TAction],
        backward_policy: PolicyBase[TState, TActionSpace, TAction],
        detach_backward: bool = True,
        cost_beta: float = 1.0,
        cost_alpha: float = 1.0,
        backward_loss_weight: float = 1.0,
        entropy_weight: float = 1.0,
        z_dim: int = 16,
    ):
        super().__init__(forward_policy=forward_policy, backward_policy=backward_policy)
        # Learnable log-partition vector; logZ is the sum of its entries (each starts at 150/64).
        self.forward_logZ = nn.Parameter(torch.ones(z_dim) * 150.0 / 64, requires_grad=True)
        # NOTE(review): not read by this class; presumably used elsewhere — verify before removing.
        self.mol_fingerprint_size = 2048
        self.detach_backward = detach_backward
        self.cost_beta = cost_beta
        self.cost_alpha = cost_alpha
        self.path_cost_proxy = path_cost_proxy
        self.entropy_weight = entropy_weight
        self.backward_loss_weight = backward_loss_weight

    def _isinstance_mask(self, states, state_type: type) -> torch.Tensor:
        """Return a bool tensor that is True where the corresponding state is a ``state_type``."""
        return torch.tensor(
            [isinstance(state, state_type) for state in states],
            dtype=torch.bool,
            device=self.device,
        )

    def _sum_per_trajectory(
        self, values: torch.Tensor, index: torch.Tensor, n_trajectories: int
    ) -> torch.Tensor:
        """Sum per-action ``values`` into a float32 tensor of shape ``[n_trajectories]`` by ``index``."""
        return torch.scatter_add(
            input=torch.zeros(size=(n_trajectories,), dtype=torch.float32, device=self.device),
            index=index,
            src=values,
            dim=0,
        )

    @staticmethod
    def _mean_or_zero(values: torch.Tensor) -> torch.Tensor:
        """
        Return ``values.mean()``, or ``values.sum()`` (a zero that stays in the autograd graph)
        when ``values`` is empty, so an empty selection does not turn the loss into NaN.
        """
        return values.mean() if values.numel() > 0 else values.sum()

    def _compute_log_probs(
        self, trajectories: Trajectories[TState, TActionSpace, TAction], forward: bool
    ):
        """
        Compute per-trajectory sums of the log-probabilities and entropies of the taken actions.

        Args:
            trajectories: trajectories whose actions are scored.
            forward: if True, score the actions with the forward policy on the non-last states
                and forward action spaces. Otherwise, score them with the backward policy on
                the non-source states and backward action spaces.

        Returns:
            A ``(log_prob, entropy)`` pair of float32 tensors of shape ``[len(trajectories)]``.
            Each entry holds the per-action values summed within one trajectory.
        """
        actions = trajectories.get_actions_flat()  # [n_actions]
        # Trajectory id of every action, used to aggregate per-action values per trajectory.
        index = trajectories.get_index_flat().to(self.device)  # [n_actions]
        if forward:
            log_prob, entropy = self.forward_policy.compute_action_log_probs(
                states=trajectories.get_non_last_states_flat(),
                action_spaces=trajectories.get_forward_action_spaces_flat(),
                actions=actions,
            )  # [n_actions]
        else:
            log_prob, entropy = self.backward_policy.compute_action_log_probs(
                states=trajectories.get_non_source_states_flat(),
                action_spaces=trajectories.get_backward_action_spaces_flat(),
                actions=actions,
            )  # [n_actions]
        n_trajectories = len(trajectories)
        return (
            self._sum_per_trajectory(log_prob, index, n_trajectories),
            self._sum_per_trajectory(entropy, index, n_trajectories),
        )

    def compute_objective_output(
        self, trajectories_container: TrajectoriesContainer[TState, TActionSpace, TAction]
    ) -> ObjectiveOutput:
        """
        Compute the objective output on a batch of trajectories.

        The replay, forward and backward trajectories are concatenated. The TB loss is computed
        on the replay and forward trajectories that start in ``ReactionState0``. The REINFORCE
        loss for the backward policy is computed on the forward and backward trajectories that
        end in ``ReactionStateTerminal``. An empty subset contributes a zero loss instead of NaN.

        Args:
            trajectories_container: the batch of trajectories obtained in the sampling process.
                It contains the states, actions, action spaces in forward and backward
                directions, and rewards. Other important quantities (e.g. log probabilities of
                taking actions in forward and backward directions) should be assigned in this
                method using appropriate methods (e.g. assign_log_probs).

        Returns:
            The output of the objective function, containing the loss and the metrics
            "forward_logZ", "forward_loss", "backward_loss" and "backward_reward".
        """
        replay_trajectories = trajectories_container.replay_trajectories
        backward_sampled = trajectories_container.backward_trajectories
        all_trajectories = Trajectories.from_trajectories(
            [
                replay_trajectories,
                trajectories_container.forward_trajectories,
                backward_sampled,
            ]
        )

        # Prepare masks. This assumes from_trajectories concatenates in list order, so replay
        # trajectories come first and backward-sampled ones last.
        n_all = len(all_trajectories)
        replay_mask = torch.zeros(n_all, dtype=torch.bool, device=self.device)
        replay_mask[: len(replay_trajectories)] = True
        backward_mask = torch.zeros(n_all, dtype=torch.bool, device=self.device)
        # Use an explicit start index: ``[-len(backward_sampled):]`` would be ``[0:]`` (the
        # whole batch) when there are no backward trajectories.
        backward_mask[n_all - len(backward_sampled) :] = True
        valid_for_forward_mask = self._isinstance_mask(
            all_trajectories.get_source_states_flat(), ReactionState0
        )
        valid_for_backward_mask = self._isinstance_mask(
            all_trajectories.get_last_states_flat(), ReactionStateTerminal
        )

        # TB loss: replay + forward-sampled trajectories that start from ReactionState0.
        forward_trajectories_mask = ~backward_mask & valid_for_forward_mask
        # REINFORCE loss: forward- + backward-sampled trajectories ending in ReactionStateTerminal.
        backward_trajectories_mask = ~replay_mask & valid_for_backward_mask

        # Compute the log probs for forward and backward trajectories.
        forward_trajectories = all_trajectories.masked_select(forward_trajectories_mask)
        forward_log_prob_for_forward, _ = self._compute_log_probs(
            forward_trajectories, forward=True
        )
        backward_log_prob_for_all, entropies_for_all = self._compute_log_probs(
            all_trajectories, forward=False
        )
        backward_log_prob_for_forward = backward_log_prob_for_all[forward_trajectories_mask]
        backward_log_prob_for_backward = backward_log_prob_for_all[backward_trajectories_mask]
        entropies_for_backward = entropies_for_all[backward_trajectories_mask]

        # Forward (TB) loss: (logZ + log P_F - log P_B - log R)^2.
        forward_output_log_flow = (
            forward_trajectories.get_reward_outputs().log_reward
        )  # [n_trajectories]
        forward_input_log_flow = self.forward_logZ.sum()
        backward_term = (
            backward_log_prob_for_forward.detach()
            if self.detach_backward
            else backward_log_prob_for_forward
        )
        forward_residual = (
            forward_input_log_flow
            + forward_log_prob_for_forward
            - backward_term
            - forward_output_log_flow
        )
        forward_loss = self._mean_or_zero(forward_residual.pow(2))

        # Backward (REINFORCE) loss. Trajectories that do not reach ReactionState0 get reward -1.
        backward_trajectories = all_trajectories.masked_select(backward_trajectories_mask)
        backward_reward = (
            torch.exp(
                -torch.tensor(backward_trajectories.get_costs(), device=self.device)
                * self.cost_beta
            )
            * self.cost_alpha
        )
        backward_valid_mask = self._isinstance_mask(
            backward_trajectories.get_source_states_flat(), ReactionState0
        )
        backward_reward = torch.masked_fill(backward_reward, ~backward_valid_mask, -1)
        backward_loss = -self._mean_or_zero(
            backward_log_prob_for_backward * backward_reward
        ) + self.entropy_weight * (-self._mean_or_zero(entropies_for_backward))

        # Compute the final loss and some metrics.
        loss = forward_loss + backward_loss * self.backward_loss_weight
        metrics = {
            "forward_logZ": forward_input_log_flow.item(),
            "forward_loss": forward_loss.item(),
            "backward_loss": backward_loss.item(),
            "backward_reward": self._mean_or_zero(backward_reward).item(),
        }

        return ObjectiveOutput(loss=loss, metrics=metrics)
