import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch_geometric.nn.conv import TransformerConv
from torch_geometric.nn.pool import global_max_pool
from rl.net import NeuralNetworkModule
from model.resnet import ResnetBlockFC
from model.robot_simple.encoder import RLEncoder
from model.robot_simple.decoder import RLDecoder


class RobotEncoder(nn.Module):
    """Encode a batch of robots into per-node and per-robot feature vectors.

    Sampled voxels are passed through ``RLEncoder`` to obtain a latent, which
    ``RLDecoder`` is then called on together with the normalized rigid body
    (node) positions. The resulting per-node voxel features are concatenated
    with the raw node features and fed through three ``TransformerConv``
    layers over the kinematic graph.
    """

    def __init__(
        self,
        normalize_radius: float,
        grid_resolution: int = 64,
    ):
        """
        Args:
            normalize_radius: Positions are divided by ``2 * normalize_radius``,
                clamped to [-0.5, 0.5] and shifted by 0.5 before being given
                to the voxel encoder / decoder.
            grid_resolution: Forwarded to both ``RLEncoder`` and ``RLDecoder``.
        """
        super().__init__()
        self.normalize_radius = normalize_radius
        self.voxel_encoder_net = RLEncoder(
            f_dim=1,
            c_dim=4,
            hidden_dim=4,
            scatter_type="max",
            grid_resolution=grid_resolution,
            padding=0.02,
            unet3d_kwargs={"layer_order": "cl", "num_levels": 1},
        )
        self.voxel_decoder_net = RLDecoder(
            f_dim=8,
            c_dim=4,
            hidden_dim=4,
            grid_resolution=grid_resolution,
            sample_mode="bilinear",
            padding=0.02,
        )
        self.kinematic_encoder_conv1 = TransformerConv(
            in_channels=22, out_channels=256, edge_dim=9
        )
        self.kinematic_encoder_conv2 = TransformerConv(
            in_channels=256, out_channels=512, edge_dim=9
        )
        self.kinematic_encoder_conv3 = TransformerConv(
            in_channels=512, out_channels=512, edge_dim=9
        )

    def _normalize_positions(self, positions):
        """Scale by ``1 / (2 * normalize_radius)``, clamp to [-0.5, 0.5], shift to [0, 1]."""
        return t.clamp(positions / (2 * self.normalize_radius), -0.5, 0.5) + 0.5

    def forward(
        self,
        com,
        relative_voxel_positions,
        voxel_features,
        relative_node_positions,
        node_features,
        edges,
        edge_features,
    ):
        """
        Note:
            Voxel num V* corresponds to voxel num, variable across batch
            Node num N* corresponds to rigid body num, variable across batch
            Edge num E* corresponds to joint num, variable across batch
            There are E* * 2 edges for edges in both directions; the first E*
            columns of each ``edges`` tensor are taken as the unique edges.
            Every robot's voxels are randomly sub-sampled (without
            replacement) down to the smallest voxel count in the batch, so
            the output is not deterministic unless the RNG is seeded.
        Args:
            com: List of length B, inner shape [1, 3]
            relative_voxel_positions: List of length B, inner shape [V*, 3], must be relative
            positions in meters to body com.
            voxel_features: List of length B, inner shape [V*, 1]
            relative_node_positions: List of length B, inner shape [N*, 3], must be relative
            positions in meters to body com.
            node_features: List of length B, inner shape [N*, 14]
            edges: List of length B, inner shape [2, E* * 2]
            edge_features: List of length B, inner shape [E* * 2, 9]
        Returns:
            A 5-tuple:
            out_node_feature: Shape [N_0 + N_1 + ... + N_(B-1), 512]
            pooled_out_node_feature: Max pooled over each robot's nodes, Shape [B, 512]
            unique_edge_batch: Index tensor mapping each unique edge to its
            robot, Shape [E_0 + E_1 + ... + E_(B-1)]
            unique_edge_indices: Node indices into the concatenated node
            dimension, Shape [2, E_0 + E_1 + ... + E_(B-1)]
            unique_edge_num_per_robot: Python list of length B with the
            unique edge num E* of every robot
        """
        com = t.concatenate(com, dim=0)
        device = com.device
        batch_size = com.shape[0]

        # Sample equal number of voxels for each sample in the batch,
        # so that they can be stacked into a dense tensor
        sample_num = min(len(v) for v in relative_voxel_positions)
        voxel_sample_mask = [
            t.randperm(len(v))[:sample_num] for v in relative_voxel_positions
        ]
        sampled_voxel_positions = t.stack(
            [v[mask] for v, mask in zip(relative_voxel_positions, voxel_sample_mask)],
            dim=0,
        )
        sampled_voxel_features = t.stack(
            [v[mask] for v, mask in zip(voxel_features, voxel_sample_mask)], dim=0
        )

        latent = self.voxel_encoder_net(
            self._normalize_positions(sampled_voxel_positions), sampled_voxel_features
        )

        # Per-robot node counts and the robot index of every concatenated node
        node_num = [len(positions) for positions in relative_node_positions]
        node_batch = []
        for i, num in enumerate(node_num):
            node_batch += [i] * num
        node_batch = t.tensor(node_batch, dtype=t.long, device=device)

        # Zero-padded dense node positions, Shape [B, max N*, 3]
        norm_node_positions_batch = t.zeros(
            [batch_size, max(node_num, default=0), 3],
            dtype=sampled_voxel_positions.dtype,
            device=device,
        )
        for i, positions in enumerate(relative_node_positions):
            norm_node_positions_batch[i, : len(positions)] = (
                self._normalize_positions(positions)
            )

        # Add offsets to edge indices so that they address the concatenated nodes
        edge_indices = []
        unique_edge_indices = []
        unique_edge_batch = []
        unique_edge_num_per_robot = []
        offset = 0
        for i, (e, n_num) in enumerate(zip(edges, node_num)):
            unique_num = e.shape[1] // 2
            shifted_edges = e + offset
            edge_indices.append(shifted_edges)
            unique_edge_indices.append(shifted_edges[:, :unique_num])
            unique_edge_batch += [i] * unique_num
            unique_edge_num_per_robot.append(unique_num)
            offset += n_num

        edge_indices = t.concatenate(edge_indices, dim=1)
        unique_edge_batch = t.tensor(unique_edge_batch, dtype=t.long, device=device)
        unique_edge_indices = t.concatenate(unique_edge_indices, dim=1)
        edge_features = t.concatenate(edge_features, dim=0)

        node_voxel_features_batch = self.voxel_decoder_net(
            norm_node_positions_batch, latent
        )
        # Drop the padding rows and flatten back to the concatenated node layout
        node_voxel_features = t.concatenate(
            [node_voxel_features_batch[i, :num] for i, num in enumerate(node_num)],
            dim=0,
        )

        # Shape [N_0 + N_1 + ... N_(B-1), 8 + 14 = 22]
        node_features = t.concatenate(
            (node_voxel_features, t.concatenate(node_features, dim=0).to(device)),
            dim=1,
        )

        out1 = self.kinematic_encoder_conv1(node_features, edge_indices, edge_features)
        out2 = F.leaky_relu(
            self.kinematic_encoder_conv2(
                F.leaky_relu(out1, negative_slope=0.2), edge_indices, edge_features
            ),
            negative_slope=0.2,
        )
        out_node_feature = self.kinematic_encoder_conv3(
            out2, edge_indices, edge_features
        )

        pooled_out_node_feature = global_max_pool(
            out_node_feature, node_batch, size=batch_size
        )

        return (
            out_node_feature,
            pooled_out_node_feature,
            unique_edge_batch,
            unique_edge_indices,
            unique_edge_num_per_robot,
        )


