import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch_geometric.nn.conv import TransformerConv
from torch_geometric.nn.pool import global_max_pool
from rl.net import NeuralNetworkModule
from model.resnet import ResnetBlockFC
from model.robot.encoder import RLEncoder
from model.robot.decoder import RLDecoder

from model.robot.resnet_conv import ResNetFeatureExtractor
from model.robot.multi_head_attention import MultiHeadAttention
from model.robot.position_encoder import PositionEncoder


class VisionEncoder(nn.Module):
    """Fuse camera observations into per-node features through attention.

    For every camera the pixel embedding (from ``feature_extractor``) and the
    pose embedding (from ``position_encoder``) are concatenated, passed through
    a two-layer MLP, and then used as keys/values of ``node_camera_attention``
    with the node features as queries.

    The first linear layer takes 128 + 64 = 192 inputs, so the two embeddings
    must have a combined width of 192.
    """

    def __init__(
        self,
        position_encoder: "PositionEncoder",
        feature_extractor: "ResNetFeatureExtractor",
        node_camera_attention: "MultiHeadAttention",
        image_size: int = 128,
    ):
        """
        Args:
            position_encoder: Callable mapping camera features [*, 7] to a
                pose embedding.
            feature_extractor: Callable mapping images
                [*, 3, image_size, image_size] to a pixel embedding.
            node_camera_attention: Callable invoked as
                ``(query, key, value, mask=...)``.
            image_size: Height and width of the (square) camera images.
        """
        super().__init__()
        self.position_encoder = position_encoder
        self.feature_extractor = feature_extractor
        self.node_camera_attention = node_camera_attention
        self.image_size = image_size

        self.camera_feature_linear1 = nn.Linear(128 + 64, 128)
        self.camera_feature_linear2 = nn.Linear(128, 128)

    def forward(
        self, camera_feature_tensor, camera_pixels, node_features_batched, node_mask
    ):
        """
        Note: Camera number N_C is variable across batch
        Args:
            camera_feature_tensor: list of length B, inner shape [N_C, 7]
            camera_pixels: list of length B, inner shape [N_C, 3, H, W],
                where H == W == ``self.image_size``
            node_features_batched: Shape [batch_size, max_node_num, 128]
            node_mask: Bool tensor, shape [batch_size, max_node_num],
                True for real (non-padded) nodes
        Returns:
            The output of ``node_camera_attention``. The queries are the node
            features, so the expected shape is
            [batch_size, max_node_num, 128].
        """
        device = camera_feature_tensor[0].device
        batch_size = len(camera_feature_tensor)
        num_cameras = [c.shape[0] for c in camera_feature_tensor]
        max_num_cameras = max(num_cameras)

        # Pad every sample up to max_num_cameras. ``new_zeros`` makes the
        # padding inherit dtype and device from the tensor it extends, so
        # half/bfloat16 inputs are not silently promoted to float32 by cat.
        camera_feature_tensor_stacked = t.stack(
            [
                t.cat([c, c.new_zeros(max_num_cameras - c.shape[0], 7)], dim=0)
                for c in camera_feature_tensor
            ],
            dim=0,
        )
        camera_pixels_stacked = t.stack(
            [
                t.cat(
                    [
                        c,
                        c.new_zeros(
                            max_num_cameras - c.shape[0],
                            3,
                            self.image_size,
                            self.image_size,
                        ),
                    ],
                    dim=0,
                )
                for c in camera_pixels
            ],
            dim=0,
        )

        # Fold the camera axis into the batch axis so both sub-networks see a
        # plain batch of images / poses.
        camera_pixels_features = self.feature_extractor(
            camera_pixels_stacked.view(
                batch_size * max_num_cameras, 3, self.image_size, self.image_size
            )
        )
        camera_position_embedding = self.position_encoder(
            camera_feature_tensor_stacked.view(batch_size * max_num_cameras, 7)
        )

        # Shape [batch_size * max_num_cameras, 128 + 64 = 192]
        vision_features = t.cat(
            (camera_pixels_features, camera_position_embedding), dim=-1
        )
        # Shape [batch_size * max_num_cameras, 128]
        vision_features = F.leaky_relu(self.camera_feature_linear1(vision_features))
        # Shape [batch_size * max_num_cameras, 128]
        vision_features = F.leaky_relu(self.camera_feature_linear2(vision_features))

        # Shape [batch_size, max_num_cameras, 128]
        vision_features = vision_features.view(batch_size, max_num_cameras, -1)

        # Shape [batch_size, max_num_cameras]; True for real cameras, False
        # for the zero padding added above.
        camera_mask = t.arange(max_num_cameras, device=device).unsqueeze(
            0
        ) < t.tensor(num_cameras, device=device).unsqueeze(1)

        # Shape [batch_size, max_node_num, max_num_cameras]; True only where
        # both the node and the camera are real.
        # NOTE(review): rows belonging to padded nodes are entirely False;
        # confirm MultiHeadAttention handles fully-masked rows without NaNs.
        attention_mask = node_mask.unsqueeze(2) & camera_mask.unsqueeze(1)

        # Nodes are the queries, cameras are keys and values.
        # Expected shape [batch_size, max_node_num, 128]
        node_camera_feature = self.node_camera_attention(
            node_features_batched, vision_features, vision_features, mask=attention_mask
        )

        return node_camera_feature


