import torch
import torch.nn as nn
from sklearn.mixture import GaussianMixture
import numpy as np

def obtain1(model, y_hat, main_r=200):
    """Fit a 2-component GMM to the raw unknown-subspace norms of ``y_hat``.

    The left singular vectors of the transposed classifier weight
    (``model.fc.weight``) are split at ``main_r``: the first ``main_r``
    vectors become the "known" basis, the remaining ones the "unknown" basis.
    Every row of ``y_hat`` is projected onto the unknown basis and a
    two-component Gaussian mixture is fitted to the L2 norms of those
    projections (the norms are NOT rescaled in this variant).

    Side effects:
        Calls ``model.eval()`` and rebinds the module-level globals
        ``known_space_basis`` and ``unknown_space_basis``.

    Args:
        model: Object exposing ``eval()`` and ``fc.weight``. The weight is
            assumed to be laid out as [C, D] -- TODO confirm against caller.
        y_hat: 2-D feature bank, one row per sample; its trailing dimension
            must match the trailing dimension of the classifier weight.
        main_r: Number of leading singular vectors assigned to the known
            subspace. Defaults to 200, the value that used to be hard-coded.

    Returns:
        ``(mu_low, mu_high)``: 0-dim tensors holding the smaller and the
        larger of the two fitted component means.
    """
    global known_space_basis
    global unknown_space_basis

    model.eval()

    classifier_weight = model.fc.weight.data
    # Only the left singular vectors are needed; singular values and the
    # right singular vectors are discarded.
    left_vecs, _, _ = torch.linalg.svd(classifier_weight.t())

    # Rows of each basis are re-normalised to unit L2 length before use.
    known_space_basis = left_vecs[:, :main_r].t()
    known_space_basis = known_space_basis / torch.norm(
        known_space_basis, p=2, dim=-1, keepdim=True
    )
    unknown_space_basis = left_vecs[:, main_r:].t()
    unknown_space_basis = unknown_space_basis / torch.norm(
        unknown_space_basis, p=2, dim=-1, keepdim=True
    )

    # Coordinates of every sample in the unknown subspace, then their length.
    unknown_coords = torch.einsum("nd, cd -> nc", y_hat, unknown_space_basis)
    unknown_norms = torch.norm(unknown_coords, p=2, dim=-1)

    mixture = GaussianMixture(n_components=2, random_state=0).fit(
        unknown_norms.detach().cpu().view(-1, 1).numpy()
    )
    component_means = torch.tensor(mixture.means_).squeeze()

    return torch.min(component_means), torch.max(component_means)


def obtain(model, y_hat, main_r=200):
    """Fit a 2-component GMM to min-max normalised unknown-subspace norms.

    The left singular vectors of the transposed classifier weight
    (``model.fc.weight``) are split at ``main_r`` into a "known" basis (first
    ``main_r`` vectors) and an "unknown" basis (the rest). Every row of
    ``y_hat`` is projected onto the unknown basis; the L2 norms of those
    projections are rescaled to [0, 1] and a two-component Gaussian mixture
    is fitted to them.

    Side effects:
        Calls ``model.eval()`` and rebinds the module-level globals
        ``known_space_basis`` and ``unknown_space_basis``.

    Args:
        model: Object exposing ``eval()`` and ``fc.weight``. The weight is
            assumed to be laid out as [C, D] -- TODO confirm against caller.
        y_hat: 2-D feature bank, one row per sample; its trailing dimension
            must match the trailing dimension of the classifier weight.
        main_r: Number of leading singular vectors assigned to the known
            subspace. Defaults to 200, the value that used to be hard-coded.

    Returns:
        ``(mu_low, mu_high)``: 0-dim tensors holding the smaller and the
        larger fitted component mean, expressed on the normalised [0, 1]
        scale.
    """
    global known_space_basis
    global unknown_space_basis

    model.eval()

    classifier_weight = model.fc.weight.data
    # Only the left singular vectors are needed.
    left_vecs, _, _ = torch.linalg.svd(classifier_weight.t())

    # Rows of each basis are re-normalised to unit L2 length before use.
    known_space_basis = left_vecs[:, :main_r].t()
    known_space_basis = known_space_basis / torch.norm(
        known_space_basis, p=2, dim=-1, keepdim=True
    )
    unknown_space_basis = left_vecs[:, main_r:].t()
    unknown_space_basis = unknown_space_basis / torch.norm(
        unknown_space_basis, p=2, dim=-1, keepdim=True
    )

    unknown_coords = torch.einsum("nd, cd -> nc", y_hat, unknown_space_basis)
    unknown_norms = torch.norm(unknown_coords, p=2, dim=-1)

    # Min-max rescale to [0, 1]. The denominator is clamped to the smallest
    # positive normal float so that identical norms give zeros instead of the
    # 0/0 NaNs that GaussianMixture.fit refuses to accept.
    lowest = unknown_norms.min()
    spread = torch.clamp(
        unknown_norms.max() - lowest,
        min=torch.finfo(unknown_norms.dtype).tiny,
    )
    normalized_norms = (unknown_norms - lowest) / spread

    mixture = GaussianMixture(n_components=2, random_state=0).fit(
        normalized_norms.detach().cpu().view(-1, 1).numpy()
    )
    component_means = torch.tensor(mixture.means_).squeeze()

    return torch.min(component_means), torch.max(component_means)

