import numpy as np
import torch
import torch.nn as nn
from sklearn.mixture import GaussianMixture
from sklearn.mixture._gaussian_mixture import _compute_precision_cholesky
from scipy.stats import rankdata
import os
import pickle
import einops
from linetimer import CodeTimer
from torch.distributions import Beta, Normal
import torchsort
from einops import rearrange
#from random_norm_flow import run_realnvp_batch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import torch
import torch.nn as nn
import torch.nn.functional as F

class CouplingLayer(nn.Module):
    """Affine coupling transform over the last axis of the input.

    The input is split into two halves along the last axis. The first half is
    passed through untouched and is also fed to two small MLPs that produce a
    log-scale ``s`` and a shift ``t``; the second half becomes
    ``x_b * exp(s) + t``. Only this forward map is implemented here (no
    inverse, no log-determinant).

    Args:
        input_dim: Size of the last axis of the input; each conditioner maps
            ``input_dim // 2`` features to ``input_dim // 2`` features.
        hidden_dim: Width of the hidden layer of both conditioners.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        half_dim = input_dim // 2

        # Log-scale conditioner; the trailing Tanh bounds s to (-1, 1).
        scale_layers = [
            nn.Linear(half_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, half_dim),
            nn.Tanh(),
        ]
        self.scale_net = nn.Sequential(*scale_layers)

        # Shift conditioner; unbounded output.
        translate_layers = [
            nn.Linear(half_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, half_dim),
        ]
        self.translate_net = nn.Sequential(*translate_layers)

    def forward(self, x):
        """Return ``x`` with its second half affinely transformed."""
        passthrough, transformed = torch.chunk(x, 2, dim=-1)
        log_scale = self.scale_net(passthrough)
        shift = self.translate_net(passthrough)
        transformed = transformed * log_scale.exp() + shift
        return torch.cat((passthrough, transformed), dim=-1)

class RealNVP(nn.Module):
    """Sequential stack of ``CouplingLayer`` modules.

    The layers are applied one after another with no permutation or swap of
    the two halves in between.

    Args:
        dim: Feature size handed to every coupling layer.
        hidden_dim: Hidden width handed to every coupling layer.
        n_coupling_layers: Number of coupling layers in the stack.
    """

    def __init__(self, dim=16, hidden_dim=128, n_coupling_layers=4):
        super().__init__()
        self.dim = dim
        self.n_coupling_layers = n_coupling_layers

        blocks = []
        for _ in range(n_coupling_layers):
            blocks.append(CouplingLayer(dim, hidden_dim))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Run ``x`` through every coupling layer in order."""
        out = x
        for coupling in self.layers:
            out = coupling(out)
        return out

    def randomize_params_partial(self, percent=0.5, control_vector=None):
        """Overwrite the first scale-net bias of a random subset of layers.

        ``int(len(layers) * percent)`` layers are picked via a random
        permutation. For each picked layer the bias of ``scale_net[0]`` is
        replaced by ``control_vector``. When ``control_vector`` is ``None``
        nothing is modified (the permutation is still drawn).
        """
        total = len(self.layers)
        n_selected = int(total * percent)
        chosen = torch.randperm(total)[:n_selected]

        if control_vector is None:
            return

        for layer_idx in chosen:
            first_bias = self.layers[layer_idx].scale_net[0].bias
            # copy_ under no_grad so the write is not tracked by autograd
            with torch.no_grad():
                first_bias.copy_(control_vector.to(first_bias.device))

def run_realnvp_batch(batch_input, hidden_dim=128, n_coupling_layers=4):
    """Push every batch element through its own freshly initialised RealNVP.

    A new ``RealNVP`` is created per batch element, its linear layers are
    initialised with small weights (normal, std 0.05) and zero bias, and a
    random subset of layers gets a small random bias via
    ``randomize_params_partial``. No gradients are recorded.

    Args:
        batch_input: Tensor of shape ``[batchsize, seq_len, dim]``.
        hidden_dim: Hidden width of the coupling layers; also the length of
            the random control vector written into the selected biases.
        n_coupling_layers: Number of coupling layers per model.

    Returns:
        Tensor with the same shape, dtype and device as ``batch_input``.
    """
    batchsize, seq_len, dim = batch_input.shape
    outputs = torch.zeros_like(batch_input)

    def init_weights(m):
        # Small weights and zero bias for every linear layer.
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.05)
            nn.init.constant_(m.bias, 0)

    for i in range(batchsize):
        model = RealNVP(dim=dim, hidden_dim=hidden_dim, n_coupling_layers=n_coupling_layers)
        model.apply(init_weights)

        control_vector = torch.randn(hidden_dim) * 0.01
        model.randomize_params_partial(percent=0.5, control_vector=control_vector)

        # The model is built on CPU; move it next to the data so that batches
        # living on an accelerator do not trigger a device-mismatch error.
        model.to(batch_input.device)

        with torch.no_grad():
            x_i = batch_input[i:i + 1]           # [1, seq_len, dim]
            outputs[i] = model(x_i).squeeze(0)   # [seq_len, dim]

    return outputs

