
from functools import partial
from typing import Optional, Union

import numpy as np
import torch
import torch
from torch import nn
from torch import nn
from torch.nn import functional as F

from hetreg.models import get_activation, get_head
from .modules import initialize_prior_knowledge, BaseLipBlock_sep, MLP, SubnetMLPNN, SubnetMonoNN
from .covariance import extract_distributional_pars_from_vec, sample_from_multivariate
from .util_funcs import *
from .DeepSetEncoder import DeepSetEncoder


def initialize_prior_knowledge(list_covariates, dict_prior, type='pred'):
    """Return a copy of ``dict_prior`` completed with default priors for missing covariates.

    Args:
        list_covariates: names of all covariates the model uses.
        dict_prior: user-supplied priors keyed by covariate name (``None`` is
            treated as empty). It is NOT modified; entries already present are
            kept unchanged.
        type: ``'pred'`` -- each value is a scalar (per the original comment,
            one of {-1, 0, 1}); missing covariates get 0.
            ``'corr'`` -- each value is a dict over all covariates; a missing
            covariate gets 1 for itself and 0 for every other covariate.

    Returns:
        A new dict with an entry for every name in ``list_covariates`` (plus
        any extra keys that were already in ``dict_prior``).

    Raises:
        ValueError: if ``type`` is neither ``'pred'`` nor ``'corr'``.
    """
    if type not in ('pred', 'corr'):
        # Previously this only printed a message and stored ``None`` priors.
        raise ValueError(f"Wrong type for prior: {type!r} (expected 'pred' or 'corr')")

    def impute_prior(current_cov):
        if type == 'pred':
            return 0
        # identity row: a covariate is fully correlated only with itself
        return {other: 1 if other == current_cov else 0 for other in list_covariates}

    # Work on a copy so callers' dicts (and mutable default arguments
    # upstream) are never modified.
    completed = dict(dict_prior or {})
    for ith_cov in list_covariates:
        if ith_cov not in completed:
            completed[ith_cov] = impute_prior(ith_cov)
    return completed



