import copy
import random
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Any, Set

import numpy as np
import torch

from ...agents.actor import LLMActor, MotifAgent
from ...agents.coordinator import CentralizedCoordinator
from ...environment import AssemblyEnvironment
from ...core.representation import MolecularGraph, Motif
from ...core.topology import TopologyAnalyzer


@dataclass
class GenerationResult:
    """Outcome of one molecule-generation run.

    ``MolecularGenerator`` sets ``success`` to True when the assembled graph is a
    single connected component; runs that raised an exception in
    ``batch_generate`` are reported with ``success=False`` and empty fields.
    """
    # Whether the run produced a usable molecule (see class docstring).
    success: bool
    # Final assembled motif graph (an empty graph for runs that raised).
    generated_graph: MolecularGraph
    # Number of environment steps actually executed.
    steps_taken: int
    # Estimated properties of the assembled molecule, keyed by property name.
    properties: Dict[str, float]
    # Actions executed during the run, in order.
    action_sequence: List[Any]
    # One reasoning entry per executed action, or a single "Error: ..." entry
    # for runs that raised.
    reasoning_log: List[str]
    # Result of TopologyAnalyzer.analyze_topology on the final graph.
    topology_analysis: Dict[str, Any]
    # Quality metrics for the run. NOTE(review): for failed runs this holds an
    # 'error' string despite the float annotation.
    generation_metrics: Dict[str, float]


