import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional, Any
from transformers import AutoModel, AutoTokenizer
from ...environment.state import AssemblyState
from ...environment.actions import AssemblyAction, ActionSpace


class LLMPromptTemplate:
    """Render assembly states and candidate actions as plain-text LLM prompts.

    The fixed system prompt is built once in ``__init__`` and exposed as
    ``system_prompt``. The two ``create_*_prompt`` methods build a fresh
    prompt string from the objects passed in on each call.
    """

    def __init__(self):
        self.system_prompt = self._create_system_prompt()

    def _create_system_prompt(self) -> str:
        """Return the static instructions that frame the assembly task."""
        return """You are a molecular assembly expert responsible for connecting molecular fragments (motifs) to form complete molecules.

Your task is to analyze the current molecular topology and suggest the best connection between available motifs based on:
1. Chemical compatibility and valence rules
2. Topological connectivity requirements
3. Target molecular structure (if provided)
4. Thermodynamic stability considerations

Always consider the complete 2D molecular topology when making decisions."""

    def create_state_prompt(self, state: AssemblyState, motif_focus: Optional[str] = None) -> str:
        """Describe the current assembly state as a newline-joined prompt.

        Sections are emitted in this order: header (mode, step, component
        count), available motifs, current connections, target topology (only
        when a target graph is set and the mode is ``"reconstruction"``), and
        topology statistics.

        Args:
            state: Assembly state to describe.
            motif_focus: Optional motif id; the matching motif's text is
                wrapped in ``FOCUS MOTIF`` marker lines.

        Returns:
            The prompt text.
        """
        graph = state.current_graph

        lines = ["CURRENT_ASSEMBLY_STATE:"]
        lines.append(f"Mode: {state.mode}")
        lines.append(f"Step: {state.step}/{state.max_steps}")
        lines.append(f"Connected_Components: {graph.num_connected_components()}")
        lines += ["", "AVAILABLE_MOTIFS:"]

        # One entry per available motif, each followed by a blank separator.
        for motif_id in state.available_motifs:
            description = graph.motifs[motif_id].to_text_representation()
            if motif_focus and motif_id == motif_focus:
                description = f">>> FOCUS MOTIF >>>\n{description}\n<<< FOCUS MOTIF <<<"
            lines += [description, ""]

        lines.append("CURRENT_CONNECTIONS:")
        lines.extend([link.to_text_representation() for link in graph.connections])
        lines.append("")

        # The target is only shown when reconstructing a known molecule.
        if state.target_graph and state.mode == "reconstruction":
            lines.append("TARGET_TOPOLOGY:")
            lines.append(state.target_graph.to_text_representation())
            lines.append("")

        # Missing statistics fall back to 0.
        stats = graph.get_topology_features()
        density = stats.get('density', 0)
        clustering = stats.get('clustering_coefficient', 0)
        connected = stats.get('is_connected', 0)
        lines += [
            "TOPOLOGY_ANALYSIS:",
            f"Density: {density:.3f}",
            f"Clustering: {clustering:.3f}",
            f"Connected: {connected}",
            "",
        ]

        return "\n".join(lines)

    def create_action_prompt(self, state: AssemblyState, valid_actions: List[AssemblyAction]) -> str:
        """List the candidate actions and ask for the index of the best one.

        Args:
            state: Accepted for interface symmetry; not read by this method.
            valid_actions: Candidate actions, numbered from 0 in list order.

        Returns:
            The prompt text.
        """
        header = [
            "TASK: Select the best action from the following options:",
            "",
        ]
        options = [
            f"{index}: {candidate.to_text_representation()}"
            for index, candidate in enumerate(valid_actions)
        ]
        footer = [
            "",
            "Consider:",
            "1. Chemical compatibility between connection sites",
            "2. Progress toward connectivity goals",
            "3. Maintenance of chemical stability",
            "4. Topological requirements",
            "",
            "Provide your analysis and selected action index.",
        ]
        return "\n".join(header + options + footer)


