import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch_geometric.nn.conv import TransformerConv
from torch_geometric.nn.pool import global_max_pool
from typing import Optional
from rl.net import NeuralNetworkModule
from model.resnet import ResnetBlockFC
from model.robot_simple_motor_head_moe.encoder import RLEncoder
from model.robot_simple_motor_head_moe.decoder import RLDecoder


class RobotMotorHeadEncoder(nn.Module):
    """Shared robot encoder combining voxel geometry with the kinematic graph.

    Voxels are encoded into a latent by ``voxel_encoder_net``; that latent is
    sampled at every node (rigid body) position by ``voxel_decoder_net``; the
    sampled features are concatenated with the raw node features and passed
    through three ``TransformerConv`` layers over the joint graph.
    """

    def __init__(
        self,
        normalize_radius: float,
        grid_resolution: int = 64,
    ):
        """Build the voxel encoder/decoder and the graph convolutions.

        Args:
            normalize_radius: Positions are divided by ``2 * normalize_radius``,
                clamped to [-0.5, 0.5] and shifted by 0.5 before being fed to
                the voxel networks.
            grid_resolution: Forwarded to both the voxel encoder and decoder.
        """
        super().__init__()
        self.normalize_radius = normalize_radius
        self.voxel_encoder_net = RLEncoder(
            f_dim=1,
            c_dim=4,
            hidden_dim=4,
            scatter_type="max",
            grid_resolution=grid_resolution,
            padding=0.02,
            unet3d_kwargs={"layer_order": "cl", "num_levels": 1},
        )
        self.voxel_decoder_net = RLDecoder(
            f_dim=8,
            c_dim=4,
            hidden_dim=4,
            grid_resolution=grid_resolution,
            sample_mode="bilinear",
            padding=0.02,
        )
        self.kinematic_encoder_conv1 = TransformerConv(
            in_channels=22, out_channels=256, edge_dim=9
        )
        self.kinematic_encoder_conv2 = TransformerConv(
            in_channels=256, out_channels=512, edge_dim=9
        )
        self.kinematic_encoder_conv3 = TransformerConv(
            in_channels=512, out_channels=512, edge_dim=9
        )

    def _normalize_positions(self, positions):
        """Map positions relative to the com into the unit cube [0, 1]."""
        return t.clamp(positions / (2 * self.normalize_radius), -0.5, 0.5) + 0.5

    def forward(
        self,
        com,
        relative_voxel_positions,
        voxel_features,
        relative_node_positions,
        node_features,
        edges,
        edge_features,
    ):
        """
        Note:
            Voxel num V* corresponds to voxel num, variable across batch
            Node num N* corresponds to rigid body num, variable across batch
            Edge num E* corresponds to joint num, variable across batch
            There are E* * 2 edges for edges in both directions; the first E*
            columns of each edge tensor are taken as the unique (one per joint)
            edges.
        Args:
            com: List of length B, inner shape [1, 3]
            relative_voxel_positions: List of length B, inner shape [V*, 3], must be relative
            positions in meters to body com.
            voxel_features: List of length B, inner shape [V*, 1]
            relative_node_positions: List of length B, inner shape [N*, 3], must be relative
            positions in meters to body com.
            node_features: List of length B, inner shape [N*, 14]
            edges: List of length B, inner shape [2, E* * 2]
            edge_features: List of length B, inner shape [E* * 2, 9]
        Returns:
            out_node_feature: Shape [N_0 + N_1 + ... + N_(B-1), 512]
            pooled_out_node_feature: Shape [B, 512], max-pooled over each robot's nodes
            unique_edge_batch: Index tensor, Shape [E_0 + E_1 + ... + E_(B-1)]
            unique_edge_indices: Shape [2, E_0 + E_1 + ... + E_(B-1)], node indices
            already offset into the concatenated node dimension
            unique_edge_num_per_robot: List of length B with E_i for each robot
        """
        com = t.concatenate(com, dim=0)
        device = com.device

        # Stacking requires the same voxel count per robot, so every robot is
        # randomly subsampled down to the smallest voxel count in the batch.
        sample_num = min(len(v) for v in relative_voxel_positions)
        voxel_sample_mask = [
            t.randperm(len(v))[:sample_num] for v in relative_voxel_positions
        ]
        relative_voxel_positions = t.stack(
            [v[mask] for v, mask in zip(relative_voxel_positions, voxel_sample_mask)],
            dim=0,
        )
        voxel_features = t.stack(
            [v[mask] for v, mask in zip(voxel_features, voxel_sample_mask)], dim=0
        )
        latent = self.voxel_encoder_net(
            self._normalize_positions(relative_voxel_positions), voxel_features
        )

        # Per-robot node counts and the robot index of every concatenated node.
        node_num = [len(positions) for positions in relative_node_positions]
        node_batch = t.tensor(
            [i for i, num in enumerate(node_num) for _ in range(num)],
            dtype=t.long,
            device=device,
        )

        # Zero-padded [B, max N, 3] query positions for the voxel decoder; the
        # padding rows are dropped again after decoding.
        norm_node_positions_batch = t.zeros(
            [com.shape[0], max(node_num, default=0), 3],
            dtype=relative_voxel_positions.dtype,
            device=device,
        )
        for i, positions in enumerate(relative_node_positions):
            norm_node_positions_batch[i, : len(positions)] = (
                self._normalize_positions(positions)
            )

        # Shift each robot's edge indices by the number of nodes before it so
        # they address rows of the concatenated node feature matrix.
        edge_indices = []
        unique_edge_indices = []
        unique_edge_batch = []
        unique_edge_num_per_robot = []
        offset = 0
        for i, (e, n_num) in enumerate(zip(edges, node_num)):
            shifted = e + offset
            unique_num = e.shape[1] // 2
            edge_indices.append(shifted)
            unique_edge_indices.append(shifted[:, :unique_num])
            unique_edge_batch += [i] * unique_num
            unique_edge_num_per_robot.append(unique_num)
            offset += n_num

        # Graph inputs are moved to the com device, same as node features below.
        edge_indices = t.concatenate(edge_indices, dim=1).to(device)
        unique_edge_indices = t.concatenate(unique_edge_indices, dim=1).to(device)
        unique_edge_batch = t.tensor(unique_edge_batch, dtype=t.long, device=device)
        edge_features = t.concatenate(edge_features, dim=0).to(device)

        node_voxel_features_batch = self.voxel_decoder_net(
            norm_node_positions_batch, latent
        )
        node_voxel_features = t.concatenate(
            [node_voxel_features_batch[i, :num] for i, num in enumerate(node_num)],
            dim=0,
        )

        # Shape [N_0 + N_1 + ... N_(B-1), 8 + 14 = 22]
        node_features = t.concatenate(
            (node_voxel_features, t.concatenate(node_features, dim=0).to(device)),
            dim=1,
        )

        out1 = self.kinematic_encoder_conv1(node_features, edge_indices, edge_features)
        out2 = F.leaky_relu(
            self.kinematic_encoder_conv2(
                F.leaky_relu(out1, negative_slope=0.2), edge_indices, edge_features
            ),
            negative_slope=0.2,
        )
        out_node_feature = self.kinematic_encoder_conv3(
            out2, edge_indices, edge_features
        )

        pooled_out_node_feature = global_max_pool(
            out_node_feature, node_batch, size=com.shape[0]
        )

        return (
            out_node_feature,
            pooled_out_node_feature,
            unique_edge_batch,
            unique_edge_indices,
            unique_edge_num_per_robot,
        )


