import torch
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional, Set
from rdkit import Chem


@dataclass
class ConnectionSite:
    """Describe one attachment point on a motif.

    Instances are plain records; the class defines no behaviour of its own.
    ``Motif`` reads these fields when rendering text and when checking
    which bond types a site allows.
    """

    # Identifier of the site; rendered as ``Site_<site_id>`` in the motif text.
    site_id: int
    # Index of the atom carrying the site.
    # NOTE(review): presumably an index into the owning motif's ``mol`` -- confirm.
    atom_idx: int
    # Free-form label for the kind of site; rendered as ``type_<site_type>``.
    site_type: str
    # Free-form label for the site's surroundings; rendered as ``env_<...>``.
    chemical_environment: str
    # Bond-type labels this site accepts. A site whose set is empty is
    # treated as unavailable by ``Motif.get_available_sites``.
    allowed_bond_types: Set[str]
    # Whether the site is flagged aromatic; adds ", aromatic" to the text form.
    is_aromatic: bool = False


@dataclass
class Motif:
    """A molecular fragment together with the sites where it can be bonded.

    Attributes:
        motif_id: Identifier of the motif, emitted verbatim in the text form.
        smiles: SMILES string of the motif, emitted verbatim in the text form.
        mol: RDKit molecule object for the motif (not read by any method here).
        connection_sites: Attachment points offered by this motif.
        properties: Named numeric properties (not read by any method here).
        is_aromatic: Aromaticity flag, emitted in the text form.
        ring_info: Ring metadata; ``None`` is replaced by an empty dict.
            The text form counts the entries under its ``"rings"`` key.
        functional_groups: Functional-group labels; ``None`` is replaced by
            an empty list.
    """

    motif_id: str
    smiles: str
    mol: Chem.Mol
    connection_sites: List[ConnectionSite]
    properties: Dict[str, float]
    is_aromatic: bool = False
    # ``None`` (rather than a ``default_factory``) stays the default so that
    # callers who pass ``None`` explicitly keep getting a fresh container.
    ring_info: Optional[Dict] = None
    functional_groups: Optional[List[str]] = None

    def __post_init__(self):
        """Replace ``None`` containers with fresh, per-instance empties."""
        if self.ring_info is None:
            self.ring_info = {}
        if self.functional_groups is None:
            self.functional_groups = []

    def to_text_representation(self) -> str:
        """Render the motif as a newline-separated block of labelled fields.

        The header lines are followed by one indented line per connection
        site. Bond types are emitted in sorted order: they are stored in a
        ``set``, whose iteration order for strings can differ between
        interpreter runs, and the rendered text must not depend on that.

        Returns:
            The multi-line text description (no trailing newline).
        """
        text_parts = [
            f"MOTIF_ID: {self.motif_id}",
            f"SMILES: {self.smiles}",
            f"AROMATIC: {self.is_aromatic}",
            f"RINGS: {len(self.ring_info.get('rings', []))}",
            f"FUNCTIONAL_GROUPS: {', '.join(self.functional_groups)}",
            "CONNECTION_SITES:",
        ]

        for site in self.connection_sites:
            # Sort so the same site always produces the same text.
            bond_types = ",".join(sorted(site.allowed_bond_types))
            site_text = (
                f"  Site_{site.site_id}: atom_{site.atom_idx}, "
                f"type_{site.site_type}, "
                f"env_{site.chemical_environment}, "
                f"bonds_{bond_types}"
            )
            if site.is_aromatic:
                site_text += ", aromatic"
            text_parts.append(site_text)

        return "\n".join(text_parts)

    def get_available_sites(self) -> List[ConnectionSite]:
        """Return the connection sites that allow at least one bond type."""
        return [site for site in self.connection_sites if site.allowed_bond_types]

    def is_compatible_with(self, other_site: ConnectionSite, bond_type: str) -> bool:
        """Report whether ``bond_type`` can join this motif to ``other_site``.

        Args:
            other_site: Site on another motif to connect to.
            bond_type: Bond-type label to test.

        Returns:
            ``True`` when at least one of this motif's sites allows
            ``bond_type`` and ``other_site`` allows it as well.
        """
        # ``other_site`` does not change per iteration, so it is checked once,
        # and only after some local site is known to accept the bond type.
        return (
            any(
                bond_type in site.allowed_bond_types
                for site in self.connection_sites
            )
            and bond_type in other_site.allowed_bond_types
        )


@dataclass
class Connection:
    """A directed bond between a site on one motif and a site on another."""

    # Identifier of the motif the connection starts from.
    source_motif: str
    # Site identifier on the source motif.
    source_site: int
    # Identifier of the motif the connection ends at.
    target_motif: str
    # Site identifier on the target motif.
    target_site: int
    # Label of the bond joining the two sites.
    bond_type: str

    def to_text_representation(self) -> str:
        """Render the connection as a single ``CONNECTION: src --bond--> dst`` line."""
        origin = f"{self.source_motif}[site_{self.source_site}]"
        destination = f"{self.target_motif}[site_{self.target_site}]"
        arrow = f"--{self.bond_type}-->"
        return " ".join(("CONNECTION:", origin, arrow, destination))