def scale_data_pytorch(input_tensor, method="linear_scaling"):
    """Rescale a tensor along dimension 1, independently per remaining index.

    Args:
        input_tensor: Tensor whose dimension 1 is the axis the statistics are
            computed over (callers pass ``[batchsize, seq_len, dim]``).
        method: ``"linear_scaling"`` maps each slice linearly so its minimum
            becomes -1 and its maximum +1; ``"standardization"`` subtracts the
            mean and divides by the (unbiased) standard deviation.

    Returns:
        The rescaled tensor, same shape as ``input_tensor``.

    Raises:
        ValueError: If ``method`` is not one of the two supported names.
    """
    if method == "linear_scaling":
        min_val = torch.min(input_tensor, dim=1, keepdim=True).values
        max_val = torch.max(input_tensor, dim=1, keepdim=True).values
        span = max_val - min_val
        # A constant slice has zero span; divide by 1 instead so it maps to -1
        # rather than producing NaN from 0/0.
        safe_span = torch.where(span > 0, span, torch.ones_like(span))
        return 2 * (input_tensor - min_val) / safe_span - 1

    if method == "standardization":
        mean_val = torch.mean(input_tensor, dim=1, keepdim=True)
        std_val = torch.std(input_tensor, dim=1, keepdim=True)
        # A constant slice has zero std; divide by 1 so it maps to 0, not NaN.
        safe_std = torch.where(std_val == 0, torch.ones_like(std_val), std_val)
        return (input_tensor - mean_val) / safe_std

    raise ValueError(
        f"Unsupported method {method!r}. Use 'linear_scaling' or 'standardization'."
    )

def softrank_preprocessing(input_tensor, regularization_strength=0.1, gauss_copula=True, gauss_range=1.1):
    """Soft-rank every feature along the sequence axis and optionally Gaussianise.

    Steps, all per batch element and per feature:
      1. min-max scale the input to [0, 1] along dimension 1;
      2. compute ``torchsort.soft_rank`` along dimension 1;
      3. min-max scale the ranks to [0, 1], shrink by ``1 / gauss_range`` and
         re-centre around 0.5;
      4. if ``gauss_copula`` is True, apply the standard-normal inverse CDF.

    Args:
        input_tensor: Tensor of shape ``[batchsize, seq_len, dim]``.
        regularization_strength: Forwarded to ``torchsort.soft_rank``.
        gauss_copula: Whether to apply the inverse normal CDF at the end.
        gauss_range: Shrink factor for the rank interval; with a value above 1
            the scaled ranks stay strictly inside (0, 1).

    Returns:
        Tensor of shape ``[batchsize, seq_len, dim]``.
    """
    def _unit_interval(values):
        # Min-max scale along the sequence axis.
        lo = torch.min(values, dim=1, keepdim=True).values
        hi = torch.max(values, dim=1, keepdim=True).values
        return (values - lo) / (hi - lo)

    _, _, n_features = input_tensor.shape

    # soft_rank ranks along the last axis of a 2-D tensor, so fold batch and
    # feature axes together and put the sequence axis last.
    rows = rearrange(_unit_interval(input_tensor), 'b n d -> (d b) n')
    ranks = torchsort.soft_rank(rows, regularization_strength=regularization_strength)
    ranks = rearrange(ranks, '(d b) n -> b n d', d=n_features)

    squeezed = _unit_interval(ranks) / gauss_range + (0.5 - 0.5 / gauss_range)

    # Equality (not truthiness) is kept on purpose so the accepted flag values
    # stay exactly as before.
    if gauss_copula == True:
        # icdf is element-wise, so no axis reshuffling is required.
        return torch.distributions.Normal(0, 1).icdf(squeezed)

    return squeezed

def softrank_preprocessing_new(input_tensor, regularization_strength=0.1, gauss_copula=True, gauss_range=1.1):
    """Soft-rank every feature along the sequence axis without pre-scaling.

    Unlike ``softrank_preprocessing`` the raw input is handed to
    ``torchsort.soft_rank`` directly (no min-max scaling beforehand). The
    ranks are then min-max scaled along dimension 1, shrunk by
    ``1 / gauss_range`` around 0.5 and, if requested, mapped through the
    standard-normal inverse CDF.

    Args:
        input_tensor: Tensor of shape ``[batchsize, seq_len, dim]``.
        regularization_strength: Forwarded to ``torchsort.soft_rank``.
        gauss_copula: Whether to apply the inverse normal CDF at the end.
        gauss_range: Shrink factor for the rank interval.

    Returns:
        Tensor of shape ``[batchsize, seq_len, dim]``.
    """
    b, n, d = input_tensor.shape

    # The soft rank used to be computed twice with identical arguments and the
    # first result thrown away; one evaluation gives the same output.
    softrank = torchsort.soft_rank(
        rearrange(input_tensor, 'b n d -> (d b) n'),
        regularization_strength=regularization_strength,
    )
    softrank = rearrange(softrank, '(d b) n -> b n d', d=d)

    min_val = torch.min(softrank, dim=1, keepdim=True).values
    max_val = torch.max(softrank, dim=1, keepdim=True).values
    scaled_softrank = ((softrank - min_val) / (max_val - min_val)) / gauss_range + (0.5 - 0.5 / gauss_range)

    if gauss_copula == True:
        # icdf is element-wise, so it can be applied in the original layout.
        normal_dist = torch.distributions.Normal(0, 1)
        scaled_softrank = normal_dist.icdf(scaled_softrank)

    return scaled_softrank