class RobotMotorHeadCritic(NeuralNetworkModule):
    """Value head on top of the shared robot encoder.

    The encoder's pooled per-robot feature is concatenated with the robot
    velocity and reduced to one scalar per robot by a small FC resnet.
    """

    def __init__(
        self,
        encoder: RobotMotorHeadEncoder,
        freeze_encoder: bool = False,
    ):
        """
        Args:
            encoder: Encoder producing the pooled per-robot feature.
            freeze_encoder: When True the encoder is run with autograd disabled.
        """
        super().__init__()
        self.encoder = encoder
        self.freeze_encoder = freeze_encoder
        # 515 = 512 pooled encoder channels + 3 velocity components.
        value_layers = [
            ResnetBlockFC(515, 256),
            ResnetBlockFC(256, 64),
            ResnetBlockFC(64, 1),
        ]
        self.value_resnet = nn.Sequential(*value_layers)
        self.set_input_module(self.value_resnet)
        self.set_output_module(self.value_resnet)

    def forward(
        self,
        com,
        relative_voxel_positions,
        voxel_features,
        relative_node_positions,
        node_features,
        edges,
        edge_features,
        velocity,
    ):
        """Return one value estimate per robot as a flat tensor.

        The first seven arguments are forwarded unchanged to the encoder;
        ``velocity`` is iterated and its items concatenated along dim 0 before
        being appended to the pooled feature.
        """
        # A frozen encoder runs without autograd; otherwise the ambient grad
        # mode is kept exactly as the caller set it.
        track_grad = t.is_grad_enabled() and not self.freeze_encoder
        with t.set_grad_enabled(track_grad):
            encoder_outputs = self.encoder(
                com,
                relative_voxel_positions,
                voxel_features,
                relative_node_positions,
                node_features,
                edges,
                edge_features,
            )
        pooled = encoder_outputs[1]

        # Shape [B, 512+3=515]
        stacked_velocity = t.cat(list(velocity), dim=0)
        critic_input = t.cat((pooled, stacked_velocity), dim=1)

        return self.value_resnet(critic_input).view(-1)