class RobotEncoder(nn.Module):
    """Encode a batch of robots into per-node and per-robot feature vectors.

    Pipeline, as implemented in ``forward``:
      1. Voxel features are encoded into a latent by ``voxel_encoder_net`` and
         queried at the node positions by ``voxel_decoder_net``.
      2. The queried voxel features are concatenated with the raw node
         features and passed through three ``TransformerConv`` layers over the
         kinematic graph.
      3. ``vision_encoder`` attends from the graph node features to the
         camera features; its output is concatenated to the node features.
      4. Node features are max-pooled per robot.

    Optional input corruption (random noise replacement and zero masking) can
    be configured with ``set_noise_rate`` and ``set_mask``.
    """

    def __init__(
        self,
        normalize_radius: float,
        vision_encoder: "VisionEncoder",
        grid_resolution: int = 64,
    ):
        """
        Args:
            normalize_radius: Positions are divided by ``2 * normalize_radius``
                and shifted by 0.5 before being fed to the voxel networks.
            vision_encoder: Module called as
                ``(camera_feature_tensor, camera_pixels, node_features, node_mask)``.
            grid_resolution: Grid resolution forwarded to the voxel encoder
                and decoder.
        """
        super().__init__()
        self.normalize_radius = normalize_radius
        self.voxel_encoder_net = RLEncoder(
            f_dim=1,
            c_dim=8,
            hidden_dim=8,
            scatter_type="max",
            grid_resolution=grid_resolution,
            padding=0.02,
            unet3d_kwargs={"layer_order": "cl", "num_levels": 2},
        )
        self.voxel_decoder_net = RLDecoder(
            f_dim=8,
            c_dim=8,
            hidden_dim=8,
            grid_resolution=grid_resolution,
            sample_mode="bilinear",
            padding=0.02,
        )
        # Input width 22 = 8 (voxel features sampled at the node) + 14 (raw
        # node features); edge features have width 9.
        self.kinematic_encoder_conv1 = TransformerConv(
            in_channels=22, out_channels=64, edge_dim=9
        )
        self.kinematic_encoder_conv2 = TransformerConv(
            in_channels=64, out_channels=128, edge_dim=9
        )
        self.kinematic_encoder_conv3 = TransformerConv(
            in_channels=128, out_channels=128, edge_dim=9
        )

        self.vision_encoder = vision_encoder

        self.mask_voxel = False
        self.mask_node = False
        self.mask_edge = False

        self.voxel_noise_rate = 0
        self.node_noise_rate = 0
        self.edge_noise_rate = 0

    def set_mask(self, mask_voxel: bool, mask_node: bool, mask_edge: bool):
        """Choose which input feature groups are replaced by zeros in ``forward``."""
        self.mask_voxel = mask_voxel
        self.mask_node = mask_node
        self.mask_edge = mask_edge

    def set_noise_rate(
        self, voxel_noise_rate: float, node_noise_rate: float, edge_noise_rate: float
    ):
        """Set, per feature group, the probability that an element is replaced
        by standard-normal noise in ``forward``. A rate <= 0 disables noise."""
        self.voxel_noise_rate = voxel_noise_rate
        self.node_noise_rate = node_noise_rate
        self.edge_noise_rate = edge_noise_rate

    @staticmethod
    def _apply_noise(features, rate):
        """Replace each element of every tensor with N(0, 1) noise with probability ``rate``."""
        if rate > 0:
            return [
                t.where(t.rand_like(f) < rate, t.randn_like(f), f) for f in features
            ]
        return features

    def forward(
        self,
        com,
        voxel_positions,
        voxel_features,
        node_positions,
        node_features,
        edges,
        edge_features,
        camera_feature_tensor,
        camera_pixels,
    ):
        """
        Note:
            Voxel num V* corresponds to voxel num, variable across batch
            Node num N* corresponds to rigid body num, variable across batch
            Edge num E* corresponds to joint num, variable across batch
            There are E* * 2 edges for edges in both directions
        Args:
            com: List of length B, inner shape [1, 3]
            voxel_positions: List of length B, inner shape [V*, 3],
            voxel_features: List of length B, inner shape [V*, 1]
            node_positions: List of length B, inner shape [N*, 3]
            node_features: List of length B, inner shape [N*, 14]
            edges: List of length B, inner shape [2, E* * 2]
            edge_features: List of length B, inner shape [E* * 2, 9]
            camera_feature_tensor: List of length B, inner shape [N_C, 7]
            camera_pixels: List of length B, inner shape [N_C, 3, H, W]
        Returns:
            A 5-tuple:
            out_node_feature_with_camera:
                Shape [N_0 + N_1 + ... + N_(B-1), 256], graph node features
                (128) concatenated with the vision encoder output (128)
            pooled_out_node_feature_with_camera:
                Shape [B, 256], per-robot max pool of the above
            unique_edge_batch:
                Long index tensor, Shape [E_0 + E_1 + ... + E_(B-1)],
                the robot index of every unique edge
            unique_edge_indices:
                Shape [2, E_0 + E_1 + ... + E_(B-1)], node indices of every
                unique edge, offset into the concatenated node list
            unique_edge_num_per_robot:
                Python list of length B with the unique edge count per robot
        """
        # Each feature group is corrupted with its OWN noise rate.
        voxel_features = self._apply_noise(voxel_features, self.voxel_noise_rate)
        node_features = self._apply_noise(node_features, self.node_noise_rate)
        edge_features = self._apply_noise(edge_features, self.edge_noise_rate)

        if self.mask_voxel:
            voxel_features = [t.zeros_like(v) for v in voxel_features]
        if self.mask_node:
            node_features = [t.zeros_like(n) for n in node_features]
        if self.mask_edge:
            edge_features = [t.zeros_like(e) for e in edge_features]

        com = t.concatenate(com, dim=0)
        device = com.device
        batch_size = com.shape[0]

        # Robots have different voxel counts; randomly keep the same number
        # (the smallest count in the batch) from each so they can be stacked.
        sample_num = min(len(v) for v in voxel_positions)
        voxel_sample_index = [
            t.randperm(len(v))[:sample_num] for v in voxel_positions
        ]
        voxel_positions = t.stack(
            [v[idx] for v, idx in zip(voxel_positions, voxel_sample_index)], dim=0
        )
        voxel_features = t.stack(
            [v[idx] for v, idx in zip(voxel_features, voxel_sample_index)], dim=0
        )

        # NOTE(review): unlike node positions below, voxel positions are not
        # shifted by ``com`` here — presumably they are already expressed
        # relative to the centre of mass; confirm against the caller.
        norm_voxel_positions = (voxel_positions / 2 / self.normalize_radius) + 0.5
        latent = self.voxel_encoder_net(norm_voxel_positions, voxel_features)

        node_num = [len(p) for p in node_positions]
        max_node_num = max(node_num, default=0)
        # Robot index of every node in the concatenated node list.
        node_batch = t.tensor(
            [i for i, n in enumerate(node_num) for _ in range(n)],
            dtype=t.long,
            device=device,
        )

        # Padded query positions; padding slots stay at the grid centre (0.5).
        norm_node_positions_batch = t.full(
            [batch_size, max_node_num, 3],
            0.5,
            dtype=voxel_positions.dtype,
            device=device,
        )
        for i, positions in enumerate(node_positions):
            norm_node_positions_batch[i, : len(positions)] += (
                (positions - com[i]) / 2 / self.normalize_radius
            )

        # Shift every robot's edge indices so they address the concatenated
        # node list.
        # NOTE(review): the first half of each robot's edge columns is taken
        # as its set of unique (undirected) edges — this assumes the caller
        # lists one direction of every joint first and the reversed copies
        # afterwards; confirm against the caller.
        edge_indices = []
        unique_edge_indices = []
        unique_edge_batch = []
        unique_edge_num_per_robot = []
        offset = 0
        for i, (e, n_num) in enumerate(zip(edges, node_num)):
            unique_num = e.shape[1] // 2
            shifted = e + offset
            edge_indices.append(shifted)
            unique_edge_indices.append(shifted[:, :unique_num])
            unique_edge_batch += [i] * unique_num
            unique_edge_num_per_robot.append(unique_num)
            offset += n_num

        edge_indices = t.concatenate(edge_indices, dim=1)
        unique_edge_batch = t.tensor(unique_edge_batch, dtype=t.long, device=device)
        unique_edge_indices = t.concatenate(unique_edge_indices, dim=1)
        edge_features = t.concatenate(edge_features, dim=0)

        node_voxel_features_batch = self.voxel_decoder_net(
            norm_node_positions_batch, latent
        )
        # Drop the padding slots again.
        node_voxel_features = t.concatenate(
            [node_voxel_features_batch[i, :num] for i, num in enumerate(node_num)],
            dim=0,
        )
        # Shape [N_0 + N_1 + ... N_(B-1), 8 + 14 = 22]
        node_features = t.concatenate(
            (
                node_voxel_features,
                t.concatenate(node_features, dim=0).to(device),
            ),
            dim=1,
        )

        out1 = self.kinematic_encoder_conv1(node_features, edge_indices, edge_features)
        out2 = F.leaky_relu(
            self.kinematic_encoder_conv2(
                F.leaky_relu(out1, negative_slope=0.2), edge_indices, edge_features
            ),
            negative_slope=0.2,
        )
        out_node_feature = self.kinematic_encoder_conv3(
            out2, edge_indices, edge_features
        )

        # Split the flat node features per robot and zero-pad each part to
        # max_node_num so they can be stacked.
        # Shape [batch, max_node_num, 128]
        node_features_batched = t.stack(
            [
                t.cat(
                    [item, item.new_zeros(max_node_num - item.shape[0], item.shape[1])],
                    dim=0,
                )
                for item in t.split(out_node_feature, node_num, dim=0)
            ],
            dim=0,
        )

        # Shape [batch, max_node_num]; True for real nodes, False for padding.
        node_mask = t.arange(max_node_num, device=device).unsqueeze(0) < t.tensor(
            node_num, device=device
        ).unsqueeze(1)

        # Shape [batch, max_node_num, 128]
        node_camera_feature = self.vision_encoder(
            camera_feature_tensor, camera_pixels, node_features_batched, node_mask
        )

        # Concatenate node_features_batched and node_camera_feature
        # Shape [batch, max_node_num, 128 + 128 = 256]
        node_features_batched_concat = t.cat(
            (node_features_batched, node_camera_feature), dim=2
        )

        # Shape [N_0 + N_1 + ... + N_(B-1), 256]
        out_node_feature_with_camera = t.cat(
            [node_features_batched_concat[i, :num] for i, num in enumerate(node_num)],
            dim=0,
        )

        pooled_out_node_feature_with_camera = global_max_pool(
            out_node_feature_with_camera, node_batch, size=batch_size
        )

        return (
            out_node_feature_with_camera,
            pooled_out_node_feature_with_camera,
            unique_edge_batch,
            unique_edge_indices,
            unique_edge_num_per_robot,
        )