def gauss_noise_padding(batch, aim_dim=32, perm=True):
    """Pad the feature axis with standard-normal noise up to ``aim_dim``.

    Args:
        batch: Tensor of shape ``[batchsize, seq_len, current_dim]``.
        aim_dim: Target size of the feature axis.
        perm: If true, shuffle the feature columns with one random
            permutation shared by the whole batch.

    Returns:
        ``batch`` itself (unscaled) when it already has ``aim_dim`` features;
        otherwise the padded tensor after ``scale_data_pytorch`` with its
        default method.

    Raises:
        ValueError: If ``batch`` has more than ``aim_dim`` features.
    """
    batchsize = batch.shape[0]
    seq_len = batch.shape[1]
    current_dim = batch.shape[2]

    if current_dim > aim_dim:
        raise ValueError("current dimension is larger than the padding dimension!")
    if current_dim == aim_dim:
        return batch

    padding_dim = aim_dim - current_dim

    # Allocate the noise and the permutation on the batch's own device;
    # torch.cat / index_select require all operands on the same device, so the
    # default device or the module-level `device` could mismatch.
    noise = torch.randn(batchsize, seq_len, padding_dim, device=batch.device)
    padded_batch = torch.cat((batch, noise), dim=2)

    if perm:
        re_setdim = torch.randperm(aim_dim).to(batch.device)
        padded_batch = torch.index_select(padded_batch, 2, re_setdim)

    return scale_data_pytorch(padded_batch)

def random_corr_factor(batchsize, d, k_values, device='cuda'):
    """Draw a batch of random correlation matrices from a factor model.

    For each batch element a ``d x k`` Gaussian factor matrix ``W`` is drawn
    (columns at index ``>= k_values[i]`` are zeroed), ``W W^T`` plus a random
    positive diagonal is formed, and the result is normalised to unit
    diagonal.

    NOTE(review): this module defines ``random_corr_factor`` a second time
    further down; the later definition replaces this one once the module has
    been imported, so this version is effectively unused.

    Args:
        batchsize: Number of matrices.
        d: Matrix size.
        k_values: Integer tensor of shape ``(batchsize,)`` with the number of
            active factor columns per matrix.
        device: Device for the generated tensors.

    Returns:
        Tensor of shape ``(batchsize, d, d)``.
    """
    widest = int(k_values.max())  # factor columns are padded to the largest k
    factors = torch.randn(batchsize, d, widest, device=device)

    column_ids = torch.arange(widest, device=device)
    keep = column_ids[None, None, :] < k_values[:, None, None]
    factors = factors * keep.float()

    jitter = torch.rand(batchsize, d, device=device)
    cov = torch.bmm(factors, factors.transpose(1, 2)) + torch.diag_embed(jitter)

    # D^{-1/2} cov D^{-1/2} with D = diag(cov)
    variances = torch.diagonal(cov, dim1=-2, dim2=-1)
    inv_std = torch.diag_embed(1.0 / torch.sqrt(variances))
    return torch.bmm(torch.bmm(inv_std, cov), inv_std)

def gen_gmm_samples_factor(batchsize, seq_len, dim, num_mixtures, eigen_range=10, device='cuda', max_retries=100):
    """Sample a batch of sequences from random Gaussian mixtures.

    Because of the vectorised implementation, all batch elements share the
    same number of components, the same mixture weights and the same
    per-position component assignment; means and covariances are drawn
    independently per batch element and component.

    Args:
        batchsize: Number of independent mixtures (batch elements).
        seq_len: Number of samples per batch element.
        dim: Half of the feature size; samples have ``2 * dim`` features.
            May be a Python int or a one-element integer tensor.
        num_mixtures: Number of mixture components. May be a Python int or a
            one-element integer tensor.
        eigen_range: Per-feature variances are drawn uniformly from
            ``[0.01, eigen_range + 0.01)``.
        device: Device for the generated tensors.
        max_retries: Maximum number of attempts when the Cholesky
            factorisation of a drawn covariance fails.

    Returns:
        Tensor of shape ``[batchsize, seq_len, 2 * dim]``.

    Raises:
        RuntimeError: If no positive-definite covariance batch was obtained
            within ``max_retries`` attempts, or on any unrelated runtime
            error raised by the factorisation.
    """
    # Normalise tensor-valued arguments once so the rest is plain-int
    # arithmetic (previously `.item()` / `torch.round` required tensors).
    batchsize = int(batchsize)
    dim = int(dim)
    num_mixtures = int(num_mixtures)
    feature_dim = 2 * dim
    n_components = batchsize * num_mixtures

    for _ in range(max_retries):
        # rand() + 1e-6 is strictly positive, so no clamping is needed.
        weights = torch.rand(num_mixtures, device=device) + 1e-6
        weights /= weights.sum()

        means = torch.rand(batchsize, num_mixtures, feature_dim, device=device) * 10 - 5  # Uniform(-5, 5)

        # Number of random factors per component: fewer factors give larger
        # off-diagonal correlations.
        if dim <= 10:
            k_low, k_high = dim, 2 * dim + 1
        else:
            k_low, k_high = 1, round(6 * dim / 3.5) + 1
        k_values = torch.randint(k_low, k_high, (n_components,)).to(device)

        correlation_matrices = random_corr_factor(
            n_components, feature_dim, k_values, device=device
        ).reshape(batchsize, num_mixtures, feature_dim, feature_dim)
        standard_deviation = torch.sqrt(
            torch.rand(batchsize, num_mixtures, feature_dim, device=device) * eigen_range + 0.01
        )
        SD = torch.diag_embed(standard_deviation)
        cov = torch.einsum('bmij,bmjk,bmkl->bmil', SD, correlation_matrices, SD)

        try:
            scale_tril = torch.linalg.cholesky(cov)
        except RuntimeError as e:
            message = str(e).lower()
            if 'cholesky' in message or 'positive definite' in message:
                # Numerically non-PD draw: start over with fresh parameters.
                continue
            raise

        # One component label per sequence position, shared by the whole
        # batch. Shape: [seq_len]
        components = torch.multinomial(weights, seq_len, replacement=True)

        samples = torch.zeros(batchsize, seq_len, feature_dim, device=device)
        for mix in range(num_mixtures):
            mask = components == mix
            num_samples = int(mask.sum())
            if num_samples == 0:
                continue
            dist = torch.distributions.MultivariateNormal(
                means[:, mix, :], scale_tril=scale_tril[:, mix, :, :]
            )
            # sample() yields [num_samples, batchsize, feature_dim]
            samples[:, mask, :] = dist.sample((num_samples,)).transpose(0, 1)

        return samples

    raise RuntimeError(
        f"gen_gmm_samples_factor: no positive-definite covariance after {max_retries} attempts"
    )

