import math
from typing import Dict, List, Optional, Tuple, Set

import torch

from ...environment.state import AssemblyState
from ...environment.actions import AssemblyAction
from ...core.topology import TopologyAnalyzer


class TopologicalRewards:
    """Topology-driven reward terms for motif-graph assembly.

    Each public method inspects the ``current_graph`` of an ``AssemblyState``
    (and, in ``"reconstruction"`` mode, its ``target_graph``) and returns a
    scalar or a dict of named reward components.
    """

    # Weights applied to the components built in calculate_topological_reward.
    # Keys must match the component names used there.
    DEFAULT_WEIGHTS: Dict[str, float] = {
        'connectivity': 1.0,
        'edge_progress': 0.8,
        'topology_similarity': 0.6,
        'over_connection_penalty': 0.5,
    }

    def __init__(self, weights: Optional[Dict[str, float]] = None):
        """Initialise the reward calculator.

        Args:
            weights: Optional overrides for ``DEFAULT_WEIGHTS``. Keys that are
                not supplied fall back to the defaults, so a partial mapping
                no longer raises ``KeyError`` when the weighted total is
                computed. The mapping is copied, not aliased.
        """
        self.weights = {**self.DEFAULT_WEIGHTS, **(weights or {})}
        self.topology_analyzer = TopologyAnalyzer()

    @staticmethod
    def _normalize_edges(edges) -> Set[Tuple]:
        """Return ``edges`` as a set of sorted tuples.

        An undirected edge may be reported as ``(u, v)`` or ``(v, u)``.
        Sorting each tuple makes set comparisons orientation-independent.
        This assumes node identifiers are mutually orderable.
        """
        return {tuple(sorted(edge)) for edge in edges}

    def calculate_topological_reward(self, prev_state: AssemblyState, action: AssemblyAction,
                                     next_state: AssemblyState) -> Dict[str, float]:
        """Compute the weighted topological reward for one transition.

        Returns:
            A dict with one entry per component (``connectivity``,
            ``edge_progress``, ``topology_similarity``,
            ``over_connection_penalty``) plus ``total_topological``. The total
            is the sum of each component multiplied by its weight in
            ``self.weights``.
        """
        rewards = {
            # Merging previously disconnected components.
            'connectivity': self._calculate_connectivity_reward(prev_state, next_state),
            # Adding/removing edges that exist in the target (reconstruction only).
            'edge_progress': self._calculate_edge_progress_reward(prev_state, next_state, action),
            # Similarity of the resulting graph to the target.
            'topology_similarity': self._calculate_topology_similarity_reward(next_state),
            # Penalty (<= 0) for exceeding the target's connection count.
            'over_connection_penalty': self._calculate_over_connection_penalty(next_state, action),
        }
        rewards['total_topological'] = sum(
            self.weights[key] * value for key, value in rewards.items()
        )
        return rewards

    def _calculate_connectivity_reward(self, prev_state: AssemblyState,
                                       next_state: AssemblyState) -> float:
        """Return 2.0 per connected component removed by the transition, else 0.0."""
        prev_components = prev_state.current_graph.num_connected_components()
        next_components = next_state.current_graph.num_connected_components()

        component_reduction = prev_components - next_components
        if component_reduction > 0:
            return float(component_reduction * 2.0)
        return 0.0

    def _calculate_edge_progress_reward(self, prev_state: AssemblyState, next_state: AssemblyState,
                                        action: AssemblyAction) -> float:
        """Reward gaining, and penalise losing, edges that appear in the target graph.

        Returns +3.0 per target edge gained and -2.0 per target edge lost. It
        returns 0.0 for stop actions, outside reconstruction mode, or when
        there is no target graph.
        """
        if action.is_stop_action() or not next_state.target_graph:
            return 0.0
        if next_state.mode != "reconstruction":
            return 0.0

        prev_edges = self._normalize_edges(prev_state.current_graph.graph.edges())
        next_edges = self._normalize_edges(next_state.current_graph.graph.edges())
        target_edges = self._normalize_edges(next_state.target_graph.graph.edges())

        edge_progress = len(next_edges & target_edges) - len(prev_edges & target_edges)

        if edge_progress > 0:
            return float(edge_progress * 3.0)  # Correct edges added.
        if edge_progress < 0:
            return float(edge_progress * 2.0)  # Correct edges removed.
        return 0.0

    def _calculate_topology_similarity_reward(self, state: AssemblyState) -> float:
        """Map graph edit distance to the target into a similarity reward in [0, 1].

        Returns -1.0 when the distance is not finite. Returns 0.0 outside
        reconstruction mode, without a target graph, or if the analyzer
        raises (best-effort).
        """
        if not state.target_graph or state.mode != "reconstruction":
            return 0.0

        try:
            current_graph = state.current_graph
            target_graph = state.target_graph

            edit_distance = self.topology_analyzer.compute_graph_edit_distance(current_graph, target_graph)

            # Treat inf and NaN alike; a NaN would otherwise leak into the total.
            if not math.isfinite(edit_distance):
                return -1.0

            # NOTE(review): this normaliser mixes current-graph motifs with
            # target-graph connections. Confirm it matches the cost model of
            # compute_graph_edit_distance.
            max_possible_distance = len(current_graph.motifs) + len(target_graph.connections)
            normalized_distance = edit_distance / max(max_possible_distance, 1)

            return 1.0 - min(normalized_distance, 1.0)

        except Exception:
            # Deliberately best-effort: analyzer failures yield a neutral reward.
            return 0.0

    def _calculate_over_connection_penalty(self, state: AssemblyState, action: AssemblyAction) -> float:
        """Return -1.0 per connection beyond the target graph's connection count (0.0 otherwise)."""
        if action.is_stop_action() or state.mode != "reconstruction" or not state.target_graph:
            return 0.0

        current_edges = len(state.current_graph.connections)
        target_edges = len(state.target_graph.connections)

        over_connections = max(0, current_edges - target_edges)
        if over_connections > 0:
            return -float(over_connections * 1.0)
        return 0.0

    def calculate_topology_distance_reward(self, prev_state: AssemblyState,
                                           next_state: AssemblyState) -> float:
        """Reward a reduction in graph edit distance to ``next_state.target_graph``.

        A decrease earns ``0.5 * reduction``, capped at 2.0. An increase costs
        ``0.3 * increase``, capped at -1.0. Returns 0.0 outside reconstruction
        mode, without a target graph, when the change is undefined (e.g. both
        distances are infinite), or if the analyzer raises.
        """
        if not next_state.target_graph or next_state.mode != "reconstruction":
            return 0.0

        try:
            prev_distance = self.topology_analyzer.compute_graph_edit_distance(
                prev_state.current_graph, next_state.target_graph
            )
            next_distance = self.topology_analyzer.compute_graph_edit_distance(
                next_state.current_graph, next_state.target_graph
            )

            distance_reduction = prev_distance - next_distance

            # inf - inf (or a NaN distance) gives NaN. NaN fails `> 0`, and
            # max(nan, -1.0) returns nan, so it would poison the total reward.
            if math.isnan(distance_reduction):
                return 0.0

            if distance_reduction > 0:
                return min(distance_reduction * 0.5, 2.0)  # Capped bonus.
            return max(distance_reduction * 0.3, -1.0)  # Light, capped penalty.

        except Exception:
            # Deliberately best-effort: analyzer failures yield a neutral reward.
            return 0.0

    def calculate_assembly_progress_reward(self, state: AssemblyState) -> Dict[str, float]:
        """Summarise assembly progress as metrics in [0, 1].

        Returns:
            ``connectivity_progress``: 1.0 when all motifs form one component.
            ``edge_completion``: fraction of target edges present. Only in
            reconstruction mode with a target graph.
            ``complexity_progress``: average clustering relative to a 0.5
            baseline, capped at 1.0.
        """
        progress_metrics = {}

        num_components = state.current_graph.num_connected_components()
        num_motifs = len(state.current_graph.motifs)

        if num_motifs > 1:
            connectivity_progress = 1.0 - (num_components - 1) / (num_motifs - 1)
        else:
            connectivity_progress = 1.0
        progress_metrics['connectivity_progress'] = connectivity_progress

        if state.target_graph and state.mode == "reconstruction":
            current_edges = self._normalize_edges(state.current_graph.graph.edges())
            target_edges = self._normalize_edges(state.target_graph.graph.edges())

            if target_edges:
                edge_completion = len(current_edges & target_edges) / len(target_edges)
            else:
                edge_completion = 1.0
            progress_metrics['edge_completion'] = edge_completion

        topology_features = state.current_graph.get_topology_features()
        target_complexity = 0.5  # Baseline clustering target.
        average_clustering = topology_features.get('average_clustering', 0)

        if average_clustering > target_complexity:
            progress_metrics['complexity_progress'] = 1.0
        else:
            progress_metrics['complexity_progress'] = average_clustering / target_complexity

        return progress_metrics

    def get_critical_connections(self, state: AssemblyState) -> List[Tuple[str, str, float]]:
        """Suggest the best-scoring motif pair for each pair of connected components.

        Returns:
            ``(motif1, motif2, score)`` triples sorted by descending score.
            Only pairs with a positive score are included. Returns an empty
            list when the graph is already a single component.
        """
        if state.current_graph.num_connected_components() <= 1:
            return []

        components = state.current_graph.get_connected_components()
        critical_connections = []

        for i, comp1 in enumerate(components):
            for comp2 in components[i + 1:]:
                best_score = 0.0
                best_pair = None

                for motif1 in comp1:
                    for motif2 in comp2:
                        score = self._score_potential_connection(state, motif1, motif2)
                        if score > best_score:
                            best_score = score
                            best_pair = (motif1, motif2)

                if best_pair is not None:
                    critical_connections.append((best_pair[0], best_pair[1], best_score))

        critical_connections.sort(key=lambda x: x[2], reverse=True)
        return critical_connections

    def _score_potential_connection(self, state: AssemblyState, motif1: str, motif2: str) -> float:
        """Heuristically score joining ``motif1`` and ``motif2``.

        The score combines:
          * +1.0 if both motifs have available sites;
          * +0.2 per shared allowed bond type, for every site pair;
          * +0.5 if both motifs are aromatic;
          * +2.0 if the pair is an edge of the target graph (reconstruction only).

        Returns 0.0 if either motif is not in the current graph.
        """
        motifs = state.current_graph.motifs
        if motif1 not in motifs or motif2 not in motifs:
            return 0.0

        motif1_obj = motifs[motif1]
        motif2_obj = motifs[motif2]

        score = 0.0

        sites1 = motif1_obj.get_available_sites()
        sites2 = motif2_obj.get_available_sites()

        if sites1 and sites2:
            score += 1.0

        for site1 in sites1:
            for site2 in sites2:
                common_bonds = site1.allowed_bond_types & site2.allowed_bond_types
                if common_bonds:
                    score += len(common_bonds) * 0.2

        if motif1_obj.is_aromatic and motif2_obj.is_aromatic:
            score += 0.5

        if state.target_graph and state.mode == "reconstruction":
            target_edges = set(state.target_graph.graph.edges())
            # Check both orientations because edges may be stored either way.
            if (motif1, motif2) in target_edges or (motif2, motif1) in target_edges:
                score += 2.0

        return score

    def calculate_potential_shaping_reward(self, state: AssemblyState, gamma: float = 0.99,
                                           beta: float = 1.0) -> float:
        """Return the shaping potential Phi(state), not a reward.

        The potential is ``-(missing target edges) - beta * (components - 1)``.
        Edge orientation is normalised, consistent with the other edge-based
        terms, so an edge reported as ``(v, u)`` is not counted as missing.
        Returns 0.0 outside reconstruction mode or without a target graph.

        Args:
            gamma: Unused. It is kept for interface compatibility; the discount
                is applied in ``get_reward_breakdown``.
            beta: Weight of the disconnected-component term.
        """
        if not state.target_graph or state.mode != "reconstruction":
            return 0.0

        current_edges = self._normalize_edges(state.current_graph.graph.edges())
        target_edges = self._normalize_edges(state.target_graph.graph.edges())

        missing_edges = len(target_edges - current_edges)
        num_components = state.current_graph.num_connected_components()

        return -missing_edges - beta * (num_components - 1)

    def get_reward_breakdown(self, prev_state: AssemblyState, action: AssemblyAction,
                             next_state: AssemblyState, gamma: float = 0.99) -> Dict[str, float]:
        """Return all topological reward components for one transition.

        The dict extends ``calculate_topological_reward`` with:
          * ``topology_distance``: from ``calculate_topology_distance_reward``.
          * ``potential_shaping``: ``gamma * Phi(next) - Phi(prev)``. Only
            present when ``next_state`` has a target graph.
        Both extra terms are added, unweighted, to ``total_topological``.

        Args:
            gamma: Discount factor used in the potential-shaping term.
        """
        topological_rewards = self.calculate_topological_reward(prev_state, action, next_state)

        topological_rewards['topology_distance'] = self.calculate_topology_distance_reward(prev_state, next_state)

        if next_state.target_graph:
            prev_potential = self.calculate_potential_shaping_reward(prev_state)
            next_potential = self.calculate_potential_shaping_reward(next_state)
            topological_rewards['potential_shaping'] = gamma * next_potential - prev_potential

        topological_rewards['total_topological'] += (
            topological_rewards['topology_distance']
            + topological_rewards.get('potential_shaping', 0.0)
        )

        return topological_rewards