import torch
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from ...agents.actor import LLMActor, MotifAgent
from ...agents.coordinator import CentralizedCoordinator
from ...environment import AssemblyEnvironment
from ...core.representation import MolecularGraph, Motif
from ...core.segmentation import BRICSSegmentation
from ...core.topology import TopologyAnalyzer


@dataclass
class ReconstructionResult:
    """Outcome of a single reconstruction attempt.

    Instances are built by ``MolecularReconstructor`` both for completed
    searches and for failures (invalid SMILES, exceptions in batch mode).
    """
    # Set from accuracy_metrics['perfect_match'] for completed searches;
    # always False for results built by _create_failed_result.
    success: bool
    # Graph in its final assembled state (an empty graph for failed results).
    reconstructed_graph: MolecularGraph
    # Graph the search was trying to reproduce (an empty graph for failed results).
    target_graph: MolecularGraph
    # Number of environment steps executed along the returned path.
    steps_taken: int
    # Edge/topology metrics from _compute_accuracy_metrics, or {'error': message}
    # for failed results. NOTE(review): values also include bools, ints and (on
    # failure) a string, so the Dict[str, float] annotation is looser in practice.
    accuracy_metrics: Dict[str, float]
    # Actions applied, in execution order.
    action_sequence: List[Any]
    # The `reasoning` attribute of each applied action, or the error message
    # for failed results.
    reasoning_log: List[str]
    # Output of TopologyAnalyzer.analyze_topology on the final graph
    # ({} for failed results).
    topology_analysis: Dict[str, Any]


