"""
�������
C�� �����ܶȣ� ƽ���ܶ���C log n /n����
tau: latent variable z_i ~ N(mu(label(i)), tau^2 I)
gamma: community size
radius: n-classed dim, control the density in different communities.
"""


import copy
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
from torch import optim
import torch.nn.functional as F
from load import get_P, get_Pd, get_W_lg
import random
from controlsnr import find_a_given_snr
from scipy.sparse import csr_matrix, save_npz
def simple_collate_fn(batch):
    """
    Collate a batch of samples; only applicable when every graph in the
    batch has the same size (the tensors are stacked, not padded).

    Each sample is indexed with the keys 'adj' (an object exposing
    ``toarray()``) and 'labels'.

    Returns a dict with:
      - adj: [B, N, N] float32 tensor
      - labels: [B, N] int64 tensor
    """
    def _dense_adj(sample):
        # Densify the stored adjacency before turning it into a tensor.
        return torch.tensor(sample['adj'].toarray(), dtype=torch.float32)

    def _node_labels(sample):
        return torch.tensor(sample['labels'], dtype=torch.long)

    return dict(
        adj=torch.stack(list(map(_dense_adj, batch))),         # [B, N, N]
        labels=torch.stack(list(map(_node_labels, batch))),    # [B, N]
    )

# def midpoints(seq):
#     return [(a + b) / 2 for a, b in zip(seq[:-1], seq[1:])]
#
# # Parameter settings of the model (1500)
# tau_train   = [0, 0, 0.25, 0.5]                # 3
# gamma_train = [0.30, 1.20, 3.00, 5.00]        # 3
# C_train     = [ 20, 30, 40, 50]        # 3
# latent_dim = 4
#
# radii_train = [
#     [0.55,0.60,0.58,0.62],
#     [0.90,0.70,0.50,0.30],
#     [0.80,0.75,0.35,0.40],
# ]
# norm_train = [0.25, 0.5, 0.75 , 1]           # 5
# per_cell_tr = 5
#
#
# C_train_sbm = np.logspace(np.log10(5), np.log10(45), 10)
# SNR_train_sbm = np.logspace(np.log10(0.5), np.log10(4), 10)
# C_val = np.sqrt(C_train_sbm[:-1] * C_train_sbm[1:])
# SNR_val = np.sqrt(SNR_train_sbm[:-1] * SNR_train_sbm[1:])
# gamma_val = np.array([0.30, 1.20, 3.00, 5.00])
# per_cell_v = 2


# # Test set (=12)
# tau_test = [0.60, 0.80, 1.00]
# gamma_test = [0.15, 0.60, 1.50, 3.00]
# norm_test = [0.025, 0.1]
# radii_test = [
#     [1.2, 0.8, 0.8,0.6,0.6],
#     [1, 1, 1,1,1]
# ]
# # radii_test = [
# #     [1, 0.6, 0.2],
# #     [0.8, 0.8, 0.5],
# #     [1, 0.4, 0]
# # ]
# C_test = (10.0,)
# per_cell_te = 1