def obtain_with_overlap(model, y_hat, main_r=200):
    """Fit a 2-component GMM and flag the samples with the highest density score.

    The left singular vectors of the transposed classifier weight
    (``model.fc.weight``) are split at ``main_r`` into a "known" basis (first
    ``main_r`` vectors) and an "unknown" basis (the rest). Every row of
    ``y_hat`` is projected onto the unknown basis; the L2 norms of those
    projections are rescaled to [0, 1] and a two-component Gaussian mixture
    is fitted to them. Each sample is then scored with the sum of the two
    component Gaussian densities, and the samples whose score lies strictly
    above the 90th percentile are reported.

    Side effects:
        Calls ``model.eval()`` and rebinds the module-level globals
        ``known_space_basis`` and ``unknown_space_basis``.

    Args:
        model: Object exposing ``eval()`` and ``fc.weight``. The weight is
            assumed to be laid out as [C, D] -- TODO confirm against caller.
        y_hat: 2-D feature bank, one row per sample; its trailing dimension
            must match the trailing dimension of the classifier weight.
        main_r: Number of leading singular vectors assigned to the known
            subspace. Defaults to 200, the value that used to be hard-coded.

    Returns:
        ``(mu1, mu2, overlap_indices)``: ``mu1``/``mu2`` are Python floats,
        the smaller and larger fitted component mean on the normalised scale;
        ``overlap_indices`` is a 1-D NumPy integer array of sample positions
        whose score exceeds the 90th percentile.
    """
    global known_space_basis
    global unknown_space_basis

    model.eval()

    classifier_weight = model.fc.weight.data
    # Only the left singular vectors are needed.
    left_vecs, _, _ = torch.linalg.svd(classifier_weight.t())

    # Rows of each basis are re-normalised to unit L2 length before use.
    known_space_basis = left_vecs[:, :main_r].t()
    known_space_basis = known_space_basis / torch.norm(
        known_space_basis, p=2, dim=-1, keepdim=True
    )
    unknown_space_basis = left_vecs[:, main_r:].t()
    unknown_space_basis = unknown_space_basis / torch.norm(
        unknown_space_basis, p=2, dim=-1, keepdim=True
    )

    unknown_coords = torch.einsum("nd, cd -> nc", y_hat, unknown_space_basis)
    unknown_norms = torch.norm(unknown_coords, p=2, dim=-1)

    # Min-max rescale to [0, 1]. The denominator is clamped to the smallest
    # positive normal float so that identical norms give zeros instead of the
    # 0/0 NaNs that GaussianMixture.fit refuses to accept.
    lowest = unknown_norms.min()
    spread = torch.clamp(
        unknown_norms.max() - lowest,
        min=torch.finfo(unknown_norms.dtype).tiny,
    )
    normalized_norms = (unknown_norms - lowest) / spread
    normalized_vals = normalized_norms.detach().cpu().numpy()

    mixture = GaussianMixture(n_components=2, random_state=0).fit(
        normalized_vals.reshape(-1, 1)
    )

    # With 1-D data, squeezing leaves one mean and one variance per component.
    means = mixture.means_.squeeze()
    variances = mixture.covariances_.squeeze()

    def component_pdf(x, k):
        """Gaussian density of mixture component ``k`` evaluated at ``x``."""
        return (1 / (np.sqrt(2 * np.pi * variances[k]))) * np.exp(
            -0.5 * ((x - means[k]) ** 2) / variances[k]
        )

    # NOTE(review): the score is the unweighted SUM of the two densities, so
    # it is large wherever either component is dense, not only where both
    # are. Confirm this is the intended notion of "overlap".
    density_score = component_pdf(normalized_vals, 0) + component_pdf(normalized_vals, 1)

    # Samples strictly above the 90th percentile of the score (top ~10%).
    overlap_threshold = np.percentile(density_score, 90)
    overlap_indices = np.where(density_score > overlap_threshold)[0]

    # mu1 is the smaller component mean, mu2 the larger one.
    mu1 = float(means.min())
    mu2 = float(means.max())

    return mu1, mu2, overlap_indices