class MolecularGenerator:
    """Generate molecules by assembling library motifs under an LLM-driven policy.

    Each run samples motifs (optionally filtered by a target molecular weight),
    places them as disconnected motifs in a ``MolecularGraph`` and repeatedly
    asks the ``CentralizedCoordinator`` for the next action, executing it in an
    ``AssemblyEnvironment`` until the episode ends or the early-stop rule in
    ``_should_terminate_generation`` fires.
    """

    # Approximate mass (Da) of one hydrogen atom. Each connection between two
    # motifs is assumed to replace one hydrogen on each side.
    _HYDROGEN_MASS = 1.008

    def __init__(self, actor: LLMActor, coordinator: CentralizedCoordinator,
                 motif_library: Optional[List[Motif]] = None):
        """Create a generator.

        Args:
            actor: LLM actor shared by every per-motif ``MotifAgent``.
            coordinator: Coordinator queried at each step for the next action.
            motif_library: Motifs to sample from. When None or empty, a small
                built-in library is built from SMILES (requires RDKit).
        """
        self.actor = actor
        self.coordinator = coordinator
        self.motif_library = motif_library or self._create_default_motif_library()
        self.topology_analyzer = TopologyAnalyzer()

    def _create_default_motif_library(self) -> List[Motif]:
        """Build a small built-in library of common motifs.

        A connection site is created on every atom that carries at least one
        hydrogen, because forming a new bond replaces one of those hydrogens.
        In practice a library would be loaded from a comprehensive database.
        """
        # Imported lazily so RDKit is only required when the default library is used.
        from rdkit import Chem
        # ``Descriptors`` is a submodule; ``from rdkit import Chem`` alone does
        # not guarantee that ``Chem.Descriptors`` is available.
        from rdkit.Chem import Descriptors
        from ...core.representation import ConnectionSite

        common_smiles = [
            'c1ccccc1',      # benzene
            'CCO',           # ethanol
            'CC(=O)O',       # acetic acid
            'CCN',           # ethylamine
            'C=C',           # ethene
            'C#C',           # ethyne
            'C1CCCCC1',      # cyclohexane
            'c1ccncc1',      # pyridine
            'c1ccc2ccccc2c1', # naphthalene
            'CC(C)C',        # isobutane
        ]

        motifs = []
        for i, smiles in enumerate(common_smiles):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue

            sites = []
            for atom_idx, atom in enumerate(mol.GetAtoms()):
                # GetTotalValence() already includes implicit hydrogens, so it
                # equals the maximum valence for every neutral atom and can never
                # signal free valence. Replaceable hydrogens are the real measure.
                if atom.GetTotalNumHs() == 0:
                    continue
                sites.append(ConnectionSite(
                    site_id=len(sites),
                    atom_idx=atom_idx,
                    # NOTE(review): labelled sp3 even for aromatic/sp2 atoms; kept
                    # unchanged because other code may match on this string.
                    site_type=f"{atom.GetSymbol()}_sp3",
                    chemical_environment=f"{atom.GetSymbol()}",
                    allowed_bond_types={"SINGLE", "DOUBLE"} if atom.GetAtomicNum() == 6 else {"SINGLE"},
                    is_aromatic=atom.GetIsAromatic()
                ))

            motifs.append(Motif(
                motif_id=f"lib_motif_{i}",
                smiles=smiles,
                mol=mol,
                connection_sites=sites,
                properties={
                    'molecular_weight': Descriptors.MolWt(mol),
                    'logp': Descriptors.MolLogP(mol),
                    # Also consumed (summed) by _estimate_molecular_properties.
                    'tpsa': Descriptors.TPSA(mol),
                    'num_rings': Descriptors.RingCount(mol),
                    'num_aromatic_rings': Descriptors.NumAromaticRings(mol),
                    'num_heteroatoms': Descriptors.NumHeteroatoms(mol),
                },
                is_aromatic=any(atom.GetIsAromatic() for atom in mol.GetAtoms()),
                functional_groups=self._identify_functional_groups(mol)
            ))

        return motifs

    def _get_max_valence(self, atomic_num: int) -> int:
        """Return a typical maximum valence for ``atomic_num`` (4 for unlisted elements)."""
        valence_dict = {6: 4, 7: 3, 8: 2, 16: 6, 15: 5, 9: 1, 17: 1, 35: 1, 53: 1}
        return valence_dict.get(atomic_num, 4)

    def _identify_functional_groups(self, mol) -> List[str]:
        """Return a sorted list of simplified functional-group labels for ``mol``.

        Labels:
            * ``'aromatic'``: at least one aromatic atom is present.
            * ``'alcohol'``: a hydroxyl oxygen bound to a saturated,
              non-aromatic carbon. Phenols, enols and carboxylic acids are not
              counted.
        """
        groups = set()
        for atom in mol.GetAtoms():
            if atom.GetIsAromatic():
                groups.add('aromatic')
            if self._is_alcohol_oxygen(atom):
                groups.add('alcohol')
        # Sorted so the order does not depend on string-hash randomisation.
        return sorted(groups)

    @staticmethod
    def _is_alcohol_oxygen(atom) -> bool:
        """Return True if ``atom`` is the oxygen of an aliphatic alcohol -OH group."""
        # GetTotalDegree() counts implicit hydrogens (a hydroxyl O has total
        # degree 2, a carbonyl O has 1), so check heavy-atom degree and H count
        # separately.
        if atom.GetAtomicNum() != 8 or atom.GetDegree() != 1 or atom.GetTotalNumHs() != 1:
            return False
        carbon = atom.GetNeighbors()[0]
        if carbon.GetAtomicNum() != 6 or carbon.GetIsAromatic():
            return False
        # Any multiple bond on the carbon (C=O of an acid, C=C of an enol) excludes it.
        return all(bond.GetBondTypeAsDouble() < 2.0 for bond in carbon.GetBonds())

    def generate_molecule(self, target_properties: Optional[Dict[str, float]] = None,
                         num_motifs: int = 5, max_steps: int = 100,
                         temperature: float = 0.5, diversity_bonus: float = 0.1) -> GenerationResult:
        """Generate a single molecule.

        Args:
            target_properties: Optional property targets. Only
                ``'molecular_weight'`` influences motif sampling here; the whole
                dict is also passed to the environment as ``properties_target``.
            num_motifs: Number of motifs to assemble.
            max_steps: Step limit for the assembly environment.
            temperature: Initial sampling temperature for the coordinator.
            diversity_bonus: Temperature increase applied whenever a topology is revisited.

        Returns:
            The ``GenerationResult`` of the run.
        """
        selected_motifs = self._sample_motifs_for_generation(num_motifs, target_properties)

        # Start from a graph whose motifs are all disconnected.
        initial_graph = MolecularGraph(selected_motifs)

        env = AssemblyEnvironment(max_steps=max_steps, mode="generation")
        state = env.reset(initial_graph, target_graph=None, mode="generation",
                          properties_target=target_properties)

        return self._generate_with_policy(env, state, temperature, diversity_bonus)

    def _sample_motifs_for_generation(self, num_motifs: int,
                                      target_properties: Optional[Dict[str, float]] = None) -> List[Motif]:
        """Pick ``num_motifs`` motifs from the library and return fresh copies.

        With a ``'molecular_weight'`` target, only motifs no heavier than twice
        the per-motif share of the target are candidates. If none qualify, the
        lightest library motif is used. If there are fewer candidates than
        ``num_motifs``, they are repeated in order; otherwise a random subset is
        drawn without replacement.

        Raises:
            ValueError: if ``num_motifs`` is negative or the library is empty.
        """
        if num_motifs < 0:
            raise ValueError(f"num_motifs must be non-negative, got {num_motifs}")
        if num_motifs == 0:
            return []
        if not self.motif_library:
            raise ValueError("Cannot sample motifs: the motif library is empty")

        def motif_mw(motif):
            return motif.properties.get('molecular_weight', 0)

        if target_properties and 'molecular_weight' in target_properties:
            avg_motif_mw = target_properties['molecular_weight'] / num_motifs
            suitable_motifs = [
                motif for motif in self.motif_library
                if motif_mw(motif) <= avg_motif_mw * 2
            ]
            if not suitable_motifs:
                # Nothing fits the weight budget; the lightest motif gets closest.
                suitable_motifs = [min(self.motif_library, key=motif_mw)]
        else:
            suitable_motifs = list(self.motif_library)

        if len(suitable_motifs) < num_motifs:
            repeats = num_motifs // len(suitable_motifs) + 1
            selected = (suitable_motifs * repeats)[:num_motifs]
        else:
            selected = random.sample(suitable_motifs, num_motifs)

        # Copies get unique IDs so repeated library motifs stay distinct nodes.
        return [self._copy_motif_with_id(motif, f"gen_{i}") for i, motif in enumerate(selected)]

    def _copy_motif_with_id(self, motif: Motif, new_id: str) -> Motif:
        """Return a copy of ``motif`` named ``new_id``.

        Connection sites are deep-copied, so per-site state changed on this copy
        (presumably availability, cf. ``get_available_sites``) cannot leak into
        the library motif or into other copies of it. The RDKit ``mol`` is shared.
        """
        return Motif(
            motif_id=new_id,
            smiles=motif.smiles,
            mol=motif.mol,
            connection_sites=copy.deepcopy(motif.connection_sites),
            properties=motif.properties.copy(),
            is_aromatic=motif.is_aromatic,
            ring_info=motif.ring_info.copy() if motif.ring_info else {},
            functional_groups=motif.functional_groups.copy() if motif.functional_groups else []
        )

    def _generate_with_policy(self, env: AssemblyEnvironment, initial_state,
                              temperature: float, diversity_bonus: float) -> GenerationResult:
        """Roll out the coordinator policy in ``env`` and package the outcome.

        Whenever the current connection topology was already seen in this run,
        ``temperature`` is raised by ``diversity_bonus``. The increase persists
        for the rest of the run, pushing the policy towards exploration.
        """
        state = initial_state
        action_sequence = []
        reasoning_log = []
        steps_taken = 0

        # One agent per motif present at the start, all sharing the same actor.
        motif_agents = {
            motif_id: MotifAgent(motif_id, self.actor)
            for motif_id in state.available_motifs
        }

        visited_topologies = set()

        while not state.terminated and not state.truncated and steps_taken < env.max_steps:
            topology_key = self._hash_topology(state.current_graph)
            if topology_key in visited_topologies:
                temperature += diversity_bonus
            visited_topologies.add(topology_key)

            action, _coordination_info = self.coordinator.coordinate_actions(
                state, motif_agents, temperature=temperature
            )
            action_sequence.append(action)
            reasoning_log.append(action.reasoning)

            state, _reward, _terminated, _truncated, _info = env.step(action)
            steps_taken += 1

            if self._should_terminate_generation(state):
                break

        final_graph = state.current_graph
        properties = self._estimate_molecular_properties(final_graph)
        generation_metrics = self._compute_generation_metrics(final_graph, properties)
        topology_analysis = self.topology_analyzer.analyze_topology(final_graph)

        return GenerationResult(
            success=final_graph.num_connected_components() == 1,
            generated_graph=final_graph,
            steps_taken=steps_taken,
            properties=properties,
            action_sequence=action_sequence,
            reasoning_log=reasoning_log,
            topology_analysis=topology_analysis,
            generation_metrics=generation_metrics
        )

    def _hash_topology(self, graph: MolecularGraph) -> str:
        """Return a key identifying which motifs are connected in ``graph``.

        Each undirected edge is normalised to a sorted pair before the edge list
        is sorted, so the key does not depend on edge orientation. The key is
        the textual edge list itself rather than its ``hash()``, so distinct
        topologies cannot collide.
        """
        edges = sorted(tuple(sorted(map(str, edge))) for edge in graph.graph.edges())
        return repr(edges)

    def _should_terminate_generation(self, state) -> bool:
        """Decide whether a generation episode can stop early.

        Stops once the graph is one connected component with at least
        ``len(motifs) - 1`` connections and at most one motif still has more
        than two available connection sites.
        """
        graph = state.current_graph
        if graph.num_connected_components() != 1:
            return False
        if len(graph.connections) < len(graph.motifs) - 1:
            return False

        high_valence_motifs = sum(
            1 for motif in graph.motifs.values()
            if len(motif.get_available_sites()) > 2
        )
        return high_valence_motifs <= 1

    def _estimate_molecular_properties(self, graph: MolecularGraph) -> Dict[str, float]:
        """Estimate molecular properties additively from the assembled motifs.

        Each motif's stored property values are summed; missing keys count as 0.
        The molecular weight is then reduced by two hydrogen masses per
        connection. NOTE(review): this assumes every connection is a single bond
        replacing one H on each side.
        """
        properties = {
            'molecular_weight': 0.0,
            'logp': 0.0,
            'tpsa': 0.0,
            'num_rings': 0,
            'num_aromatic_rings': 0,
            'num_heteroatoms': 0
        }

        for motif in graph.motifs.values():
            motif_props = motif.properties
            for key, zero in (('molecular_weight', 0.0), ('logp', 0.0), ('tpsa', 0.0),
                              ('num_rings', 0), ('num_aromatic_rings', 0),
                              ('num_heteroatoms', 0)):
                properties[key] += motif_props.get(key, zero)

        num_connections = len(graph.connections)
        properties['molecular_weight'] -= num_connections * 2 * self._HYDROGEN_MASS

        return properties

    def _compute_generation_metrics(self, graph: MolecularGraph,
                                    properties: Dict[str, float]) -> Dict[str, float]:
        """Compute structural, drug-likeness and diversity metrics for a generated graph.

        Drug-likeness flags use the ranges 150 <= MW <= 500 and -2 <= logP <= 5.
        Ratios are divided by ``max(num_motifs, 1)`` so an empty graph yields 0.
        """
        num_motifs = len(graph.motifs)
        denom = max(num_motifs, 1)

        metrics = {
            'connectivity': float(graph.num_connected_components() == 1),
            'num_motifs': num_motifs,
            'num_connections': len(graph.connections),
            'connection_density': len(graph.connections) / denom,
            'avg_degree': sum(graph.graph.degree(motif_id) for motif_id in graph.motifs) / denom
        }

        mw = properties.get('molecular_weight', 0)
        metrics['mw_druglike'] = 1.0 if 150 <= mw <= 500 else 0.0

        logp = properties.get('logp', 0)
        metrics['logp_druglike'] = 1.0 if -2 <= logp <= 5 else 0.0

        # Fraction of motifs with a distinct SMILES.
        unique_motif_types = len({motif.smiles for motif in graph.motifs.values()})
        metrics['motif_diversity'] = unique_motif_types / denom

        aromatic_motifs = sum(1 for motif in graph.motifs.values() if motif.is_aromatic)
        metrics['aromatic_fraction'] = aromatic_motifs / denom

        return metrics

    def guided_generation(self, target_properties: Dict[str, float],
                         num_attempts: int = 5, **kwargs) -> List[GenerationResult]:
        """Run several generations and rank the successful ones by property alignment.

        Attempt ``k`` (0-based) uses ``temperature + 0.2 * k``, where
        ``temperature`` may be given in ``kwargs`` (default 0.3). The remaining
        ``kwargs`` are forwarded to ``generate_molecule``.

        Returns:
            Only successful results, best-aligned first. The list is empty if no
            attempt succeeds.
        """
        # Taking temperature out of kwargs avoids a "multiple values for keyword
        # argument 'temperature'" TypeError when a caller supplies it.
        base_temperature = kwargs.pop('temperature', 0.3)

        results = [
            self.generate_molecule(target_properties,
                                   temperature=base_temperature + 0.2 * attempt,
                                   **kwargs)
            for attempt in range(num_attempts)
        ]

        scored_results = [
            (result, self._score_property_alignment(result.properties, target_properties))
            for result in results
            if result.success
        ]
        scored_results.sort(key=lambda pair: pair[1], reverse=True)
        return [result for result, _score in scored_results]

    def _score_property_alignment(self, properties: Dict[str, float],
                                  target_properties: Dict[str, float]) -> float:
        """Score in [0, 1] how closely ``properties`` match ``target_properties``.

        For each target also present in ``properties``, the score is
        ``max(0, 1 - |actual - target| / max(|target|, 1))``. The result is the
        mean over those targets (0 if none overlap, 1 if there are no targets).
        """
        if not target_properties:
            return 1.0

        scores = [
            max(0, 1.0 - abs(properties[name] - target) / max(abs(target), 1.0))
            for name, target in target_properties.items()
            if name in properties
        ]
        return sum(scores, 0.0) / max(len(scores), 1)

    def batch_generate(self, num_molecules: int, target_properties: Optional[Dict[str, float]] = None,
                      **kwargs) -> List[GenerationResult]:
        """Generate ``num_molecules`` molecules independently.

        A run that raises is deliberately not propagated. It is recorded as a
        failed ``GenerationResult`` carrying the error message, so one bad run
        does not abort the batch.
        """
        results = []

        for _ in range(num_molecules):
            try:
                results.append(self.generate_molecule(target_properties, **kwargs))
            except Exception as e:  # batch boundary: keep going, record the failure
                results.append(GenerationResult(
                    success=False,
                    generated_graph=MolecularGraph([]),
                    steps_taken=0,
                    properties={},
                    action_sequence=[],
                    reasoning_log=[f"Error: {str(e)}"],
                    topology_analysis={},
                    generation_metrics={'error': str(e)}
                ))

        return results

    def get_generation_statistics(self, results: List[GenerationResult]) -> Dict[str, Any]:
        """Aggregate statistics over ``results``.

        Per-run averages are taken over successful runs only. Returns ``{}`` for
        no results and a short summary with an ``'error'`` key when none succeeded.
        """
        if not results:
            return {}

        valid_results = [r for r in results if r.success]

        if not valid_results:
            return {
                'total_generations': len(results),
                'success_rate': 0.0,
                'error': 'No successful generations'
            }

        def mean_metric(name: str):
            return np.mean([r.generation_metrics.get(name, 0) for r in valid_results])

        stats = {
            'total_generations': len(results),
            'successful_generations': len(valid_results),
            'success_rate': len(valid_results) / len(results),
            'average_steps': np.mean([r.steps_taken for r in valid_results]),
            'average_num_motifs': mean_metric('num_motifs'),
            'average_connections': mean_metric('num_connections'),
            'connectivity_rate': mean_metric('connectivity'),
            'druglike_mw_rate': mean_metric('mw_druglike'),
            'druglike_logp_rate': mean_metric('logp_druglike'),
            'average_diversity': mean_metric('motif_diversity')
        }

        property_stats = {}
        for prop in ['molecular_weight', 'logp', 'tpsa']:
            values = [r.properties[prop] for r in valid_results if prop in r.properties]
            if values:
                property_stats[prop] = {
                    'mean': np.mean(values),
                    'std': np.std(values),
                    'min': np.min(values),
                    'max': np.max(values)
                }

        stats['property_statistics'] = property_stats

        return stats