def pad_gmm_data(batch, aim_dim=32, perm=True, device="cuda"):
    """Pad both halves of a paired batch with Gaussian noise, then rescale.

    The feature axis is treated as two halves (x, y) of ``current_dim``
    features each. Each half is padded with standard-normal noise up to
    ``aim_dim`` features, optionally permuted with one permutation shared by
    both halves, and the concatenation is linearly scaled with
    ``scale_data_pytorch``.

    Args:
        batch: Tensor of shape ``[batchsize, seq_len, 2 * current_dim]``.
        aim_dim: Target number of features per half.
        perm: If true, shuffle the columns of both halves identically.
        device: Kept for backward compatibility. Noise is allocated on
            ``batch.device`` because concatenation requires matching devices,
            so a CPU batch no longer fails with the default ``"cuda"``.

    Returns:
        ``batch`` itself (unpadded and unscaled) when each half already has
        ``aim_dim`` features; otherwise a tensor of shape
        ``[batchsize, seq_len, 2 * aim_dim]``.

    Raises:
        ValueError: If a half has more than ``aim_dim`` features.
    """
    current_dim = batch.shape[2] // 2
    batchsize = batch.shape[0]
    seq_len = batch.shape[1]

    if current_dim > aim_dim:
        raise ValueError("current dimension is larger than the padding dimension!")
    if current_dim == aim_dim:
        return batch

    target_device = batch.device
    aim_dim = int(aim_dim)  # callers may pass a one-element tensor
    padding_dim = aim_dim - current_dim

    x_samples = batch[:, :, :current_dim]
    y_samples = batch[:, :, current_dim:2 * current_dim]

    noise_x = torch.randn(batchsize, seq_len, padding_dim, device=target_device)
    noise_y = torch.randn(batchsize, seq_len, padding_dim, device=target_device)

    padded_x = torch.cat((x_samples, noise_x), dim=2)
    padded_y = torch.cat((y_samples, noise_y), dim=2)

    if perm:
        # Same permutation for x and y so paired columns stay aligned.
        re_setdim = torch.randperm(aim_dim).to(target_device)
        padded_x = torch.index_select(padded_x, 2, re_setdim)
        padded_y = torch.index_select(padded_y, 2, re_setdim)

    padded_batch = torch.cat((padded_x, padded_y), dim=2)
    return scale_data_pytorch(padded_batch, method="linear_scaling")

def zero_pad_data(batch, aim_dim=16, perm=False, device="cuda"):
    """Zero-pad both halves of a paired batch up to ``aim_dim`` features.

    The feature axis is treated as two halves (x, y) of ``current_dim``
    features each; each half is padded with zero columns and the halves are
    concatenated again. No rescaling is applied.

    Args:
        batch: Tensor of shape ``[batchsize, seq_len, 2 * current_dim]``.
        aim_dim: Target number of features per half.
        perm: If true, shuffle the columns of both halves identically.
        device: Kept for backward compatibility. Padding is allocated on
            ``batch.device`` because concatenation requires matching devices,
            so a CPU batch no longer fails with the default ``"cuda"``.

    Returns:
        ``batch`` itself when each half already has ``aim_dim`` features;
        otherwise a tensor of shape ``[batchsize, seq_len, 2 * aim_dim]``.

    Raises:
        ValueError: If a half has more than ``aim_dim`` features.
    """
    current_dim = batch.shape[2] // 2
    batchsize = batch.shape[0]
    seq_len = batch.shape[1]

    if current_dim > aim_dim:
        raise ValueError("current dimension is larger than the padding dimension!")
    if current_dim == aim_dim:
        return batch

    target_device = batch.device
    aim_dim = int(aim_dim)
    padding_dim = aim_dim - current_dim

    x_samples = batch[:, :, :current_dim]
    y_samples = batch[:, :, current_dim:2 * current_dim]

    zeros_x = torch.zeros(batchsize, seq_len, padding_dim, device=target_device)
    zeros_y = torch.zeros(batchsize, seq_len, padding_dim, device=target_device)

    padded_x = torch.cat((x_samples, zeros_x), dim=2)
    padded_y = torch.cat((y_samples, zeros_y), dim=2)

    if perm:
        # Same permutation for x and y so paired columns stay aligned.
        re_setdim = torch.randperm(aim_dim).to(target_device)
        padded_x = torch.index_select(padded_x, 2, re_setdim)
        padded_y = torch.index_select(padded_y, 2, re_setdim)

    return torch.cat((padded_x, padded_y), dim=2)