class LLMActor(nn.Module):
    """Actor/critic that scores assembly actions from an LLM embedding.

    A textual description of the assembly state is encoded by a pretrained
    transformer backbone. One linear head per action component
    (``source_motif``, ``source_site``, ``target_motif``, ``target_site``,
    ``bond_type``) maps the resulting hidden vector to logits, and a scalar
    value head produces a state-value estimate.

    Every head is an ``nn.Linear`` applied to a ``(batch, hidden)`` tensor, so
    each logits tensor has shape ``(batch, n_choices)``. Prompts are encoded
    one at a time, i.e. ``batch == 1``.
    """

    def __init__(self, model_name: str = "microsoft/DialoGPT-medium",
                 max_length: int = 1024, action_space: Optional[ActionSpace] = None):
        """Load the tokenizer and backbone and create the heads.

        Args:
            model_name: Hugging Face model identifier passed to
                ``AutoTokenizer`` / ``AutoModel``.
            max_length: Maximum prompt length in tokens; longer prompts are
                truncated by the tokenizer.
            action_space: Action space providing head sizes, masks and
                sampling. A default ``ActionSpace()`` is created when ``None``.
        """
        super().__init__()

        self.model_name = model_name
        self.max_length = max_length
        # Explicit None check: a caller-supplied action space must never be
        # replaced just because it happens to evaluate as falsy.
        self.action_space = ActionSpace() if action_space is None else action_space

        # Initialize LLM backbone
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            # The tokenizer is called with padding=True, which needs a pad token.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.llm_backbone = AutoModel.from_pretrained(model_name)
        self.hidden_size = self.llm_backbone.config.hidden_size

        # One linear head per action component
        head_sizes = self.action_space.action_dim
        self.action_heads = nn.ModuleDict({
            'source_motif': nn.Linear(self.hidden_size, head_sizes['source_motif']),
            'source_site': nn.Linear(self.hidden_size, head_sizes['source_site']),
            'target_motif': nn.Linear(self.hidden_size, head_sizes['target_motif']),
            'target_site': nn.Linear(self.hidden_size, head_sizes['target_site']),
            'bond_type': nn.Linear(self.hidden_size, head_sizes['bond_type'])
        })

        # Value head for advantage calculation
        self.value_head = nn.Linear(self.hidden_size, 1)

        self.prompt_template = LLMPromptTemplate()

    def forward(self, state: AssemblyState, motif_focus: Optional[str] = None) -> Dict[str, torch.Tensor]:
        """Encode ``state`` and compute masked action logits and a value.

        Args:
            state: Assembly state to encode.
            motif_focus: Optional motif id highlighted in the prompt.

        Returns:
            Dict with ``'action_logits'`` (head name -> ``(batch, n)`` logits,
            invalid choices set to ``-inf``), ``'hidden_state'`` (the vector
            fed to the heads) and ``'value'`` (``(batch, 1)`` value estimate).
        """
        prompt = self.prompt_template.create_state_prompt(state, motif_focus)

        # Tokenize prompt
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
            padding=True
        )

        # The tokenizer returns CPU tensors; move them to wherever the module
        # lives so the backbone does not fail with a device mismatch.
        # All sub-modules are expected to share one device.
        device = self.value_head.weight.device
        inputs = {
            name: (value.to(device) if torch.is_tensor(value) else value)
            for name, value in inputs.items()
        }

        # The backbone runs without gradient tracking, so only the heads
        # receive gradients from the outputs of this method.
        with torch.no_grad():
            outputs = self.llm_backbone(**inputs)

        # Prefer the pooled output when the model provides one; otherwise use
        # the representation of the last token.
        pooled = getattr(outputs, 'pooler_output', None)
        if pooled is not None:
            hidden_state = pooled
        else:
            hidden_state = outputs.last_hidden_state[:, -1, :]

        # Generate action logits
        action_logits = {
            action_type: head(hidden_state)
            for action_type, head in self.action_heads.items()
        }

        # Apply action masks. Only 1-D masks are applied (they broadcast over
        # the batch dimension); masks of any other rank are left untouched.
        masks = self.action_space.create_action_masks(state)
        for action_type, logits in list(action_logits.items()):
            if action_type not in masks:
                continue
            mask = masks[action_type]
            if mask.dim() != 1:
                continue
            # masked_fill needs a bool mask on the same device as the logits.
            valid = mask.to(device=logits.device, dtype=torch.bool)
            action_logits[action_type] = logits.masked_fill(~valid, float('-inf'))

        return {
            'action_logits': action_logits,
            'hidden_state': hidden_state,
            'value': self.value_head(hidden_state)
        }

    def get_action_probabilities(self, state: AssemblyState, motif_focus: Optional[str] = None) -> Dict[str, torch.Tensor]:
        """Return the softmax of each head's masked logits.

        Args:
            state: Assembly state to encode.
            motif_focus: Optional motif id highlighted in the prompt.

        Returns:
            Head name -> probabilities over that head's last dimension.
        """
        action_logits = self.forward(state, motif_focus)['action_logits']
        return {
            action_type: F.softmax(logits, dim=-1)
            for action_type, logits in action_logits.items()
        }

    def sample_action(self, state: AssemblyState, motif_focus: Optional[str] = None,
                     temperature: float = 1.0) -> Tuple[AssemblyAction, Dict[str, torch.Tensor]]:
        """Sample an action from temperature-scaled logits.

        Args:
            state: Assembly state to encode.
            motif_focus: Optional motif id highlighted in the prompt.
            temperature: Positive divisor applied to every head's logits.

        Returns:
            ``(action, outputs)`` where ``outputs`` is the ``forward`` result.
            Its ``'action_logits'`` entry holds the temperature-scaled logits
            that were actually sampled from.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            # Dividing by zero would yield inf/NaN logits and a negative value
            # would invert the preference order.
            raise ValueError(f"temperature must be positive, got {temperature}")

        outputs = self.forward(state, motif_focus)
        action_logits = outputs['action_logits']

        # Apply temperature scaling (updates the dict stored in ``outputs``)
        for action_type in action_logits:
            action_logits[action_type] = action_logits[action_type] / temperature

        # Sample action using action space
        action = self.action_space.sample_action(state, action_logits)

        return action, outputs

    def get_action_log_prob(self, state: AssemblyState, action: AssemblyAction,
                           motif_focus: Optional[str] = None) -> torch.Tensor:
        """Return the log-probability of ``action`` under the current policy.

        Each head emits a single categorical distribution per prompt (logits
        of shape ``(1, n)``), so the log-probability of a connection action is
        the sum of the five per-head log-probabilities at the chosen indices.

        Args:
            state: Assembly state the action was taken in.
            action: Action to score.
            motif_focus: Optional motif id highlighted in the prompt.

        Returns:
            A scalar tensor. ``-inf`` when the action refers to a motif or
            bond type that is not available, or to an index outside a head's
            range.
        """
        action_logits = self.forward(state, motif_focus)['action_logits']
        source_log_probs = F.log_softmax(action_logits['source_motif'], dim=-1)

        if action.is_stop_action():
            # STOP action is the last index in source_motif
            return source_log_probs[0, -1]  # batch_size is 1

        impossible = source_log_probs.new_tensor(float('-inf'))

        # Motif indices follow the iteration order of state.available_motifs.
        motif_ids = list(state.available_motifs)
        try:
            source_idx = motif_ids.index(action.source_motif)
            target_idx = motif_ids.index(action.target_motif)
            bond_idx = self.action_space.bond_types.index(action.bond_type)
        except ValueError:
            return impossible

        chosen = (
            ('source_motif', source_idx),
            ('source_site', action.source_site),
            ('target_motif', target_idx),
            ('target_site', action.target_site),
            ('bond_type', bond_idx),
        )

        total_log_prob = None
        for action_type, index in chosen:
            logits = action_logits[action_type]
            if not 0 <= index < logits.shape[-1]:
                # The head has no slot for this choice, so it cannot be sampled.
                return impossible
            component = F.log_softmax(logits, dim=-1)[0, index]
            total_log_prob = component if total_log_prob is None else total_log_prob + component

        return total_log_prob

    def analyze_action_reasoning(self, state: AssemblyState, action: AssemblyAction) -> str:
        """Build a short textual rationale for ``action``.

        The text is assembled from fixed templates and attributes of the
        state; the LLM is not invoked.

        Args:
            state: Assembly state the action applies to.
            action: Action to describe.

        Returns:
            Newline-joined explanation lines.
        """
        if action.is_stop_action():
            reasoning_parts = [
                "STOP ACTION ANALYSIS:",
                f"Current connectivity: {state.current_graph.num_connected_components()} components"
            ]

            if state.mode == "reconstruction" and state.target_graph:
                # Fraction of target edges already present in the current
                # graph; an empty target counts as fully reconstructed.
                current_edges = set(state.current_graph.graph.edges())
                target_edges = set(state.target_graph.graph.edges())
                completion = len(current_edges & target_edges) / len(target_edges) if target_edges else 1.0
                reasoning_parts.append(f"Reconstruction progress: {completion:.2%}")

            return "\n".join(reasoning_parts)

        reasoning_parts = [
            f"CONNECTION ANALYSIS: {action.to_text_representation()}",
            f"Chemical rationale: Connecting {action.bond_type} bond between compatible sites"
        ]

        # Analyze motifs involved
        source_motif = state.current_graph.motifs[action.source_motif]
        target_motif = state.current_graph.motifs[action.target_motif]

        reasoning_parts.extend([
            f"Source motif properties: {source_motif.functional_groups}",
            f"Target motif properties: {target_motif.functional_groups}",
            f"Topological impact: Will reduce components from {state.current_graph.num_connected_components()}"
        ])

        return "\n".join(reasoning_parts)


class MotifAgent:
    """Agent bound to one motif that proposes actions through a shared actor.

    Every proposal is sampled with the agent's own motif id passed as the
    prompt focus. Actions recorded via ``update_history`` are kept in
    ``action_history`` in call order.
    """

    def __init__(self, motif_id: str, actor: LLMActor):
        self.actor = actor
        self.motif_id = motif_id
        self.action_history: List[AssemblyAction] = []

    def propose_action(self, state: AssemblyState, temperature: float = 1.0) -> Tuple[AssemblyAction, Dict]:
        """Sample an action focused on this motif and package it as a proposal.

        Args:
            state: Current assembly state.
            temperature: Sampling temperature forwarded to the actor.

        Returns:
            ``(action, proposal)``. The action gets a ``reasoning`` attribute
            set on it. ``proposal`` carries the action, the actor's value
            estimate as ``'confidence'`` (0.0 when the actor outputs contain
            no ``'value'``), this agent's motif id, the reasoning text and the
            raw actor outputs.
        """
        sampled, actor_out = self.actor.sample_action(
            state,
            motif_focus=self.motif_id,
            temperature=temperature,
        )

        # Attach the explanation to the action itself as well as the proposal.
        explanation = self.actor.analyze_action_reasoning(state, sampled)
        sampled.reasoning = explanation

        value_estimate = actor_out.get('value', torch.tensor(0.0))
        proposal = dict(
            action=sampled,
            confidence=value_estimate.item(),
            motif_id=self.motif_id,
            reasoning=explanation,
            outputs=actor_out,
        )
        return sampled, proposal

    def update_history(self, action: AssemblyAction):
        """Append ``action`` to this agent's action history."""
        self.action_history.append(action)

    def get_local_state_representation(self, state: AssemblyState) -> str:
        """Describe this motif and the connections that touch it.

        Args:
            state: Current assembly state.

        Returns:
            Newline-joined description, or a one-line "not found" message
            when the motif id is absent from the current graph.
        """
        graph = state.current_graph
        if self.motif_id not in graph.motifs:
            return f"Motif {self.motif_id} not found in current graph"

        own_motif = graph.motifs[self.motif_id]

        # Keep connections in which this motif is either endpoint.
        touching = []
        for link in graph.connections:
            if link.source_motif == self.motif_id or link.target_motif == self.motif_id:
                touching.append(link)

        lines = [f"LOCAL_STATE for {self.motif_id}:"]
        lines.append(own_motif.to_text_representation())
        lines.append(f"Current connections: {len(touching)}")
        lines.extend([link.to_text_representation() for link in touching])
        lines.append(f"Available sites: {len(own_motif.get_available_sites())}")
        return "\n".join(lines)