import numpy as np
import torch
from typing import Dict, List, Optional, Tuple
from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors, rdFingerprintGenerator
from ...environment.state import AssemblyState
from ...environment.actions import AssemblyAction
from ...core.representation import MolecularGraph


class ChemicalRewards:
    def __init__(self, weights: Optional[Dict[str, float]] = None):
        self.weights = weights or {
            'validity': 1.0,
            'stability': 0.5,
            'functional_groups': 0.3,
            'property_alignment': 0.8,
            'synthetic_accessibility': 0.2,
            'novelty': 0.1
        }

        # Property target ranges (can be customized)
        self.property_targets = {
            'molecular_weight': (200, 500),
            'logp': (0, 3),
            'tpsa': (20, 130),
            'rotatable_bonds': (2, 10),
            'hbd': (0, 5),
            'hba': (0, 10)
        }

    def calculate_chemical_reward(self, prev_state: AssemblyState, action: AssemblyAction,
                                next_state: AssemblyState) -> Dict[str, float]:
        rewards = {}

        # Chemical validity reward
        rewards['validity'] = self._calculate_validity_reward(action, next_state)

        # Local stability reward
        rewards['stability'] = self._calculate_stability_reward(prev_state, next_state, action)

        # Functional group formation reward
        rewards['functional_groups'] = self._calculate_functional_group_reward(prev_state, next_state)

        # Property alignment reward
        rewards['property_alignment'] = self._calculate_property_alignment_reward(next_state)

        # Synthetic accessibility reward
        rewards['synthetic_accessibility'] = self._calculate_sa_reward(next_state)

        # Novelty reward
        rewards['novelty'] = self._calculate_novelty_reward(next_state)

        # Calculate weighted total
        total_reward = sum(self.weights[key] * reward for key, reward in rewards.items())
        rewards['total_chemical'] = total_reward

        return rewards

    def _calculate_validity_reward(self, action: AssemblyAction, state: AssemblyState) -> float:
        if action.is_stop_action():
            return 0.0

        # Check if the action results in chemically valid connections
        try:
            # Simple validity check - in practice, this would involve more sophisticated validation
            source_motif = state.current_graph.motifs.get(action.source_motif)
            target_motif = state.current_graph.motifs.get(action.target_motif)

            if not source_motif or not target_motif:
                return -1.0

            # Check if bond type is chemically reasonable
            valid_bonds = {'SINGLE': 1.0, 'DOUBLE': 0.8, 'TRIPLE': 0.6, 'AROMATIC': 0.9}
            return valid_bonds.get(action.bond_type, -1.0)

        except Exception:
            return -1.0

    def _calculate_stability_reward(self, prev_state: AssemblyState, next_state: AssemblyState,
                                  action: AssemblyAction) -> float:
        if action.is_stop_action():
            return 0.0

        # Simplified stability calculation based on strain energy estimation
        stability_score = 0.0

        # Bond-specific stability
        bond_stability = {
            'SINGLE': 0.5,
            'DOUBLE': 0.3,
            'TRIPLE': 0.1,
            'AROMATIC': 0.8
        }
        stability_score += bond_stability.get(action.bond_type, 0.0)

        # Avoid creating unstable patterns
        if self._creates_strain(action, next_state):
            stability_score -= 0.3

        # Bonus for aromatic connections
        source_motif = next_state.current_graph.motifs.get(action.source_motif)
        target_motif = next_state.current_graph.motifs.get(action.target_motif)

        if source_motif and target_motif and source_motif.is_aromatic and target_motif.is_aromatic:
            stability_score += 0.2

        return stability_score

    def _creates_strain(self, action: AssemblyAction, state: AssemblyState) -> bool:
        # Simple strain detection - could be more sophisticated
        graph = state.current_graph

        # Check for 2-membered rings (highly strained)
        if graph.graph.has_edge(action.source_motif, action.target_motif):
            return True

        # Check for excessive connections to one motif
        source_degree = graph.graph.degree(action.source_motif) if action.source_motif in graph.graph else 0
        target_degree = graph.graph.degree(action.target_motif) if action.target_motif in graph.graph else 0

        if source_degree > 4 or target_degree > 4:  # High degree might indicate strain
            return True

        return False

    def _calculate_functional_group_reward(self, prev_state: AssemblyState,
                                         next_state: AssemblyState) -> float:
        # Reward for forming or maintaining important functional groups
        prev_groups = self._count_functional_groups(prev_state)
        next_groups = self._count_functional_groups(next_state)

        reward = 0.0

        # Reward for forming new functional groups
        for group, count in next_groups.items():
            prev_count = prev_groups.get(group, 0)
            if count > prev_count:
                group_weight = self._get_functional_group_weight(group)
                reward += (count - prev_count) * group_weight

        return reward

    def _count_functional_groups(self, state: AssemblyState) -> Dict[str, int]:
        group_counts = {}

        for motif in state.current_graph.motifs.values():
            for group in motif.functional_groups:
                group_counts[group] = group_counts.get(group, 0) + 1

        return group_counts

    def _get_functional_group_weight(self, group_name: str) -> float:
        weights = {
            'aromatic_ring': 0.3,
            'carbonyl': 0.2,
            'alcohol': 0.15,
            'amine': 0.15,
            'carboxyl': 0.25,
            'ester': 0.2,
            'amide': 0.2,
            'nitrile': 0.1
        }
        return weights.get(group_name, 0.1)

    def _calculate_property_alignment_reward(self, state: AssemblyState) -> float:
        if not hasattr(state, 'properties_target') or not state.properties:
            return 0.0

        # Calculate current molecular properties (simplified)
        current_props = self._estimate_molecular_properties(state)
        target_props = getattr(state, 'properties_target', {})

        alignment_reward = 0.0
        num_properties = 0

        for prop_name, target_value in target_props.items():
            if prop_name in current_props:
                current_value = current_props[prop_name]

                if prop_name in self.property_targets:
                    target_min, target_max = self.property_targets[prop_name]
                    if target_min <= current_value <= target_max:
                        alignment_reward += 1.0
                    else:
                        # Distance penalty
                        if current_value < target_min:
                            distance = target_min - current_value
                        else:
                            distance = current_value - target_max
                        alignment_reward -= min(distance / (target_max - target_min), 1.0)
                else:
                    # Direct target matching
                    error = abs(current_value - target_value) / max(abs(target_value), 1.0)
                    alignment_reward += max(0, 1.0 - error)

                num_properties += 1

        return alignment_reward / max(num_properties, 1)

    def _estimate_molecular_properties(self, state: AssemblyState) -> Dict[str, float]:
        # Simplified property estimation based on motif properties
        properties = {
            'molecular_weight': 0.0,
            'logp': 0.0,
            'tpsa': 0.0,
            'rotatable_bonds': 0,
            'num_rings': 0,
            'aromatic_rings': 0
        }

        for motif in state.current_graph.motifs.values():
            motif_props = motif.properties

            properties['molecular_weight'] += motif_props.get('molecular_weight', 0.0)
            properties['logp'] += motif_props.get('logp', 0.0)
            properties['tpsa'] += motif_props.get('tpsa', 0.0)
            properties['rotatable_bonds'] += motif_props.get('num_rotatable_bonds', 0)
            properties['num_rings'] += motif_props.get('num_rings', 0)
            properties['aromatic_rings'] += motif_props.get('num_aromatic_rings', 0)

        # Adjust for connections (simplified)
        num_connections = len(state.current_graph.connections)
        properties['rotatable_bonds'] = max(0, properties['rotatable_bonds'] - num_connections)

        return properties

    def _calculate_sa_reward(self, state: AssemblyState) -> float:
        # Simplified synthetic accessibility score
        # In practice, this would use more sophisticated SA scoring

        num_motifs = len(state.current_graph.motifs)
        num_connections = len(state.current_graph.connections)

        # Prefer reasonable number of fragments
        if 3 <= num_motifs <= 10:
            size_bonus = 0.2
        elif num_motifs <= 15:
            size_bonus = 0.1
        else:
            size_bonus = -0.1  # Very large molecules harder to synthesize

        # Prefer well-connected molecules
        if num_motifs > 1:
            connectivity_ratio = num_connections / (num_motifs - 1)  # Ratio to minimum connections
            connectivity_bonus = min(connectivity_ratio * 0.1, 0.2)
        else:
            connectivity_bonus = 0.0

        # Bonus for common functional groups
        functional_bonus = 0.0
        for motif in state.current_graph.motifs.values():
            common_groups = {'aromatic_ring', 'alcohol', 'amine', 'ester'}
            for group in motif.functional_groups:
                if group in common_groups:
                    functional_bonus += 0.05

        return size_bonus + connectivity_bonus + min(functional_bonus, 0.3)

    def _calculate_novelty_reward(self, state: AssemblyState) -> float:
        # Simplified novelty calculation
        # In practice, this would compare against known molecule databases

        # Novel combination of motifs gets small bonus
        unique_motif_types = set()
        for motif in state.current_graph.motifs.values():
            # Use simplified motif type based on functional groups
            motif_type = tuple(sorted(motif.functional_groups))
            unique_motif_types.add(motif_type)

        novelty_score = min(len(unique_motif_types) * 0.1, 0.5)

        # Penalty for overly complex structures
        if len(state.current_graph.motifs) > 20:
            novelty_score -= 0.2

        return max(novelty_score, 0.0)

    def calculate_termination_reward(self, state: AssemblyState) -> float:
        # Additional reward when episode terminates successfully
        if not state.terminated:
            return 0.0

        termination_reward = 0.0

        # Connectivity bonus
        if state.current_graph.num_connected_components() == 1:
            termination_reward += 5.0

        # Completion bonus for reconstruction
        if state.mode == "reconstruction" and state.target_graph:
            current_edges = set(state.current_graph.graph.edges())
            target_edges = set(state.target_graph.graph.edges())

            if current_edges == target_edges:
                termination_reward += 10.0  # Perfect reconstruction
            elif target_edges.issubset(current_edges):
                termination_reward += 7.0   # Over-connected but complete

        # Size appropriateness bonus
        num_motifs = len(state.current_graph.motifs)
        if 3 <= num_motifs <= 15:
            termination_reward += 2.0

        return termination_reward

    def get_reward_breakdown(self, prev_state: AssemblyState, action: AssemblyAction,
                           next_state: AssemblyState) -> Dict[str, float]:
        chemical_rewards = self.calculate_chemical_reward(prev_state, action, next_state)

        if next_state.terminated:
            chemical_rewards['termination'] = self.calculate_termination_reward(next_state)
            chemical_rewards['total_chemical'] += chemical_rewards['termination']

        return chemical_rewards