def gen_train_dataset(batchsize, seq_len, dim, max_num_mixtures=50, eigen_range=10, regularization_strength=0.1, device="cuda"):
    """Generate one training batch of soft-ranked GMM samples.

    The number of mixture components is drawn from ``1..max_num_mixtures``
    with probability proportional to the count. Samples come from
    ``gen_gmm_samples_factor``, go through ``pad_gmm_data`` (target size
    ``dim`` per half) and finally through ``softrank_preprocessing`` with the
    Gaussian copula enabled.

    Args:
        batchsize: Number of batch elements.
        seq_len: Number of samples per batch element.
        dim: Feature size per half; converted to a tensor on ``device``.
        max_num_mixtures: Upper bound for the number of components.
        eigen_range: Forwarded to ``gen_gmm_samples_factor``.
        regularization_strength: Forwarded to ``softrank_preprocessing``.
        device: Device for the generated tensors.

    Returns:
        The preprocessed sample tensor.
    """
    dim = torch.tensor(dim).to(device)

    # P(count = c) is proportional to c
    candidate_counts = torch.arange(1, max_num_mixtures + 1)
    count_prob = candidate_counts / candidate_counts.sum()
    num_mixtures = (torch.multinomial(count_prob, 1) + 1).to(device)

    raw_samples = gen_gmm_samples_factor(
        batchsize=batchsize,
        seq_len=seq_len,
        dim=dim,
        num_mixtures=num_mixtures,
        eigen_range=eigen_range,
        device=device,
    )
    padded = pad_gmm_data(raw_samples, aim_dim=dim, perm=True, device=device)
    return softrank_preprocessing(
        padded, regularization_strength=regularization_strength, gauss_copula=True
    )

def gen_train_dataset_normflow(batchsize, seq_len, dim, max_num_mixtures=50, eigen_range=10, regularization_strength=0.1, device="cuda"):
    """Generate one training batch, usually warped by a random RealNVP.

    Same pipeline as ``gen_train_dataset``, except that the raw GMM samples
    are passed through ``run_realnvp_batch`` when a uniform draw falls below
    0.8; otherwise they are used unchanged.

    Args:
        batchsize: Number of batch elements.
        seq_len: Number of samples per batch element.
        dim: Feature size per half; converted to a tensor on ``device``.
        max_num_mixtures: Upper bound for the number of components; the count
            is drawn with probability proportional to its value.
        eigen_range: Forwarded to ``gen_gmm_samples_factor``.
        regularization_strength: Forwarded to ``softrank_preprocessing``.
        device: Device for the generated tensors.

    Returns:
        The preprocessed sample tensor.
    """
    dim = torch.tensor(dim).to(device)

    # P(count = c) is proportional to c
    mixture_counts = torch.arange(1, max_num_mixtures + 1)
    mixture_prob = mixture_counts / mixture_counts.sum()
    num_mixtures = (torch.multinomial(mixture_prob, 1) + 1).to(device)

    gmm_sample = gen_gmm_samples_factor(
        batchsize=batchsize,
        seq_len=seq_len,
        dim=dim,
        num_mixtures=num_mixtures,
        eigen_range=eigen_range,
        device=device,
    )

    apply_flow = torch.rand(1, device=device).item() < 0.8
    gmm_normflow = run_realnvp_batch(gmm_sample) if apply_flow else gmm_sample

    padded = pad_gmm_data(gmm_normflow, aim_dim=dim, perm=True, device=device)
    return softrank_preprocessing(
        padded, regularization_strength=regularization_strength, gauss_copula=True
    )

def gen_train_dataset_mixed_dim(batchsize, seq_len, max_dim=10, max_num_mixtures=60, eigen_range=10, regularization_strength=0.1, device="cuda"):
    """Generate a batch whose elements have randomly chosen intrinsic dimension.

    For every batch element a dimension is drawn from ``1..max_dim`` (weight
    ``i ** 2.7``) and a component count from ``1..max_num_mixtures`` (weight
    ``i ** 1.5``). One GMM sequence is sampled, padded to ``max_dim`` features
    per half with ``pad_gmm_data``, and the stacked batch is soft-ranked with
    the Gaussian copula enabled.

    Args:
        batchsize: Number of batch elements.
        seq_len: Number of samples per batch element.
        max_dim: Largest per-half dimension; also the padded size.
        max_num_mixtures: Upper bound for the number of components.
        eigen_range: Forwarded to ``gen_gmm_samples_factor``.
        regularization_strength: Forwarded to ``softrank_preprocessing``.
        device: Device for the generated tensors.

    Returns:
        The preprocessed batch; every element is padded to
        ``2 * max_dim`` features before stacking.
    """
    dim_weights = torch.tensor([i**2.7 for i in range(1, max_dim + 1)], dtype=torch.float32)
    dim_prob = dim_weights / dim_weights.sum()
    sampled_dim = torch.multinomial(dim_prob, batchsize, replacement=True) + 1

    # Loop-invariant distribution over the number of mixture components.
    num_weights = torch.tensor([i**1.5 for i in range(1, max_num_mixtures + 1)])
    num_prob = num_weights / num_weights.sum()

    padded_samples_list = []
    for dim in sampled_dim:
        num_mixtures = (torch.multinomial(num_prob, 1) + 1).to(device)

        gmm_sample = gen_gmm_samples_factor(
            batchsize=1,
            seq_len=seq_len,
            dim=dim,
            num_mixtures=num_mixtures,
            eigen_range=eigen_range,
            device=device,
        )

        # Samples with different `dim` have different feature widths and
        # cannot be concatenated directly, so bring each one to the common
        # width first. Padding/scaling is per batch element, so doing it here
        # is equivalent to doing it on the stacked batch.
        padded_samples_list.append(
            pad_gmm_data(gmm_sample, aim_dim=max_dim, perm=False, device=device)
        )

    padded_samples = torch.cat(padded_samples_list, dim=0)
    padded_samples = softrank_preprocessing(
        padded_samples, regularization_strength=regularization_strength, gauss_copula=True
    )

    return padded_samples

