from typing import Dict

import gin
import numpy as np
import torch
from rdkit import DataStructs
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
from torch import nn

from rgfn.api.objective_base import ObjectiveBase, ObjectiveOutput
from rgfn.api.policy_base import PolicyBase
from rgfn.api.trajectories import Trajectories, TrajectoriesContainer
from rgfn.api.type_variables import TAction, TActionSpace, TState
from rgfn.gfns.reaction_gfn.api.data_structures import ParallelCache
from rgfn.gfns.reaction_gfn.api.reaction_api import (
    ReactionState0,
    ReactionStateTerminal,
)
from rgfn.gfns.reaction_gfn.policies.utils import one_hot
from rgfn.gfns.reaction_gfn.proxies.path_cost_proxy import PathCostProxy


@gin.configurable()
class DualTrajectoryBalanceObjective(ObjectiveBase[TState, TActionSpace, TAction]):
    """
    Objective made of two squared trajectory-balance style residuals that are summed into one loss.

    Forward residual (trajectories that start in `ReactionState0` and do not belong to the
    backward-trajectories slice of the batch):
        `forward_logZ.sum() + log P_F(trajectory) - log P_B(trajectory) - log_reward`

    Backward residual (trajectories that end in `ReactionStateTerminal` and do not belong to the
    replay-trajectories slice of the batch):
        `backward_input_log_flow + log P_B(trajectory) - backward_output_log_flow`

    The final loss is `forward_loss + backward_loss_weight * backward_loss`.

    Args:
        path_cost_proxy: stored on the objective; when it is not None the backward output log flow is
            computed from `trajectories.get_costs()`, otherwise a constant is used.
        forward_policy: policy used to compute forward action log probabilities.
        backward_policy: policy used to compute backward action log probabilities.
        max_num_reactions: the number of reactions of a state is one-hot encoded with
            `max_num_reactions + 1` entries and appended to the molecular fingerprint.
        backward_beta: multiplier applied to the backward output log flow.
        detach_backward: if True, the backward log probabilities are detached in the forward residual.
        backward_loss_weight: weight of the backward loss in the final loss.
        invalid_log_flow: backward output log flow assigned to trajectories whose source state is not
            a `ReactionState0`.
        z_dim: number of elements of the learnable `forward_logZ` parameter (they are summed on use).
        backward_input_flow_with_mlp: if True, the backward input log flow is predicted by an MLP from
            the state fingerprint; otherwise it is read from a per-SMILES table initialised with 1.0.
    """

    def __init__(
        self,
        path_cost_proxy: PathCostProxy,
        forward_policy: PolicyBase[TState, TActionSpace, TAction],
        backward_policy: PolicyBase[TState, TActionSpace, TAction],
        max_num_reactions: int,
        backward_beta: float = 8.0,
        detach_backward: bool = False,
        backward_loss_weight: float = 1.0,
        invalid_log_flow: float = -12.0,
        z_dim: int = 1,
        backward_input_flow_with_mlp: bool = True,
    ):
        super().__init__(forward_policy=forward_policy, backward_policy=backward_policy)
        self.max_num_reactions = max_num_reactions
        self.forward_logZ = nn.Parameter(torch.ones(z_dim) * 150.0 / 64, requires_grad=True)
        self.mol_fingerprint_size = 2048
        self.detach_backward = detach_backward
        self.backward_input_log_flow_with_mlp = backward_input_flow_with_mlp
        if self.backward_input_log_flow_with_mlp:
            # Input: Morgan fingerprint bits followed by the one-hot encoded number of reactions.
            self.backward_logZ = nn.Sequential(
                nn.Linear(self.mol_fingerprint_size + self.max_num_reactions + 1, 32),
                nn.GELU(),
                nn.Linear(32, 1),
            )
        else:
            self.smiles_to_log_flow: Dict[str, float] = {}
        self.backward_beta = backward_beta
        self._fingerprint_cache = ParallelCache(max_size=50_000)
        self.path_cost_proxy = path_cost_proxy
        self.backward_loss_weight = backward_loss_weight
        self.invalid_log_flow = invalid_log_flow

    def _get_state_fingerprint(self, state: ReactionStateTerminal) -> torch.Tensor:
        """
        Return the (cached) input features of the backward log-flow MLP for a terminal state.

        Args:
            state: terminal state providing `molecule.smiles`, `molecule.rdkit_mol` and `num_reactions`.

        Returns:
            A 1D float32 tensor: the Morgan fingerprint (radius 2, `mol_fingerprint_size` bits)
            concatenated with the one-hot encoding of `state.num_reactions`.
        """
        item = (state.molecule.smiles, state.num_reactions)
        if item in self._fingerprint_cache:
            return self._fingerprint_cache[item]

        fp = GetMorganFingerprintAsBitVect(
            state.molecule.rdkit_mol, radius=2, nBits=self.mol_fingerprint_size
        )
        array = np.zeros((0,), dtype=np.float32)
        DataStructs.ConvertToNumpyArray(fp, array)
        molecular_fingerprint = torch.tensor(array, dtype=torch.float32)
        # Explicit dtype so that the concatenated features are float32 whatever `one_hot` returns.
        num_reaction_fingerprints = torch.tensor(
            one_hot(state.num_reactions, self.max_num_reactions + 1), dtype=torch.float32
        )
        fingerprint = torch.cat([molecular_fingerprint, num_reaction_fingerprints], dim=0)
        self._fingerprint_cache[item] = fingerprint
        # Return the freshly built tensor directly instead of re-reading it from the cache.
        return fingerprint

    def _get_state_log_flow(self, state: ReactionStateTerminal) -> float:
        """
        Return the tabular log flow of the state's molecule, inserting 1.0 on first access.

        Only usable when the objective was built with `backward_input_flow_with_mlp=False`, because
        `smiles_to_log_flow` is created only in that case.
        """
        return self.smiles_to_log_flow.setdefault(state.molecule.smiles, 1.0)

    def compute_backward_input_log_flow(
        self, trajectories: Trajectories[TState, TActionSpace, TAction]
    ) -> torch.Tensor:
        """
        Compute the log flow entering the backward trajectories at their last states.

        Args:
            trajectories: trajectories whose last states are used.

        Returns:
            With the MLP enabled: the MLP output of shape [n_trajectories, 1]. Otherwise: a tensor of
            shape [n_trajectories] with the values read from the per-SMILES table.
        """
        states = trajectories.get_last_states_flat()
        if self.backward_input_log_flow_with_mlp:
            fingerprints = torch.stack([self._get_state_fingerprint(state) for state in states]).to(
                self.device
            )
            return self.backward_logZ(fingerprints)

        # NOTE: the table holds plain floats, so gradients of this tensor are not written back to
        # `smiles_to_log_flow` by this class.
        # The tensor is created on `self.device` to match every other term of the backward residual.
        return torch.tensor(
            [self._get_state_log_flow(state) for state in states],
            dtype=torch.float32,
            device=self.device,
            requires_grad=True,
        )

    def compute_backward_output_log_flow(
        self, trajectories: Trajectories[TState, TActionSpace, TAction]
    ) -> torch.Tensor:
        """
        Compute the log flow that backward trajectories should carry when reaching their source state.

        The value is `backward_beta / cost` when `path_cost_proxy` is not None (costs are taken from
        `trajectories.get_costs()`), and `backward_beta` otherwise. Trajectories whose source state is
        not a `ReactionState0` get `invalid_log_flow` instead.

        Args:
            trajectories: trajectories to compute the output log flow for.

        Returns:
            A float32 tensor of shape [n_trajectories].
        """
        valid_mask = torch.tensor(
            [isinstance(state, ReactionState0) for state in trajectories.get_source_states_flat()],
            dtype=torch.bool,
            device=self.device,
        )
        if self.path_cost_proxy is not None:
            path_costs = trajectories.get_costs()
            log_flow = 1 / torch.tensor(path_costs, device=self.device, dtype=torch.float32)
        else:
            log_flow = torch.ones(len(trajectories), device=self.device, dtype=torch.float32)
        log_flow = log_flow * self.backward_beta
        log_flow[~valid_mask] = self.invalid_log_flow
        return log_flow

    def _compute_log_probs(
        self, trajectories: Trajectories[TState, TActionSpace, TAction], forward: bool
    ) -> torch.Tensor:
        """
        Sum the per-action log probabilities of every trajectory.

        Args:
            trajectories: trajectories to score.
            forward: if True the forward policy is used (non-last states, forward action spaces),
                otherwise the backward policy (non-source states, backward action spaces).

        Returns:
            A float32 tensor of shape [n_trajectories] with the summed log probabilities.
        """
        actions = trajectories.get_actions_flat()  # [n_actions]
        index = trajectories.get_index_flat().to(self.device)  # [n_actions]
        if forward:
            log_prob, _ = self.forward_policy.compute_action_log_probs(
                states=trajectories.get_non_last_states_flat(),
                action_spaces=trajectories.get_forward_action_spaces_flat(),
                actions=actions,
            )  # [n_actions]
        else:
            log_prob, _ = self.backward_policy.compute_action_log_probs(
                states=trajectories.get_non_source_states_flat(),
                action_spaces=trajectories.get_backward_action_spaces_flat(),
                actions=actions,
            )  # [n_actions]
        # `index` maps every action to the position of its trajectory in the batch.
        log_prob = torch.scatter_add(
            input=torch.zeros(size=(len(trajectories),), dtype=torch.float32, device=self.device),
            index=index,
            src=log_prob,
            dim=0,
        )
        return log_prob

    def compute_objective_output(
        self, trajectories_container: TrajectoriesContainer[TState, TActionSpace, TAction]
    ) -> ObjectiveOutput:
        """
        Compute the objective output on a batch of trajectories.

        Args:
            trajectories_container: the batch of trajectories obtained in the sampling process. Its
                replay, forward and backward trajectories are concatenated (in this order) and scored
                with the forward and backward residuals described in the class docstring.

        Returns:
            The output of the objective function, containing the loss and some metrics
            (`forward_logZ`, `backward_logZ`, `backward_mean_logR`, `forward_loss`, `backward_loss`).
        """

        all_trajectories = Trajectories.from_trajectories(
            [
                trajectories_container.replay_trajectories,
                trajectories_container.forward_trajectories,
                trajectories_container.backward_trajectories,
            ]
        )

        # prepare masks for forward and backward trajectories
        n_total = len(all_trajectories)
        n_replay = len(trajectories_container.replay_trajectories)
        n_backward = len(trajectories_container.backward_trajectories)
        replay_mask = torch.zeros(n_total, dtype=torch.bool, device=self.device)
        replay_mask[:n_replay] = True
        backward_mask = torch.zeros(n_total, dtype=torch.bool, device=self.device)
        # Slice from an explicit start index: `[-n_backward:]` would select the WHOLE batch when
        # there are no backward trajectories (`[-0:]` is `[0:]`).
        backward_mask[n_total - n_backward :] = True
        valid_for_forward_mask = torch.tensor(
            [
                isinstance(state, ReactionState0)
                for state in all_trajectories.get_source_states_flat()
            ],
            dtype=torch.bool,
            device=self.device,
        )
        valid_for_backward_mask = torch.tensor(
            [
                isinstance(state, ReactionStateTerminal)
                for state in all_trajectories.get_last_states_flat()
            ],
            dtype=torch.bool,
            device=self.device,
        )

        forward_trajectories_mask = ~backward_mask & valid_for_forward_mask
        backward_trajectories_mask = ~replay_mask & valid_for_backward_mask

        # compute the log probs for forward and backward trajectories
        forward_trajectories = all_trajectories.masked_select(forward_trajectories_mask)
        forward_log_prob_for_forward = self._compute_log_probs(forward_trajectories, forward=True)

        backward_log_prob_for_all = self._compute_log_probs(all_trajectories, forward=False)
        backward_log_prob_for_forward = backward_log_prob_for_all[forward_trajectories_mask]
        backward_log_prob_for_backward = backward_log_prob_for_all[backward_trajectories_mask]
        if self.detach_backward:
            # The forward loss must not update the backward policy.
            backward_log_prob_for_forward = backward_log_prob_for_forward.detach()

        # compute the forward loss
        forward_output_log_flow = (
            forward_trajectories.get_reward_outputs().log_reward
        )  # [n_trajectories]
        forward_input_log_flow = self.forward_logZ.sum()
        forward_loss = (
            forward_input_log_flow
            + forward_log_prob_for_forward
            - backward_log_prob_for_forward
            - forward_output_log_flow
        )
        forward_loss = forward_loss.pow(2).mean()

        # compute the backward loss
        backward_trajectories = all_trajectories.masked_select(backward_trajectories_mask)
        # The MLP head returns [n, 1]; flatten to [n] so that the residual below is element-wise
        # instead of silently broadcasting to an [n, n] matrix.
        backward_input_log_flow = self.compute_backward_input_log_flow(
            backward_trajectories
        ).reshape(-1)
        backward_output_log_flow = self.compute_backward_output_log_flow(backward_trajectories)
        backward_loss = (
            backward_input_log_flow + backward_log_prob_for_backward - backward_output_log_flow
        )
        backward_loss = backward_loss.pow(2).mean()

        # compute the final loss and some metrics
        loss = forward_loss + backward_loss * self.backward_loss_weight
        metrics = {
            "forward_logZ": forward_input_log_flow.sum().item(),
            "backward_logZ": backward_input_log_flow.sum().item(),
            "backward_mean_logR": backward_output_log_flow.mean().item(),
            "forward_loss": forward_loss.item(),
            "backward_loss": backward_loss.item(),
        }

        return ObjectiveOutput(loss=loss, metrics=metrics)