class RobotActor(NeuralNetworkModule):
    """Policy head producing one categorical action per unique edge (joint).

    For every unique edge, the encoder features of its two end nodes, the
    pooled feature of its robot and the robot velocity are concatenated and
    mapped to ``resolution`` logits by a stack of ``ResnetBlockFC`` layers.
    """

    def __init__(
        self,
        encoder: RobotEncoder,
        freeze_encoder: bool = False,
        resolution: int = 5,
    ):
        """
        Args:
            encoder: Encoder used to compute node and pooled features.
            freeze_encoder: If True, the encoder is run with gradients disabled.
            resolution: Number of discrete action choices per edge.
        """
        super().__init__()
        self.resolution = resolution
        self.encoder = encoder
        self.action_resnet = nn.Sequential(
            ResnetBlockFC(1539, 256),
            ResnetBlockFC(256, 64),
            ResnetBlockFC(64, resolution),
        )
        self.freeze_encoder = freeze_encoder
        self.set_input_module(self.action_resnet)
        self.set_output_module(self.action_resnet)

    def forward(
        self,
        com,
        relative_voxel_positions,
        voxel_features,
        relative_node_positions,
        node_features,
        edges,
        edge_features,
        velocity,
        action=None,
    ):
        """
        Args:
            com, relative_voxel_positions, voxel_features,
            relative_node_positions, node_features, edges, edge_features:
            Forwarded unchanged to the encoder.
            velocity: List of length B, each entry is expanded to
            [E*, 3] for the E* unique edges of its robot.
            action: Optional list of per-robot action index tensors. When
            given, they are concatenated and evaluated instead of sampling
            new actions.
        Returns:
            action: List of length B, inner shape [E*]
            log_prob: Shape [B], log probability summed over each robot's edges
            entropy: Shape [B], entropy summed over each robot's edges
        """
        # Only disable gradients when the encoder is frozen, otherwise keep
        # whatever grad mode the caller is running in.
        with t.set_grad_enabled(t.is_grad_enabled() and not self.freeze_encoder):
            (
                out_node_features,
                pooled_out_node_features,
                out_unique_edge_batch,
                out_unique_edge_indices,
                unique_edge_num_per_robot,
            ) = self.encoder(
                com,
                relative_voxel_positions,
                voxel_features,
                relative_node_positions,
                node_features,
                edges,
                edge_features,
            )

        # Each has shape [E_0 + E_1 + ... + E_(B-1), 512]
        first_node_features = out_node_features[out_unique_edge_indices[0]]
        second_node_features = out_node_features[out_unique_edge_indices[1]]
        node_pooled_features = pooled_out_node_features[out_unique_edge_batch]

        # Shape [E_0 + E_1 + ... + E_(B-1), 3], robot velocity repeated per edge
        edge_velocity = t.cat(
            [
                v.expand(unique_edge_num_per_robot[i], 3)
                for i, v in enumerate(velocity)
            ],
            dim=0,
        )

        # Shape [E_0 + E_1 + ... + E_(B-1), 512+512+512+3=1539]
        edge_features_of_robot = t.cat(
            (
                first_node_features,
                second_node_features,
                node_pooled_features,
                edge_velocity,
            ),
            dim=1,
        )

        action_logits = self.action_resnet(edge_features_of_robot)
        dist = Categorical(logits=action_logits)
        raw_action = dist.sample() if action is None else t.concatenate(action)
        raw_log_prob = dist.log_prob(raw_action)
        raw_entropy = dist.entropy()

        # Now convert from concatenated batch to list batch. The encoder lays
        # out the unique edges of each robot contiguously and in batch order,
        # so splitting by the per-robot edge num recovers the per-robot chunks
        # without building a boolean mask for every robot.
        split_sizes = list(unique_edge_num_per_robot)
        action = list(t.split(raw_action, split_sizes))
        log_prob = t.stack(
            [t.sum(chunk) for chunk in t.split(raw_log_prob, split_sizes)]
        )
        entropy = t.stack(
            [t.sum(chunk) for chunk in t.split(raw_entropy, split_sizes)]
        )
        return action, log_prob, entropy