def torch_rankdata(tensor, method='average'):
    """Rank the elements of a tensor (1-based), resolving ties by ``method``.

    The ranking is taken over all elements of ``tensor`` (the tensor is
    treated as flattened for the purpose of ordering); the result has the
    same shape as the input.

    Args:
        tensor: Input tensor; converted to float32 before ranking.
        method: Tie handling. ``'average'`` assigns the mean of the ranks a
            tie group spans, ``'min'`` the smallest and ``'max'`` the largest
            of those ranks.

    Returns:
        float32 tensor of ranks on the same device as ``tensor``.

    Raises:
        ValueError: If ``method`` is not ``'average'``, ``'min'`` or ``'max'``.
    """
    if method not in ('average', 'min', 'max'):
        raise ValueError("Unsupported method. Use 'average', 'min', or 'max'.")

    tensor = tensor.float()
    # `counts[j]` is the size of the j-th tie group (groups sorted by value);
    # `inverse_indices` maps each element to its group.
    _, inverse_indices, counts = torch.unique(tensor, return_inverse=True, return_counts=True)

    # 1-based rank of the last element of every tie group.
    last_rank = torch.cumsum(counts, dim=0)

    if method == 'average':
        group_rank = last_rank.float() - (counts.float() - 1) / 2
    elif method == 'min':
        group_rank = (last_rank - counts + 1).float()
    else:  # 'max'
        group_rank = last_rank.float()

    return group_rank[inverse_indices]

def gen_train_dataset_lowdim(batchsize, seq_len, max_dim=5, max_num_mixtures=50, eigen_range=10, regularization_strength=0.1, device="cuda"):
    """Generate a low-dimensional batch: soft-rank first, then noise-pad.

    For every batch element a dimension is drawn from ``1..max_dim`` (weight
    ``i ** 2.5``) and a component count from ``1..max_num_mixtures`` (weight
    ``i ** 1.5``). One GMM sequence is sampled, passed through
    ``run_realnvp_batch`` when a uniform draw falls below 0.75, soft-ranked
    with the Gaussian copula enabled, and only then padded to ``max_dim``
    features per half with ``pad_gmm_data``.

    Args:
        batchsize: Number of batch elements.
        seq_len: Number of samples per batch element.
        max_dim: Largest per-half dimension; also the padded size.
        max_num_mixtures: Upper bound for the number of components.
        eigen_range: Forwarded to ``gen_gmm_samples_factor``.
        regularization_strength: Forwarded to ``softrank_preprocessing``.
        device: Device for the generated tensors.

    Returns:
        The stacked batch; every element is brought to ``2 * max_dim``
        features before stacking.
    """
    dim_weights = torch.tensor([i**2.5 for i in range(1, max_dim + 1)], dtype=torch.float32)
    dim_prob = dim_weights / dim_weights.sum()
    sampled_dim = torch.multinomial(dim_prob, batchsize, replacement=True) + 1

    # Loop-invariant distribution over the number of mixture components.
    num_weights = torch.tensor([i**1.5 for i in range(1, max_num_mixtures + 1)])
    num_prob = num_weights / num_weights.sum()

    padded_samples_list = []
    for dim in sampled_dim:
        num_mixtures = (torch.multinomial(num_prob, 1) + 1).to(device)

        gmm_sample = gen_gmm_samples_factor(
            batchsize=1,
            seq_len=seq_len,
            dim=dim,
            num_mixtures=num_mixtures,
            eigen_range=eigen_range,
            device=device,
        )
        if torch.rand(1, device=device).item() < 0.75:
            gmm_sample = run_realnvp_batch(gmm_sample)

        # Samples with different `dim` have different feature widths and
        # cannot be concatenated directly. Both steps below act per batch
        # element, so applying them here (rank, then pad) keeps the original
        # order of operations while making the widths equal before stacking.
        ranked = softrank_preprocessing(
            gmm_sample, regularization_strength=regularization_strength, gauss_copula=True
        )
        padded_samples_list.append(
            pad_gmm_data(ranked, aim_dim=max_dim, perm=False, device=device)
        )

    return torch.cat(padded_samples_list, dim=0)

def gen_train_dataset_lowdim_hardrank(batchsize, seq_len, max_dim=10, max_num_mixtures=50, eigen_range=10, regularization_strength=0.1, device="cuda"):
    """Generate a mixed-dimension batch: noise-pad first, then soft-rank.

    For every batch element a dimension is drawn from ``1..max_dim`` (weight
    ``i ** 2``) and a component count from ``1..max_num_mixtures`` (weight
    ``i ** 1.5``). One GMM sequence is sampled, passed through
    ``run_realnvp_batch`` when a uniform draw falls below 0.7, and padded to
    ``max_dim`` features per half with ``pad_gmm_data``. The stacked batch is
    then processed by ``softrank_preprocessing`` (despite the function name,
    the soft rank is what the code applies).

    Args:
        batchsize: Number of batch elements.
        seq_len: Number of samples per batch element.
        max_dim: Largest per-half dimension; also the padded size.
        max_num_mixtures: Upper bound for the number of components.
        eigen_range: Forwarded to ``gen_gmm_samples_factor``.
        regularization_strength: Forwarded to ``softrank_preprocessing``.
        device: Device for the generated tensors.

    Returns:
        The preprocessed batch; every element is padded to
        ``2 * max_dim`` features before stacking.
    """
    dim_weights = torch.tensor([i**2 for i in range(1, max_dim + 1)], dtype=torch.float32)
    dim_prob = dim_weights / dim_weights.sum()
    sampled_dim = torch.multinomial(dim_prob, batchsize, replacement=True) + 1

    # Loop-invariant distribution over the number of mixture components.
    num_weights = torch.tensor([i**1.5 for i in range(1, max_num_mixtures + 1)])
    num_prob = num_weights / num_weights.sum()

    padded_samples_list = []
    for dim in sampled_dim:
        num_mixtures = (torch.multinomial(num_prob, 1) + 1).to(device)

        gmm_sample = gen_gmm_samples_factor(
            batchsize=1,
            seq_len=seq_len,
            dim=dim,
            num_mixtures=num_mixtures,
            eigen_range=eigen_range,
            device=device,
        )

        if torch.rand(1, device=device).item() < 0.7:
            gmm_sample = run_realnvp_batch(gmm_sample)

        # Samples with different `dim` have different feature widths and
        # cannot be concatenated directly, so pad each one to the common
        # width first (padding/scaling is per batch element anyway).
        padded_samples_list.append(
            pad_gmm_data(gmm_sample, aim_dim=max_dim, perm=False, device=device)
        )

    padded_samples = torch.cat(padded_samples_list, dim=0)
    padded_samples = softrank_preprocessing(
        padded_samples, regularization_strength=regularization_strength, gauss_copula=True
    )

    return padded_samples


