import torch
import torch.nn as nn
import torch.nn.functional as F

class TestMLP(nn.Module):
    """Three-layer perceptron mapping a 1-dim feature to a 768-dim vector.

    Layer widths are 1 -> 32 -> 256 -> 768. ReLU follows the first two
    layers; the last layer is left linear.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys, so they are kept as-is.
        self.fc1 = nn.Linear(in_features=1, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=256)
        self.fc3 = nn.Linear(in_features=256, out_features=768)

    def forward(self, x):
        """Apply the MLP to ``x``; the last dimension of ``x`` must be 1."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        # No activation on the output layer.
        return self.fc3(hidden)

class DepthOnlyFCBackbone224x224(nn.Module):
    """Convolutional encoder compressing 224x224 frames into a latent vector.

    Args:
        output_dim: Size of the latent vector produced for each sample.
        output_activation: ``"tanh"`` applies ``nn.Tanh`` to the latent; any
            other value (including the default ``None``) applies ELU.
        num_frames: Number of stacked frames, used as the input channel
            count of the first convolution.
    """

    def __init__(self, output_dim, output_activation=None, num_frames=1):
        super().__init__()

        self.num_frames = num_frames
        # A single ELU instance is shared by every activation slot below.
        elu = nn.ELU()
        stages = [
            # input: [num_frames, 224, 224]
            nn.Conv2d(in_channels=self.num_frames, out_channels=32, kernel_size=5),
            # -> [32, 220, 220]
            nn.MaxPool2d(kernel_size=2, stride=2),
            # -> [32, 110, 110]
            elu,
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            # -> [64, 108, 108]
            elu,
            nn.Flatten(),
            # -> [64 * 108 * 108]
            nn.Linear(64 * 108 * 108, 128),
            elu,
            nn.Linear(128, output_dim),
        ]
        self.image_compression = nn.Sequential(*stages)

        # Anything other than "tanh" falls back to the shared ELU.
        self.output_activation = nn.Tanh() if output_activation == "tanh" else elu

    def forward(self, images: torch.Tensor):
        """Encode ``images`` and apply the output activation to the result."""
        return self.output_activation(self.image_compression(images))

class Fusion(nn.Module):
    """Fuse an observation embedding with a language embedding.

    The language embedding goes through a linear layer and ReLU, is
    concatenated after the observation embedding along the last axis, and
    the result is projected to ``obs_emd_size`` by two linear layers with a
    ReLU in between.

    Args:
        obs_emd_size: Feature size of the observation embedding (and of the
            output).
        lang_emd_size: Feature size of the language embedding.
    """

    def __init__(self, obs_emd_size, lang_emd_size):
        super().__init__()
        self.input_size = obs_emd_size + lang_emd_size
        # Layers are created in the order linear1, linear2, linear_add1 so
        # that seeded initialisation stays reproducible.
        self.linear1 = nn.Linear(self.input_size, obs_emd_size)
        self.linear2 = nn.Linear(obs_emd_size, obs_emd_size)
        self.linear_add1 = nn.Linear(lang_emd_size, lang_emd_size)

    def forward(self, obs, lang):
        """Return the fused feature, with last dimension ``obs_emd_size``."""
        lang_feat = F.relu(self.linear_add1(lang))
        fused = torch.cat((obs, lang_feat), dim=-1)
        fused = F.relu(self.linear1(fused))
        return self.linear2(fused)


class FiLM(nn.Module):
    """Feature-wise linear modulation of ``x`` driven by a condition vector.

    Args:
        input_dim: Feature size of the tensor being modulated.
        condition_dim: Feature size of the conditioning tensor.
    """

    def __init__(self, input_dim, condition_dim):
        super().__init__()
        # Linear heads that produce the scale (gamma) and shift (beta).
        self.fc_gamma = nn.Linear(condition_dim, input_dim)
        self.fc_beta = nn.Linear(condition_dim, input_dim)

    def forward(self, x, condition):
        """Return ``gamma * x + beta`` with both terms derived from ``condition``."""
        scale = self.fc_gamma(condition)
        shift = self.fc_beta(condition)
        return scale * x + shift

