import torch
import torch.nn as nn
import sys
import torchvision

# class RecurrentDepthBackbone(nn.Module):
#     def __init__(self, base_backbone, env_cfg) -> None:
#         super().__init__()
#         activation = nn.ELU()
#         last_activation = nn.Tanh()
#         self.base_backbone = base_backbone
#         if env_cfg == None:
#             self.combination_mlp = nn.Sequential(
#                                     nn.Linear(32 + 53, 128),
#                                     activation,
#                                     nn.Linear(128, 32)
#                                 )
#         else:
#             self.combination_mlp = nn.Sequential(
#                                         nn.Linear(32 + env_cfg.env.n_proprio, 128),
#                                         activation,
#                                         nn.Linear(128, 32)
#                                     )
#         self.rnn = nn.GRU(input_size=32, hidden_size=512, batch_first=True)
#         self.output_mlp = nn.Sequential(
#                                 nn.Linear(512, 32+2),
#                                 last_activation
#                             )
#         self.hidden_states = None

#     def forward(self, depth_image, proprioception):
#         depth_image = self.base_backbone(depth_image)
#         depth_latent = self.combination_mlp(torch.cat((depth_image, proprioception), dim=-1))
#         # depth_latent = self.base_backbone(depth_image)
#         depth_latent, self.hidden_states = self.rnn(depth_latent[:, None, :], self.hidden_states)
#         depth_latent = self.output_mlp(depth_latent.squeeze(1))
        
#         return depth_latent

#     def detach_hidden_states(self):
#         self.hidden_states = self.hidden_states.detach().clone()

# class StackDepthEncoder(nn.Module):
#     def __init__(self, base_backbone, env_cfg) -> None:
#         super().__init__()
#         activation = nn.ELU()
#         self.base_backbone = base_backbone
#         self.combination_mlp = nn.Sequential(
#                                     nn.Linear(32 + env_cfg.env.n_proprio, 128),
#                                     activation,
#                                     nn.Linear(128, 32)
#                                 )

#         self.conv1d = nn.Sequential(nn.Conv1d(in_channels=env_cfg.depth.buffer_len, out_channels=16, kernel_size=4, stride=2),  # (30 - 4) / 2 + 1 = 14,
#                                     activation,
#                                     nn.Conv1d(in_channels=16, out_channels=16, kernel_size=2), # 14-2+1 = 13,
#                                     activation)
#         self.mlp = nn.Sequential(nn.Linear(16*14, 32), 
#                                  activation)
        
#     def forward(self, depth_image, proprioception):
#         # depth_image shape: [batch_size, num, 58, 87]
#         depth_latent = self.base_backbone(None, depth_image.flatten(0, 1), None)  # [batch_size * num, 32]
#         depth_latent = depth_latent.reshape(depth_image.shape[0], depth_image.shape[1], -1)  # [batch_size, num, 32]
#         depth_latent = self.conv1d(depth_latent)
#         depth_latent = self.mlp(depth_latent.flatten(1, 2))
#         return depth_latent

    
# class DepthOnlyFCBackbone58x87(nn.Module):
#     def __init__(self, scandots_output_dim, output_activation=None, num_frames=1):
#         super().__init__()

#         self.num_frames = num_frames
#         activation = nn.ELU()
#         self.image_compression = nn.Sequential(
#             # [1, 58, 87]
#             nn.Conv2d(in_channels=self.num_frames, out_channels=16, kernel_size=5),
#             # [32, 54, 83]
#             nn.MaxPool2d(kernel_size=2, stride=2),
#             # [32, 27, 41]
#             activation,
#             nn.Conv2d(in_channels=16, out_channels=4, kernel_size=3),
#             activation,
#             nn.Flatten(),
#             # [32, 25, 39]
#             nn.Linear(4 * 25 * 39, 128),
#             activation,
#             nn.Linear(128, scandots_output_dim)
#         )

#         if output_activation == "tanh":
#             self.output_activation = nn.Tanh()
#         else:
#             self.output_activation = activation

#     def forward(self, images: torch.Tensor):
#         images_compressed = self.image_compression(images.unsqueeze(1))
#         latent = self.output_activation(images_compressed)
#         return latent


class DepthOnlyFCBackbone58x87(nn.Module):
    """Convolutional encoder that compresses 58x87 depth images to a latent.

    The network is three convolutions (with one max-pool after the first)
    followed by a two-layer MLP. The size of the first linear layer is
    hard-coded for a 58x87 spatial input; other resolutions will fail at
    that layer.

    Args:
        scandots_output_dim: Size of the latent vector produced per sample.
        output_activation: ``"tanh"`` applies ``nn.Tanh`` to the latent; any
            other value (including ``None``) applies ``nn.ELU``.
        num_frames: Number of stacked depth frames, used as the input
            channel count of the first convolution.
    """

    def __init__(self, scandots_output_dim: int, output_activation=None, num_frames: int = 1) -> None:
        super().__init__()

        self.num_frames = num_frames
        # ELU holds no parameters or buffers, so one instance can safely be
        # reused at every position below.
        activation = nn.ELU()
        # NOTE: keep the layer order (and therefore the Sequential indices)
        # stable; state_dict keys of saved checkpoints depend on it.
        self.image_compression = nn.Sequential(
            # Per-sample input: [num_frames, 58, 87].
            # 16 output channels, 5x5 kernel, stride 2 -> [16, 27, 42]
            nn.Conv2d(in_channels=self.num_frames, out_channels=16, kernel_size=5, stride=2),
            # Max-pool; stride defaults to the kernel size -> [16, 13, 21]
            nn.MaxPool2d(kernel_size=2),
            activation,
            # 32 output channels, 4x4 kernel, stride 2 -> [32, 5, 9]
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2),
            activation,
            # 32 output channels, 3x3 kernel, stride 1 -> [32, 3, 7]
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1),
            activation,
            nn.Flatten(),
            # Linear input size follows from the conv output above:
            # 32 channels * 3 * 7 = 672 features.
            nn.Linear(32 * 3 * 7, 128),
            activation,
            nn.Linear(128, scandots_output_dim)
        )

        if output_activation == "tanh":
            self.output_activation = nn.Tanh()
        else:
            self.output_activation = activation

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Encode a batch of depth images.

        Args:
            images: Either ``[batch, 58, 87]`` (a single frame per sample; a
                channel axis is inserted automatically) or
                ``[batch, num_frames, 58, 87]`` (frames already stacked along
                the channel axis).

        Returns:
            Tensor of shape ``[batch, scandots_output_dim]`` after the output
            activation.
        """
        # Only add the channel axis when it is missing. Unsqueezing an input
        # that is already 4-D would yield a 5-D tensor, which Conv2d rejects,
        # making num_frames > 1 impossible to use.
        if images.dim() == 4:
            frames = images
        else:
            frames = images.unsqueeze(1)
        images_compressed = self.image_compression(frames)
        latent = self.output_activation(images_compressed)
        return latent
