import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional, Any
from transformers import AutoModel, AutoTokenizer
from torch_geometric.nn import GCNConv, GlobalAttention, global_mean_pool
from ...environment.state import AssemblyState
from ...core.topology import TopologyAnalyzer


class MultiModalFusion(nn.Module):
    """Fuse text, graph, mask and topology features into one vector.

    Each modality is linearly projected to ``hidden_dim``, the four projections
    are treated as a length-4 sequence for multi-head self-attention, and the
    attended sequence is flattened and reduced back to ``hidden_dim`` by a
    two-layer MLP.

    Args:
        text_dim: Size of the last dimension of ``text_features``.
        graph_dim: Size of the last dimension of ``graph_features``.
        mask_dim: Size of the last dimension of ``mask_features``.
        topo_dim: Size of the last dimension of ``topo_features``.
        hidden_dim: Common projection size and output size. Must be divisible
            by ``num_heads`` (enforced by ``nn.MultiheadAttention``).
        num_heads: Number of attention heads. Defaults to 8.
        dropout: Dropout probability inside the fusion MLP. Defaults to 0.1.
            The attention layer itself keeps its default (no dropout).
    """

    # The attention sequence always holds exactly these four modalities.
    _NUM_MODALITIES = 4

    def __init__(self, text_dim: int, graph_dim: int, mask_dim: int, topo_dim: int, hidden_dim: int,
                 num_heads: int = 8, dropout: float = 0.1):
        super().__init__()

        self.text_projection = nn.Linear(text_dim, hidden_dim)
        self.graph_projection = nn.Linear(graph_dim, hidden_dim)
        self.mask_projection = nn.Linear(mask_dim, hidden_dim)
        self.topo_projection = nn.Linear(topo_dim, hidden_dim)

        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=num_heads, batch_first=True)
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_dim * self._NUM_MODALITIES, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )

    def forward(self, text_features: torch.Tensor, graph_features: torch.Tensor,
                mask_features: torch.Tensor, topo_features: torch.Tensor) -> torch.Tensor:
        """Return the fused representation.

        Args:
            text_features: Tensor of shape ``(batch, text_dim)``.
            graph_features: Tensor of shape ``(batch, graph_dim)``.
            mask_features: Tensor of shape ``(batch, mask_dim)``.
            topo_features: Tensor of shape ``(batch, topo_dim)``.

        Returns:
            Tensor of shape ``(batch, hidden_dim)``.
        """
        # Project all features to the common dimension.
        text_proj = self.text_projection(text_features)
        graph_proj = self.graph_projection(graph_features)
        mask_proj = self.mask_projection(mask_features)
        topo_proj = self.topo_projection(topo_features)

        # (batch, 4, hidden_dim): one "token" per modality.
        stacked_features = torch.stack([text_proj, graph_proj, mask_proj, topo_proj], dim=1)

        # Self-attention lets each modality attend to the other three.
        attended_features, _ = self.attention(stacked_features, stacked_features, stacked_features)

        # (batch, 4 * hidden_dim) -> (batch, hidden_dim)
        fused = attended_features.flatten(start_dim=1)
        return self.fusion_layer(fused)


