import torch
from typing import List, Dict, Tuple, Set, Optional
from .assembly_action import AssemblyAction, ActionType
from ..state import AssemblyState


class ActionSpace:
    """Hierarchical, maskable action space for motif assembly.

    A connection action is chosen in five conditional layers::

        source motif (or STOP) -> source site -> target motif -> target site -> bond type

    Motifs are indexed by their position in ``list(state.available_motifs)``
    (only the first ``max_motifs`` are considered), sites by their ``site_id``
    (sites with ``site_id >= max_sites_per_motif`` are ignored), and bond types
    by their position in ``self.bond_types``. In the source-motif layer the
    last index (``max_motifs``) is the STOP action.
    """

    def __init__(self, max_motifs: int = 50, max_sites_per_motif: int = 10):
        self.max_motifs = max_motifs
        self.max_sites_per_motif = max_sites_per_motif
        self.bond_types = ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"]
        # Mapping of layer name -> number of choices in that layer.
        self.action_dim = self._calculate_action_dim()

    def _calculate_action_dim(self) -> Dict[str, int]:
        """Return the number of choices in each hierarchical layer, keyed by layer name."""
        return {
            'source_motif': self.max_motifs + 1,  # +1: the last slot is STOP
            'source_site': self.max_sites_per_motif,
            'target_motif': self.max_motifs,
            'target_site': self.max_sites_per_motif,
            'bond_type': len(self.bond_types),
        }

    def get_valid_actions(self, state: "AssemblyState") -> List["AssemblyAction"]:
        """Enumerate every currently valid action.

        The STOP action is always first. It is followed by one connect action
        per ``(source_motif, source_site, target_motif, target_site, bond_type)``
        tuple reported by ``state.get_available_connections()``.
        """
        valid_actions = [AssemblyAction.create_stop_action()]
        for connection in state.get_available_connections():
            source_motif, source_site, target_motif, target_site, bond_type = connection
            valid_actions.append(
                AssemblyAction.create_connect_action(
                    source_motif, source_site, target_motif, target_site, bond_type
                )
            )
        return valid_actions

    def create_action_masks(self, state: "AssemblyState") -> Dict[str, torch.Tensor]:
        """Build boolean validity masks for every layer of the action hierarchy.

        With M = max_motifs, S = max_sites_per_motif, B = len(bond_types), the
        returned dict holds:

        - ``'source_motif'``: shape ``(M + 1,)``; the last entry (STOP) is always True.
        - ``'source_site'``: shape ``(M, S)``, indexed by source motif.
        - ``'target_motif'``: shape ``(M, S, M)``, conditioned on source motif and site.
        - ``'target_site'``: shape ``(M, S, M, S)``, conditioned on the layers above.
        - ``'bond_type'``: shape ``(M, S, M, S, B)``, conditioned on the layers above.

        Every True entry in a layer has at least one True continuation in the
        next layer, so hierarchical sampling never meets an all-False mask.
        """
        motif_ids = list(state.available_motifs)
        masks = self._create_conditional_masks(state, motif_ids)

        source_mask = torch.zeros(self.action_dim['source_motif'], dtype=torch.bool)
        # A motif is selectable only if one of its sites can complete a connection.
        source_mask[:self.max_motifs] = masks['source_site'].any(dim=-1)
        source_mask[-1] = True  # STOP action always valid

        return {'source_motif': source_mask, **masks}

    def _create_conditional_masks(self, state: "AssemblyState",
                                  motif_ids: List[str]) -> Dict[str, torch.Tensor]:
        """Build the source-site, target-motif, target-site and bond-type masks.

        The layers are derived bottom-up from the bond-type mask:
        - a target site is valid iff some bond type is valid for it;
        - a target motif is valid iff some target site is valid;
        - a source site is valid iff some target motif is valid.
        """
        target_site_mask, bond_type_mask = self._create_site_bond_masks(state, motif_ids)
        target_motif_mask = target_site_mask.any(dim=-1)  # (M, S, M)
        source_site_mask = target_motif_mask.any(dim=-1)  # (M, S)
        return {
            'source_site': source_site_mask,
            'target_motif': target_motif_mask,
            'target_site': target_site_mask,
            'bond_type': bond_type_mask,
        }

    def _collect_motif_sites(self, state: "AssemblyState",
                             motif_ids: List[str]) -> List[Tuple[int, str, list]]:
        """Return ``(index, motif_id, in-range available sites)`` for each usable motif.

        ``index`` is the motif's position in ``motif_ids``; only the first
        ``max_motifs`` ids are considered. Motifs missing from
        ``state.current_graph.motifs`` are skipped but keep their index slot.
        """
        motifs = state.current_graph.motifs
        collected = []
        for index, motif_id in enumerate(motif_ids[:self.max_motifs]):
            if motif_id not in motifs:
                continue
            sites = [
                site for site in motifs[motif_id].get_available_sites()
                if site.site_id < self.max_sites_per_motif
            ]
            collected.append((index, motif_id, sites))
        return collected

    def _create_site_bond_masks(self, state: "AssemblyState",
                                motif_ids: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the target-site ``(M, S, M, S)`` and bond-type ``(M, S, M, S, B)`` masks.

        A bond type is valid for (source motif, source site, target motif,
        target site) when all of these hold:
        - the two motifs are different;
        - ``graph.has_edge(source, target)`` is False;
        - the bond type is in both sites' ``allowed_bond_types``.

        A target site is valid iff at least one bond type is valid for it.
        """
        bond_type_mask = torch.zeros(
            (self.max_motifs, self.max_sites_per_motif,
             self.max_motifs, self.max_sites_per_motif, len(self.bond_types)),
            dtype=torch.bool,
        )
        graph = state.current_graph.graph
        candidates = self._collect_motif_sites(state, motif_ids)

        for i, source_motif, source_sites in candidates:
            if not source_sites:
                continue
            for j, target_motif, target_sites in candidates:
                # NOTE(review): only the (source, target) direction is checked; if the
                # graph is directed, confirm the reverse edge should not also block.
                if i == j or graph.has_edge(source_motif, target_motif):
                    continue
                for source_site in source_sites:
                    for target_site in target_sites:
                        compatible = (source_site.allowed_bond_types
                                      & target_site.allowed_bond_types)
                        for k, bond_type in enumerate(self.bond_types):
                            if bond_type in compatible:
                                bond_type_mask[i, source_site.site_id,
                                               j, target_site.site_id, k] = True

        # A target site with no compatible bond would dead-end bond-type sampling.
        target_site_mask = bond_type_mask.any(dim=-1)
        return target_site_mask, bond_type_mask

    @staticmethod
    def _masked_sample(layer_logits: torch.Tensor,
                       mask: torch.Tensor) -> Tuple[int, torch.Tensor]:
        """Sample one index from ``layer_logits``, restricted to entries where ``mask`` is True.

        Returns the sampled index and the masked softmax distribution it was drawn from.

        Raises:
            RuntimeError: if ``mask`` has no True entry (nothing can be sampled).
        """
        mask = mask.to(layer_logits.device)
        if not bool(mask.any()):
            raise RuntimeError("cannot sample from a layer whose action mask is empty")
        masked_logits = layer_logits.masked_fill(~mask, float('-inf'))
        probs = torch.softmax(masked_logits, dim=-1)
        return torch.multinomial(probs, 1).item(), probs

    def sample_action(self, state: "AssemblyState",
                      logits: Dict[str, torch.Tensor]) -> "AssemblyAction":
        """Sample one action hierarchically from per-layer ``logits`` under the current masks.

        ``logits`` is keyed like the dict from ``create_action_masks``. Each
        tensor is indexed with the same layout as its mask.

        Returns a STOP action, or a connect action. For a connect action,
        ``confidence`` is the product of the probabilities of the five
        sampled choices.
        """
        masks = self.create_action_masks(state)

        source_idx, source_probs = self._masked_sample(
            logits['source_motif'], masks['source_motif']
        )
        source_prob = source_probs[source_idx].item()

        # If STOP action selected (last slot of the source-motif layer)
        if source_idx == len(masks['source_motif']) - 1:
            return AssemblyAction.create_stop_action(confidence=source_prob)

        # Must match the ordering used inside create_action_masks; assumes
        # available_motifs iterates in a stable order between the two calls.
        motif_ids = list(state.available_motifs)
        source_motif = motif_ids[source_idx]

        components = self._sample_connection_components(
            state, logits, masks, source_idx, source_motif, motif_ids,
            source_prob=source_prob,
        )

        return AssemblyAction.create_connect_action(
            source_motif=source_motif,
            source_site=components['source_site'],
            target_motif=components['target_motif'],
            target_site=components['target_site'],
            bond_type=components['bond_type'],
            confidence=components['confidence']
        )

    def _sample_connection_components(self, state: "AssemblyState",
                                      logits: Dict[str, torch.Tensor],
                                      masks: Dict[str, torch.Tensor], source_idx: int,
                                      source_motif: str, motif_ids: List[str],
                                      source_prob: float = 1.0) -> Dict:
        """Sample the source site, target motif, target site and bond type for ``source_idx``.

        ``state`` and ``source_motif`` are accepted for interface compatibility
        but are not used. ``source_prob`` is the probability with which the
        source motif was chosen. It is folded into the returned confidence;
        the default 1.0 leaves it out.

        Returns a dict with:
        - ``'source_site'``: site id;
        - ``'target_motif'``: motif id from ``motif_ids``;
        - ``'target_site'``: site id;
        - ``'bond_type'``: an entry of ``self.bond_types``;
        - ``'confidence'``: product of all sampled choice probabilities.
        """
        sample = self._masked_sample

        source_site_idx, source_site_probs = sample(
            logits['source_site'][source_idx],
            masks['source_site'][source_idx],
        )
        target_idx, target_motif_probs = sample(
            logits['target_motif'][source_idx, source_site_idx],
            masks['target_motif'][source_idx, source_site_idx],
        )
        target_site_idx, target_site_probs = sample(
            logits['target_site'][source_idx, source_site_idx, target_idx],
            masks['target_site'][source_idx, source_site_idx, target_idx],
        )
        bond_type_idx, bond_type_probs = sample(
            logits['bond_type'][source_idx, source_site_idx, target_idx, target_site_idx],
            masks['bond_type'][source_idx, source_site_idx, target_idx, target_site_idx],
        )

        confidence = (
            source_prob
            * source_site_probs[source_site_idx].item()
            * target_motif_probs[target_idx].item()
            * target_site_probs[target_site_idx].item()
            * bond_type_probs[bond_type_idx].item()
        )

        return {
            'source_site': source_site_idx,
            'target_motif': motif_ids[target_idx],
            'target_site': target_site_idx,
            'bond_type': self.bond_types[bond_type_idx],
            'confidence': confidence
        }