def random_corr_factor(batchsize, d, k_values, device='cpu',
                       corr_strength=None, diag_strength=None):
    """Generate a batch of normalised (unit-diagonal) correlation matrices.

    For each batch element a ``d x k`` Gaussian factor matrix ``W`` is drawn
    (columns at index ``>= k_values[i]`` are zeroed) and scaled by
    ``corr_strength``; ``W W^T`` plus a random positive diagonal is then
    normalised as ``D^{-1/2} S D^{-1/2}``.

    Args:
        batchsize (int): Number of matrices.
        d (int): Matrix size.
        k_values (Tensor): Number of non-zero factor columns per matrix,
            shape ``(batchsize,)``.
        device (str): Device for the generated tensors.
        corr_strength (Tensor or float, optional): Scale applied to the
            factors (controls the off-diagonal strength). A scalar, or a
            tensor with ``batchsize`` elements such as shape
            ``(batchsize, 1)``. Defaults to random values in ``[0.5, 1.0)``.
        diag_strength (Tensor or float, optional): Scale of the random
            diagonal noise. A scalar, a tensor of shape ``(batchsize,)``
            (one value per matrix), or any tensor broadcastable to
            ``(batchsize, d)``. Defaults to noise in ``[0, 1)``.

    Returns:
        Tensor: Normalised correlation matrices, shape ``(batchsize, d, d)``.
    """
    max_k = int(torch.max(k_values))
    W = torch.randn(batchsize, d, max_k, device=device)

    # Without an explicit value, draw the correlation strength from [0.5, 1.0).
    if corr_strength is None:
        corr_strength = torch.rand(batchsize, 1, device=device) * 0.5 + 0.5
    elif not torch.is_tensor(corr_strength):
        # Plain Python scalars have no .view(); wrap them so they broadcast.
        corr_strength = torch.tensor(float(corr_strength), device=device)

    # Zero out the factor columns beyond each matrix's k.
    mask = (torch.arange(max_k, device=device)[None, None, :] < k_values[:, None, None])
    W = W * mask.float()

    # Scale the factors to adjust the off-diagonal correlation strength.
    W = W * corr_strength.reshape(-1, 1, 1)

    # Un-normalised covariance.
    S = torch.bmm(W, W.transpose(1, 2))

    # Diagonal noise: random in [0, 1), optionally scaled by diag_strength.
    diag_noise = torch.rand(batchsize, d, device=device)
    if diag_strength is not None:
        if not torch.is_tensor(diag_strength):
            diag_strength = torch.tensor(float(diag_strength), device=device)
        elif diag_strength.dim() == 1 and diag_strength.shape[0] == batchsize:
            # One strength per matrix: align it with the batch axis, not
            # with the trailing feature axis.
            diag_strength = diag_strength.unsqueeze(-1)
        diag_noise = diag_noise * diag_strength

    S = S + torch.diag_embed(diag_noise)

    # Normalise: D^{-1/2} S D^{-1/2}; the small constant guards the sqrt.
    diag_S = torch.diagonal(S, dim1=-2, dim2=-1)
    D_inv_sqrt = torch.diag_embed(1.0 / torch.sqrt(diag_S + 1e-12))
    S_normalized = torch.bmm(torch.bmm(D_inv_sqrt, S), D_inv_sqrt)

    return S_normalized


