import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical, Normal

# Define the GMM model.
class GMM(nn.Module):
    """Diagonal-Gaussian mixture over the input space.

    ``forward`` returns, for every sample and every component ``k``, the
    weighted density ``pi_k * N(x | mu_k, sigma_k)`` (not summed over ``k``).

    Args:
        num_components: Number of mixture components ``K``.
        input_dim: Dimensionality ``D`` of each input sample.
    """

    # Lower bound on the per-dimension standard deviation so Normal never
    # receives a non-positive scale (which it rejects with ValueError).
    _MIN_SCALE = 1e-6

    def __init__(self, num_components, input_dim):
        super().__init__()
        self.num_components = num_components
        self.input_dim = input_dim
        # Unconstrained mixing logits, shape (K,). They are normalized with a
        # softmax in forward(). The uniform init yields uniform weights 1/K.
        self.weights = nn.Parameter(torch.ones(num_components) / num_components)
        # Component means, shape (K, D).
        self.means = nn.Parameter(torch.randn(num_components, input_dim))
        # Per-dimension standard deviations (Normal's ``scale``), shape (K, D).
        # They are mapped to strictly positive values in forward().
        self.covars = nn.Parameter(torch.ones(num_components, input_dim))

    def forward(self, x):
        """Return per-component weighted densities of ``x``.

        Args:
            x: Tensor of shape (N, D).

        Returns:
            Tensor of shape (N, K) holding ``pi_k * N(x_n | mu_k, sigma_k)``.
        """
        # An optimizer step can push the raw parameters out of their valid
        # range. Constrain them here instead of trusting the stored values.
        scales = self.covars.abs().clamp_min(self._MIN_SCALE)
        log_weights = torch.log_softmax(self.weights, dim=0)
        dists = Normal(self.means, scales)
        # (N, 1, D) broadcasts against (K, D) -> (N, K, D); sum over D gives
        # the diagonal-Gaussian joint log-density per component.
        log_likelihoods = dists.log_prob(x.unsqueeze(1)).sum(dim=2)
        # Combine in log space, exponentiate once at the end.
        return (log_weights + log_likelihoods).exp()

# Define the MDN model.
class MDN(nn.Module):
    """Pair a GMM over the inputs with a linear head that emits one
    ``output_dim``-sized prediction per mixture component.

    Args:
        num_components: Number of mixture components ``K``.
        input_dim: Dimensionality of each input sample.
        output_dim: Dimensionality of each per-component prediction.
    """

    def __init__(self, num_components, input_dim, output_dim):
        super().__init__()
        self.num_components = num_components
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Keep this construction order: the GMM must be created before the
        # Linear layer, so the random initializations match.
        self.gmm = GMM(num_components, input_dim)
        self.fc = nn.Linear(input_dim, num_components * output_dim)

    def forward(self, x):
        """Run both heads on ``x`` (expected shape (N, input_dim)).

        Returns:
            A tuple ``(mixture, per_component)``. ``mixture`` is the GMM
            output, shape (N, K). ``per_component`` is the linear output
            reshaped to (N, K, output_dim).
        """
        mixture = self.gmm(x)
        flat = self.fc(x)
        per_component = flat.view(-1, self.num_components, self.output_dim)
        return mixture, per_component

# Example data (random placeholders).
input_data = torch.randn(100, 2)  # inputs, shape (100, 2)
target_data = torch.randn(100, 2)  # targets, shape (100, 2)

# Build the MDN model.
mdn = MDN(num_components=3, input_dim=2, output_dim=2)

# Loss function.
loss_fn = nn.MSELoss()

# Optimizer.
optimizer = optim.Adam(mdn.parameters(), lr=0.001)

# Floor applied before taking the log of the mixture densities, so that an
# underflowed density (exactly 0) cannot turn the loss into inf/NaN.
_LOG_EPS = 1e-12

# Train the model.
num_epochs = 1
for epoch in range(num_epochs):
    optimizer.zero_grad()
    gmm_output, fc_output = mdn(input_data)  # (N, K), (N, K, output_dim)
    # Score every component's prediction against the same target. Expand
    # explicitly so MSELoss does not warn about mismatched input/target sizes.
    regression_loss = loss_fn(fc_output, target_data.unsqueeze(1).expand_as(fc_output))
    # -log(gmm_output) has shape (N, K). It must be reduced to a scalar,
    # otherwise backward() raises RuntimeError and item() fails.
    loss = (-torch.log(gmm_output.clamp_min(_LOG_EPS)) * regression_loss).mean()
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, Loss: {loss.item()}")

# Use the trained model for prediction.
test_input = torch.randn(10, 2)
with torch.no_grad():  # inference only; no autograd graph is needed
    gmm_output, fc_output = mdn(test_input)
    # Normalize the weighted densities into per-sample responsibilities.
    # The prediction is then a convex combination of the component outputs,
    # rather than being scaled by how dense the GMM happens to be at each input.
    responsibilities = gmm_output / gmm_output.sum(dim=1, keepdim=True).clamp_min(1e-12)
    predicted_output = (responsibilities.unsqueeze(2) * fc_output).sum(dim=1)
print("Predicted Output:")
print(predicted_output)