import torch
from dataclasses import dataclass
from typing import List, Dict, Set, Tuple, Optional
from ...core.representation import MolecularGraph, Motif, Connection


@dataclass
class AssemblyState:
    """Snapshot of a motif-assembly episode.

    Bundles the graph assembled so far, the ids of the motifs that may still
    take part in new connections, an optional target graph, and the
    step/termination bookkeeping of the episode.
    """

    # Graph assembled so far. The methods below read its ``motifs`` mapping,
    # its ``connections`` and its underlying ``graph`` object.
    current_graph: MolecularGraph
    # Ids (keys of ``current_graph.motifs``) of motifs that may be connected.
    available_motifs: Set[str]
    # Graph to compare against; ``None`` when there is no target.
    target_graph: Optional[MolecularGraph] = None
    # Current step counter and the step budget it is compared against.
    step: int = 0
    max_steps: int = 100
    # Episode-end flags; only reported by this class, never set by it.
    terminated: bool = False
    truncated: bool = False
    mode: str = "reconstruction"  # "reconstruction" or "generation"
    # Extra named scalars; ``None`` is replaced by ``{}`` in ``__post_init__``.
    properties: Optional[Dict[str, float]] = None

    def __post_init__(self):
        """Replace a missing ``properties`` mapping with an empty dict."""
        if self.properties is None:
            self.properties = {}

    def to_text_representation(self) -> str:
        """Render the state as a newline-separated, human-readable report.

        The report always contains the header fields and the current graph's
        own text representation. A ``TARGET_TOPOLOGY`` section is appended
        when a target graph is set, and a ``PROPERTIES`` section when
        ``properties`` is non-empty.
        """
        text_parts = [
            "ASSEMBLY_STATE:",
            f"MODE: {self.mode}",
            f"STEP: {self.step}/{self.max_steps}",
            f"TERMINATED: {self.terminated}",
            f"TRUNCATED: {self.truncated}",
            f"AVAILABLE_MOTIFS: {len(self.available_motifs)}",
            f"CONNECTED_COMPONENTS: {self.current_graph.num_connected_components()}",
            "",
            "CURRENT_TOPOLOGY:",
            self.current_graph.to_text_representation(),
        ]

        # Compare against None explicitly: a graph type that defines
        # __len__/__bool__ would make an empty target look "absent".
        if self.target_graph is not None:
            text_parts.extend([
                "",
                "TARGET_TOPOLOGY:",
                self.target_graph.to_text_representation(),
            ])

        if self.properties:
            text_parts.extend(["", "PROPERTIES:"])
            text_parts.extend(f"{k}: {v}" for k, v in self.properties.items())

        return "\n".join(text_parts)

    def get_global_features(self) -> Dict[str, float]:
        """Collect scalar features describing the whole state.

        Returns:
            A new dict holding the current graph's topology features, the
            step/mode/termination indicators and, when a target graph is
            set, the target's topology features under a ``target_`` prefix
            plus ``edge_completion_ratio`` and ``topology_distance``.
            ``properties`` are merged last, so a property sharing a key with
            a computed feature overrides it.
        """
        # Copy so that the updates below never mutate a dict the graph
        # object may be holding on to.
        features = dict(self.current_graph.get_topology_features())

        # A non-positive budget cannot be divided by; report it as used up.
        step_ratio = self.step / self.max_steps if self.max_steps > 0 else 1.0

        features.update({
            'step_ratio': step_ratio,
            'num_available_motifs': len(self.available_motifs),
            'terminated': float(self.terminated),
            'mode_reconstruction': float(self.mode == "reconstruction"),
            'mode_generation': float(self.mode == "generation"),
        })

        if self.target_graph is not None:
            target_features = self.target_graph.get_topology_features()
            features.update({
                f'target_{k}': v for k, v in target_features.items()
            })

            features.update({
                'edge_completion_ratio': self._compute_edge_completion_ratio(),
                'topology_distance': self._compute_topology_distance(),
            })

        features.update(self.properties)

        return features

    def _compute_edge_completion_ratio(self) -> float:
        """Return the fraction of target edges already present.

        Returns:
            ``0.0`` without a target graph, ``1.0`` when the target has no
            edges, otherwise (current edges also found in the target) /
            (number of distinct target edges).
        """
        if self.target_graph is None:
            return 0.0

        target_nx = self.target_graph.graph
        target_edges = set(target_nx.edges())
        if not target_edges:
            return 1.0

        current_edges = set(self.current_graph.graph.edges())
        # Ask the target graph for membership instead of intersecting raw
        # tuples: an undirected graph may report the same edge as (u, v) in
        # one graph and (v, u) in the other, which tuple equality misses.
        matched = sum(1 for u, v in current_edges if target_nx.has_edge(u, v))
        return matched / len(target_edges)

    def _compute_topology_distance(self) -> float:
        """Return the graph edit distance from the current to the target graph.

        Returns:
            ``0.0`` without a target graph; ``inf`` when the distance
            computation raises an ordinary exception.
        """
        if self.target_graph is None:
            return 0.0

        try:
            return self.current_graph.get_graph_edit_distance(self.target_graph)
        except Exception:
            # Best-effort feature: a failed computation becomes "infinitely
            # far". KeyboardInterrupt/SystemExit are deliberately not caught.
            return float('inf')

    def get_available_connections(self) -> List[Tuple[str, int, str, int, str]]:
        """Enumerate candidate connections between pairs of available motifs.

        Each unordered pair of distinct available motifs is visited once; for
        every pair of their available sites, one entry is produced per bond
        type allowed by both sites.

        Returns:
            Tuples ``(motif_a, site_a_id, motif_b, site_b_id, bond_type)``.

        Raises:
            KeyError: if an id in ``available_motifs`` is missing from
                ``current_graph.motifs``.
        """
        # Sorted so the enumeration order (and which motif of a pair comes
        # first) does not depend on set iteration order.
        motif_ids = sorted(self.available_motifs)

        # Query each motif's available sites once, not once per inner loop.
        sites_by_motif = {
            motif_id: list(self.current_graph.motifs[motif_id].get_available_sites())
            for motif_id in motif_ids
        }

        available_connections = []
        for i, motif_i in enumerate(motif_ids):
            for site_i in sites_by_motif[motif_i]:
                # Only partners later in the list, so each pair appears once.
                for motif_j in motif_ids[i + 1:]:
                    for site_j in sites_by_motif[motif_j]:
                        shared_bonds = site_i.allowed_bond_types & site_j.allowed_bond_types
                        for bond_type in shared_bonds:
                            available_connections.append((
                                motif_i, site_i.site_id,
                                motif_j, site_j.site_id,
                                bond_type,
                            ))

        return available_connections

    def is_valid_connection(self, source_motif: str, source_site: int,
                          target_motif: str, target_site: int,
                          bond_type: str) -> bool:
        """Check whether a proposed connection is admissible in this state.

        A connection is rejected when both ends are the same motif, when
        either motif is not available or not part of the current graph, when
        the two motifs already share an edge, when either site id is unknown
        on its motif, or when the bond type is not allowed by both sites.

        Returns:
            ``True`` if none of the rejection rules applies, else ``False``.
        """
        if source_motif == target_motif:
            return False

        if source_motif not in self.available_motifs or target_motif not in self.available_motifs:
            return False

        if self.current_graph.graph.has_edge(source_motif, target_motif):
            return False

        motifs = self.current_graph.motifs
        # A predicate should answer "no" rather than raise when the id set
        # and the graph have drifted apart.
        if source_motif not in motifs or target_motif not in motifs:
            return False

        # NOTE(review): sites are looked up in ``connection_sites``, not in
        # ``get_available_sites()``; whether an already-used site should be
        # rejected here cannot be told from this file - confirm.
        source_sites = {site.site_id: site for site in motifs[source_motif].connection_sites}
        target_sites = {site.site_id: site for site in motifs[target_motif].connection_sites}

        if source_site not in source_sites or target_site not in target_sites:
            return False

        return (bond_type in source_sites[source_site].allowed_bond_types and
                bond_type in target_sites[target_site].allowed_bond_types)

    def clone(self) -> 'AssemblyState':
        """Return a copy of this state backed by a freshly built graph.

        A new ``MolecularGraph`` is created from the current motifs and every
        existing connection is re-added to it. ``available_motifs`` and
        ``properties`` are shallow-copied; ``target_graph`` is shared by
        reference.
        """
        # NOTE(review): the same motif objects are handed to the new graph;
        # whether MolecularGraph copies them is not visible here - confirm
        # if motifs carry mutable per-episode state.
        new_graph = MolecularGraph(list(self.current_graph.motifs.values()))
        for connection in self.current_graph.connections:
            new_graph.add_connection(connection)

        return AssemblyState(
            current_graph=new_graph,
            available_motifs=self.available_motifs.copy(),
            target_graph=self.target_graph,
            step=self.step,
            max_steps=self.max_steps,
            terminated=self.terminated,
            truncated=self.truncated,
            mode=self.mode,
            properties=self.properties.copy(),
        )