class LucidAtlas_1d_v22(nn.Module):
    """Additive atlas with one subnetwork per covariate plus a DeepSet model of
    the joint covariate distribution.

    Along the last dimension the model input holds
    ``[geo features (in_geo_features) | covariates | missing flags]``.
    ``get_muters`` turns the trailing flags into a validity mask (1 = observed,
    0 = missing). Each covariate has its own subnetwork whose head yields a
    (mean, variance) pair. The prediction is the sum of all per-covariate terms
    plus an atlas term computed from the geo features only. Missing covariates
    are imputed with the mean predicted by ``self.correlation``.
    """

    def __init__(self,
                 list_covariates: list,
                 dict_prior: Optional[dict] = None,
                 dict_cor_prior: Optional[dict] = None,
                 in_cov_features: int = 1,
                 in_geo_features=0,
                 hidden_features: int = 128,
                 hidden_layers: int = 6,
                 out_features: int = 2,
                 device: str = 'cuda:0',
                 activation='gelu',
                 head='gaussian',
                 head_activation='softplus',
                 ):
        """
        Args:
            list_covariates: ordered covariate names; position i is covariate
                column ``in_geo_features + i`` of the model input.
            dict_prior: optional per-covariate prior. A non-zero value selects a
                ``SubnetMonoNN`` (``val_prior`` = value; presumably a
                monotonicity direction -- confirm); 0 or missing selects a
                ``SubnetMLPNN``. The caller's dict is not modified.
            dict_cor_prior: optional correlation priors, completed and stored on
                ``self.dict_cor_prior``. The caller's dict is not modified.
            in_cov_features: num of covariates for each subnetwork.
            in_geo_features: number of leading geometric features in the input.
            hidden_features, hidden_layers: width / depth of the subnetworks
                and the atlas MLP.
            out_features: output size of the atlas MLP.
            device: stored on ``self.device``; not used for placement here.
            activation: activation of the atlas MLP.
            head, head_activation: passed to ``get_head`` to build the heads.
        """
        super().__init__()
        self.pos_enc = False

        self.in_cov_features = in_cov_features
        self.in_geo_features = in_geo_features

        # Copy first: the priors are completed in place downstream, which used
        # to mutate the caller's dicts and the shared mutable defaults.
        self.dict_prior = initialize_prior_knowledge(
            list_covariates, dict(dict_prior or {}), type='pred')
        self.dict_cor_prior = initialize_prior_knowledge(
            list_covariates, dict(dict_cor_prior or {}), type='corr')

        # index <-> name lookups; the index is relative to the first covariate column
        self.dict_idx_covariates = dict(enumerate(list_covariates))
        self.dict_cov_idx = {name: idx for idx, name in enumerate(list_covariates)}
        self.covariate_names = list_covariates

        self.device = device
        self.net_per_cov = nn.ModuleDict({})
        for ith_attri in self.covariate_names:
            if self.dict_prior[ith_attri] != 0:
                self.net_per_cov[ith_attri] = SubnetMonoNN(
                    in_cov_features=self.in_cov_features,
                    in_geo_features=self.in_geo_features,
                    val_prior=self.dict_prior[ith_attri],
                    hidden_features=hidden_features,
                    hidden_layers=hidden_layers)
            else:
                self.net_per_cov[ith_attri] = SubnetMLPNN(
                    in_cov_features=self.in_cov_features,
                    in_geo_features=self.in_geo_features,
                    output_size=1,
                    hidden_features=hidden_features,
                    hidden_layers=hidden_layers)

        self.correlation = DeepSetEncoder(num_covariates=len(self.covariate_names),
                                          val_dim=1,
                                          phi_dim=512,
                                          hidden_layers=hidden_layers,
                                          rho_dim=512)

        self.atlas = MLP(
            input_size=self.in_geo_features,
            width=hidden_features,
            depth=hidden_layers,
            output_size=out_features,
            activation=activation)

        self.head = get_head(head)(activation=head_activation)
        self.cor_head = get_head(head)(activation=head_activation)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _sum_feature_terms(list_terms: list):
        """Sum a list of (..., 1) tensors into one (..., 1) tensor; 0 if the list is empty."""
        if not list_terms:
            return 0
        return torch.sum(torch.cat(list_terms, dim=-1), dim=-1, keepdim=True)

    @staticmethod
    def _impute(arr_observed: torch.Tensor, arr_valid: torch.Tensor, arr_imputed_mu: torch.Tensor):
        """Keep observed entries (mask 1) and fill missing ones (mask 0) with the detached imputed mean."""
        return arr_observed * arr_valid + arr_imputed_mu.detach() * (1 - arr_valid)

    def _correlation_params(self, arr_covariates: torch.Tensor, arr_mask: torch.Tensor, training: bool):
        """Run the correlation encoder and decode (mu, var, covariance).

        Only the encoder's own ``training`` attribute is set (not its
        submodules', unlike ``.train()``), and it stays set after the call.
        """
        self.correlation.training = training
        arr_pred = self.correlation(arr_covariates, arr_mask)
        return extract_distributional_pars_from_vec(arr_pred, len(self.covariate_names))

    # ------------------------------------------------------------ sub-models

    def f_cov_pred(self, coords_with_cov: torch.Tensor, which_cov_name: str):
        """Head output (mean, variance) of the subnetwork for ``which_cov_name``.

        ``coords_with_cov`` holds the geo features followed by that covariate's value.
        """
        return self.head(self.net_per_cov[which_cov_name](coords_with_cov))

    def atlas_pred(self, coords_with_cov: torch.Tensor):
        """Head output (mean, variance) of the covariate-independent atlas term."""
        return self.head(self.atlas(self.get_coords(model_input=coords_with_cov)))

    def g_corr_pred(self, covariate, which_covariate: Union[str, list]):
        """Predict the joint covariate distribution from the covariates in ``which_covariate`` only.

        ``covariate`` holds the values of those covariates, one column per name
        (assumed (B, |S|)). All other covariates are zeroed and masked out.

        Returns:
            (mu, var, covariance) as decoded by ``extract_distributional_pars_from_vec``.
        """
        set_S = [which_covariate] if isinstance(which_covariate, str) else which_covariate
        num_of_cov = len(self.covariate_names)
        idxes = names2idxes(self.dict_cov_idx, set_S)

        arr_input = torch.zeros((covariate.shape[0], num_of_cov), device=covariate.device)
        arr_input[..., idxes] = covariate
        arr_muter = torch.zeros((covariate.shape[0], num_of_cov), device=covariate.device)
        arr_muter[..., idxes] = 1

        arr_pred = self.correlation(arr_input, mask=arr_muter)
        return extract_distributional_pars_from_vec(arr_pred, num_of_cov)

    def select_src_cov(self, tgt_cov, confusion_matrix=None, HOW='first'):
        """Pick a source covariate for ``tgt_cov``.

        HOW='first': name of the first covariate that differs from ``tgt_cov``.
        HOW='max': for each batch entry of ``confusion_matrix`` (assumed
            (B, C, C)), the index of the largest entry in row ``tgt_cov`` after
            masking out the ``tgt_cov`` column.
        Returns None otherwise (e.g. HOW='max' without a matrix).
        """
        if HOW == 'first':
            for ith_cov in self.covariate_names:
                if ith_cov != tgt_cov:
                    return ith_cov
        elif HOW == 'max' and confusion_matrix is not None:
            arr_cfs_mtx = confusion_matrix.clone()
            idx_tgt_cov = self.dict_cov_idx[tgt_cov]
            arr_cfs_mtx[..., idx_tgt_cov] = -1000  # never pick the target itself
            return torch.argmax(arr_cfs_mtx[:, idx_tgt_cov, :], dim=-1)
        return None

    # ------------------------------------------------------- input slicing

    def concat_geo_and_cov_from_input(self, model_input: torch.Tensor, list_cov_names: list):
        """Return the geo features followed by the columns of the named covariates."""
        idxes = torch.as_tensor(names2idxes(self.dict_cov_idx, list_cov_names),
                                dtype=torch.long, device=model_input.device)
        # A flat 1-D index gives a (..., len(list_cov_names)) block in both
        # branches (the geo branch used to wrap the index in an extra list).
        arr_cov = model_input[..., self.in_geo_features + idxes]
        if self.in_geo_features == 0:
            return arr_cov
        return torch.cat((self.get_coords(model_input), arr_cov), dim=-1)

    def concat_geo_and_updated_cov(self, model_input: torch.Tensor, arr_covariates: torch.Tensor):
        """Return the geo features of ``model_input`` followed by ``arr_covariates``."""
        if self.in_geo_features == 0:
            return arr_covariates
        return torch.cat((self.get_coords(model_input), arr_covariates), dim=-1)

    def get_covariates(self, model_input):
        """Covariate columns ``[G, G + num_covariates)`` of the input."""
        return model_input[..., self.in_geo_features:(len(self.covariate_names) + self.in_geo_features)]

    def get_muters(self, model_input):
        """Validity mask from the trailing columns: 1 = observed, 0 = missing (the input stores the opposite)."""
        return 1. - model_input[..., (self.in_geo_features + len(self.covariate_names))::]

    def get_coords(self, model_input):
        """Leading ``in_geo_features`` geometric columns of the input."""
        return model_input[..., :self.in_geo_features]

    def encode_coord(self, coords):
        """Apply ``self.pos_encoder`` when ``pos_enc`` is set.

        NOTE: ``pos_encoder`` is not created in this class, so enabling
        ``pos_enc`` also requires assigning it.
        """
        if self.pos_enc:
            return self.pos_encoder(coords)
        return coords

    # ---------------------------------------------------- correlation model

    def train_correlation(self, arr_covariates: torch.Tensor, arr_mask: torch.Tensor):
        """Correlation model outputs (mu, var, covariance) with its ``training`` flag set to True."""
        return self._correlation_params(arr_covariates, arr_mask, training=True)

    def infer_all_correlation(self, arr_covariates: torch.Tensor, arr_mask: torch.Tensor):
        """Correlation model outputs (mu, var, covariance) with its ``training`` flag set to False."""
        return self._correlation_params(arr_covariates, arr_mask, training=False)

    # -------------------------------------------------------------- forward

    def forward(self, model_input):
        """Training forward pass.

        Missing covariates are imputed with the correlation model's (detached)
        mean before the additive model is evaluated. The correlation model is
        then re-run in training mode on the raw covariates.

        Returns:
            dict with keys ``overall_mu_and_var``, ``list_mu``, ``list_var``,
            ``dict_correlation`` (``mean``/``variance``/``covariance``),
            ``arr_covariates_ori`` and ``arr_muter`` (bool validity mask).
        """
        arr_covariates = self.get_covariates(model_input=model_input)
        arr_covariate_muter = self.get_muters(model_input)  # 1 = observed, 0 = missing
        arr_imputed_mu, _, _ = self.infer_all_correlation(arr_covariates, arr_covariate_muter)

        arr_covariates_ori = self.get_covariates(model_input).clone()
        arr_covariates_imputed = self._impute(arr_covariates_ori, arr_covariate_muter, arr_imputed_mu)
        arr_imputed_input = self.concat_geo_and_updated_cov(model_input, arr_covariates_imputed)

        overall_mu, overall_var, list_mu, list_var = self.infer_mu_and_var(arr_imputed_input)
        overall_mu_and_var = torch.cat((overall_mu, overall_var), dim=-1)
        arr_muter = arr_covariate_muter.bool()

        # second pass on the raw (unimputed) covariates, in training mode
        arr_z_mu_all, arr_z_var_all, arr_z_covariance_all = self.train_correlation(
            arr_covariates_ori, arr_covariate_muter)
        dict_correlation = {'mean': arr_z_mu_all, 'variance': arr_z_var_all, 'covariance': arr_z_covariance_all}

        return {"overall_mu_and_var": overall_mu_and_var,
                "list_mu": list_mu,
                "list_var": list_var,
                "dict_correlation": dict_correlation,
                "arr_covariates_ori": arr_covariates_ori,
                "arr_muter": arr_muter,
                }

    def infer_mu_and_var(self, model_input):
        """Sum every covariate term and the atlas term.

        ``model_input`` holds geo features followed by (already imputed) covariates.

        Returns:
            (mu, var, list_mu, list_var): the totals and the per-covariate
            (..., 1) terms (the atlas term is not in the lists).
        """
        list_mu = []
        list_var = []
        for cov_name in self.covariate_names:
            arr_coords_cov = self.concat_geo_and_cov_from_input(model_input, [cov_name])
            arr_mu_and_var = self.f_cov_pred(coords_with_cov=arr_coords_cov, which_cov_name=cov_name)
            list_mu.append(arr_mu_and_var[..., [0]])
            list_var.append(arr_mu_and_var[..., [1]])

        bias_mu_and_var = self.atlas_pred(model_input)
        arr_mu = self._sum_feature_terms(list_mu) + bias_mu_and_var[..., [0]]
        arr_var = self._sum_feature_terms(list_var) + bias_mu_and_var[..., [1]]
        return arr_mu, arr_var, list_mu, list_var

    def infer_mu_and_var_testing(self, model_input):
        """Inference-time prediction: impute missing covariates, then return (mu, var)."""
        arr_covariates = self.get_covariates(model_input=model_input)
        arr_covariate_muter = self.get_muters(model_input)  # 1 = observed, 0 = missing
        arr_imputed_mu, _, _ = self.infer_all_correlation(arr_covariates, arr_covariate_muter)

        arr_covariates = self._impute(arr_covariates, arr_covariate_muter, arr_imputed_mu)
        arr_imputed_input = self.concat_geo_and_updated_cov(model_input, arr_covariates)

        overall_mu, overall_var, _, _ = self.infer_mu_and_var(arr_imputed_input)
        return overall_mu, overall_var

    def infer_with_subnetwork(self, model_input: torch.Tensor, set_S: Union[str, list]):
        """Prediction using only the subnetworks of covariates in ``set_S`` plus the atlas.

        If ``set_S`` names no covariate, the atlas term alone is returned
        (this used to crash on ``torch.cat([])``).
        """
        if isinstance(set_S, str):
            set_S = [set_S]

        list_mu = []
        list_var = []
        for k_name in self.covariate_names:
            if k_name not in set_S:
                continue
            coords_with_cov = self.concat_geo_and_cov_from_input(model_input, [k_name])
            arr_mu_var = self.f_cov_pred(coords_with_cov=coords_with_cov, which_cov_name=k_name)
            list_mu.append(arr_mu_var[..., [0]])
            list_var.append(arr_mu_var[..., [1]])

        bias_mu_var = self.atlas_pred(model_input)
        arr_mu = self._sum_feature_terms(list_mu) + bias_mu_var[..., [0]]
        arr_var = self._sum_feature_terms(list_var) + bias_mu_var[..., [1]]
        return arr_mu, arr_var

    # ---------------------------------------------------- marginalization

    def infer_global_importance(self, model_input: torch.Tensor, set_S: Union[str, list], IGNORE_CORR=False,
                                num_of_samples: int = 5000):
        """Predictive mean and variance of y given only the covariates in ``set_S`` (and geo features).

        The other covariates are marginalized by Monte-Carlo sampling. By the
        law of total variance,
            Var(y | c_S, x) = E_c[Var(y | c, x) | c_S, x] + Var_c(E[y | c, x] | c_S, x).
        E[y | c, x] is a sum of per-covariate terms, so the second term splits
        into per-covariate variances (part 1) plus pairwise covariances between
        the sampled covariates (part 2).

        Args:
            model_input: source covariate has different numbers, other are set to 0.
            set_S: name(s) of the conditioning (source) covariates.
            IGNORE_CORR: if True, the other covariates are sampled independently
                (see ``E_of_dis_y_j_indp``) and part 2 is skipped.
            num_of_samples: Monte-Carlo samples per marginalized covariate
                (default 5000, the previously hard-coded value).

        Returns:
            (arr_E, arr_Var), each with a trailing dimension of size 1.
        """
        if isinstance(set_S, str):
            set_S = [set_S]

        list_E_of_E = []          # E[f^m_k | c_S, x]; indexed by covariate index (used by part 2)
        list_E_of_Var = []        # E[f^v_k | c_S, x]
        list_Var_of_E_part1 = []  # Var[f^m_k | c_S, x] for the sampled covariates

        for k_name in self.covariate_names:
            if k_name in set_S:
                # source covariate: evaluate its subnetwork at the given value
                coords_cov = self.concat_geo_and_cov_from_input(model_input, [k_name])
                arr_mu_var = self.f_cov_pred(coords_with_cov=coords_cov, which_cov_name=k_name)
                list_E_of_E.append(arr_mu_var[..., [0]])
                list_E_of_Var.append(arr_mu_var[..., [1]])
                continue

            # other covariates: sample c_k and average the subnetwork outputs
            if IGNORE_CORR:
                E_of_E, E_of_Var, Var_of_E_part1 = self.E_of_dis_y_j_indp(
                    model_input, set_S, k_name, num_of_samples=num_of_samples)
            else:
                E_of_E, E_of_Var, Var_of_E_part1 = self.E_of_dis_y_j_given_c_i(
                    model_input, set_S, k_name, num_of_samples=num_of_samples)
            list_E_of_E.append(E_of_E)
            list_E_of_Var.append(E_of_Var)
            list_Var_of_E_part1.append(Var_of_E_part1)

        if IGNORE_CORR:
            arr_Var_of_E_part2 = 0
        else:
            arr_Var_of_E_part2 = self.V_of_E_given_c_i_part2(
                model_input, set_S, list_E_of_E, num_of_samples=num_of_samples)

        bias_mu_var = self.atlas_pred(model_input)

        arr_E = self._sum_feature_terms(list_E_of_E) + bias_mu_var[..., [0]]
        arr_E_of_Var = self._sum_feature_terms(list_E_of_Var) + bias_mu_var[..., [1]]
        # empty when set_S covers every covariate (used to crash on torch.cat([]))
        arr_Var_of_E_part1 = self._sum_feature_terms(list_Var_of_E_part1)

        # Var = E_of_Var + Var_of_E, with Var_of_E = Var_of_E_part1 + Var_of_E_part2
        arr_Var = arr_E_of_Var + arr_Var_of_E_part1 + arr_Var_of_E_part2

        if arr_Var.min() < 0:
            import warnings
            warnings.warn(f"infer_global_importance: negative predictive variance "
                          f"(min={arr_Var.min().item():.6g})")
        return arr_E, arr_Var

    def E_of_dis_y_j_indp(self, model_input, src_cov_name, tgt_cov_name, num_of_samples=10000):
        """Moments of covariate ``tgt_cov_name``'s subnetwork with the covariate sampled independently.

        c_tgt is drawn uniformly from [-2, 2); ``src_cov_name`` is not used.

        Returns:
            (E[f^m], E[f^v], Var[f^m]) over the samples, each (B, 1).
        """
        batch_size = model_input.shape[0]
        arr_cov_samples = (torch.rand((batch_size, num_of_samples), device=model_input.device) - 0.5) * 4

        current_input = self.make_query_samples_for_f(
            model_input, arr_cov_samples, ck_name=tgt_cov_name, num_of_samples=num_of_samples)
        # The repeated rows still carry every input column. Select geo + the
        # target covariate first: reshaping whole rows to (G + 1) columns
        # scrambled values across rows.
        coords_cov = self.concat_geo_and_cov_from_input(
            current_input.reshape(-1, current_input.shape[-1]), [tgt_cov_name])

        arr_mu_var = self.f_cov_pred(coords_with_cov=coords_cov, which_cov_name=tgt_cov_name)
        arr_mu_var = arr_mu_var.reshape(batch_size, -1, 2)  # dim 1 is the sample dim, averaged below

        E_of_E = torch.mean(arr_mu_var[..., [-2]], dim=-2)
        E_of_Var = torch.mean(arr_mu_var[..., [-1]], dim=-2)
        V_of_E_part1 = torch.var(arr_mu_var[..., [-2]], dim=-2)
        return E_of_E, E_of_Var, V_of_E_part1

    # ------------------------------------------------------ growth velocity

    def _growth_projection(self, model_input_t0_ori, gt_t0_ori, model_input_t1_ori, project):
        """Shared driver for the ``growth_velocity_*`` methods.

        ``project(gt_t0, mu_t0, var_t0, mu_t1, var_t1)`` maps the t0 observation
        and the model moments at both times (rows present at both times) to a
        prediction at t1.

        Returns:
            (f_ind_t1, f_mu_t1, gt_t0 or a pseudo ground truth). When t1 has
            more rows than t0, the extra rows of f_ind_t1 and of the pseudo
            ground truth are filled with f_mu_t1.
        """
        model_input_t0 = model_input_t0_ori.clone()
        model_input_t1 = model_input_t1_ori.clone()
        gt_t0 = gt_t0_ori.clone()
        n_t1 = model_input_t1.shape[0]

        if model_input_t0.shape[0] >= n_t1:
            # t1 may be an incomplete observation: drop the surplus t0 rows
            model_input_t0 = model_input_t0[:n_t1, ...]
            gt_t0 = gt_t0[:n_t1, ...]

            f_mu_t0, f_var_t0 = self.infer_mu_and_var_testing(model_input_t0)
            f_mu_t1, f_var_t1 = self.infer_mu_and_var_testing(model_input_t1)
            f_ind_t1 = project(gt_t0, f_mu_t0, f_var_t0, f_mu_t1, f_var_t1)
            return f_ind_t1, f_mu_t1, gt_t0

        # t1 has more rows than t0: only the first n_t0 rows are projected
        n_t0 = model_input_t0.shape[0]
        f_mu_t0, f_var_t0 = self.infer_mu_and_var_testing(model_input_t0)
        f_mu_t1, f_var_t1 = self.infer_mu_and_var_testing(model_input_t1)

        f_ind_t1 = f_mu_t1.clone()
        f_ind_t1[:n_t0, ...] = project(gt_t0, f_mu_t0, f_var_t0, f_mu_t1[:n_t0, ...], f_var_t1[:n_t0, ...])

        gt_t0_pseudo = f_mu_t1.clone()
        gt_t0_pseudo[:n_t0, ...] = gt_t0
        return f_ind_t1, f_mu_t1, gt_t0_pseudo

    def growth_velocity_distri(self, model_input_t0_ori, gt_t0_ori, model_input_t1_ori):
        """Predict t1 values by keeping each subject's standardized score from t0."""
        def project(gt_t0, mu_t0, var_t0, mu_t1, var_t1):
            return mu_t1 + (gt_t0 - mu_t0) * torch.sqrt(var_t1) / torch.sqrt(var_t0)

        return self._growth_projection(model_input_t0_ori, gt_t0_ori, model_input_t1_ori, project)

    def growth_velocity_mean(self, model_input_t0_ori, gt_t0_ori, model_input_t1_ori):
        """Predict t1 values by adding the change of the predicted mean to the t0 observation."""
        def project(gt_t0, mu_t0, var_t0, mu_t1, var_t1):
            return gt_t0 + (mu_t1 - mu_t0)

        return self._growth_projection(model_input_t0_ori, gt_t0_ori, model_input_t1_ori, project)

    # ------------------------------------------------------------ sampling

    def make_query_samples_for_f(self, model_input: torch.Tensor, arr_ck: torch.Tensor, ck_name: str, num_of_samples: int):
        """Repeat each input row ``num_of_samples`` times and overwrite covariate ``ck_name`` with ``arr_ck``.

        ``model_input`` is assumed (B, D). ``arr_ck`` must hold B * num_of_samples
        values, e.g. (B, num_of_samples) or (B, num_of_samples, 1).

        Returns:
            (B, num_of_samples, D) tensor.
        """
        arr_rep_input = torch.repeat_interleave(model_input[:, None, ...], dim=1, repeats=num_of_samples)
        ck_col = self.dict_cov_idx[ck_name] + self.in_geo_features
        # reshape, not squeeze(): squeeze dropped the batch/sample axis when B == 1 or num_of_samples == 1
        arr_rep_input[..., ck_col] = arr_ck.reshape(arr_rep_input.shape[0], num_of_samples)
        return arr_rep_input

    def p_ck_given_cS(self, arr_covariates: torch.Tensor, set_S_names: list, k_names: Union[str, list]):
        """Parameters of the conditional Gaussian p(c_k | c_S) from the correlation model.

        Args:
            arr_covariates: covariate columns of the input (all covariates).
            set_S_names: names of the conditioning covariates S.
            k_names: name(s) of the target covariate(s) k.

        Returns:
            (mu, var, cov) restricted to ``k_names``: means and variances indexed
            by ``k_names`` on the last dim, and the covariance block
            ``[k_names, k_names]`` on the last two dims.
        """
        if isinstance(k_names, str):
            k_names = [k_names]

        idxes_S = names2idxes(self.dict_cov_idx, set_S_names)
        mu_z, var_z, covar_z = self.g_corr_pred(covariate=arr_covariates[..., idxes_S], which_covariate=set_S_names)

        idxes_k = names2idxes(self.dict_cov_idx, k_names)
        mu_zS = mu_z[..., idxes_k]
        var_zS = var_z[..., idxes_k]
        cov_zS = covar_z[..., idxes_k, :][..., :, idxes_k]
        return mu_zS, var_zS, cov_zS

    def sampling_from_p_ck_given_cS(self, mu_ck: torch.Tensor, covar_ck: torch.Tensor, num_samples: int, clamp_std: float = 2.0):
        """Draw samples from the conditional Gaussian p(c_k | c_S) ~ N(mu_ck, covar_ck).

        Args:
            mu_ck: (B, D_k) conditional means.
            covar_ck: (B, D_k, D_k) conditional covariances.
            num_samples: number of samples per batch element.
            clamp_std: standard-normal noise is clamped to [-clamp_std, clamp_std] (truncation).

        Returns:
            (B, num_samples, D_k) samples.
        """
        B, D_k = mu_ck.shape

        if mu_ck.shape[-1] == 1:
            # 1x1 covariance: its square root is the Cholesky factor
            L = torch.sqrt(covar_ck)
        else:
            # Cholesky factor; the small diagonal jitter keeps it numerically stable
            L = torch.linalg.cholesky(
                covar_ck + 1e-6 * torch.eye(D_k, device=covar_ck.device).unsqueeze(0))  # (B, D_k, D_k)

        # standard-normal noise eps (B, num_samples, D_k)
        if covar_ck.sum() == 0:  # no noise in covariates
            eps = torch.zeros((B, num_samples, D_k), device=mu_ck.device)
        else:
            eps = torch.randn(B, num_samples, D_k, device=mu_ck.device)
        eps = torch.clamp(eps, -clamp_std, clamp_std)

        # multiply by L to impose the covariance structure
        return mu_ck.unsqueeze(1) + torch.matmul(eps, L.transpose(1, 2))

    def E_of_dis_y_j_given_c_i(self, model_input: torch.Tensor, set_S_names: list, ck_name: str, num_of_samples: int = 512):
        """Monte-Carlo moments of covariate ``ck_name``'s subnetwork under c_k ~ p(c_k | c_S).

        Returns:
            (E[f^m], E[f^v], Var[f^m]) over the samples, each (B, 1).
        """
        arr_covariates = self.get_covariates(model_input)
        mu_ck, _, covar_ck = self.p_ck_given_cS(arr_covariates, set_S_names, [ck_name])

        samples_ck = self.sampling_from_p_ck_given_cS(mu_ck, covar_ck, num_of_samples)
        arr_samples = self.make_query_samples_for_f(model_input, samples_ck, ck_name, num_of_samples)

        coords_cov = self.concat_geo_and_cov_from_input(
            arr_samples.reshape(-1, arr_samples.shape[-1]), [ck_name])
        arr_mu_var = self.f_cov_pred(coords_with_cov=coords_cov, which_cov_name=ck_name)
        arr_mu_var = arr_mu_var.reshape(model_input.shape[0], -1, 2)  # dim 1 is the sample dim

        f_ck = arr_mu_var[..., [-2]]  # f^m samples, (B, S, 1)
        V_ck = arr_mu_var[..., [-1]]  # f^v samples, (B, S, 1)

        E_of_Var = V_ck.mean(dim=1)
        E_of_E = f_ck.mean(dim=1)
        V_of_E_part1 = ((f_ck - E_of_E[:, None, ...]) ** 2).mean(dim=1)
        return E_of_E, E_of_Var, V_of_E_part1

    def make_query_samples_for_Cov_K1K2(self,
                                        model_input: torch.Tensor,
                                        samples_c_K1: torch.Tensor,
                                        samples_c_K2: torch.Tensor,
                                        k1_name: str,
                                        k2_name: str,
                                        num_of_samples: int):
        """Repeat each input row and overwrite covariates ``k1_name``/``k2_name`` with the given samples.

        Each sample tensor must hold B * num_of_samples values.

        Raises:
            ValueError: if ``in_geo_features`` exceeds the input width.
        """
        if self.in_geo_features > model_input.shape[-1]:
            raise ValueError("Geo feature index exceeds input dimension!")

        arr_rep_input = torch.repeat_interleave(model_input[:, None, ...], dim=1, repeats=num_of_samples)
        target_shape = (arr_rep_input.shape[0], num_of_samples)
        arr_rep_input[..., self.dict_cov_idx[k1_name] + self.in_geo_features] = samples_c_K1.reshape(target_shape)
        arr_rep_input[..., self.dict_cov_idx[k2_name] + self.in_geo_features] = samples_c_K2.reshape(target_shape)
        return arr_rep_input

    def V_of_E_given_c_i_part2(self, model_input: torch.Tensor, set_S_names: Union[str, list], list_of_E: list, num_of_samples=1000):
        """Sum of Cov(f^m_k1, f^m_k2 | c_S) over ordered pairs k1 != k2 of non-source covariates.

        ``list_of_E[k]`` must be E[f^m_k | c_S] for covariate index k.

        Returns:
            (B, 1) tensor, or 0 when there are fewer than two non-source covariates.
        """
        set_S_names = check_S_set(set_S_names)
        S_idxes = set(names2idxes(self.dict_cov_idx, set_S_names))
        num_of_cov = len(self.covariate_names)
        batch_size = model_input.shape[0]
        arr_covariates = self.get_covariates(model_input)  # loop-invariant

        list_of_cov_K1K2 = []
        for k1 in range(num_of_cov):
            for k2 in range(num_of_cov):
                if k1 == k2 or k1 in S_idxes or k2 in S_idxes:
                    continue
                k1_name = self.dict_idx_covariates[k1]
                k2_name = self.dict_idx_covariates[k2]

                # joint samples of (c_k1, c_k2) | c_S
                mu_ck, _, covar_ck = self.p_ck_given_cS(arr_covariates, set_S_names, [k1_name, k2_name])
                arr_pair = self.sampling_from_p_ck_given_cS(mu_ck, covar_ck, num_of_samples)
                samples_c_K1, samples_c_K2 = arr_pair[..., [0]], arr_pair[..., [1]]

                arr_samples = self.make_query_samples_for_Cov_K1K2(
                    model_input, samples_c_K1, samples_c_K2, k1_name, k2_name, samples_c_K1.shape[1])

                coords_cov_K1 = self.concat_geo_and_cov_from_input(arr_samples, [k1_name])
                f_mu_k1 = self.f_cov_pred(coords_with_cov=coords_cov_K1.reshape(-1, self.in_geo_features + 1),
                                          which_cov_name=k1_name)[..., [0]].reshape(batch_size, -1)

                coords_cov_K2 = self.concat_geo_and_cov_from_input(arr_samples, [k2_name])
                f_mu_k2 = self.f_cov_pred(coords_with_cov=coords_cov_K2.reshape(-1, self.in_geo_features + 1),
                                          which_cov_name=k2_name)[..., [0]].reshape(batch_size, -1)

                # Cov = E[f1 * f2] - E[f1] * E[f2]
                E_f1f2 = (f_mu_k1 * f_mu_k2).mean(dim=1)
                list_of_cov_K1K2.append(E_f1f2[..., None] - list_of_E[k1] * list_of_E[k2])

        if not list_of_cov_K1K2:
            return 0
        return torch.sum(torch.cat(list_of_cov_K1K2, dim=-1), dim=-1, keepdim=True)

    def infer_cov_correlation(self, arr_covariate, src_cov_name, tgt_cov_name):
        """Predict covariate ``tgt_cov_name`` from covariate ``src_cov_name`` alone.

        Returns:
            (mu_tgt, var_tgt, full covariance) from the correlation model.
        """
        tgt_idx = self.dict_cov_idx[tgt_cov_name]
        src_idx = self.dict_cov_idx[src_cov_name]
        arr_feat_mu, arr_feat_var, arr_feat_covariance = self.g_corr_pred(
            covariate=arr_covariate[..., [src_idx]], which_covariate=src_cov_name)
        return arr_feat_mu[..., [tgt_idx]], arr_feat_var[..., [tgt_idx]], arr_feat_covariance