class RobotCritic(NeuralNetworkModule):
    """Value head mapping the pooled robot feature and velocity to a scalar."""

    def __init__(
        self,
        encoder: RobotEncoder,
        freeze_encoder: bool = False,
    ):
        """
        Args:
            encoder: Encoder used to compute the pooled robot feature.
            freeze_encoder: If True, the encoder is run with gradients disabled.
        """
        super().__init__()
        self.encoder = encoder
        # 512 pooled features + 3 velocity components -> 256 -> 64 -> 1
        layer_widths = (515, 256, 64, 1)
        self.value_resnet = nn.Sequential(
            *[
                ResnetBlockFC(in_dim, out_dim)
                for in_dim, out_dim in zip(layer_widths[:-1], layer_widths[1:])
            ]
        )
        self.freeze_encoder = freeze_encoder
        self.set_input_module(self.value_resnet)
        self.set_output_module(self.value_resnet)

    def forward(
        self,
        com,
        relative_voxel_positions,
        voxel_features,
        relative_node_positions,
        node_features,
        edges,
        edge_features,
        velocity,
    ):
        """
        Args:
            com, relative_voxel_positions, voxel_features,
            relative_node_positions, node_features, edges, edge_features:
            Forwarded unchanged to the encoder.
            velocity: List of length B, entries are concatenated along dim 0
            and appended to the pooled feature of each robot.
        Returns:
            Flattened value tensor.
        """
        # Gradients are switched off only for a frozen encoder, the grad mode
        # of the caller is left untouched otherwise.
        track_grad = t.is_grad_enabled() and not self.freeze_encoder
        with t.set_grad_enabled(track_grad):
            encoder_outputs = self.encoder(
                com,
                relative_voxel_positions,
                voxel_features,
                relative_node_positions,
                node_features,
                edges,
                edge_features,
            )
        # Only the pooled per-robot feature (second output) is needed here
        _, pooled_feature, *_unused = encoder_outputs

        batch_velocity = t.cat(list(velocity), dim=0)

        # Shape [B, 512+3=515]
        critic_input = t.cat((pooled_feature, batch_velocity), dim=1)

        return self.value_resnet(critic_input).view(-1)