def obtain_with_indices1(model, y_hat, main_r=200):
    """Split samples into three regions using GMM density cut points.

    The left singular vectors of the transposed classifier weight
    (``model.fc.weight``) are split at ``main_r`` into a "known" basis (first
    ``main_r`` vectors) and an "unknown" basis (the rest). Every row of
    ``y_hat`` is projected onto the unknown basis; the L2 norms of those
    projections are rescaled to [0, 1] and a two-component Gaussian mixture
    is fitted to them. The mixture density is evaluated on a 1000-point grid
    over [0, 1]; one cut point is taken from each half of the grid at the
    first position where the density falls below 1e-3, and the samples are
    partitioned by those two cut points.

    Unlike the sibling helpers, this function does not touch any module-level
    globals; it only calls ``model.eval()``.

    Args:
        model: Object exposing ``eval()`` and ``fc.weight``. The weight is
            assumed to be laid out as [C, D] -- TODO confirm against caller.
        y_hat: 2-D feature bank, one row per sample; its trailing dimension
            must match the trailing dimension of the classifier weight.
        main_r: Number of leading singular vectors assigned to the known
            subspace. Defaults to 200, the value that used to be hard-coded.

    Returns:
        ``(region1_indices, region2_indices, region3_indices)``: 1-D index
        tensors of the samples whose normalised norm lies in
        ``[0, threshold1]``, ``(threshold1, threshold2]`` and
        ``(threshold2, 1]`` respectively.
    """
    model.eval()

    # Weights of the model's classification layer.
    classifier_weight = model.fc.weight.data
    left_vecs, _, _ = torch.linalg.svd(classifier_weight.t())

    # Only the unknown-subspace basis is needed here; rows are re-normalised
    # to unit L2 length before use.
    unknown_basis = left_vecs[:, main_r:].t()
    unknown_basis = unknown_basis / torch.norm(
        unknown_basis, p=2, dim=-1, keepdim=True
    )

    # Coordinates of every sample in the unknown subspace, then their length.
    unknown_coords = torch.einsum("nd, cd -> nc", y_hat, unknown_basis)
    unknown_norms = torch.norm(unknown_coords, p=2, dim=-1)

    # Min-max rescale to [0, 1]. The denominator is clamped to the smallest
    # positive normal float so that identical norms give zeros instead of the
    # 0/0 NaNs that GaussianMixture.fit refuses to accept.
    lowest = unknown_norms.min()
    spread = torch.clamp(
        unknown_norms.max() - lowest,
        min=torch.finfo(unknown_norms.dtype).tiny,
    )
    normalized_norms = (unknown_norms - lowest) / spread

    # Fit the Gaussian mixture to the normalised norms.
    mixture = GaussianMixture(n_components=2, random_state=0).fit(
        normalized_norms.detach().cpu().view(-1, 1).numpy()
    )

    # Evaluate the mixture density on an even grid; score_samples returns
    # log-densities, hence the exp.
    grid = np.linspace(0, 1, 1000)
    grid_density = np.exp(mixture.score_samples(grid.reshape(-1, 1)))

    # First grid position in each half where the density drops below 1e-3.
    # NOTE: np.argmax returns 0 when no entry satisfies the condition, so a
    # half with no near-zero density yields its first grid point as the cut.
    half = 500
    near_zero = grid_density < 1e-3
    threshold1 = grid[np.argmax(near_zero[:half])]
    threshold2 = grid[half + np.argmax(near_zero[half:])]

    # Boolean masks for the three regions, then the matching sample indices.
    in_region1 = (normalized_norms >= 0) & (normalized_norms <= threshold1)
    in_region2 = (normalized_norms > threshold1) & (normalized_norms <= threshold2)
    in_region3 = (normalized_norms > threshold2) & (normalized_norms <= 1)

    region1_indices = torch.nonzero(in_region1, as_tuple=True)[0]
    region2_indices = torch.nonzero(in_region2, as_tuple=True)[0]
    region3_indices = torch.nonzero(in_region3, as_tuple=True)[0]

    return region1_indices, region2_indices, region3_indices

