# -*- coding: utf-8 -*-
"""
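Latent space model (LSM) graph generator: parameter notes.
C: controls the overall edge density; the average density is about C log n / n.
tau: latent variable z_i ~ N(mu(label(i)), tau^2 I)
gamma: community size (concentration of the Dirichlet draw for class proportions)
radius: one entry per class; controls the density in different communities.

new version: adjust the angle by the angle between the mu vectors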
"""


import copy
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
from torch import optim
import torch.nn.functional as F
from load import get_P, get_Pd, get_W_lg
import random
from controlsnr import find_a_given_snr
from scipy.sparse import csr_matrix, save_npz
from data_generator import Generator as SBMGenerator

def simple_collate_fn(batch):
    """
    Collate a list of samples into one batch (all graphs must have the same size).

    Every sample is indexed with 'adj' (an object exposing ``toarray()``)
    and 'labels' (per-node class ids).

    Returns a dict with:
      - adj: float32 tensor [B, N, N]
      - labels: int64 tensor [B, N]
    """
    # Densify every adjacency first, then convert the labels.
    dense_adjs = []
    for item in batch:
        dense_adjs.append(torch.tensor(item['adj'].toarray(), dtype=torch.float32))

    node_labels = []
    for item in batch:
        node_labels.append(torch.tensor(item['labels'], dtype=torch.long))

    return {
        'adj': torch.stack(dense_adjs),       # [B, N, N]
        'labels': torch.stack(node_labels),   # [B, N]
    }

def midpoints(seq):
    """Return the arithmetic mean of each pair of consecutive items of ``seq``."""
    averaged = []
    for left, right in zip(seq[:-1], seq[1:]):
        averaged.append((left + right) / 2)
    return averaged

# ---- Training grid (passed to get_npz_dataset in Generator.prepare_data) ----
# Standard deviations tau of the latent positions, z_i ~ N(mu(label(i)), tau^2 I).
# For tau == 0, Generator.create_dataset_grid additionally writes a second
# "_DCBM" graph generated with alpha_std=0.
tau_train   = [ 0.25, 0.5, 0]
# 8 log-spaced density constants C in [1.5, 3.5].
C_train = np.logspace(np.log10(1.5), np.log10(3.5), 8)
# 10 log-spaced target SNR values in [1.2, 3].
snr_train = np.logspace(np.log10(1.2),np.log10(3),10)
# snr_train = [1]
# Dimension of the latent positions z_i.
latent_dim = 2
# Dirichlet concentration values used when sampling class proportions.
gamma_train = [ 0.3, 1.2, 3, 5]
# One list of per-class radius multipliers per grid entry; they scale the class centres.
radii_train = [
    [1.2,  1],
    # [1, 1]

]
# Number of graphs generated per training grid cell.
per_cell_tr = 4


# ---- Validation grid (passed to the SBM generator in Generator.prepare_data) ----
# Geometric means of neighbouring training C values (one fewer than C_train).
C_mid = np.sqrt(C_train[:-1] * C_train[1:])
# 8 integer indices into the 7 midpoints, so at least one midpoint is picked twice.
C_idx = np.linspace(0, len(C_mid)-1, 8, dtype=int)
C_val = C_mid[C_idx]
# Same construction for the SNR axis: 8 of the 9 geometric midpoints.
snr_mid = np.sqrt(snr_train[:-1] * snr_train[1:])
snr_idx = np.linspace(0, len(snr_mid)-1, 8, dtype = int)
snr_val = snr_mid[snr_idx]
# snr_val = [1]
# Validation reuses the training gamma values (independent list copy).
gamma_val = gamma_train.copy()
# NOTE(review): radii_val is not read by the code visible in this file — confirm whether it is still needed.
radii_val = [0.85, 0.85, 0.9, 0.9, 0.95, 0.95, 1, 1]
# Number of graphs generated per validation grid cell.
per_cell_v = 1



class Generator(object):
    def __init__(self, N_train=50, N_test=100, N_val=50, generative_model='LSM_multiclass',
                 p_SBM=0.8, q_SBM=0.2, n_classes=2, path_dataset='dataset',
                 num_examples_train=100, num_examples_test=10, num_examples_val=10):
        """
        Store the dataset configuration; no data is read or generated here.

        Args:
            N_train, N_test, N_val: nominal node count per split; read by
                ``create_dataset_grid`` for the matching mode and embedded in
                the dataset directory names built by ``prepare_data``.
            generative_model: model tag, used in the dataset directory names.
            p_SBM, q_SBM: stored as given; not read by the methods of this class.
            n_classes: number of communities per graph.
            path_dataset: root directory under which dataset folders live.
            num_examples_train, num_examples_test, num_examples_val: example
                counts; used in the dataset directory names.
        """
        # Nominal graph sizes per split.
        self.N_train, self.N_test, self.N_val = N_train, N_test, N_val

        # Model description and storage location.
        self.generative_model = generative_model
        self.p_SBM, self.q_SBM = p_SBM, q_SBM
        self.n_classes = n_classes
        self.path_dataset = path_dataset

        # Dataset handles; filled in later by prepare_data / create_dataset_grid.
        self.data_train = self.data_test = self.data_val = None

        # Requested number of examples per split.
        self.num_examples_train = num_examples_train
        self.num_examples_test = num_examples_test
        self.num_examples_val = num_examples_val

        # Two-class size splits that always sum to 1000 nodes.
        self.fixed_class_sizes = [
            (smaller, 1000 - smaller) for smaller in (500, 400, 300, 200, 100, 50)
        ]



    def prepare_data(self):
        """
        Locate the training and validation datasets, generating them on first use.

        Sets ``self.data_train`` and ``self.data_val`` to sorted lists of
        ``.npz`` file paths. Training graphs are produced by this class'
        ``create_dataset_grid`` with the module-level ``*_train`` grids;
        validation graphs are produced by the generator imported as
        ``SBMGenerator`` with the module-level ``*_val`` grids.

        A split is (re)generated whenever its directory contains no ``.npz``
        file, including when the directory exists but is empty (e.g. after an
        interrupted run). ``self.data_test`` is not touched: no test split is
        prepared here.
        """
        def list_npz(path):
            """Return the sorted names of the ``.npz`` files directly inside ``path``."""
            return sorted(f for f in os.listdir(path) if f.endswith(".npz"))

        def get_npz_dataset(path, mode, *, tau_grid, gamma_grid,C_grid,  snr_grid, radii_grid, latent_dim, per_cell, min_size=30, base_seed=0):
            """Return the ``.npz`` paths under ``path``, generating the grid first if none exist."""
            if not os.path.exists(path):
                os.makedirs(path)
                print(f"[创建数据集] {mode} 数据目录不存在，已新建：{path}")

            npz_files = list_npz(path)
            if not npz_files:
                print(f"[创建数据集] {mode} 数据未找到，开始生成...")
                self.create_dataset_grid(
                    path, mode=mode,
                    tau_grid=tau_grid,
                    gamma_grid=gamma_grid,
                    C_grid=C_grid,
                    snr_grid=snr_grid,
                    radii_grid=radii_grid,
                    latent_dim=latent_dim,
                    per_cell=per_cell,
                    min_size=min_size,
                    base_seed=base_seed
                )
                npz_files = list_npz(path)
            else:
                print(f"[读取数据] {mode} 集已存在，共 {len(npz_files)} 张图：{path}")

            return [os.path.join(path, f) for f in npz_files]

        # ==== Dataset directories ====
        train_dir = f"{self.generative_model}_nc{self.n_classes}_rand_gstr{self.N_train}_numtr{self.num_examples_train}"
        val_dir = f"{self.generative_model}_nc{self.n_classes}_rand_val{self.N_val}_numval{self.num_examples_val}"

        train_path = os.path.join(self.path_dataset, train_dir)
        val_path = os.path.join(self.path_dataset, val_dir)

        # ==== Training split: LSM grid defined at module level ====
        self.data_train = get_npz_dataset(
            train_path, 'train',
            tau_grid=tau_train, gamma_grid=gamma_train, C_grid=C_train, snr_grid=snr_train,
            radii_grid=radii_train, latent_dim=latent_dim, per_cell=per_cell_tr,
            min_size=30, base_seed=123
        )

        # ==== Validation split: generated with the SBM generator ====
        # Decide on "contains .npz files" rather than "directory exists", so
        # that an empty directory left by a failed run is filled again.
        os.makedirs(val_path, exist_ok=True)
        val_files = list_npz(val_path)
        if not val_files:
            print(f"[创建验证集] 使用 SBM 生成验证集：{val_path}")

            # NOTE(review): every split size of the SBM generator is set to
            # N_train, while the directory name above uses N_val — confirm
            # this is intended.
            # NOTE(review): num_examples_val ignores len(snr_val) although
            # snr_val is part of the grid below — confirm how SBMGenerator
            # uses this count.
            sbm_gen = SBMGenerator(
                N_train=self.N_train,
                N_test=self.N_train,
                N_val=self.N_train,
                generative_model='SBM_multiclass',
                n_classes=self.n_classes,
                path_dataset=self.path_dataset,
                num_examples_train=0,
                num_examples_test=0,
                num_examples_val=len(gamma_val) * len(C_val) * per_cell_v
            )

            sbm_gen.create_dataset_grid(
                directory=val_path,
                mode='val',
                snr_grid=snr_val,
                gamma_grid=gamma_val,
                C_grid=C_val,
                per_cell=per_cell_v,
                min_size=30,
                base_seed=2025
            )
            val_files = list_npz(val_path)

        # Sorted so the validation order is the same on every run and platform.
        self.data_val = [os.path.join(val_path, f) for f in val_files]
        print(f"[读取验证集] 共 {len(self.data_val)} 张图：{val_path}")



    # 生成每个类别的样本个数
    def _sample_class_sizes_dirichlet(self, N, n_classes, gamma, min_size, rng,
                                      gamma_jitter=0.5):
        """按 Dirichlet(gamma) 采样类比例 + min_size 下界，支持 gamma 抖动。"""
        assert N >= min_size * n_classes
        remaining = N - min_size * n_classes

        # 对 gamma 做整体抖动：γ' = γ * U(1-δ, 1+δ)
        if gamma_jitter and gamma_jitter > 0:
            mult = rng.uniform(1.0 - gamma_jitter, 1.0 + gamma_jitter)
            gamma_used = gamma * mult
        else:
            gamma_used = gamma

        probs = rng.dirichlet(np.full(n_classes, gamma_used, dtype=float))
        extras = rng.multinomial(remaining, probs)
        return [min_size + int(e) for e in extras]  # 也可以返回实际使用的 γ





    # 根据之前的参数，生成标签、概率连接矩阵、邻接矩阵
    def gen_one_lsm_by_targets(
            self, N, n_classes, gamma, C,norm_mu, alpha_std, min_size=5, latent_dim=5,radii=[1,1,1,1,1],tau=0.5, rng=None,
    ):

        import numpy as np
        from numpy.random import multivariate_normal
        from scipy.special import expit as sigmoid

        # 在 d 维空间中，通过等角构造 K 个向量,代表mu1, mu2, ... mu K的方向
        # 按照半径进行缩放，从而控制p1, p2, ... p_k的大小不同
        import numpy as np
        def simplex_mu_dense(K, d, radii=None, rng=None):
            if rng is None:
                rng = np.random.default_rng()
            if d < K - 1:
                raise ValueError("需要 d >= K-1")

            # Step 1: 构造标准 simplex (K x (K-1))
            S = np.eye(K) - np.ones((K, K)) / K
            vals, vecs = np.linalg.eigh(S)
            idx = np.argsort(vals)[::-1][:K - 1]
            basis = vecs[:, idx]  # K x (K-1)
            coords = S @ basis
            coords = coords / np.linalg.norm(coords, axis=1, keepdims=True)

            # Step 2: 随机正交变换，嵌入到 R^d
            A = rng.normal(size=(d, K - 1))
            Q, _ = np.linalg.qr(A)  # d x (K-1)，正交列
            mu = coords @ Q.T  # (K, d)

            # Step 3: 半径缩放
            if radii is not None:
                radii = np.asarray(radii)
                if radii.shape[0] < K:
                    raise ValueError("radii 长度至少需为 K")
                radii = radii[:K]
                mu = mu * radii[:, None]
            return mu

        def construct_mu_from_angles(radii, rng,theta_mat, r=1.0, tol=1e-8):
            """
            完全控制: 输入一个 KxK 的角度矩阵 theta_mat (弧度制)，
            返回 Kxd 的向量矩阵 mu，使得向量对之间的夹角尽量接近输入。
            """
            K = theta_mat.shape[0]
            # 检查对称性
            if not np.allclose(theta_mat, theta_mat.T, atol=1e-6):
                raise ValueError("角度矩阵必须对称")
            if not np.allclose(np.diag(theta_mat), 0, atol=1e-6):
                raise ValueError("对角线必须是 0 (theta_kk=0)")

            # 构造 Gram 矩阵 G = cos(theta)
            G = np.cos(theta_mat)

            # 检查半正定
            eigvals = np.linalg.eigvalsh(G)
            if np.min(eigvals) < -tol:
                print("角度矩阵不合法，对应的 cos(theta) 矩阵不是正定的")
                mu = simplex_mu_dense(K, K, radii, rng)
            else:
                # Cholesky/EVD 分解，取低秩表示
                vals, vecs = np.linalg.eigh(G)
                vals[vals < 0] = 0  # 数值修正
                X = vecs @ np.diag(np.sqrt(vals))

                # 缩放半径
                mu = r * X[:, :K]  # KxK 向量
            return mu

        def construct_mu_partial(radii,K, r=1.0,  rng = None,small_angle=60, large_angle=120):
            """
            部分控制: 自动生成角度矩阵.
            - 相邻社区角度小 (q 大一些),
            - 非相邻社区角度大 (q 小一些).
            """
            theta_mat = np.zeros((K, K))

            for i in range(K):
                for j in range(i + 1, K):
                    if abs(i - j) == 1:
                        theta = np.deg2rad(small_angle + rng.normal(0, 2))  # 加点扰动
                    else:
                        theta = np.deg2rad(large_angle + rng.normal(0, 5))
                    theta_mat[i, j] = theta_mat[j, i] = theta

            mu = construct_mu_from_angles(radii, rng, theta_mat, r)
            return mu, theta_mat

        class_sizes = self._sample_class_sizes_dirichlet(N, n_classes, gamma, min_size, rng)
        boundaries = np.cumsum([0] + class_sizes)

        # 生成并置乱标签
        labels = np.zeros(N, dtype=int)
        for i in range(n_classes):
            start, end = boundaries[i], boundaries[i + 1]
            labels[start:end] = i
        # 打乱节点顺序
        perm = rng.permutation(N)
        labels = labels[perm]

        rng = np.random.default_rng() if rng is None else rng
        # === 1) 生成mu_1,... mu_k===
        """
        思路：椭圆上的随机分布，可以保证社区内连接概率比社区间高，且社区内能有区别
        也可以固定夹角生成数据
        """
        # mu = np.zeros((n_classes, latent_dim))
        # if n_classes == 2:
        #     mu[0] = radii[:latent_dim]
        #     mu[1] = [x * (-1) for x in mu[0] ]
        # else:
        radii = np.array(radii)
        # scaling_factor = np.sqrt(norm_mu) / np.max(radii)
        radii_scaled = radii * norm_mu
        mu = simplex_mu_dense(n_classes, latent_dim, radii_scaled, rng)
        # mu,theta_mat = construct_mu_partial( radii, n_classes, norm_mu, rng, small_angle=30, large_angle=80)

        # === 2) 生成z_i ===
        z_s = np.zeros((N, latent_dim))
        for k in range(n_classes):
            start, end = boundaries[k], boundaries[k + 1]
            z_ks = rng.multivariate_normal(mu[k], tau**2 * np.eye(latent_dim), size=class_sizes[k])
            z_s[start:end] = z_ks
        z_s = z_s[perm]


        # === 3) 生成alpha_i ===
        logn = np.log(N)
        alpha_bar = np.log(logn / N * C/n_classes) * 0.5
        alpha_s = rng.normal(alpha_bar, alpha_std, size=N)

        #生成alpha的方差参数设置为1

        # === 4) 计算连接概率矩阵 ===
        # 创建所有向量对的外积
        alpha_matrix =np.add.outer(alpha_s, alpha_s)  # 所有alpha_i + alpha_j
        # 计算所有z向量的点积矩阵
        z_dot_matrix = np.einsum('ik,jk->ij', z_s, z_s)  # 矩阵乘法得到所有点积
        # 计算所有对的logit
        logit_matrix = alpha_matrix + z_dot_matrix
        # 应用sigmoid，然后设置对角线为0
        connectivity_mat = sigmoid(logit_matrix)
        np.fill_diagonal(connectivity_mat, 0)  # 移除自环

        eigvals, eigvecs = np.linalg.eigh(connectivity_mat)
        idx = np.argsort(eigvals)[::-1]
        eigvecs_top = eigvecs[:, idx[:n_classes]]

        # === 5) 生成邻接矩阵 W  ===
        # 使用稀疏矩阵直接生成
        from scipy.sparse import random as sparse_random
        from scipy.sparse import lil_matrix

        W_sparse = lil_matrix((N, N), dtype=np.int8)
        for i in range(N):
            for j in range(i + 1, N):
                if rng.random() < connectivity_mat[i, j]:
                    W_sparse[i, j] = 1
        W_sparse = W_sparse + W_sparse.T

        # === 6) 计算概率连接矩阵的期望 B，exp(2 alpha + zi'zj)  ===
        mu_prod = mu @ mu.T
        B = sigmoid(mu_prod + 2*alpha_bar)


        return  class_sizes, gamma, W_sparse, labels,alpha_bar,B

    def create_dataset_grid(self, directory, mode='train', *,
                            tau_grid = (0.2,0.4,0.6,0.8),
                            gamma_grid=(0.15, 0.3, 0.6, 1.0, 2.0),
                            C_grid = [5,10,15],
                            snr_grid=[0.5, 1, 1.5, 2, 2.5],
                            radii_grid = [
                                [1,1,1,1,1],
                                [1.2,1.1, 1, 0.9, 0.8]
                            ],
                            latent_dim = 5,
                            per_cell=20,
                            min_size=5,
                            base_seed=0):
        """
        在 (SNR × gamma × C) 的笛卡尔网格上生成数据；每个网格点生成 per_cell 张图。
        文件名写入网格信息；.npz 中保存全部元数据，保证复现性与分析便利。
        """

        os.makedirs(directory, exist_ok=True)

        if mode == 'train':
            N = self.N_train
            num_graphs_expected = len(tau_grid) * len(gamma_grid) *len(C_grid)* len(snr_grid)*len(radii_grid) * per_cell
            self.data_train = directory
        elif mode == 'val':
            N = self.N_val
            num_graphs_expected = len(tau_grid) * len(gamma_grid)*len(C_grid) * len(snr_grid)* len(radii_grid) * per_cell
            self.data_val = directory
        elif mode == 'test':
            N = self.N_test
            num_graphs_expected = len(tau_grid) * len(gamma_grid) *len(C_grid)** len(snr_grid)* len(radii_grid) * per_cell
            self.data_test = directory
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        from concurrent.futures import ProcessPoolExecutor
        import itertools
        from scipy.special import expit


        def precompute_scalings(K, n, C, max_r, target_snr):
            """半径全为1时预计算缩放系数"""

            abar = 0.5 * np.log(C * np.log(n) / n)

            def compute_snr(s):
                p = expit(2 * abar + max_r ** 2 * s**2)
                q = expit(2 * abar - 1 / (K - 1) * max_r ** 2 * s**2)
                return (p - q) ** 2 * n / (K * (p + (K - 1) * q) * np.log(n))

            grid = np.logspace(np.log10(0.01), np.log10(10), 100)  # 搜索区间 [0.01, 100]
            snrs = np.array([compute_snr(s) for s in grid])
            idx = np.argmin(np.abs(snrs - target_snr))
            scaling = grid[idx]
            return scaling


        idx = 0
        for t_idx, tau in enumerate(tau_grid):
            for r_idx, radii in enumerate(radii_grid):
                for C_idx, C in enumerate(C_grid):
                    for g_idx, gamma in enumerate(gamma_grid):
                        for s_idx, snr in enumerate(snr_grid):
                            # 每个格点单独的 RNG，保证可复现
                            cell_seed = base_seed + (t_idx * 10_000_000
                                                     + r_idx * 10_000
                                                     + s_idx *1000
                                                     + C_idx * 100
                                                     + g_idx * 10)

                            rng = np.random.default_rng(cell_seed)

                            for rep in range(per_cell):
                                rand_N = int(N + (rng.random() * 2 - 1) * 500)  # 修正：使用 rng 而不是 np.random

                                mean_r = np.mean(radii)
                                norm_mu = precompute_scalings(self.n_classes, N, C,mean_r,snr)

                                class_sizes, gamma, W_sparse, labels, alpha_bar,B = self.gen_one_lsm_by_targets(
                                    N=rand_N, n_classes=self.n_classes,gamma = gamma, C=C,  norm_mu = norm_mu, alpha_std = 0.5, min_size = min_size, latent_dim = latent_dim,
                                    radii = radii, tau = tau, rng = rng)
                                # W_sparse = csr_matrix(W_dense)

                                fname = (f"{mode}_i{idx:05d}"
                                         f"_C{round(C, 2)}"  # 直接四舍五入到2位小数
                                         f"_snr{round(snr, 2)}"
                                         f"_g{round(gamma, 3)}__rep{rep:02d}.npz")
                                path = os.path.join(directory, fname)

                                np.savez_compressed(
                                    path,
                                    adj_data=W_sparse.data,
                                    adj_indices=W_sparse.indices,
                                    adj_indptr=W_sparse.indptr,
                                    adj_shape=W_sparse.shape,
                                    labels=labels,
                                    tau =  tau,
                                    radii = radii,
                                    gamma=gamma,
                                    C = C,
                                    alpha = alpha_bar,
                                    connect_expected = B,
                                    class_sizes=np.array(class_sizes, dtype=np.int32),

                                )
                                if tau == 0:
                                    class_sizes, gamma, W_sparse, labels, alpha_bar, B = self.gen_one_lsm_by_targets(
                                        N=rand_N, n_classes=self.n_classes, gamma=gamma, C=C, norm_mu=norm_mu, alpha_std=0,
                                        min_size=min_size, latent_dim=latent_dim,
                                        radii=radii, tau=tau, rng=rng)
                                    # W_sparse = csr_matrix(W_dense)

                                    fname = (f"{mode}_i{idx:05d}"
                                             f"_DCBM"
                                             f"_C{round(C, 2)}"  # 直接四舍五入到2位小数
                                             f"_snr{round(snr, 2)}"
                                             f"_g{round(gamma, 2)}"
                                             f"_rep{rep:02d}")
                                    path = os.path.join(directory, fname)
                                    np.savez_compressed(
                                        path,
                                        adj_data=W_sparse.data,
                                        adj_indices=W_sparse.indices,
                                        adj_indptr=W_sparse.indptr,
                                        adj_shape=W_sparse.shape,
                                        labels=labels,
                                        tau=tau,
                                        radii=radii,
                                        gamma=gamma,
                                        C=C,
                                        alpha=alpha_bar,
                                        alpha_std=0,
                                        connect_expected=B,
                                        class_sizes=np.array(class_sizes, dtype=np.int32),

                                    )
                                idx += 1

        print(f"[{mode}] 网格数据完成: 共 {idx} 张（期望 {num_graphs_expected}）。目录: {directory}")

    def copy(self):
        return copy.deepcopy(self)