class MolecularReconstructor:
    def __init__(self, actor: LLMActor, coordinator: CentralizedCoordinator,
                 segmentation: Optional[BRICSSegmentation] = None):
        self.actor = actor
        self.coordinator = coordinator
        self.segmentation = segmentation or BRICSSegmentation()
        self.topology_analyzer = TopologyAnalyzer()

    def reconstruct_molecule(self, target_smiles: str, max_steps: int = 100,
                           temperature: float = 0.1, beam_width: int = 1) -> ReconstructionResult:
        # Parse target molecule and create motifs
        target_mol = self._parse_smiles(target_smiles)
        if target_mol is None:
            return self._create_failed_result("Invalid target SMILES")

        # Segment target molecule
        target_motifs = self.segmentation.segment_molecule(target_mol, "target")
        target_graph = MolecularGraph(target_motifs)
        self._reconstruct_target_connections(target_graph, target_mol)

        # Create initial state with disconnected motifs
        initial_motifs = [self._copy_motif_disconnected(motif) for motif in target_motifs]
        initial_graph = MolecularGraph(initial_motifs)

        # Set up environment
        env = AssemblyEnvironment(max_steps=max_steps, mode="reconstruction")
        state = env.reset(initial_graph, target_graph, mode="reconstruction")

        if beam_width == 1:
            return self._greedy_reconstruction(env, state, target_graph, temperature)
        else:
            return self._beam_search_reconstruction(env, state, target_graph, beam_width, temperature)

    def _parse_smiles(self, smiles: str):
        try:
            from rdkit import Chem
            return Chem.MolFromSmiles(smiles)
        except Exception:
            return None

    def _reconstruct_target_connections(self, target_graph: MolecularGraph, target_mol):
        # This is simplified - in practice, would need sophisticated connection inference
        # For now, assume all motifs should be connected in sequence
        motif_ids = list(target_graph.motifs.keys())

        for i in range(len(motif_ids) - 1):
            # Create simple connection between consecutive motifs
            from ...core.representation import Connection
            connection = Connection(
                source_motif=motif_ids[i],
                source_site=0,
                target_motif=motif_ids[i + 1],
                target_site=0,
                bond_type="SINGLE"
            )
            target_graph.add_connection(connection)

    def _copy_motif_disconnected(self, motif: Motif) -> Motif:
        # Create a copy of the motif with all connection sites available
        new_motif = Motif(
            motif_id=motif.motif_id,
            smiles=motif.smiles,
            mol=motif.mol,
            connection_sites=motif.connection_sites.copy(),
            properties=motif.properties.copy(),
            is_aromatic=motif.is_aromatic,
            ring_info=motif.ring_info.copy() if motif.ring_info else {},
            functional_groups=motif.functional_groups.copy() if motif.functional_groups else []
        )
        return new_motif

    def _greedy_reconstruction(self, env: AssemblyEnvironment, initial_state, target_graph: MolecularGraph,
                             temperature: float) -> ReconstructionResult:
        state = initial_state
        action_sequence = []
        reasoning_log = []
        steps_taken = 0

        # Create motif agents
        motif_agents = {}
        for motif_id in state.available_motifs:
            motif_agents[motif_id] = MotifAgent(motif_id, self.actor)

        while not state.terminated and not state.truncated and steps_taken < env.max_steps:
            # Get action from coordinator
            action, coordination_info = self.coordinator.coordinate_actions(
                state, motif_agents, temperature=temperature
            )

            action_sequence.append(action)
            reasoning_log.append(action.reasoning)

            # Execute action
            next_state, reward, terminated, truncated, info = env.step(action)
            state = next_state
            steps_taken += 1

            if terminated:
                break

        # Evaluate reconstruction accuracy
        accuracy_metrics = self._compute_accuracy_metrics(state.current_graph, target_graph)
        topology_analysis = self.topology_analyzer.analyze_topology(state.current_graph)

        return ReconstructionResult(
            success=accuracy_metrics['perfect_match'],
            reconstructed_graph=state.current_graph,
            target_graph=target_graph,
            steps_taken=steps_taken,
            accuracy_metrics=accuracy_metrics,
            action_sequence=action_sequence,
            reasoning_log=reasoning_log,
            topology_analysis=topology_analysis
        )

    def _beam_search_reconstruction(self, env: AssemblyEnvironment, initial_state, target_graph: MolecularGraph,
                                  beam_width: int, temperature: float) -> ReconstructionResult:
        # Beam search implementation for exploring multiple reconstruction paths

        @dataclass
        class BeamState:
            state: Any
            env: AssemblyEnvironment
            score: float
            action_sequence: List[Any]
            reasoning_log: List[str]

        # Initialize beam
        initial_beam_state = BeamState(
            state=initial_state,
            env=env,
            score=0.0,
            action_sequence=[],
            reasoning_log=[]
        )
        beam = [initial_beam_state]

        max_steps = env.max_steps
        for step in range(max_steps):
            new_beam = []

            for beam_state in beam:
                if beam_state.state.terminated or beam_state.state.truncated:
                    new_beam.append(beam_state)
                    continue

                # Create motif agents for this state
                motif_agents = {}
                for motif_id in beam_state.state.available_motifs:
                    motif_agents[motif_id] = MotifAgent(motif_id, self.actor)

                # Get multiple action candidates
                valid_actions = env.get_valid_actions()
                action_scores = []

                for action in valid_actions[:beam_width * 2]:  # Consider more actions than beam width
                    # Score action based on topology progress
                    score = self._score_action_for_reconstruction(beam_state.state, action, target_graph)
                    action_scores.append((action, score))

                # Sort by score and take top actions
                action_scores.sort(key=lambda x: x[1], reverse=True)
                top_actions = action_scores[:beam_width]

                # Expand beam
                for action, action_score in top_actions:
                    # Clone environment and execute action
                    new_env = self._clone_environment(beam_state.env)
                    next_state, reward, terminated, truncated, info = new_env.step(action)

                    new_beam_state = BeamState(
                        state=next_state,
                        env=new_env,
                        score=beam_state.score + action_score + reward,
                        action_sequence=beam_state.action_sequence + [action],
                        reasoning_log=beam_state.reasoning_log + [action.reasoning]
                    )
                    new_beam.append(new_beam_state)

            # Prune beam to keep only top candidates
            new_beam.sort(key=lambda x: x.score, reverse=True)
            beam = new_beam[:beam_width]

            # Check if any path is complete
            completed_paths = [bs for bs in beam if bs.state.terminated]
            if completed_paths:
                best_path = max(completed_paths, key=lambda x: x.score)
                return self._create_result_from_beam_state(best_path, target_graph)

        # Return best incomplete path
        if beam:
            best_path = max(beam, key=lambda x: x.score)
            return self._create_result_from_beam_state(best_path, target_graph)

        return self._create_failed_result("Beam search failed")

    def _score_action_for_reconstruction(self, state, action, target_graph: MolecularGraph) -> float:
        # Score actions based on how much they advance reconstruction
        if action.is_stop_action():
            # Score STOP based on completion
            current_edges = set(state.current_graph.graph.edges())
            target_edges = set(target_graph.graph.edges())
            completion = len(current_edges & target_edges) / max(len(target_edges), 1)
            connectivity = 1.0 if state.current_graph.num_connected_components() == 1 else 0.0
            return completion + connectivity

        # Score connection actions
        score = 0.0

        # Check if this creates a target edge
        proposed_edge = (action.source_motif, action.target_motif)
        reverse_edge = (action.target_motif, action.source_motif)
        target_edges = set(target_graph.graph.edges())

        if proposed_edge in target_edges or reverse_edge in target_edges:
            score += 2.0  # High bonus for correct edge

        # Connectivity bonus
        current_components = state.current_graph.num_connected_components()
        if current_components > 1:
            score += 1.0  # Bonus for reducing fragmentation

        return score

    def _clone_environment(self, env: AssemblyEnvironment) -> AssemblyEnvironment:
        # Create a copy of the environment - simplified
        new_env = AssemblyEnvironment(max_steps=env.max_steps, mode="reconstruction")
        new_env.current_state = env.current_state.clone()
        return new_env

    def _create_result_from_beam_state(self, beam_state, target_graph: MolecularGraph) -> ReconstructionResult:
        accuracy_metrics = self._compute_accuracy_metrics(beam_state.state.current_graph, target_graph)
        topology_analysis = self.topology_analyzer.analyze_topology(beam_state.state.current_graph)

        return ReconstructionResult(
            success=accuracy_metrics['perfect_match'],
            reconstructed_graph=beam_state.state.current_graph,
            target_graph=target_graph,
            steps_taken=len(beam_state.action_sequence),
            accuracy_metrics=accuracy_metrics,
            action_sequence=beam_state.action_sequence,
            reasoning_log=beam_state.reasoning_log,
            topology_analysis=topology_analysis
        )

    def _compute_accuracy_metrics(self, reconstructed_graph: MolecularGraph,
                                 target_graph: MolecularGraph) -> Dict[str, float]:
        current_edges = set(reconstructed_graph.graph.edges())
        target_edges = set(target_graph.graph.edges())

        # Normalize edges (ensure consistent ordering)
        current_edges = {tuple(sorted(edge)) for edge in current_edges}
        target_edges = {tuple(sorted(edge)) for edge in target_edges}

        # Basic metrics
        if not target_edges:
            edge_precision = 1.0 if not current_edges else 0.0
            edge_recall = 1.0
            edge_f1 = 1.0
        else:
            correct_edges = len(current_edges & target_edges)
            edge_precision = correct_edges / len(current_edges) if current_edges else 0.0
            edge_recall = correct_edges / len(target_edges)
            edge_f1 = 2 * edge_precision * edge_recall / (edge_precision + edge_recall) if (edge_precision + edge_recall) > 0 else 0.0

        # Perfect match
        perfect_match = current_edges == target_edges

        # Connectivity match
        connectivity_match = (reconstructed_graph.num_connected_components() == 1 and
                            target_graph.num_connected_components() == 1)

        # Topology similarity
        try:
            edit_distance = self.topology_analyzer.compute_graph_edit_distance(reconstructed_graph, target_graph)
            topology_similarity = 1.0 / (1.0 + edit_distance) if edit_distance != float('inf') else 0.0
        except:
            topology_similarity = 0.0

        return {
            'edge_precision': edge_precision,
            'edge_recall': edge_recall,
            'edge_f1': edge_f1,
            'perfect_match': perfect_match,
            'connectivity_match': connectivity_match,
            'topology_similarity': topology_similarity,
            'num_correct_edges': len(current_edges & target_edges),
            'num_extra_edges': len(current_edges - target_edges),
            'num_missing_edges': len(target_edges - current_edges)
        }

    def _create_failed_result(self, error_message: str) -> ReconstructionResult:
        return ReconstructionResult(
            success=False,
            reconstructed_graph=MolecularGraph([]),
            target_graph=MolecularGraph([]),
            steps_taken=0,
            accuracy_metrics={'error': error_message},
            action_sequence=[],
            reasoning_log=[error_message],
            topology_analysis={}
        )

    def batch_reconstruct(self, target_smiles_list: List[str], **kwargs) -> List[ReconstructionResult]:
        # Reconstruct multiple molecules in batch
        results = []

        for smiles in target_smiles_list:
            try:
                result = self.reconstruct_molecule(smiles, **kwargs)
                results.append(result)
            except Exception as e:
                failed_result = self._create_failed_result(f"Error: {str(e)}")
                results.append(failed_result)

        return results

    def get_reconstruction_statistics(self, results: List[ReconstructionResult]) -> Dict[str, Any]:
        # Compute statistics across multiple reconstructions
        if not results:
            return {}

        valid_results = [r for r in results if r.success is not False and 'error' not in r.accuracy_metrics]

        if not valid_results:
            return {'error': 'No valid results to analyze'}

        stats = {
            'total_reconstructions': len(results),
            'successful_reconstructions': sum(1 for r in results if r.success),
            'success_rate': sum(1 for r in results if r.success) / len(results),
            'average_steps': sum(r.steps_taken for r in valid_results) / len(valid_results),
            'average_edge_precision': sum(r.accuracy_metrics.get('edge_precision', 0) for r in valid_results) / len(valid_results),
            'average_edge_recall': sum(r.accuracy_metrics.get('edge_recall', 0) for r in valid_results) / len(valid_results),
            'average_edge_f1': sum(r.accuracy_metrics.get('edge_f1', 0) for r in valid_results) / len(valid_results),
            'perfect_match_rate': sum(1 for r in valid_results if r.accuracy_metrics.get('perfect_match', False)) / len(valid_results),
            'connectivity_match_rate': sum(1 for r in valid_results if r.accuracy_metrics.get('connectivity_match', False)) / len(valid_results)
        }

        return stats