class MultiStagePointNetEncoder(nn.Module):
    """Point-cloud encoder built from stacked point-wise convolutions.

    Each stage applies a per-point 1x1 convolution, max-pools over the point
    axis, concatenates that pooled feature back onto every point and mixes
    the pair with a second 1x1 convolution. The outputs of all stages are
    concatenated, projected to ``out_channels`` and max-pooled over points.

    Args:
        h_dim: Hidden channel width used inside every stage.
        out_channels: Channel width of the returned global feature.
        num_layers: Number of stages.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, h_dim=128, out_channels=128, num_layers=4, **kwargs):
        super().__init__()

        self.h_dim = h_dim
        self.out_channels = out_channels
        self.num_layers = num_layers

        # negative_slope=0.0 makes this LeakyReLU zero out negative inputs.
        self.act = nn.LeakyReLU(negative_slope=0.0, inplace=False)

        self.conv_in = nn.Conv1d(3, h_dim, kernel_size=1)
        self.layers = nn.ModuleList()
        self.global_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(nn.Conv1d(h_dim, h_dim, kernel_size=1))
            # Input is the per-point feature concatenated with the pooled one.
            self.global_layers.append(nn.Conv1d(2 * h_dim, h_dim, kernel_size=1))
        self.conv_out = nn.Conv1d(h_dim * num_layers, out_channels, kernel_size=1)

    def forward(self, x):
        """Encode points ``x`` of shape [B, N, 3] into a [B, out_channels] feature."""
        points = x.transpose(1, 2)  # [B, N, 3] --> [B, 3, N]
        feat = self.act(self.conv_in(points))

        stage_feats = []
        for local_conv, global_conv in zip(self.layers, self.global_layers):
            feat = self.act(local_conv(feat))
            # Max over the point axis, broadcast back onto every point.
            pooled = feat.max(-1, keepdim=True).values
            mixed = torch.cat([feat, pooled.expand_as(feat)], dim=1)
            feat = self.act(global_conv(mixed))
            stage_feats.append(feat)

        projected = self.conv_out(torch.cat(stage_feats, dim=1))
        return projected.max(-1).values


class FiLMBlock(nn.Module):
    """Parameter-free FiLM operation computing ``gamma * x + beta``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, gamma, beta):
        """Scale ``x`` by ``gamma`` and shift it by ``beta``.

        Both ``gamma`` and ``beta`` are viewed as
        ``(x.size(0), x.size(1))`` before being applied.
        """
        target_shape = (x.size(0), x.size(1))
        shift = beta.view(*target_shape)
        scale = gamma.view(*target_shape)
        return scale * x + shift

class ResBlock(nn.Module):
    """Residual block of two linear layers with FiLM modulation.

    The first linear layer plus ReLU produces the residual branch; the
    second linear layer is followed by batch norm, FiLM and ReLU, and the
    residual is added at the end.

    Args:
        in_place: Input feature size.
        out_place: Output feature size (also the width of the residual).
    """

    def __init__(self, in_place, out_place):
        super().__init__()

        self.conv1 = nn.Linear(in_place, out_place)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Linear(out_place, out_place)
        self.norm2 = nn.BatchNorm1d(out_place)
        self.film = FiLMBlock()
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x, beta, gamma):
        """Run the block on ``x`` with FiLM parameters ``beta`` and ``gamma``."""
        # The activated output of the first layer doubles as the skip path.
        identity = self.relu1(self.conv1(x))

        out = self.norm2(self.conv2(identity))
        # NOTE(review): FiLMBlock.forward is declared as (x, gamma, beta), so
        # this positional call uses `beta` as the scale and `gamma` as the
        # shift. Kept as-is because swapping would change results for
        # already-trained weights; confirm whether this is intended.
        out = self.film(out, beta, gamma)
        out = self.relu2(out)
        return out + identity