def gen_gmm_parameters(batchsize, dim, num_mixtures,
                       eigen_range=10, device='cpu',
                       corr_strength_value=0.1, diag_strength_value=2):
    """Draw random GMM parameters: weights, means and covariances.

    Args:
        batchsize: Number of independent mixtures.
        dim: Half of the feature size; parameters live in ``2 * dim``
            dimensions.
        num_mixtures: Number of components per mixture.
        eigen_range: Per-feature variances are drawn uniformly from
            ``[0.01, eigen_range + 0.01)``.
        device: Device for the generated tensors.
        corr_strength_value: Constant correlation strength passed to
            ``random_corr_factor`` for every component.
        diag_strength_value: Diagonal-noise strength passed to
            ``random_corr_factor``.

    Returns:
        weights: Tensor(batchsize, num_mixtures), rows sum to 1.
        means: Tensor(batchsize, num_mixtures, 2*dim), uniform in [-5, 5).
        cov: Tensor(batchsize, num_mixtures, 2*dim, 2*dim).
    """
    total_dim = 2 * dim
    n_components = batchsize * num_mixtures

    # 1. Mixture weights, strictly positive and normalised per mixture.
    raw_weights = torch.rand(batchsize, num_mixtures, device=device) + 1e-6
    weights = raw_weights / raw_weights.sum(dim=-1, keepdim=True)

    # 2. Component means.
    means = torch.rand(batchsize, num_mixtures, total_dim, device=device) * 10 - 5

    # 3. Number of active factors per component, and a constant strength.
    k_values = torch.randint(1, total_dim + 1, (n_components,), device=device)
    corr_strength = torch.full((n_components, 1), corr_strength_value, device=device)

    # 4. Correlation matrices, one per (mixture, component) pair.
    corr_mats = random_corr_factor(
        n_components, total_dim, k_values,
        device=device, corr_strength=corr_strength,
        diag_strength=diag_strength_value
    ).view(batchsize, num_mixtures, total_dim, total_dim)

    # 5. Diagonal matrix of standard deviations.
    variances = torch.rand(batchsize, num_mixtures, total_dim, device=device) * eigen_range + 0.01
    SD = torch.diag_embed(torch.sqrt(variances))

    # 6. Covariance = SD @ corr @ SD, plus a small ridge on the diagonal.
    cov = torch.einsum('bmij,bmjk,bmkl->bmil', SD, corr_mats, SD)
    ridge = 1e-6 * torch.eye(total_dim, device=device)
    cov += ridge[None, None]

    return weights, means, cov

def is_positive_definite_cholesky(cov):
    """Return True if ``torch.linalg.cholesky(cov)`` succeeds, else False.

    A ``RuntimeError`` from the factorisation (e.g. the matrix is not
    positive definite) is caught and reported as ``False`` rather than
    propagated.
    """
    try:
        torch.linalg.cholesky(cov)
    except RuntimeError:
        # Factorisation failed, so the matrix is treated as not positive definite.
        return False
    else:
        return True

def sample_multivariate_t(mean, cov, df, size=1, random_state=None):
    """Sample from a multivariate Student t distribution.

    Uses the normal / chi-square construction: a zero-mean Gaussian draw with
    covariance ``cov`` is divided by ``sqrt(U / df)`` with ``U`` chi-square
    distributed with ``df`` degrees of freedom, then shifted by ``mean``.

    Args:
        mean: Location vector, shape ``(D,)``.
        cov: Scale matrix, shape ``(D, D)``.
        df: Degrees of freedom.
        size: Number of samples.
        random_state: Seed or ``numpy.random.Generator`` handed to
            ``np.random.default_rng``.

    Returns:
        Array of shape ``(size, D)``.
    """
    rng = np.random.default_rng(random_state)
    n_features = mean.shape[0]

    # 1) zero-mean Gaussian draws with covariance `cov`
    gaussian = rng.multivariate_normal(np.zeros(n_features), cov, size=size)
    # 2) one chi-square variate per sample
    chi2 = rng.chisquare(df, size=size)
    # 3) rescale each row and add the mean
    row_scale = np.sqrt(chi2 / df)
    return mean + gaussian / row_scale[:, None]

def sample_mixture_of_t(weights, means, covs, dfs, n_samples, random_state=None):
    """Sample from a mixture of multivariate Student t distributions.

    Args:
        weights: Mixture probabilities, shape ``(K,)``.
        means: Sequence of ``K`` location vectors of shape ``(D,)``.
        covs: Sequence of ``K`` scale matrices of shape ``(D, D)``.
        dfs: Sequence of ``K`` degrees of freedom.
        n_samples: Total number of samples to draw.
        random_state: Seed or ``numpy.random.Generator`` handed to
            ``np.random.default_rng``.

    Returns:
        X: Array of shape ``(n_samples, D)``. The component labels are used
        internally only and are not returned.
    """
    rng = np.random.default_rng(random_state)
    n_components = len(weights)

    # 1) draw a component label for every sample
    labels = rng.choice(n_components, size=n_samples, p=weights)

    # 2) fill the rows of each component with one batched draw
    draws = np.zeros((n_samples, means[0].shape[0]))
    for comp in range(n_components):
        rows = np.nonzero(labels == comp)[0]
        if rows.size == 0:
            continue
        draws[rows] = sample_multivariate_t(
            mean=means[comp],
            cov=covs[comp],
            df=dfs[comp],
            size=rows.size,
            random_state=rng,  # share one generator across components
        )
    return draws

def sample_skewed_df(K, a=0.2, b=10.0, h=0.5, seed=None):
    """Draw ``K`` values in ``[a, b]`` from a power-transformed uniform.

    Each value is ``a + (b - a) * u ** h`` with ``u ~ U(0, 1)``.

    NOTE(review): for ``0 < h < 1`` we have ``u ** h >= u``, so the draws are
    shifted toward ``b`` (more so for smaller ``h``); ``h > 1`` shifts them
    toward ``a``. The earlier description claimed the opposite — confirm
    which skew is intended.

    Args:
        K: Number of values to draw.
        a: Lower bound.
        b: Upper bound.
        h: Exponent applied to the uniform draw.
        seed: If given, passed to ``torch.manual_seed`` (this reseeds the
            global torch generator).

    Returns:
        Tensor of shape ``(K,)``.
    """
    if seed is not None:
        torch.manual_seed(seed)

    uniform = torch.rand(K)        # U(0, 1)
    warped = uniform ** h          # power transform on [0, 1]
    return a + (b - a) * warped    # affine map onto [a, b]