class Generator(object):
    def __init__(self, N_train=50, N_test=100, N_val=50, generative_model='LSM_multiclass',
                 p_SBM=0.8, q_SBM=0.2, n_classes=2, path_dataset='dataset',
                 num_examples_train=100, num_examples_test=10, num_examples_val=10,
                 # LSM-specific parameters
                 lsm_C=20.0, lsm_norm_mu=0.5, lsm_radii=None):
        """Store the generator configuration on the instance.

        Args:
            N_train, N_test, N_val: Number of nodes per graph for each split.
            generative_model: Name of the generative model (stored as given).
            p_SBM, q_SBM: SBM probabilities (stored as given).
            n_classes: Number of communities.
            path_dataset: Dataset root directory (stored as given).
            num_examples_train, num_examples_test, num_examples_val:
                Number of graphs per split (stored as given).
            lsm_C: LSM density parameter.
            lsm_norm_mu: LSM latent-space norm parameter.
            lsm_radii: LSM radii list; defaults to ``[1.0] * 4`` when None.
        """
        # Graph sizes for the train / test splits.
        self.N_train, self.N_test = N_train, N_test

        # LSM constants that are fixed here rather than taken from arguments.
        self.lsm_tau, self.lsm_alpha_std, self.lsm_latent_dim = 0.25, 1, 4

        # LSM settings supplied by the caller.
        self.lsm_C, self.lsm_norm_mu = lsm_C, lsm_norm_mu
        self.lsm_radii = [1.0] * 4 if lsm_radii is None else lsm_radii

        self.N_val = N_val

        # Model description and dataset location.
        self.generative_model = generative_model
        self.p_SBM, self.q_SBM = p_SBM, q_SBM
        self.n_classes = n_classes
        self.path_dataset = path_dataset

        # Dataset handles; nothing is loaded at construction time.
        self.data_train = self.data_test = self.data_val = None

        self.num_examples_train = num_examples_train
        self.num_examples_test = num_examples_test
        self.num_examples_val = num_examples_val

        # Two-class splits of 1000 nodes, from balanced to highly imbalanced.
        self.fixed_class_sizes = [
            (minority, 1000 - minority)
            for minority in (500, 400, 300, 200, 100, 50)
        ]


    def gen_one_lsm(self, class_sizes, C, norm_mu, radii, is_training=True,iter=1):
        """
        核心 LSM 生成函数
        class_sizes: 类别大小列表
        C: 平均度参数
        norm_mu: 潜在空间范数参数
        radii: 半径参数列表
        """
        if is_training:
            N = self.N_train
        else:
            N = self.N_test

        # 调用您的 LSM 生成逻辑
        max_class_size = max(class_sizes) if class_sizes else 0
        seed = self.n_classes * 100000 + max_class_size * 10000 + int(C * 1000) + int(norm_mu * 100) + iter*10

        # 初始化随机数生成器
        rng = np.random.default_rng(seed)

        class_sizes, W_sparse, labels, eigvecs_top, alpha_bar, B = self._gen_lsm_internal(
            N=N, n_classes=self.n_classes, class_sizes=class_sizes,
            C=C, norm_mu=norm_mu, tau=0.25, latent_dim=self.n_classes, radii=radii,rng=rng
        )

        # 转换为适合返回的格式
        W_dense = W_sparse.toarray()
        labels = np.expand_dims(labels, 0)
        labels = torch.from_numpy(labels)
        W = np.expand_dims(W_dense, 0)

        if torch.cuda.is_available():
            W = torch.tensor(W, dtype=torch.float32).cuda()
        else:
            W = torch.tensor(W, dtype=torch.float32)

        return W, labels, B



    # def imbalanced_sample_otf_single(self, class_sizes, is_training=True, cuda=True):
    #     """
    #     为兼容性保留的接口
    #     """
    #     # 使用固定的 C、norm_mu、radii，或者从其他地方获取
    #     C = 20.0  # 固定值或从配置获取
    #     norm_mu = 0.5  # 固定值
    #     radii = [1.0, 1.0, 1.0, 1.0]  # 固定值
    #
    #     return self.gen_one_lsm(class_sizes, C, norm_mu, radii, is_training)
    #
    # def sample_otf_single(self, is_training=True, cuda=True):
    #     """
    #     为兼容性保留的接口 - 生成平衡图
    #     """
    #     N = self.N_train if is_training else self.N_test
    #     class_sizes = [N // self.n_classes] * self.n_classes
    #     remainder = N % self.n_classes
    #     for i in range(remainder):
    #         class_sizes[i] += 1
    #
    #     return self.imbalanced_sample_otf_single(class_sizes, is_training, cuda)

    # def prepare_data(self):
    #     def get_npz_dataset(path, mode, *, tau_grid, gamma_grid,C_grid,  norm_grid, radii_grid, latent_dim, per_cell, min_size=50, base_seed=0):
    #         if not os.path.exists(path):
    #             os.makedirs(path)
    #             print(f"[创建数据集] {mode} 数据目录不存在，已新建：{path}")
    #
    #         npz_files = sorted([f for f in os.listdir(path) if f.endswith(".npz")])
    #         if not npz_files:
    #             print(f"[创建数据集] {mode} 数据未找到，开始生成...")
    #             self.create_dataset_grid(
    #                 path, mode=mode,
    #                 tau_grid=tau_grid,
    #                 gamma_grid=gamma_grid,
    #                 C_grid = C_grid,
    #                 norm_grid = norm_grid,
    #                 radii_grid=radii_grid,
    #                 latent_dim = latent_dim,
    #                 per_cell=per_cell,
    #                 min_size=min_size,
    #                 base_seed=base_seed
    #             )
    #             npz_files = sorted([f for f in os.listdir(path) if f.endswith(".npz")])
    #         else:
    #             print(f"[读取数据] {mode} 集已存在，共 {len(npz_files)} 张图：{path}")
    #
    #         return [os.path.join(path, f) for f in npz_files]
    #
    #     # ==== 目录 ====
    #     train_dir = f"{self.generative_model}_nc{self.n_classes}_rand_gstr{self.N_train}_numtr{self.num_examples_train}"
    #     test_dir = f"{self.generative_model}_nc{self.n_classes}_rand_gste{self.N_test}_numte{self.num_examples_test}"
    #     val_dir = f"{self.generative_model}_nc{self.n_classes}_rand_val{self.N_val}_numval{self.num_examples_val}"
    #
    #     train_path = os.path.join(self.path_dataset, train_dir)
    #     test_path = os.path.join(self.path_dataset, test_dir)
    #     val_path = os.path.join(self.path_dataset, val_dir)
    #
    #     # ==== 采用上面的三套参数 ====
    #     self.data_train = get_npz_dataset(
    #         train_path, 'train',
    #         tau_grid=tau_train, gamma_grid=gamma_train, C_grid = C_train,norm_grid = norm_train,   radii_grid=radii_train, latent_dim = latent_dim, per_cell=per_cell_tr,
    #         min_size=50, base_seed=123
    #     )
    #     # self.data_val = get_npz_dataset(
    #     #     val_path, 'val',
    #     #     tau_grid=tau_val, gamma_grid=gamma_val, C_grid = C_val, norm_grid = norm_val,  radii_grid=radii_val, latent_dim = latent_dim, per_cell=per_cell_v,
    #     #     min_size=50, base_seed=2025
    #     # )
    #     # self.data_test = get_npz_dataset(
    #     #     test_path, 'test',
    #     #     tau_grid=tau_test, gamma_grid=gamma_test, C_grid = C_test, norm_grid = norm_test, radii_grid=radii_test, latent_dim = latent_dim, per_cell=per_cell_te,
    #     #     min_size=50, base_seed=31415
    #     # )
    #     if not os.path.exists(val_path):
    #         os.makedirs(val_path)
    #         print(f"[创建验证集] 使用 SBM 生成验证集：{val_path}")
    #
    #         # 初始化 SBM 生成器
    #         sbm_gen = SBMGenerator(
    #             N_train=self.N_train,
    #             N_test=self.N_train,
    #             N_val=self.N_train,
    #             generative_model='SBM_multiclass',
    #             n_classes=self.n_classes,
    #             path_dataset=self.path_dataset,
    #             num_examples_train=0,
    #             num_examples_test=0,
    #             num_examples_val=len(gamma_val) * len(C_val) * per_cell_v  # 根据需要调整数量
    #         )
    #
    #         # 使用 SBM 生成验证集
    #         sbm_gen.create_dataset_grid(
    #             directory=val_path,
    #             mode='val',
    #             snr_grid=SNR_val,  # 使用 SBM 的验证集 SNR
    #             gamma_grid=gamma_val,  # 使用 SBM 的验证集 gamma
    #             C_grid=C_val,  # 使用 SBM 的验证集 C
    #             per_cell=per_cell_v,  # 每个网格点生成数量
    #             min_size=50,
    #             base_seed=2025
    #         )
    #     # 加载验证集路径
    #     self.data_val = [os.path.join(val_path, f) for f in os.listdir(val_path) if f.endswith('.npz')]
    #     print(f"[读取验证集] 共 {len(self.data_val)} 张图：{val_path}")


    # 根据之前的参数，生成标签、概率连接矩阵、邻接矩阵
    def _gen_lsm_internal(
            self, N, n_classes, class_sizes, C,norm_mu,tau,  latent_dim=2, radii=[1,1,1,1,1], rng=None,
    ):

        import numpy as np
        from numpy.random import multivariate_normal
        from scipy.special import expit as sigmoid


        # 在 d 维空间中，通过等角构造 K 个向量,代表mu1, mu2, ... mu K的方向
        # 按照半径进行缩放，从而控制p1, p2, ... p_k的大小不同
        def simplex_mu_dense(K, d, radii=None, rng=None):
            if rng is None:
                rng = np.random.default_rng()
            if d < K - 1:
                raise ValueError("需要 d >= K-1")

            # Step 1: 构造标准 simplex (K x (K-1))
            S = np.eye(K) - np.ones((K, K)) / K
            vals, vecs = np.linalg.eigh(S)
            idx = np.argsort(vals)[::-1][:K - 1]
            basis = vecs[:, idx]  # K x (K-1)
            coords = S @ basis
            coords = coords / np.linalg.norm(coords, axis=1, keepdims=True)

            # Step 2: 随机正交变换，嵌入到 R^d
            A = rng.normal(size=(d, K - 1))
            Q, _ = np.linalg.qr(A)  # d x (K-1)，正交列
            mu = coords @ Q.T  # (K, d)

            # Step 3: 半径缩放
            if radii is not None:
                radii = np.asarray(radii)
                if radii.shape[0] < K:
                    raise ValueError("radii 长度至少需为 K")
                radii = radii[:K]
                mu = mu * radii[:, None]
            return mu

        # class_sizes = self._sample_class_sizes_dirichlet(N, n_classes, gamma, min_size, rng)
        boundaries = np.cumsum([0] + class_sizes)

        # 生成并置乱标签
        labels = np.zeros(N, dtype=int)
        for i in range(n_classes):
            start, end = boundaries[i], boundaries[i + 1]
            labels[start:end] = i
        # 打乱节点顺序
        perm = rng.permutation(N)
        labels = labels[perm]

        rng = np.random.default_rng() if rng is None else rng
        # === 1) 生成mu_1,... mu_k===
        """
        思路：椭圆上的随机分布，可以保证社区内连接概率比社区间高，且社区内能有区别
        也可以固定夹角生成数据
        """
        mu = np.zeros((n_classes, latent_dim))
        # if n_classes == 2:
        #     mu[0] = radii[:latent_dim]
        #     mu[1] = [x * (-1) for x in mu[0] ]
        # else:
        radii = np.array(radii)
        scaling_factor = np.sqrt(norm_mu) / np.max(radii)
        radii_scaled = radii * scaling_factor
        mu = simplex_mu_dense(n_classes, latent_dim, radii_scaled, rng)


        # === 2) 生成z_i ===
        z_s = np.zeros((N, latent_dim))
        for k in range(n_classes):
            start, end = boundaries[k], boundaries[k + 1]
            z_ks = rng.multivariate_normal(mu[k], tau**2 * np.eye(latent_dim), size=class_sizes[k])
            z_s[start:end] = z_ks
        z_s = z_s[perm]


        # === 3) 生成alpha_i ===
        logn = np.log(N)
        alpha_bar = np.log(logn / N * C/n_classes) * 0.5
        alpha_s = rng.normal(alpha_bar, 1, size=N)

        #生成alpha的方差参数设置为1

        # === 4) 计算连接概率矩阵 ===
        # 创建所有向量对的外积
        alpha_matrix =np.add.outer(alpha_s, alpha_s)  # 所有alpha_i + alpha_j
        # 计算所有z向量的点积矩阵
        z_dot_matrix = np.einsum('ik,jk->ij', z_s, z_s)  # 矩阵乘法得到所有点积
        # 计算所有对的logit
        logit_matrix = alpha_matrix + z_dot_matrix
        # 应用sigmoid，然后设置对角线为0
        connectivity_mat = sigmoid(logit_matrix)
        np.fill_diagonal(connectivity_mat, 0)  # 移除自环

        eigvals, eigvecs = np.linalg.eigh(connectivity_mat)
        idx = np.argsort(eigvals)[::-1]
        eigvecs_top = eigvecs[:, idx[:n_classes]]

        # === 5) 生成邻接矩阵 W  ===
        # 使用稀疏矩阵直接生成
        from scipy.sparse import random as sparse_random
        from scipy.sparse import lil_matrix

        W_sparse = lil_matrix((N, N), dtype=np.int8)
        for i in range(N):
            for j in range(i + 1, N):
                if rng.random() < connectivity_mat[i, j]:
                    W_sparse[i, j] = 1
        W_sparse = W_sparse + W_sparse.T

        # === 6) 计算概率连接矩阵的期望 B，exp(2 alpha + zi'zj)  ===
        mu_prod = mu @ mu.T
        B = sigmoid(mu_prod + 2*alpha_bar)

        return  class_sizes, W_sparse, labels, eigvecs_top, alpha_bar, B

    # def imbalanced_lsm_sample_otf_single(self, class_sizes, tau=None, gamma=None, C=None,
    #                                      norm_mu=None, alpha_std=None, latent_dim=None,
    #                                      radii=None, is_training=True, cuda=True):
    #     if is_training:
    #         N = self.N_train
    #     else:
    #         N = self.N_test
    #
    #     # 使用传入参数或默认参数
    #     C = C if C is not None else self.lsm_C
    #     norm_mu = norm_mu if norm_mu is not None else self.lsm_norm_mu
    #     radii = radii if radii is not None else self.lsm_radii
    #
    #     # 调用 LSM 生成方法
    #     class_sizes, gamma, W_sparse, labels, eigvecs_top, alpha_bar, B = self.gen_one_lsm_by_targets(
    #         N=N, n_classes=self.n_classes, class_sizes=class_sizes,
    #         C=C, norm_mu=norm_mu,
    #         latent_dim=latent_dim, radii=radii,  rng=np.random.default_rng()
    #     )
    #
    #     # 转换为稠密矩阵和 torch tensor
    #     W_dense = W_sparse.toarray()
    #     labels = np.expand_dims(labels, 0)
    #     labels = torch.from_numpy(labels)
    #     W = np.expand_dims(W_dense, 0)
    #
    #     if cuda and torch.cuda.is_available():
    #         W = torch.tensor(W, dtype=torch.float32).cuda()
    #     else:
    #         W = torch.tensor(W, dtype=torch.float32)
    #
    #     return W, labels, eigvecs_top

    # def create_dataset_grid(self, directory, mode='train', *,
    #                         tau_grid = (0.2,0.4,0.6,0.8),
    #                         gamma_grid=(0.15, 0.3, 0.6, 1.0, 2.0),
    #                         C_grid = [5,10,15],
    #                         norm_grid = [0.1, 0.2],
    #                         radii_grid = [
    #                             [1,1,1,1,1],
    #                             [1.2,1.1, 1, 0.9, 0.8]
    #                         ],
    #                         latent_dim = 5,
    #                         per_cell=20,
    #                         min_size=5,
    #                         base_seed=0):
    #     """
    #     在 (SNR × gamma × C) 的笛卡尔网格上生成数据；每个网格点生成 per_cell 张图。
    #     文件名写入网格信息；.npz 中保存全部元数据，保证复现性与分析便利。
    #     """
    #
    #     os.makedirs(directory, exist_ok=True)
    #
    #     if mode == 'train':
    #         N = self.N_train
    #         num_graphs_expected = len(tau_grid) * len(gamma_grid) *len(C_grid)* len(norm_grid)*len(radii_grid) * per_cell
    #         self.data_train = directory
    #     elif mode == 'val':
    #         N = self.N_val
    #         num_graphs_expected = len(tau_grid) * len(gamma_grid)*len(C_grid) * len(norm_grid)* len(radii_grid) * per_cell
    #         self.data_val = directory
    #     elif mode == 'test':
    #         N = self.N_test
    #         num_graphs_expected = len(tau_grid) * len(gamma_grid) *len(C_grid)* len(radii_grid) * per_cell
    #         self.data_test = directory
    #     else:
    #         raise ValueError(f"Unsupported mode: {mode}")
    #
    #     from concurrent.futures import ProcessPoolExecutor
    #     import itertools
    #
    #     idx = 0
    #     for t_idx, tau in enumerate(tau_grid):
    #         for r_idx, radii in enumerate(radii_grid):
    #             for C_idx, C in enumerate(C_grid):
    #                 norm_grdi_for_C = norm_grid[C_idx]
    #                 for g_idx, gamma in enumerate(gamma_grid):
    #                     for n_idx, norm_mu in enumerate(norm_grdi_for_C):
    #                         # 每个格点单独的 RNG，保证可复现
    #                         cell_seed = base_seed + (t_idx * 10_000_000
    #                                                  + r_idx * 10_000
    #                                                  + n_idx *1000
    #                                                  + C_idx * 100
    #                                                  + g_idx * 10)
    #                         rng = np.random.default_rng(cell_seed)
    #
    #                         for rep in range(per_cell):
    #
    #                             class_sizes, gamma, W_sparse, labels,eigvecs_top, alpha_bar,B = self.gen_one_lsm_by_targets(
    #                                 N=N, n_classes=self.n_classes,   class_sizes=class_sizes, gamma = gamma, C=C,  alpha_std = 0, norm_mu = norm_mu, min_size = min_size, latent_dim = latent_dim,
    #                                 radii = radii, tau = tau, rng = rng)
    #                             # W_sparse = csr_matrix(W_dense)
    #
    #                             fname = (f"{mode}_i{idx:05d}"
    #                                     f"__dim{latent_dim}d"
    #                                      f"__tau{tau:.2f}"
    #                                      f"__g{gamma:.3f}__rep{rep:02d}.npz")
    #                             path = os.path.join(directory, fname)
    #
    #                             np.savez_compressed(
    #                                 path,
    #                                 adj_data=W_sparse.data,
    #                                 adj_indices=W_sparse.indices,
    #                                 adj_indptr=W_sparse.indptr,
    #                                 adj_shape=W_sparse.shape,
    #                                 labels=labels,
    #                                 tau =  tau,
    #                                 radii = radii,
    #                                 gamma=gamma,
    #                                 C = C,
    #                                 alpha = alpha_bar,
    #                                 alpha_std = 1,
    #                                 connect_expected = B,
    #                                 class_sizes=np.array(class_sizes, dtype=np.int32),
    #                                 eigvecs_top=eigvecs_top  # 若体积大可去掉
    #                             )
    #                             idx += 1
    #
    #                             if tau ==0: #加的是SBM
    #                                 class_sizes, gamma, W_sparse, labels, eigvecs_top, alpha_bar, B = self.gen_one_lsm_by_targets(
    #                                     N=N, n_classes=self.n_classes, gamma=gamma, C=C, alpha_std=1, norm_mu=norm_mu,
    #                                     min_size=min_size, latent_dim=latent_dim,
    #                                     radii=radii, tau=tau, rng=rng)
    #                                 # W_sparse = csr_matrix(W_dense)
    #
    #                                 fname = (f"{mode}_i{idx:05d}"
    #                                          f"_SBM__dim{latent_dim}d"  # 在dim前添加_DCBM
    #                                          f"__tau{tau:.2f}"
    #                                          f"__g{gamma:.3f}__rep{rep:02d}.npz")
    #                                 path = os.path.join(directory, fname)
    #
    #                                 np.savez_compressed(
    #                                     path,
    #                                     adj_data=W_sparse.data,
    #                                     adj_indices=W_sparse.indices,
    #                                     adj_indptr=W_sparse.indptr,
    #                                     adj_shape=W_sparse.shape,
    #                                     labels=labels,
    #                                     tau=tau,
    #                                     radii=radii,
    #                                     gamma=gamma,
    #                                     C=C,
    #                                     alpha=alpha_bar,
    #                                     alpha_std=0,
    #                                     connect_expected=B,
    #                                     class_sizes=np.array(class_sizes, dtype=np.int32),
    #                                     eigvecs_top=eigvecs_top  # 若体积大可去掉
    #                                 )
    #
    #
    #     print(f"[{mode}] 网格数据完成: 共 {idx} 张（期望 {num_graphs_expected}）。目录: {directory}")

    def copy(self):
        return copy.deepcopy(self)