import gin
import torch

from rgfn.api.trajectories import Trajectories
from rgfn.gfns.reaction_gfn.api.reaction_api import (
    ReactionAction,
    ReactionActionSpace,
    ReactionState,
    ReactionState0, ReactionActionSpaceC,
)
from rgfn.trainer.trajectory_filters.trajectory_filter_base import TrajectoryFilterBase


@gin.configurable()
class RGFNTrajectoryFilter(
    TrajectoryFilterBase[ReactionState, ReactionActionSpace, ReactionAction]
):
    """Filter that keeps only trajectories passing two consistency checks.

    A trajectory is kept when both of the following hold:

    1. Its source state is an instance of ``ReactionState0``.
    2. For every (action, backward action space) pair in the trajectory where
       the backward action space is a ``ReactionActionSpaceC``, the action is
       contained in that space's ``possible_actions``.
    """

    def __call__(
        self, trajectories: Trajectories[ReactionState, ReactionActionSpace, ReactionAction]
    ) -> Trajectories[ReactionState, ReactionActionSpace, ReactionAction]:
        """Return the subset of ``trajectories`` that pass both checks.

        Args:
            trajectories: the trajectories to filter.

        Returns:
            The result of ``trajectories.masked_select`` applied with a boolean
            mask holding one entry per trajectory (``True`` means "keep").
        """
        source_states = trajectories.get_source_states_flat()
        # The dtype is given explicitly: for an empty list `torch.tensor`
        # infers a floating-point dtype, on which the `&` below is not a valid
        # boolean operation.
        valid_source_mask = torch.tensor(
            [isinstance(state, ReactionState0) for state in source_states],
            dtype=torch.bool,
        )

        other_valid_mask = []
        # NOTE(review): this reads private attributes of `Trajectories`;
        # presumably both lists hold one inner list per trajectory - confirm.
        for action_list, backward_action_spaces_list in zip(
            trajectories._actions_list, trajectories._backward_action_spaces_list
        ):
            valid = True
            for action, backward_action_space in zip(action_list, backward_action_spaces_list):
                # Only `ReactionActionSpaceC` spaces are checked; any other
                # kind of backward action space is accepted as-is.
                if (
                    isinstance(backward_action_space, ReactionActionSpaceC)
                    and action not in backward_action_space.possible_actions
                ):
                    print("Invalid action space!")
                    print(action)
                    print(backward_action_space)
                    valid = False
                    # One inconsistent step is enough to reject the trajectory.
                    break
            other_valid_mask.append(valid)

        # A trajectory survives only if it passes both checks.
        valid_mask = valid_source_mask & torch.tensor(other_valid_mask, dtype=torch.bool)
        return trajectories.masked_select(valid_mask)
