import torch
from torch import nn
from hetreg.models import get_activation, get_head
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from functools import partial
from .modules import initialize_prior_knowledge, BaseLipBlock_sep, MLP, FeatureNN
from .util_funcs import *
import torch.nn.init as init


def generate_mask_from_indices(index_list: list, num_covariates: int, batch_size: int, device: torch.device):
    """Build a float observation mask from covariate indices.

    Args:
        index_list: either
            - a flat list of ints, e.g. ``[1, 3]``: every sample observes only
              covariates 1 and 3; or
            - a list of per-sample lists, e.g. ``[[1, 2], [0, 3]]``: sample ``i``
              observes exactly the covariates in ``index_list[i]``. Its length
              must equal ``batch_size``; an empty inner list means that sample
              observes nothing.
        num_covariates: total number of covariates (e.g. 4).
        batch_size: number of samples ``B``.
        device: device on which the mask is allocated.

    Returns:
        FloatTensor of shape ``(B, num_covariates)``; 1 = observed, 0 = unobserved.

    Raises:
        ValueError: if the per-sample form has a length different from ``batch_size``.
    """
    mask = torch.zeros((batch_size, num_covariates), dtype=torch.float32, device=device)

    per_sample = (
        isinstance(index_list, (list, tuple))
        and len(index_list) > 0
        and all(isinstance(cols, (list, tuple)) for cols in index_list)
    )
    if per_sample:
        if len(index_list) != batch_size:
            raise ValueError(
                f"per-sample index_list has {len(index_list)} entries, expected batch_size={batch_size}"
            )
        # Row-wise assignment: a single advanced-indexing call with a nested
        # list would broadcast and mark the union of indices on every row.
        for row, cols in enumerate(index_list):
            if cols:
                mask[row, list(cols)] = 1.0
    else:
        # Same subset for every sample.
        mask[..., index_list] = 1.0
    return mask