class RobotActorMoterHeadMoE(NeuralNetworkModule):
    """Mixture-of-experts robot actor with shared encoder.

    For every unique edge (joint) a feature row is built from the two endpoint
    node features, the pooled robot feature and the robot velocity. A linear
    gate scores the experts and their logits are combined into one categorical
    action distribution per joint.
    """

    def __init__(
        self,
        encoder: RobotMotorHeadEncoder,
        freeze_encoder: bool = False,
        resolution: int = 5,
        num_experts: int = 8,
        expert_hidden1: int = 128,
        expert_hidden2: int = 64,
        gating_mode: str = "soft",
        gating_threshold: Optional[float] = None,
        gating_top_k: Optional[int] = None,
    ):
        """Initialize the MoE actor.

        Args:
            encoder: Shared geometry/kinematic encoder.
            freeze_encoder: Whether to stop gradients through the encoder.
            resolution: Number of discrete actions.
            num_experts: Number of expert policies.
            expert_hidden1: Hidden width of the first FC layer in each expert.
            expert_hidden2: Hidden width of the second FC layer in each expert.
            gating_mode: 'hard' selects the single arg-max expert; any other
                value selects soft (softmax-weighted) gating.
            gating_threshold: Optional minimum gate weight for soft selection.
                The highest-weighted expert is always kept.
            gating_top_k: Optional top-k experts to keep in soft mode, clamped
                to [1, num_experts]. Ignored in hard mode.
        """
        super().__init__()
        self.resolution = resolution
        self.encoder = encoder
        self.freeze_encoder = freeze_encoder
        self.num_experts = num_experts
        self.gating_mode = gating_mode
        self.gating_threshold = gating_threshold
        self.gating_top_k = gating_top_k

        # Experts are lightweight; total params across experts stays close to original.
        # 1539 = 512 (first node) + 512 (second node) + 512 (pooled) + 3 (velocity).
        self.action_experts = nn.ModuleList(
            [
                nn.Sequential(
                    ResnetBlockFC(1539, expert_hidden1),
                    ResnetBlockFC(expert_hidden1, expert_hidden2),
                    ResnetBlockFC(expert_hidden2, resolution),
                )
                for _ in range(num_experts)
            ]
        )
        # Single-layer gate producing logits per expert.
        self.gate = nn.Linear(1539, num_experts)

        self.set_input_module(self.gate)
        self.set_output_module(self.action_experts[0])

    def _mix_logits(self, edge_features: t.Tensor) -> t.Tensor:
        """Compute mixed action logits from all experts for each edge feature row.

        Args:
            edge_features: Shape [E, 1539].
        Returns:
            Action logits of shape [E, resolution].
        """
        gate_logits = self.gate(edge_features)  # [E, num_experts]
        expert_stack = t.stack(
            [expert(edge_features) for expert in self.action_experts], dim=1
        )  # [E, num_experts, resolution]

        if self.gating_mode == "hard":
            # Hard gating always picks a single expert (argmax); top_k is ignored.
            top_idx = gate_logits.argmax(dim=1)  # [E]
            rows = t.arange(expert_stack.shape[0], device=expert_stack.device)
            return expert_stack[rows, top_idx]

        # Soft gating: softmax over experts with optional threshold/top-k mask.
        weights = F.softmax(gate_logits, dim=1)
        if self.gating_top_k is not None:
            # Clamp k so an out-of-range setting neither crashes topk nor
            # masks out every expert.
            k = max(1, min(int(self.gating_top_k), weights.shape[1]))
            top_idx = gate_logits.topk(k, dim=1).indices
            mask = t.zeros_like(weights)
            mask.scatter_(1, top_idx, 1.0)
            weights = weights * mask
        if self.gating_threshold is not None:
            mask = (weights >= self.gating_threshold).float()
            # Always keep the strongest expert: otherwise a threshold above
            # every weight zeroes all logits and yields a uniform policy.
            mask.scatter_(1, weights.argmax(dim=1, keepdim=True), 1.0)
            weights = weights * mask
        # Renormalize the surviving weights; epsilon guards the division.
        weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8)
        return (expert_stack * weights.unsqueeze(-1)).sum(dim=1)

    def forward(
        self,
        com,
        relative_voxel_positions,
        voxel_features,
        relative_node_positions,
        node_features,
        edges,
        edge_features,
        velocity,
        action=None,
    ):
        """Run MoE policy. Signature matches RobotActor for drop-in use.

        Args:
            com, relative_voxel_positions, voxel_features,
            relative_node_positions, node_features, edges, edge_features:
                Forwarded unchanged to the encoder.
            velocity: Iterable with one tensor per robot, each expanded to
                [E_i, 3] (presumably inner shape [1, 3] -- confirm with caller).
            action: Optional sequence of per-robot action tensors to evaluate
                instead of sampling; concatenated in batch order.
        Returns:
            action_out: List of length B, per-robot action indices, one per
                unique edge.
            log_prob: Shape [B], sum of per-edge log-probabilities per robot.
            entropy: Shape [B], sum of per-edge entropies per robot.
        """
        batch_size = len(com)

        # A frozen encoder runs without autograd; otherwise the ambient grad
        # mode is inherited unchanged.
        with t.set_grad_enabled(t.is_grad_enabled() and not self.freeze_encoder):
            (
                out_node_features,
                pooled_out_node_features,
                out_unique_edge_batch,
                out_unique_edge_indices,
                unique_edge_num_per_robot,
            ) = self.encoder(
                com,
                relative_voxel_positions,
                voxel_features,
                relative_node_positions,
                node_features,
                edges,
                edge_features,
            )

        # out_node_features: [sum_N, 512], pooled_out_node_features: [B, 512]
        # out_unique_edge_indices: [2, sum_E], out_unique_edge_batch: [sum_E]
        first_node_features = out_node_features[out_unique_edge_indices[0]]
        second_node_features = out_node_features[out_unique_edge_indices[1]]
        # Pooled robot feature repeated for every edge of that robot.
        node_pooled_features = pooled_out_node_features[out_unique_edge_batch]
        # Per-robot velocity repeated per edge.
        edge_velocity = t.cat(
            [
                v.expand(unique_edge_num_per_robot[i], 3)
                for i, v in enumerate(velocity)
            ],
            dim=0,
        )

        # edge_features_of_robot: [sum_E, 1539] where sum_E is total unique edges
        edge_features_of_robot = t.cat(
            (
                first_node_features,
                second_node_features,
                node_pooled_features,
                edge_velocity,
            ),
            dim=1,
        )

        action_logits = self._mix_logits(edge_features_of_robot)
        dist = Categorical(logits=action_logits)
        raw_action = dist.sample() if action is None else t.concatenate(action)
        raw_log_prob = dist.log_prob(raw_action)
        raw_entropy = dist.entropy()

        # Regroup the flat per-edge results into per-robot results.
        action_out = []
        log_prob = []
        entropy = []
        for idx in range(batch_size):
            robot_mask = out_unique_edge_batch == idx
            action_out.append(raw_action[robot_mask])
            log_prob.append(t.sum(raw_log_prob[robot_mask]))
            entropy.append(t.sum(raw_entropy[robot_mask]))
        return action_out, t.stack(log_prob), t.stack(entropy)