class RobotActor(NeuralNetworkModule):
    """Policy head that outputs one categorical action per unique joint.

    For every unique edge (joint) the features of its two end nodes, the
    pooled feature of its robot and the robot velocity are concatenated and
    mapped by ``action_resnet`` to ``resolution`` logits.
    """

    def __init__(
        self,
        encoder: RobotEncoder,
        freeze_encoder: bool = False,
        resolution: int = 5,
    ):
        """
        Args:
            encoder: Shared robot encoder.
            freeze_encoder: When True the encoder is run without gradient
                tracking, so only ``action_resnet`` is trained.
            resolution: Number of discrete action bins per joint.
        """
        super().__init__()
        self.resolution = resolution
        self.encoder = encoder
        # Input width 771 = 3 * 256 (first node, second node, pooled) + 3
        # (velocity).
        self.action_resnet = nn.Sequential(
            ResnetBlockFC(771, 64),
            ResnetBlockFC(64, 16),
            ResnetBlockFC(16, resolution),
        )
        self.freeze_encoder = freeze_encoder
        self.set_input_module(self.action_resnet)
        self.set_output_module(self.action_resnet)

    def forward(
        self,
        com,
        voxel_positions,
        voxel_features,
        node_positions,
        node_features,
        edges,
        edge_features,
        camera_feature_tensor,
        camera_pixels,
        velocity,
        action=None,
    ):
        """
        Args:
            com, voxel_positions, voxel_features, node_positions,
            node_features, edges, edge_features, camera_feature_tensor,
            camera_pixels: Forwarded unchanged to the encoder.
            velocity: Iterable of length B, each entry expandable to
                [E_i, 3] (e.g. shape [1, 3]).
            action: Optional list of per-robot action tensors. When given,
                they are concatenated and evaluated instead of sampling.
        Returns:
            action: List of length B, entry i holds the actions of robot i's
                unique edges.
            log_prob: Shape [B], MEAN log-probability over each robot's
                unique edges.
            entropy: Shape [B], MEAN entropy over each robot's unique edges.
            A robot without any unique edge yields NaN in ``log_prob`` and
            ``entropy`` (mean over an empty selection).
        """
        batch_size = len(com)

        # Run the encoder once. Gradients are tracked only if they are
        # enabled by the caller AND the encoder is not frozen, so an outer
        # ``no_grad`` context is still respected.
        with t.set_grad_enabled(t.is_grad_enabled() and not self.freeze_encoder):
            (
                out_node_features_with_camera,
                pooled_out_node_features_with_camera,
                out_unique_edge_batch,
                out_unique_edge_indices,
                unique_edge_num_per_robot,
            ) = self.encoder(
                com,
                voxel_positions,
                voxel_features,
                node_positions,
                node_features,
                edges,
                edge_features,
                camera_feature_tensor,
                camera_pixels,
            )

        # Row-index the node / pooled features per unique edge; the feature
        # width is taken from the tensors instead of being hard-coded.
        # Each has shape [E_0 + E_1 + ... + E_(B-1), 256]
        first_node_features_with_camera = out_node_features_with_camera[
            out_unique_edge_indices[0]
        ]
        second_node_features_with_camera = out_node_features_with_camera[
            out_unique_edge_indices[1]
        ]
        node_pooled_features_with_camera = pooled_out_node_features_with_camera[
            out_unique_edge_batch
        ]

        # Every edge of robot i receives robot i's velocity.
        # Shape [E_0 + E_1 + ... + E_(B-1), 3]
        velocity_per_edge = t.cat(
            [
                v.expand(unique_edge_num_per_robot[i], 3)
                for i, v in enumerate(velocity)
            ],
            dim=0,
        )

        # Shape [E_0 + E_1 + ... + E_(B-1), 256 + 256 + 256 + 3 = 771]
        node_features_of_robot_with_camera = t.cat(
            (
                first_node_features_with_camera,
                second_node_features_with_camera,
                node_pooled_features_with_camera,
                velocity_per_edge,
            ),
            dim=1,
        )

        action_logits = self.action_resnet(node_features_of_robot_with_camera)
        dist = Categorical(logits=action_logits)
        raw_action = dist.sample() if action is None else t.concatenate(action)
        raw_log_prob = dist.log_prob(raw_action)
        raw_entropy = dist.entropy()

        # Now convert from concatenated batch to list batch
        action = []
        log_prob = []
        entropy = []
        for idx in range(batch_size):
            belongs_to_robot = out_unique_edge_batch == idx
            action.append(raw_action[belongs_to_robot])
            log_prob.append(raw_log_prob[belongs_to_robot].mean())
            entropy.append(raw_entropy[belongs_to_robot].mean())
        log_prob = t.stack(log_prob)
        entropy = t.stack(entropy)
        return action, log_prob, entropy


