from torch import nn
import torch
# Seed torch's global RNG at import time. The torch.randn parameter
# initialisation in LearnRiskModel (and nn.Linear's default init) draws
# from this RNG, so model construction is repeatable for a given order of
# RNG calls.
# NOTE(review): this reseeds the global RNG for every importer of this
# module; consider moving it to the training/experiment entry point.
torch.manual_seed(666)

class LearnRiskModel(nn.Module):
    """Combine an MLP transform of ``mu`` with learned weights and scales.

    The model holds two learned ``(in_dim, n_id)`` tensors, ``sigma`` and
    ``w``, plus a three-layer MLP (``fc1``/``fc2``/``fc3``) mapping a
    length-``in_dim`` vector to a length-``in_dim`` vector. ``forward``
    returns a weighted sum of the MLP output and the square root of a
    clamped weighted sum of ``w**2 * sigma**2``.

    :param in_dim: Size of the feature axis ``N`` of the inputs. ``forward``
        feeds that axis through ``fc1``, so ``N`` must equal ``in_dim``.
    :param hidden_dim: Width of the two hidden MLP layers.
    :param n_id: Size of the last axis of ``sigma``/``w`` and of the inputs.
    :param var_min: Lower clamp bound for the summed ``w**2 * sigma**2``
        term before the square root (default ``1e-3``, the previously
        hard-coded value).
    :param var_max: Upper clamp bound for that term (default ``10.0``, the
        previously hard-coded value).
    :raises ValueError: If ``var_min > var_max``.
    """

    # Class-level fallbacks so module objects pickled before these became
    # constructor arguments still resolve them in ``forward``.
    var_min = 1e-3
    var_max = 10.0

    def __init__(self, in_dim=24, hidden_dim=64, n_id=1, *,
                 var_min=1e-3, var_max=10.0):
        super().__init__()
        if var_min > var_max:
            raise ValueError(
                f"var_min ({var_min}) must not exceed var_max ({var_max})"
            )
        self.var_min = var_min
        self.var_max = var_max

        # Drawn from torch's global RNG; creation order (sigma, w, fc1, fc2,
        # fc3) is kept so seeded initialisation matches earlier versions.
        self.sigma = nn.Parameter(torch.randn(in_dim, n_id))
        self.w = nn.Parameter(torch.randn(in_dim, n_id))

        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, in_dim)

    def forward(self, mu, activate_vector):
        """Return the weighted MLP output and the clamped scale term.

        :param mu: Shape B, N, n_id, with N equal to ``in_dim``.
        :param activate_vector: Shape B, N, n_id. Multiplies both terms
            elementwise before they are summed over the N axis.
        :return: ``(mean, scale)``, each of shape ``(B, n_id)`` for the
            documented input shapes, where
            ``mean = sum_N(activate_vector * w * MLP(mu[0]))`` (the MLP is
            applied to each of the ``n_id`` length-N columns of ``mu[0]``) and
            ``scale = sqrt(clamp(sum_N(activate_vector * w**2 * sigma**2),
            var_min, var_max))``.
        """
        # NOTE(review): only the first batch element mu[0] is passed through
        # the MLP; the other B-1 entries of mu are never read. Confirm this
        # is intended (e.g. mu is identical across the batch).
        # (N, n_id) -> (n_id, N) so the Linear layers act on the N axis.
        feats = mu[0].permute(1, 0).contiguous()
        feats = torch.relu(self.fc1(feats))
        feats = torch.relu(self.fc2(feats))
        # Back to (N, n_id) so it lines up elementwise with self.w.
        feats = self.fc3(feats).permute(1, 0).contiguous()

        # Broadcasts (N, n_id) against activate_vector's (B, N, n_id), then
        # sums out N.
        mean = torch.sum(activate_vector * (self.w * feats), dim=1)

        var = self.w ** 2 * self.sigma ** 2
        var = torch.sum(activate_vector * var, dim=1)
        # Lower bound keeps sqrt away from 0 (finite gradient); upper bound
        # caps the scale.
        var = torch.clamp(var, min=self.var_min, max=self.var_max)

        return mean, torch.sqrt(var)