class Res_FiLM(nn.Module):
    """Stack of FiLM-conditioned residual blocks.

    A single linear layer maps the condition to one pair of FiLM vectors per
    residual block; the blocks are then applied to ``x`` in sequence.

    Args:
        n_res_blocks: Number of residual blocks.
        input_dim: Feature size of ``x`` (kept constant through the stack).
        condition_dim: Feature size of the conditioning tensor.
    """

    def __init__(self, n_res_blocks, input_dim, condition_dim):
        super().__init__()

        # Produces 2 vectors of size input_dim for each residual block.
        self.film_generator = nn.Linear(condition_dim, 2 * n_res_blocks * input_dim)
        self.res_blocks = nn.ModuleList(
            [ResBlock(input_dim, input_dim) for _ in range(n_res_blocks)]
        )

        self.n_res_blocks = n_res_blocks
        self.input_dim = input_dim

    def forward(self, x, condition):
        """Return ``x`` after all residual blocks, modulated by ``condition``."""
        batch_size = x.size(0)
        film_params = self.film_generator(condition).view(
            batch_size, self.n_res_blocks, 2, self.input_dim
        )

        feature = x
        for idx, block in enumerate(self.res_blocks):
            # Slot 0 is passed as `beta`, slot 1 as `gamma`.
            beta, gamma = film_params[:, idx].unbind(dim=1)
            feature = block(feature, beta, gamma)

        return feature


class CrossAttentionLayer(nn.Module):
    """Single-head cross attention where ``x1`` attends over ``x2``."""

    def __init__(self, input_dim_1, input_dim_2, hidden_dim):
        """
        input_dim_1: feature size of the first input
        input_dim_2: feature size of the second input
        hidden_dim: width of the attention space (controls model capacity)
        """
        super().__init__()

        # Project both inputs into a shared space: queries come from the
        # first input, keys and values from the second.
        self.query_proj = nn.Linear(input_dim_1, hidden_dim)
        self.key_proj = nn.Linear(input_dim_2, hidden_dim)
        self.value_proj = nn.Linear(input_dim_2, hidden_dim)

        # Output projection back to the feature size of the first input.
        self.out_proj = nn.Linear(hidden_dim, input_dim_1)

    def forward(self, x1, x2):
        """
        x1: first input, [batch_size, seq_len_1, input_dim_1] (e.g. image features)
        x2: second input, [batch_size, seq_len_2, input_dim_2] (e.g. text features)
        Returns a tensor of shape [batch_size, seq_len_1, input_dim_1].
        """
        q = self.query_proj(x1)  # [batch_size, seq_len_1, hidden_dim]
        k = self.key_proj(x2)  # [batch_size, seq_len_2, hidden_dim]
        v = self.value_proj(x2)  # [batch_size, seq_len_2, hidden_dim]

        # Scaled dot-product scores: [batch_size, seq_len_1, seq_len_2].
        scale = k.size(-1) ** 0.5
        scores = torch.bmm(q, k.transpose(1, 2)) / scale
        weights = F.softmax(scores, dim=-1)

        # Weighted sum of values: [batch_size, seq_len_1, hidden_dim].
        attended = torch.bmm(weights, v)

        return self.out_proj(attended)

class WeightedAggregation(nn.Module):
    """Collapse dimension 1 of the input with learned softmax weights.

    Args:
        input_size: Size of the last dimension of the input.
    """

    def __init__(self, input_size):
        super().__init__()
        # Bias-free scorer producing one logit per element along dim 1.
        self.fc = nn.Linear(input_size, 1, bias=False)

    def forward(self, x):
        """Return the softmax-weighted sum of ``x`` over dimension 1."""
        scores = self.fc(x)
        weights = torch.softmax(scores, dim=1)  # normalise over dim 1
        weighted = x * weights
        return weighted.sum(dim=1)

def test01():
    """Smoke-test ``TestMLP``: print the input shape and the output shape."""
    mlp = TestMLP()
    batch = torch.randn(2, 1)
    print(batch.shape)
    output = mlp(batch)
    print(output.shape)

def test02():
    """Smoke-test the depth backbone: print the input and output shapes."""
    backbone = DepthOnlyFCBackbone224x224(256)
    frames = torch.rand(2, 1, 224, 224)
    print(frames.shape)
    latent = backbone(frames)
    print(latent.shape)

# Script entry point: only the TestMLP smoke test is run; test02 is defined
# above but not invoked here.
if __name__ == '__main__':
    test01()