class RobotCritic(NeuralNetworkModule):
    """Value head: maps the pooled robot feature plus velocity to a scalar."""

    def __init__(
        self,
        encoder: RobotEncoder,
        freeze_encoder: bool = False,
    ):
        """
        Args:
            encoder: Shared robot encoder.
            freeze_encoder: When True the encoder is run without gradient
                tracking, so only ``value_resnet`` is trained.
        """
        super().__init__()
        self.encoder = encoder
        self.freeze_encoder = freeze_encoder
        # Input width 259 = 256 (pooled robot feature) + 3 (velocity).
        self.value_resnet = nn.Sequential(
            ResnetBlockFC(259, 32),
            ResnetBlockFC(32, 8),
            ResnetBlockFC(8, 1),
        )
        self.set_input_module(self.value_resnet)
        self.set_output_module(self.value_resnet)

    def forward(
        self,
        com,
        voxel_positions,
        voxel_features,
        node_positions,
        node_features,
        edges,
        edge_features,
        camera_feature_tensor,
        camera_pixels,
        velocity,
    ):
        """
        Args:
            com, voxel_positions, voxel_features, node_positions,
            node_features, edges, edge_features, camera_feature_tensor,
            camera_pixels: Forwarded unchanged to the encoder.
            velocity: Iterable of length B whose entries are concatenated
                along dim 0 into a [B, 3] tensor.
        Returns:
            Flattened output of ``value_resnet``, one value per robot
            (shape [B]).
        """
        # Track encoder gradients only when they are enabled by the caller
        # and the encoder is not frozen.
        track_grad = t.is_grad_enabled() and not self.freeze_encoder
        with t.set_grad_enabled(track_grad):
            encoded = self.encoder(
                com,
                voxel_positions,
                voxel_features,
                node_positions,
                node_features,
                edges,
                edge_features,
                camera_feature_tensor,
                camera_pixels,
            )
        # Only the second encoder output (per-robot pooled feature) is used.
        _, pooled_feature, *_unused = encoded

        stacked_velocity = t.cat(list(velocity), dim=0)

        # Shape [B, 256 + 3 = 259]
        critic_input = t.cat((pooled_feature, stacked_velocity), dim=1)

        value = self.value_resnet(critic_input)
        return value.view(-1)