def calculate_entropy(probabilities):
    """Return the entropy (natural log) of each sample in ``probabilities``.

    The reduction runs over the last dimension. A constant of 1e-12 is added
    inside the logarithm so that zero probabilities do not evaluate
    ``log(0)``.
    """
    log_probs = torch.log(probabilities + 1e-12)
    weighted = probabilities * log_probs
    return -weighted.sum(dim=-1)

def obtain_with_indices(model, feature, probabilities, main_r=200):
    """Split samples into three regions bounded by the two GMM component means.

    The left singular vectors of the transposed classifier weight
    (``model.fc.weight``) are split at ``main_r`` into a "known" basis (first
    ``main_r`` vectors) and an "unknown" basis (the rest). Every row of
    ``feature`` is projected onto the unknown basis; the L2 norms of those
    projections are rescaled to [0, 1] and a two-component Gaussian mixture
    is fitted to them. Samples are then partitioned by the smaller and larger
    fitted component means.

    Side effects:
        Calls ``model.eval()`` and rebinds the module-level globals
        ``known_space_basis`` and ``unknown_space_basis``.

    Args:
        model: Object exposing ``eval()`` and ``fc.weight``. The weight is
            assumed to be laid out as [C, D] -- TODO confirm against caller.
        feature: 2-D feature bank, one row per sample; its trailing dimension
            must match the trailing dimension of the classifier weight.
        probabilities: Currently unused; kept so existing callers keep
            working.
        main_r: Number of leading singular vectors assigned to the known
            subspace. Defaults to 200, the value that used to be hard-coded.

    Returns:
        ``(region1_indices, region2_indices, region3_indices)``: 1-D index
        tensors of the samples whose normalised norm lies in ``[0, mu1]``,
        ``(mu1, mu2]`` and ``(mu2, 1]`` respectively, where ``mu1 <= mu2``
        are the fitted component means.
    """
    global known_space_basis
    global unknown_space_basis

    model.eval()

    # Weights of the model's classification layer.
    classifier_weight = model.fc.weight.data
    left_vecs, _, _ = torch.linalg.svd(classifier_weight.t())

    # Basis vectors of the known and unknown subspaces; rows are re-normalised
    # to unit L2 length before use.
    known_space_basis = left_vecs[:, :main_r].t()
    known_space_basis = known_space_basis / torch.norm(
        known_space_basis, p=2, dim=-1, keepdim=True
    )
    unknown_space_basis = left_vecs[:, main_r:].t()
    unknown_space_basis = unknown_space_basis / torch.norm(
        unknown_space_basis, p=2, dim=-1, keepdim=True
    )

    # Coordinates of every sample in the unknown subspace, then their length.
    unknown_coords = torch.einsum("nd, cd -> nc", feature, unknown_space_basis)
    unknown_norms = torch.norm(unknown_coords, p=2, dim=-1)

    # Min-max rescale to [0, 1]. The denominator is clamped to the smallest
    # positive normal float so that identical norms give zeros instead of the
    # 0/0 NaNs that GaussianMixture.fit refuses to accept.
    lowest = unknown_norms.min()
    spread = torch.clamp(
        unknown_norms.max() - lowest,
        min=torch.finfo(unknown_norms.dtype).tiny,
    )
    normalized_norms = (unknown_norms - lowest) / spread

    # Fit the Gaussian mixture to the normalised norms.
    mixture = GaussianMixture(n_components=2, random_state=0).fit(
        normalized_norms.detach().cpu().view(-1, 1).numpy()
    )
    component_means = mixture.means_.squeeze()
    # Plain Python floats so the comparisons below work regardless of the
    # device `feature` lives on.
    mu_low = float(component_means.min())
    mu_high = float(component_means.max())

    # NOTE: a commented-out prototype-based pseudo-labelling experiment and an
    # entropy filter used to sit here; both were dead code and were dropped.

    # Boolean masks for the three regions, then the matching sample indices.
    in_region1 = (normalized_norms >= 0) & (normalized_norms <= mu_low)
    in_region2 = (normalized_norms > mu_low) & (normalized_norms <= mu_high)
    in_region3 = (normalized_norms > mu_high) & (normalized_norms <= 1)

    region1_indices = torch.nonzero(in_region1, as_tuple=True)[0]
    region2_indices = torch.nonzero(in_region2, as_tuple=True)[0]
    region3_indices = torch.nonzero(in_region3, as_tuple=True)[0]

    return region1_indices, region2_indices, region3_indices