class DeepSetEncoder(nn.Module):
    """Permutation-invariant (Deep Sets style) encoder over a set of covariates.

    Each covariate ``j`` is represented by a learned id embedding concatenated
    with an embedding of its value. That pair is passed through ``phi``, summed
    over the observed covariates, transformed by ``rho_shared``, and projected
    to ``num_of_distri_params`` outputs by ``to_distri_params``.

    Args:
        num_covariates: total number of distinct covariates (size of the id
            embedding table).
        val_dim: dimensionality of each covariate's value (1 for scalars).
        phi_dim: width of the id/value embeddings and of the ``phi`` network.
        hidden_layers: ``depth`` passed to every internal ``MLP``.
        rho_dim: output width of ``rho_shared``.
        covariate_dropout: probability of randomly hiding each observed
            covariate during training (regularization). Defaults to 0.3, the
            value previously hard-coded in ``forward``.
    """

    def __init__(self, num_covariates, val_dim=1, phi_dim=512, hidden_layers=3, rho_dim=512,
                 covariate_dropout=0.3):
        super().__init__()
        if not 0.0 <= covariate_dropout <= 1.0:
            raise ValueError(f"covariate_dropout must be in [0, 1], got {covariate_dropout}")
        self.num_of_covariates = num_covariates
        self.val_dim = val_dim
        self.covariate_dropout = covariate_dropout

        self.id_embedding = nn.Embedding(num_covariates, phi_dim)
        init.normal_(self.id_embedding.weight, mean=0.0, std=0.01)

        # Embeds each covariate's value: (..., val_dim) -> (..., phi_dim).
        self.val_net = MLP(
            input_size=val_dim,
            width=phi_dim,
            depth=hidden_layers,
            output_size=phi_dim,
            activation='gelu')

        # Per-element network over [id embedding | value embedding].
        self.phi = MLP(
            input_size=2 * phi_dim,
            width=phi_dim,
            depth=hidden_layers,
            output_size=phi_dim,
            activation='gelu')

        # Set-level network applied after sum pooling.
        self.rho_shared = MLP(
            input_size=phi_dim,
            width=phi_dim,
            depth=hidden_layers,
            output_size=rho_dim,
            activation='gelu')

        # 2 * num_covariates + C(num_covariates, 2) outputs. NOTE(review):
        # presumably two per-covariate terms plus one term per unordered
        # covariate pair -- confirm against the loss that consumes the output.
        num_of_dependence = (self.num_of_covariates * (self.num_of_covariates - 1)) // 2
        self.num_of_distri_params = self.num_of_covariates + self.num_of_covariates + num_of_dependence

        self.to_distri_params = MLP(
            input_size=rho_dim,
            width=phi_dim,
            depth=hidden_layers,
            output_size=self.num_of_distri_params,
            activation='gelu')

    def forward(self, arr_input, mask=None):
        """Encode a (possibly partially observed) set of covariates.

        Args:
            arr_input: covariate values of shape ``(B, S)`` (one scalar per
                covariate) or ``(B, S, val_dim)``. Covariate at position ``j``
                is embedded with id ``j``, so ``S`` must not exceed
                ``num_covariates``.
            mask: optional ``(B, S)`` tensor, 1 = observed, 0 = unobserved.
                Defaults to all-observed. It is cast to the model's dtype/device.

        Returns:
            Tensor of shape ``(B, num_of_distri_params)``, the outputs of
            ``to_distri_params`` (parameters of p(c_k | c_S)).
        """
        if arr_input.dim() == 2:
            covariate_vals = arr_input[..., None]          # (B, S, 1)
        elif arr_input.dim() == 3:
            covariate_vals = arr_input                     # (B, S, val_dim)
        else:
            raise ValueError(
                f"arr_input must have shape (B, S) or (B, S, val_dim), got {tuple(arr_input.shape)}"
            )
        B, S = covariate_vals.shape[:2]
        device = covariate_vals.device
        param_dtype = self.id_embedding.weight.dtype

        if mask is None:
            mask = torch.ones((B, S), device=device, dtype=param_dtype)
        else:
            mask = mask.to(device=device, dtype=param_dtype)

        mask_training = self._dropout_covariates(mask) if self.training else mask
        weights = mask_training.unsqueeze(-1)              # (B, S, 1)

        # Encode every (id, value) pair.
        id_emb = self.id_embedding(torch.arange(S, device=device))   # (S, phi_dim)
        id_emb = id_emb.unsqueeze(0).expand(B, -1, -1)               # (B, S, phi_dim)
        val_emb = self.val_net(covariate_vals)                       # (B, S, phi_dim)
        combined = torch.cat([id_emb, val_emb], dim=-1) * weights    # (B, S, 2*phi_dim)

        # Mask AFTER phi as well: phi(0) is generally non-zero (biases), so
        # masking only the input would let every unobserved covariate add a
        # constant phi(0) to the pooled sum.
        phi_out = self.phi(combined) * weights                       # (B, S, phi_dim)
        pooled = phi_out.sum(dim=1)                                  # (B, phi_dim)

        h = self.rho_shared(pooled)                                  # (B, rho_dim)
        return self.to_distri_params(h)                              # (B, num_of_distri_params)

    def _dropout_covariates(self, mask):
        """Randomly hide observed covariates, keeping at least one per non-empty sample."""
        if self.covariate_dropout <= 0.0:
            return mask
        keep = (torch.rand_like(mask) > self.covariate_dropout).to(mask.dtype)
        mask_training = mask * keep

        # Samples that had observed covariates but lost all of them to dropout:
        # restore one observed covariate chosen uniformly at random. Restoring
        # an unobserved index would be meaningless, so sample only among
        # positions where ``mask`` is non-zero.
        emptied = ~(mask_training != 0).any(dim=1) & (mask != 0).any(dim=1)
        if emptied.any():
            rows = emptied.nonzero(as_tuple=True)[0]
            cols = torch.multinomial((mask[rows] != 0).float(), 1).squeeze(1)
            mask_training[rows, cols] = mask[rows, cols]
        return mask_training



if __name__ == "__main__":
    # Build an encoder for a fixed set of 4 covariates.
    encoder = DeepSetEncoder(
        num_covariates=4,  # the 4 fixed covariates below
        val_dim=1,         # each covariate value is a scalar
        phi_dim=64,
        rho_dim=128,
    )

    # Batch of B=2 samples with S=4 scalar covariates each: shape (B, S).
    # forward() appends the trailing value dimension itself, so passing
    # (B, S, 1) here would add a second singleton dimension.
    covariate_vals = torch.tensor([
        [25.0, 1.72, 65.0, 120.0],   # sample 1: age, height, weight, bp
        [30.0, 1.65, 58.0, 110.0],   # sample 2
    ])

    # Optional observation mask: 1 = available, 0 = missing. Shape (2, 4).
    mask = torch.tensor([
        [1, 1, 1, 1],     # sample 1: all covariates available
        [1, 0, 1, 1],     # sample 2: height missing
    ], dtype=torch.float32)

    # Covariate ids are generated inside forward(); train() enables the
    # random covariate dropout.
    encoder.train()
    params = encoder(covariate_vals, mask)   # (B, num_of_distri_params)

    # The output is one flat vector per sample; split along the LAST dim
    # (tuple-unpacking the tensor would split along the batch dim instead).
    # NOTE(review): layout assumed to be [S means | S log-variances |
    # S*(S-1)/2 dependence terms], following the original mu/logvar naming --
    # confirm against the loss that consumes it.
    S = encoder.num_of_covariates
    mu, logvar, dependence = torch.split(
        params, [S, S, encoder.num_of_distri_params - 2 * S], dim=-1
    )

    print("params shape:", tuple(params.shape))
    print("mu:", mu)                  # (B, S)
    print("logvar:", logvar)          # (B, S)
    print("dependence:", dependence)  # (B, S*(S-1)/2)
