import torch
import torch.nn.functional as F
from typing import List, Set, Dict, Optional, Any
from ...environment.state import AssemblyState
from ...environment.actions import AssemblyAction


class SetBehaviorCloning:
    def __init__(self, weight: float = 1.0, temperature: float = 1.0):
        self.weight = weight
        self.temperature = temperature

    def compute_correct_action_set(self, state: AssemblyState) -> Set[AssemblyAction]:
        # Compute set of all correct actions for the current state
        # In reconstruction mode, this includes all actions that form target edges
        correct_actions = set()

        if state.mode == "reconstruction" and state.target_graph:
            target_edges = set(state.target_graph.graph.edges())
            current_edges = set(state.current_graph.graph.edges())

            # Add actions for missing target edges
            missing_edges = target_edges - current_edges

            for source_motif, target_motif in missing_edges:
                if (source_motif in state.available_motifs and
                    target_motif in state.available_motifs and
                    source_motif in state.current_graph.motifs and
                    target_motif in state.current_graph.motifs):

                    # Find valid connection sites and bond types
                    source_motif_obj = state.current_graph.motifs[source_motif]
                    target_motif_obj = state.current_graph.motifs[target_motif]

                    for source_site in source_motif_obj.get_available_sites():
                        for target_site in target_motif_obj.get_available_sites():
                            compatible_bonds = (source_site.allowed_bond_types &
                                             target_site.allowed_bond_types)

                            for bond_type in compatible_bonds:
                                action = AssemblyAction.create_connect_action(
                                    source_motif=source_motif,
                                    source_site=source_site.site_id,
                                    target_motif=target_motif,
                                    target_site=target_site.site_id,
                                    bond_type=bond_type
                                )
                                correct_actions.add(action)

            # Add STOP action if reconstruction is complete
            if not missing_edges and state.current_graph.num_connected_components() == 1:
                stop_action = AssemblyAction.create_stop_action()
                correct_actions.add(stop_action)

        else:  # Generation mode
            # In generation mode, any chemically and topologically valid action could be correct
            # For now, include all valid actions (this could be refined)
            valid_actions = self._get_valid_actions_for_generation(state)
            correct_actions.update(valid_actions)

        return correct_actions

    def _get_valid_actions_for_generation(self, state: AssemblyState) -> List[AssemblyAction]:
        # Get all valid actions for generation mode
        valid_actions = []

        # Always include STOP if reasonable
        if (state.current_graph.num_connected_components() == 1 and
            len(state.current_graph.connections) >= len(state.current_graph.motifs) - 1):
            valid_actions.append(AssemblyAction.create_stop_action())

        # Add valid connection actions
        available_connections = state.get_available_connections()
        for source_motif, source_site, target_motif, target_site, bond_type in available_connections:
            action = AssemblyAction.create_connect_action(
                source_motif, source_site, target_motif, target_site, bond_type
            )
            valid_actions.append(action)

        return valid_actions

    def compute_bc_loss(self, state: AssemblyState, actor_outputs: Dict[str, torch.Tensor],
                       correct_actions: Optional[Set[AssemblyAction]] = None) -> torch.Tensor:
        if correct_actions is None:
            correct_actions = self.compute_correct_action_set(state)

        if not correct_actions:
            return torch.tensor(0.0)

        # Get action logits from actor
        action_logits = actor_outputs.get('action_logits', {})

        if not action_logits:
            return torch.tensor(0.0)

        # Compute the total probability mass over all correct actions
        total_correct_prob = torch.tensor(0.0)

        for action in correct_actions:
            action_prob = self._compute_action_probability(action_logits, action, state)
            total_correct_prob += action_prob

        # Set-BC loss: negative log probability of the correct set
        bc_loss = -torch.log(total_correct_prob + 1e-8)

        return bc_loss * self.weight

    def _compute_action_probability(self, action_logits: Dict[str, torch.Tensor],
                                  action: AssemblyAction, state: AssemblyState) -> torch.Tensor:
        # Compute probability of a specific action given the logits

        if action.is_stop_action():
            # STOP action probability
            source_logits = action_logits.get('source_motif', torch.zeros(1))
            if source_logits.numel() == 0:
                return torch.tensor(0.0)

            source_probs = F.softmax(source_logits / self.temperature, dim=-1)
            # Assume STOP is the last action in source_motif
            stop_prob = source_probs[0, -1] if source_probs.dim() > 1 else source_probs[-1]
            return stop_prob

        # Connection action probability
        try:
            motif_ids = list(state.available_motifs)
            source_idx = motif_ids.index(action.source_motif)
            target_idx = motif_ids.index(action.target_motif)

            # Get bond type index
            bond_types = ['SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC']  # Should match action space
            bond_idx = bond_types.index(action.bond_type)

            # Compute hierarchical probability
            # P(action) = P(source) * P(source_site|source) * P(target|source,site) *
            #             P(target_site|source,site,target) * P(bond|source,site,target,target_site)

            prob = torch.tensor(1.0)

            # Source motif probability
            source_logits = action_logits.get('source_motif')
            if source_logits is not None and source_logits.numel() > 0:
                source_probs = F.softmax(source_logits / self.temperature, dim=-1)
                prob *= source_probs[0, source_idx] if source_probs.dim() > 1 else source_probs[source_idx]

            # Source site probability
            source_site_logits = action_logits.get('source_site')
            if (source_site_logits is not None and
                source_site_logits.numel() > 0 and
                source_idx < source_site_logits.size(0)):

                site_logits = source_site_logits[source_idx]
                site_probs = F.softmax(site_logits / self.temperature, dim=-1)
                if action.source_site < site_probs.size(0):
                    prob *= site_probs[action.source_site]

            # Target motif probability
            target_logits = action_logits.get('target_motif')
            if (target_logits is not None and
                target_logits.numel() > 0 and
                source_idx < target_logits.size(0) and
                action.source_site < target_logits.size(1)):

                target_logits_slice = target_logits[source_idx, action.source_site]
                target_probs = F.softmax(target_logits_slice / self.temperature, dim=-1)
                if target_idx < target_probs.size(0):
                    prob *= target_probs[target_idx]

            # Target site probability
            target_site_logits = action_logits.get('target_site')
            if (target_site_logits is not None and
                target_site_logits.numel() > 0 and
                source_idx < target_site_logits.size(0) and
                action.source_site < target_site_logits.size(1) and
                target_idx < target_site_logits.size(2)):

                site_logits = target_site_logits[source_idx, action.source_site, target_idx]
                site_probs = F.softmax(site_logits / self.temperature, dim=-1)
                if action.target_site < site_probs.size(0):
                    prob *= site_probs[action.target_site]

            # Bond type probability
            bond_logits = action_logits.get('bond_type')
            if (bond_logits is not None and
                bond_logits.numel() > 0 and
                source_idx < bond_logits.size(0) and
                action.source_site < bond_logits.size(1) and
                target_idx < bond_logits.size(2) and
                action.target_site < bond_logits.size(3)):

                bond_logits_slice = bond_logits[source_idx, action.source_site, target_idx, action.target_site]
                bond_probs = F.softmax(bond_logits_slice / self.temperature, dim=-1)
                if bond_idx < bond_probs.size(0):
                    prob *= bond_probs[bond_idx]

            return prob

        except (ValueError, IndexError):
            # Action not valid for current state
            return torch.tensor(0.0)

    def compute_topology_equivalent_actions(self, state: AssemblyState,
                                          reference_action: AssemblyAction) -> Set[AssemblyAction]:
        # Find actions that are topologically equivalent to the reference action
        # This is used for advanced Set-BC that considers topological equivalence

        equivalent_actions = set()
        equivalent_actions.add(reference_action)

        if reference_action.is_stop_action():
            return equivalent_actions

        # For connection actions, find alternative site combinations that achieve same topology
        source_motif_obj = state.current_graph.motifs.get(reference_action.source_motif)
        target_motif_obj = state.current_graph.motifs.get(reference_action.target_motif)

        if not source_motif_obj or not target_motif_obj:
            return equivalent_actions

        # Find alternative sites with same chemical environment
        reference_source_site = None
        reference_target_site = None

        for site in source_motif_obj.connection_sites:
            if site.site_id == reference_action.source_site:
                reference_source_site = site
                break

        for site in target_motif_obj.connection_sites:
            if site.site_id == reference_action.target_site:
                reference_target_site = site
                break

        if not reference_source_site or not reference_target_site:
            return equivalent_actions

        # Find sites with equivalent chemical environments
        for source_site in source_motif_obj.get_available_sites():
            if (source_site.site_type == reference_source_site.site_type and
                source_site.chemical_environment == reference_source_site.chemical_environment):

                for target_site in target_motif_obj.get_available_sites():
                    if (target_site.site_type == reference_target_site.site_type and
                        target_site.chemical_environment == reference_target_site.chemical_environment):

                        # Check bond compatibility
                        if reference_action.bond_type in (source_site.allowed_bond_types &
                                                        target_site.allowed_bond_types):

                            equivalent_action = AssemblyAction.create_connect_action(
                                source_motif=reference_action.source_motif,
                                source_site=source_site.site_id,
                                target_motif=reference_action.target_motif,
                                target_site=target_site.site_id,
                                bond_type=reference_action.bond_type
                            )

                            equivalent_actions.add(equivalent_action)

        return equivalent_actions

    def compute_advanced_bc_loss(self, state: AssemblyState, actor_outputs: Dict[str, torch.Tensor],
                                reference_actions: List[AssemblyAction]) -> torch.Tensor:
        # Advanced Set-BC that considers topological equivalence
        total_loss = torch.tensor(0.0)

        for ref_action in reference_actions:
            equivalent_actions = self.compute_topology_equivalent_actions(state, ref_action)

            # Compute probability mass over equivalent actions
            equivalent_prob = torch.tensor(0.0)

            for action in equivalent_actions:
                action_prob = self._compute_action_probability(actor_outputs.get('action_logits', {}), action, state)
                equivalent_prob += action_prob

            # Add to loss
            total_loss += -torch.log(equivalent_prob + 1e-8)

        return total_loss * self.weight / max(len(reference_actions), 1)

    def get_bc_statistics(self, state: AssemblyState, actor_outputs: Dict[str, torch.Tensor]) -> Dict[str, float]:
        # Get statistics about BC performance
        correct_actions = self.compute_correct_action_set(state)

        if not correct_actions:
            return {'num_correct_actions': 0, 'total_correct_prob': 0.0}

        total_correct_prob = torch.tensor(0.0)
        action_probs = []

        for action in correct_actions:
            prob = self._compute_action_probability(actor_outputs.get('action_logits', {}), action, state)
            action_probs.append(prob.item())
            total_correct_prob += prob

        return {
            'num_correct_actions': len(correct_actions),
            'total_correct_prob': total_correct_prob.item(),
            'avg_correct_action_prob': total_correct_prob.item() / len(correct_actions),
            'max_correct_action_prob': max(action_probs) if action_probs else 0.0,
            'min_correct_action_prob': min(action_probs) if action_probs else 0.0
        }