import torch
from typing import Dict, List, Tuple, Optional, Any
from .state import AssemblyState
from .actions import AssemblyAction, ActionSpace
from .constraints import ChemicalConstraints, TopologyConstraints
from ..core.representation import MolecularGraph, Connection
from ..core.topology import TopologyAnalyzer


class AssemblyEnvironment:
    """Environment that assembles a molecular graph by connecting motifs step by step.

    An episode starts with ``reset`` from an initial ``MolecularGraph`` (and an
    optional target graph). Each ``step`` either adds one ``Connection`` between
    two motif connection sites or issues a STOP action, and returns
    ``(state, reward, terminated, truncated, info)``.
    """

    def __init__(self, max_steps: int = 100, max_motifs: int = 50,
                 mode: str = "reconstruction", chemical_validation: bool = True,
                 topology_validation: bool = True):
        """Configure the environment.

        Args:
            max_steps: Number of executed (valid) actions after which the
                episode is truncated.
            max_motifs: Forwarded to ``ActionSpace``.
            mode: Default episode mode used by ``reset`` when none is given.
                ``"reconstruction"`` enables target-edge rewards and
                termination checks.
            chemical_validation: If True, validate actions with
                ``ChemicalConstraints``.
            topology_validation: If True, validate actions and termination
                with ``TopologyConstraints``.
        """
        self.max_steps = max_steps
        self.max_motifs = max_motifs
        self.mode = mode

        self.action_space = ActionSpace(max_motifs=max_motifs)
        self.chemical_constraints = ChemicalConstraints() if chemical_validation else None
        self.topology_constraints = TopologyConstraints() if topology_validation else None
        self.topology_analyzer = TopologyAnalyzer()

        self.current_state: Optional[AssemblyState] = None
        self.step_history: List[Dict] = []

    def reset(self, initial_graph: MolecularGraph, target_graph: Optional[MolecularGraph] = None,
              mode: Optional[str] = None, properties_target: Optional[Dict[str, float]] = None) -> AssemblyState:
        """Start a new episode and clear ``step_history``.

        Args:
            initial_graph: Starting graph. Its motif keys become the state's
                ``available_motifs``.
            target_graph: Optional graph used for reconstruction progress,
                reward and termination.
            mode: Episode mode. Defaults to ``self.mode`` when None.
            properties_target: Stored on the state as ``properties``
                (an empty dict when None).

        Returns:
            The freshly created ``AssemblyState`` (step 0, not terminated).
        """
        if mode is None:
            mode = self.mode

        self.current_state = AssemblyState(
            current_graph=initial_graph,
            available_motifs=set(initial_graph.motifs.keys()),
            target_graph=target_graph,
            step=0,
            max_steps=self.max_steps,
            terminated=False,
            truncated=False,
            mode=mode,
            properties=properties_target or {}
        )

        self.step_history = []
        return self.current_state

    def step(self, action: AssemblyAction) -> Tuple[AssemblyState, float, bool, bool, Dict[str, Any]]:
        """Apply ``action`` and advance the episode by one step.

        Args:
            action: A connect action or a STOP action.

        Returns:
            ``(state, reward, terminated, truncated, info)``.

        Raises:
            ValueError: If ``reset`` has not been called.

        Invalid actions receive a ``-1.0`` penalty and leave the state
        unchanged. They do not advance the step counter and are not recorded
        in ``step_history``.
        """
        if self.current_state is None:
            raise ValueError("Environment not reset. Call reset() first.")

        if self.current_state.terminated or self.current_state.truncated:
            # Report the episode's real end flags instead of forcing both to
            # True, so callers can still tell termination from truncation.
            return (self.current_state, 0.0, self.current_state.terminated,
                    self.current_state.truncated, {"error": "Environment already terminated"})

        valid, error_msg = self._validate_action(action)
        if not valid:
            # NOTE(review): invalid actions do not count toward max_steps, so an
            # agent that only emits invalid actions is never truncated here.
            info = {"error": error_msg, "action_valid": False}
            reward = -1.0  # Penalty for invalid action
            return self.current_state, reward, False, False, info

        next_state, reward, terminated, truncated, info = self._execute_action(action)

        self.current_state = next_state
        self.current_state.step += 1

        # Apply the step budget BEFORE recording, so the history entry carries
        # the same truncated flag that is returned to the caller.
        if self.current_state.step >= self.max_steps:
            self.current_state.truncated = True
            truncated = True

        step_record = {
            "step": self.current_state.step,
            "action": action.to_dict(),
            "reward": reward,
            "terminated": terminated,
            "truncated": truncated,
            "info": info
        }
        self.step_history.append(step_record)

        return self.current_state, reward, terminated, truncated, info

    def _validate_action(self, action: AssemblyAction) -> Tuple[bool, str]:
        """Check ``action`` against the enabled constraints and the current state.

        Returns:
            ``(True, "")`` when the action is allowed, otherwise
            ``(False, reason)``.
        """
        if self.chemical_constraints:
            valid, msg = self.chemical_constraints.validate_action(action, self.current_state.current_graph)
            if not valid:
                return False, f"Chemical constraint violation: {msg}"

        if self.topology_constraints:
            valid, msg = self.topology_constraints.validate_topology(self.current_state, action)
            if not valid:
                return False, f"Topology constraint violation: {msg}"

        # Basic structural check for connect actions.
        if action.is_connect_action():
            valid = self.current_state.is_valid_connection(
                action.source_motif, action.source_site,
                action.target_motif, action.target_site, action.bond_type
            )
            if not valid:
                return False, "Invalid connection parameters"

        return True, ""

    def _execute_action(self, action: AssemblyAction) -> Tuple[AssemblyState, float, bool, bool, Dict[str, Any]]:
        """Apply an already-validated action to a clone of the current state.

        Returns:
            ``(next_state, reward, terminated, False, info)``. Truncation is
            handled by ``step``, not here.
        """
        # NOTE(review): the connect branch mutates next_state.current_graph, so
        # this relies on clone() copying the graph; confirm clone() is deep enough.
        next_state = self.current_state.clone()
        info = {"action_valid": True}

        if action.is_stop_action():
            terminated, termination_msg = self._check_termination_conditions(next_state)
            # Set terminated before computing the reward: the STOP reward reads it.
            next_state.terminated = terminated

            reward = self._calculate_reward(self.current_state, action, next_state)
            info.update({
                "termination_reason": termination_msg,
                "stop_action": True
            })

            return next_state, reward, terminated, False, info

        connection = Connection(
            source_motif=action.source_motif,
            source_site=action.source_site,
            target_motif=action.target_motif,
            target_site=action.target_site,
            bond_type=action.bond_type
        )

        next_state.current_graph.add_connection(connection)

        # Update available sites (simplified - remove used sites)
        self._update_available_sites(next_state, connection)

        reward = self._calculate_reward(self.current_state, action, next_state)

        terminated, termination_msg = self._check_automatic_termination(next_state)
        next_state.terminated = terminated

        info.update({
            "connection_added": connection.to_dict() if hasattr(connection, 'to_dict') else str(connection),
            "topology_features": next_state.current_graph.get_topology_features(),
            "termination_reason": termination_msg if terminated else None
        })

        return next_state, reward, terminated, False, info

    def _update_available_sites(self, state: AssemblyState, connection: Connection):
        """Remove the two connection sites used by ``connection`` from their motifs.

        Sites are matched by ``site_id``, and the motifs' ``connection_sites``
        lists are rebuilt in place on ``state.current_graph``.
        """
        # Simplified site management - in practice, you'd need more sophisticated
        # tracking of which specific sites are used
        source_motif = state.current_graph.motifs[connection.source_motif]
        target_motif = state.current_graph.motifs[connection.target_motif]

        source_motif.connection_sites = [
            site for site in source_motif.connection_sites
            if site.site_id != connection.source_site
        ]

        target_motif.connection_sites = [
            site for site in target_motif.connection_sites
            if site.site_id != connection.target_site
        ]

    @staticmethod
    def _edge_set(molecular_graph: MolecularGraph) -> set:
        """Return the edges of ``molecular_graph.graph`` as a comparable set.

        For undirected graphs each edge becomes a ``frozenset`` of its
        endpoints. An undirected graph may report the same bond as ``(u, v)``
        in one graph and ``(v, u)`` in another, so comparing raw tuples could
        miss matching bonds. Directed graphs, and graph objects without an
        ``is_directed()`` method, keep their original edge tuples.
        """
        graph = molecular_graph.graph
        edges = graph.edges()
        is_directed = getattr(graph, "is_directed", None)
        if callable(is_directed) and not is_directed():
            return {frozenset(edge) for edge in edges}
        return set(edges)

    def _calculate_reward(self, prev_state: AssemblyState, action: AssemblyAction,
                         next_state: AssemblyState) -> float:
        """Compute a placeholder shaped reward for the transition.

        STOP gives +10 if it terminated the episode and -5 otherwise. A
        connection gives:

        - +1 base;
        - +5 per connected component merged;
        - in reconstruction mode with a target, up to +3 scaled by the
          fraction of target edges present in ``next_state``.
        """
        # This will be implemented by the reward system
        # For now, return a simple reward
        if action.is_stop_action():
            if next_state.terminated:
                return 10.0  # Success bonus
            else:
                return -5.0  # Premature stop penalty

        reward = 1.0  # Base reward for valid connection

        prev_components = prev_state.current_graph.num_connected_components()
        next_components = next_state.current_graph.num_connected_components()

        if next_components < prev_components:
            reward += 5.0 * (prev_components - next_components)

        if next_state.mode == "reconstruction" and next_state.target_graph is not None:
            current_edges = self._edge_set(next_state.current_graph)
            target_edges = self._edge_set(next_state.target_graph)

            correct_edges = len(current_edges & target_edges)
            total_target_edges = len(target_edges)

            # NOTE(review): this uses the cumulative progress after the step,
            # not the change from prev_state. Once some target edges exist,
            # every later connection (even an off-target one) earns it.
            # Confirm this is intended.
            if total_target_edges > 0:
                progress = correct_edges / total_target_edges
                reward += 3.0 * progress

        return reward

    def _check_termination_conditions(self, state: AssemblyState) -> Tuple[bool, str]:
        """Decide whether a STOP action ends the episode.

        Topology constraints are consulted first. In reconstruction mode the
        episode also terminates when every target edge is present and the
        graph is a single connected component.

        Returns:
            ``(terminated, reason)``.
        """
        if self.topology_constraints:
            terminated, msg = self.topology_constraints.check_termination_conditions(state)
            if terminated:
                return True, msg

        if state.mode == "reconstruction" and state.target_graph is not None:
            current_edges = self._edge_set(state.current_graph)
            target_edges = self._edge_set(state.target_graph)

            if current_edges >= target_edges and state.current_graph.num_connected_components() == 1:
                return True, "Reconstruction completed"

        return False, ""

    def _check_automatic_termination(self, state: AssemblyState) -> Tuple[bool, str]:
        """End the episode after a connection when nothing useful is left to do.

        Terminates only when the action space offers no connect actions AND
        the graph is a single connected component.

        Returns:
            ``(terminated, reason)``.
        """
        valid_actions = self.action_space.get_valid_actions(state)
        connect_actions = [a for a in valid_actions if a.is_connect_action()]

        # While connect actions remain, or the graph is still fragmented, leave
        # ending the episode to the agent's STOP action.
        if not connect_actions and state.current_graph.num_connected_components() == 1:
            return True, "No more valid connections possible"

        return False, ""

    def get_valid_actions(self) -> List[AssemblyAction]:
        """Return the action space's valid actions, or [] before ``reset``."""
        if self.current_state is None:
            return []

        return self.action_space.get_valid_actions(self.current_state)

    def get_action_masks(self) -> Dict[str, torch.Tensor]:
        """Return the action space's masks for the current state, or {} before ``reset``."""
        if self.current_state is None:
            return {}

        return self.action_space.create_action_masks(self.current_state)

    def render(self, mode: str = "text") -> str:
        """Return a text rendering of the current state.

        Only ``mode="text"`` is supported. Other modes return a message
        string instead of raising.
        """
        if self.current_state is None:
            return "Environment not initialized"

        if mode == "text":
            return self.current_state.to_text_representation()

        return f"Render mode '{mode}' not supported"

    def get_info(self) -> Dict[str, Any]:
        """Summarise the current episode (empty dict before ``reset``).

        Target-related keys (``target_num_connections`` and
        ``completion_progress``) are added only when a target graph is set.
        """
        if self.current_state is None:
            return {}

        info = {
            "step": self.current_state.step,
            "max_steps": self.max_steps,
            "mode": self.current_state.mode,
            "terminated": self.current_state.terminated,
            "truncated": self.current_state.truncated,
            "num_motifs": len(self.current_state.current_graph.motifs),
            "num_connections": len(self.current_state.current_graph.connections),
            "num_components": self.current_state.current_graph.num_connected_components(),
            "topology_features": self.current_state.current_graph.get_topology_features()
        }

        if self.current_state.target_graph is not None:
            info["target_num_connections"] = len(self.current_state.target_graph.connections)
            info["completion_progress"] = self._get_completion_progress()

        return info

    def _get_completion_progress(self) -> float:
        """Return the fraction of target edges present in the current graph.

        Returns 0.0 with no state or no target, and 1.0 when the target has
        no edges.
        """
        if self.current_state is None or self.current_state.target_graph is None:
            return 0.0

        current_edges = self._edge_set(self.current_state.current_graph)
        target_edges = self._edge_set(self.current_state.target_graph)

        if not target_edges:
            return 1.0

        correct_edges = len(current_edges & target_edges)
        return correct_edges / len(target_edges)