class GraphEncoder(nn.Module):
    """Encode a graph of integer-labelled nodes into one embedding per graph.

    Nodes are embedded with a learned lookup table, passed through a stack of
    ``GCNConv`` layers with ReLU activations, and pooled with attention.

    Args:
        node_features: Width of the learned node embedding.
        hidden_dim: Width of every graph-convolution layer and of the output.
        num_layers: Number of ``GCNConv`` layers.
        max_nodes: Number of rows in the node embedding table. Every entry of
            ``node_ids`` passed to :meth:`forward` must be smaller than this.
            Defaults to 100.

    Attributes:
        output_dim: Size of the pooled embedding (equal to ``hidden_dim``).
    """

    def __init__(self, node_features: int = 64, hidden_dim: int = 128, num_layers: int = 3,
                 max_nodes: int = 100):
        super().__init__()

        self.node_embedding = nn.Embedding(max_nodes, node_features)
        self.convs = nn.ModuleList([
            # Only the first layer consumes the raw embedding width.
            GCNConv(node_features if i == 0 else hidden_dim, hidden_dim)
            for i in range(num_layers)
        ])

        # NOTE(review): recent PyG releases deprecate GlobalAttention in favour
        # of aggr.AttentionalAggregation - confirm against the pinned version.
        self.global_pool = GlobalAttention(nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        ))

        self.output_dim = hidden_dim

    def forward(self, node_ids: torch.Tensor, edge_index: torch.Tensor,
                batch: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the pooled graph embedding.

        Args:
            node_ids: Long tensor of embedding indices, one per node.
            edge_index: Edge list passed unchanged to every ``GCNConv`` layer.
            batch: Optional long tensor assigning each node to a graph. When
                omitted, all nodes are treated as belonging to graph 0.

        Returns:
            The output of the attention pooling over the final node features.
        """
        x = self.node_embedding(node_ids)

        for conv in self.convs:
            x = F.relu(conv(x, edge_index))

        if batch is None:
            # Single-graph input: every node belongs to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        return self.global_pool(x, batch)


class CentralizedCritic(nn.Module):
    """Value network over an assembly state.

    Four views of the state are encoded separately - its text representation,
    its motif graph, the current action masks, and topology statistics - then
    fused by ``MultiModalFusion``. Three heads read the fused vector: the main
    value head and two auxiliary heads (edges and components).

    Args:
        llm_model_name: Hugging Face model id used for both the tokenizer and
            the text encoder.
        hidden_dim: Width of the fused representation and of every per-modality
            encoder output except the text encoder.
        max_motifs: Number of entries each action mask is padded or truncated to.
    """

    # Action-mask keys, in the order their flattened values are concatenated.
    _MASK_KEYS = ('source_motif', 'source_site', 'target_motif', 'target_site', 'bond_type')

    # Topology statistics read from the analyzer, in feature-vector order.
    _TOPOLOGY_FEATURES = (
        'num_nodes', 'num_edges', 'density', 'num_connected_components',
        'is_connected', 'diameter', 'radius', 'average_clustering',
        'transitivity', 'average_degree', 'degree_centrality_mean',
        'betweenness_centrality_mean', 'closeness_centrality_mean',
        'num_cycles', 'edge_connectivity', 'node_connectivity',
        'degree_std', 'max_degree', 'min_degree', 'avg_shortest_path',
    )

    # Finite stand-in for +/-inf topology statistics.
    _INF_CLAMP = 1000.0

    def __init__(self, llm_model_name: str = "microsoft/DialoGPT-medium",
                 hidden_dim: int = 256, max_motifs: int = 50):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.max_motifs = max_motifs

        # Text encoder.
        # NOTE(review): the original comment said "shared with actor", but this
        # constructor loads its own copy - confirm how sharing is meant to work.
        self.tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        if self.tokenizer.pad_token is None:
            # Padding is requested in _encode_text, so a pad token must exist.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.text_encoder = AutoModel.from_pretrained(llm_model_name)
        self.text_dim = self.text_encoder.config.hidden_size

        # Graph encoder
        self.graph_encoder = GraphEncoder(hidden_dim=hidden_dim)
        self.graph_dim = self.graph_encoder.output_dim

        # Mask encoder: one max_motifs-wide slot per mask key.
        self.mask_dim = max_motifs * len(self._MASK_KEYS)
        self.mask_encoder = nn.Sequential(
            nn.Linear(self.mask_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Topology statistics encoder
        self.topo_dim = 20  # Matches the number of entries in _TOPOLOGY_FEATURES
        self.topo_encoder = nn.Sequential(
            nn.Linear(self.topo_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim)
        )

        # Multimodal fusion. Masks and topology are already encoded to
        # hidden_dim by the encoders above, hence mask_dim/topo_dim=hidden_dim.
        self.fusion = MultiModalFusion(
            text_dim=self.text_dim,
            graph_dim=self.graph_dim,
            mask_dim=hidden_dim,
            topo_dim=hidden_dim,
            hidden_dim=hidden_dim
        )

        # Value head
        self.main_value_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 2, 1)
        )

        # Auxiliary heads for multi-task learning
        self.edges_prediction_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

        self.components_prediction_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

        self.topology_analyzer = TopologyAnalyzer()

    def _current_device(self) -> torch.device:
        """Return the device of this module's parameters (CPU if it has none)."""
        first_param = next(self.parameters(), None)
        if first_param is None:
            return torch.device('cpu')
        return first_param.device

    def forward(self, state: "AssemblyState",
                action_masks: Optional[Dict[str, torch.Tensor]] = None) -> Dict[str, torch.Tensor]:
        """Evaluate a single assembly state.

        Args:
            state: State to evaluate.
            action_masks: Optional masks keyed by the names in ``_MASK_KEYS``.
                ``None`` is treated like an empty dict.

        Returns:
            Dict with ``'main_value'``, ``'edges_prediction'``,
            ``'components_prediction'`` (outputs of the three heads) and
            ``'fused_features'`` (the fusion output they were computed from).
        """
        text_features = self._encode_text(state)
        graph_features = self._encode_graph(state)
        mask_features = self._encode_masks(action_masks or {})
        topo_features = self._encode_topology(state)

        fused_features = self.fusion(text_features, graph_features, mask_features, topo_features)

        return {
            'main_value': self.main_value_head(fused_features),
            'edges_prediction': self.edges_prediction_head(fused_features),
            'components_prediction': self.components_prediction_head(fused_features),
            'fused_features': fused_features
        }

    def _encode_text(self, state: "AssemblyState") -> torch.Tensor:
        """Encode ``state.to_text_representation()`` with the text encoder.

        The encoder runs under ``torch.no_grad()``, so no gradient flows into
        it from the critic losses.

        Returns:
            The encoder's pooled output when it provides one, otherwise the
            hidden state at the last sequence position.
        """
        text_repr = state.to_text_representation()

        inputs = self.tokenizer(
            text_repr,
            return_tensors="pt",
            max_length=1024,
            truncation=True,
            padding=True
        )
        # The tokenizer always produces CPU tensors; follow the model's device.
        inputs = inputs.to(self._current_device())

        with torch.no_grad():
            outputs = self.text_encoder(**inputs)

        pooled = getattr(outputs, 'pooler_output', None)
        if pooled is not None:
            return pooled
        return outputs.last_hidden_state[:, -1, :]

    def _encode_graph(self, state: "AssemblyState") -> torch.Tensor:
        """Encode ``state.current_graph`` with the graph encoder.

        Nodes are numbered by their position in ``graph.motifs``; every
        connection contributes an edge in both directions.

        Returns:
            The graph encoder output, or zeros of shape ``(1, graph_dim)``
            when the graph has no motifs.

        Raises:
            ValueError: If a connection references a motif that is not in
                ``graph.motifs``.
        """
        graph = state.current_graph
        device = self._current_device()

        if not graph.motifs:
            return torch.zeros(1, self.graph_dim, device=device)

        # Position lookup built once instead of a linear search per endpoint.
        position = {motif_id: idx for idx, motif_id in enumerate(graph.motifs.keys())}
        node_ids = torch.arange(len(position), dtype=torch.long, device=device)

        edge_list = []
        for connection in graph.connections:
            try:
                source_idx = position[connection.source_motif]
                target_idx = position[connection.target_motif]
            except KeyError as err:
                # Same exception type the previous list.index lookup raised.
                raise ValueError(
                    f"connection references unknown motif {err.args[0]!r}"
                ) from err
            # Undirected: add both orientations.
            edge_list.append([source_idx, target_idx])
            edge_list.append([target_idx, source_idx])

        if edge_list:
            edge_index = torch.tensor(edge_list, dtype=torch.long, device=device).t().contiguous()
        else:
            # No edges - use self-loops so the edge index is never empty.
            edge_index = torch.stack([node_ids, node_ids])

        return self.graph_encoder(node_ids, edge_index)

    def _encode_masks(self, action_masks: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode the action masks into a ``(1, hidden_dim)`` tensor.

        Each key in ``_MASK_KEYS`` gets a slot of ``max_motifs`` values: the
        flattened mask, truncated or zero-padded on the right. Missing keys
        get an all-zero slot. The concatenated slots go through
        ``mask_encoder``.

        An empty ``action_masks`` short-circuits to plain zeros and does not
        pass through ``mask_encoder``.
        """
        device = self._current_device()
        if not action_masks:
            return torch.zeros(1, self.hidden_dim, device=device)

        slots = []
        for key in self._MASK_KEYS:
            mask = action_masks.get(key)
            if mask is None:
                slots.append(torch.zeros(self.max_motifs, device=device))
                continue
            # Cast before padding so boolean/integer masks are handled uniformly.
            flat = mask.flatten().to(device=device, dtype=torch.float32)[:self.max_motifs]
            shortfall = self.max_motifs - flat.numel()
            if shortfall > 0:
                flat = F.pad(flat, (0, shortfall))
            slots.append(flat)

        # len(_MASK_KEYS) slots of max_motifs each == mask_dim by construction.
        return self.mask_encoder(torch.cat(slots).unsqueeze(0))

    def _encode_topology(self, state: "AssemblyState") -> torch.Tensor:
        """Encode topology statistics into a ``(1, hidden_dim)`` tensor.

        Statistics named in ``_TOPOLOGY_FEATURES`` are read from the analyzer
        result (missing ones count as 0.0), padded or truncated to
        ``topo_dim``, and made finite: +/-inf become +/-1000 and NaN becomes 0.
        """
        stats = self.topology_analyzer.analyze_topology(state.current_graph)

        values = [float(stats.get(name, 0.0)) for name in self._TOPOLOGY_FEATURES]
        values = values[:self.topo_dim]
        values.extend([0.0] * (self.topo_dim - len(values)))

        topo_tensor = torch.tensor(
            values, dtype=torch.float32, device=self._current_device()
        ).unsqueeze(0)
        # Non-finite inputs would propagate through every downstream layer.
        topo_tensor = torch.nan_to_num(
            topo_tensor, nan=0.0, posinf=self._INF_CLAMP, neginf=-self._INF_CLAMP
        )
        return self.topo_encoder(topo_tensor)

    def get_assembly_progress_estimation(self, state: "AssemblyState") -> Dict[str, float]:
        """Summarise assembly progress from the auxiliary heads and the graph.

        The forward pass runs without gradient tracking. Dropout stays active
        if the module is in training mode, so call ``eval()`` first for
        deterministic estimates.

        Returns:
            Dict with ``'estimated_remaining_edges'`` (clamped to >= 0),
            ``'predicted_components'`` (clamped to >= 1),
            ``'current_components'``, ``'connectivity_progress'`` and, in
            reconstruction mode with a target graph,
            ``'edge_completion_ratio'``.
        """
        with torch.no_grad():
            outputs = self.forward(state)

        graph = state.current_graph
        current_components = graph.num_connected_components()
        # Clamp the numerator so a zero component count cannot push progress above 1.
        surplus_components = max(current_components - 1, 0)
        max_surplus = max(len(graph.motifs) - 1, 1)

        progress = {
            'estimated_remaining_edges': max(0.0, outputs['edges_prediction'].item()),
            'predicted_components': max(1.0, outputs['components_prediction'].item()),
            'current_components': current_components,
            'connectivity_progress': 1.0 - surplus_components / max_surplus
        }

        if state.target_graph and state.mode == "reconstruction":
            target_edges = len(state.target_graph.connections)
            progress['edge_completion_ratio'] = len(graph.connections) / max(target_edges, 1)

        return progress

    def compute_value_targets(self, state: "AssemblyState") -> Dict[str, float]:
        """Compute training targets for the auxiliary heads.

        Returns:
            Dict with ``'remaining_edges'`` and ``'current_components'``. In
            reconstruction mode with a target graph, ``remaining_edges`` is the
            number of target edges absent from the current graph; otherwise it
            is the number of connected components minus one, floored at 0.
        """
        num_components = state.current_graph.num_connected_components()

        if state.target_graph and state.mode == "reconstruction":
            # NOTE(review): this compares raw edge tuples. If the underlying
            # graphs are undirected and may store (u, v) vs (v, u), matching
            # edges could be counted as missing - confirm.
            current_edges = set(state.current_graph.graph.edges())
            target_edges = set(state.target_graph.graph.edges())
            remaining_edges = len(target_edges - current_edges)
        else:
            # Joining k components needs at least k - 1 more edges.
            remaining_edges = max(0, num_components - 1)

        return {
            'remaining_edges': float(remaining_edges),
            'current_components': float(num_components)
        }