import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torchdiffeq import odeint
from torch_scatter import scatter_mean, scatter_sum
import torch.distributions as dist

import numpy as np


# Natural logarithm of 2*pi, stored as a NumPy float64.
LOG2PI = np.log(np.pi * 2)

class PIFBase(nn.Module):
    """Shared helpers for interpolation/flow-based generative models.

    Intended as a general base that could be used to implement the vector field
    of a CNF, among other uses. It defines no parameters of its own and provides:
    Gaussian / Dirichlet interpolation samplers, the matching KL training
    losses, Fisher-information matrices of the underlying exponential families,
    discretised-Gaussian CDF utilities and batched Kabsch alignment.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_k_params(self, bins):
        """Return bin centres and left/right edges of a uniform discretisation of [-1, 1].

        Bin ``k`` (1-based) has centre ``(2k - 1) / bins - 1``. Its edges are
        ``centre - 1 / bins`` (left) and ``centre + 1 / bins`` (right).

        Args:
            bins: number of bins.

        Returns:
            ``(centres, lefts, rights)`` as three Python lists of floats.
        """
        half_width = 1 / bins
        centres = [(2 * k - 1) / bins - 1 for k in range(1, int(bins + 1))]
        lefts = [c - half_width for c in centres]
        rights = [c + half_width for c in centres]
        # Third element is the right edges (it used to return the left edges twice).
        return centres, lefts, rights

    def discretised_cdf(self, mu, sigma, x):
        """Gaussian CDF at ``x``, with the support clipped to [-1, 1].

        Args:
            mu, sigma: tensors of shape (B, D). An axis is inserted at dim 1 so they
                broadcast against ``x`` (presumably (B, bins, D); confirm with callers).
            x: evaluation points.

        Returns:
            Phi((x - mu) / sigma), forced to 1 where ``x >= 1`` and to 0 where
            ``x <= -1``.
        """
        mu = mu.unsqueeze(1)  # B,1,D
        sigma = sigma.unsqueeze(1)  # B,1,D

        cdf = 0.5 * (1 + torch.erf((x - mu) / (sigma * np.sqrt(2))))
        cdf = torch.where(x >= 1, torch.ones_like(cdf), cdf)
        cdf = torch.where(x <= -1, torch.zeros_like(cdf), cdf)
        return cdf

    def kabsch_align_by_batch(self, x_pred, x_ref, batch, eps=1e-8):
        """Rigidly align every graph of ``x_pred`` onto ``x_ref`` (Kabsch algorithm).

        For each graph b (the nodes with ``batch == b``):
        - both point sets are centred;
        - the proper rotation R_b = V diag(1, 1, d) U^T is found, where
          H = U S V^T is the SVD of the 3x3 covariance and d = sign(det(V U^T));
        - the rotated prediction is moved to the reference centroid.

        Args:
            x_pred: (N, 3) coordinates to be moved.
            x_ref: (N, 3) reference coordinates.
            batch: (N,) int64 graph index of each node.
            eps: added to the node counts when computing centroids.

        Returns:
            A tuple ``(x_aligned, rmsd)``:
            - ``x_aligned`` has shape (N, 3);
            - ``rmsd`` has shape (B,), where B = ``batch.max() + 1``. Graph ids
              with no nodes get an RMSD of 0.
        """
        num_graphs = int(batch.max().item()) + 1

        def seg_sum(values):
            # Sum the rows of ``values`` per graph -> (num_graphs, *values.shape[1:]).
            out = values.new_zeros((num_graphs,) + tuple(values.shape[1:]))
            return out.index_add_(0, batch, values)

        # 1. Per-graph centroids; buffers inherit the inputs' dtype and device.
        count = seg_sum(x_pred.new_ones(x_pred.size(0), 1))  # (B, 1)
        c_pred = seg_sum(x_pred) / (count + eps)
        c_ref = seg_sum(x_ref) / (count + eps)
        X = x_pred - c_pred[batch]
        Y = x_ref - c_ref[batch]

        # 2. Per-graph covariance H_b = sum_n X_n^T Y_n, shape (B, 3, 3).
        H = seg_sum(X.unsqueeze(2) * Y.unsqueeze(1))

        # 3. SVD, then 4. flip the last axis when needed so det(R) = +1 (no reflection).
        U, _, Vt = torch.linalg.svd(H)
        V = Vt.transpose(-1, -2)
        Ut = U.transpose(-1, -2)
        d = torch.sign(torch.det(V @ Ut))
        ones = torch.ones_like(d)
        R = V @ torch.diag_embed(torch.stack([ones, ones, d], dim=-1)) @ Ut  # (B, 3, 3)

        # 5. Rotate each centred node by its own graph's R, treating coordinates
        #    as column vectors (x -> R x); index R per node, do not sum over graphs.
        x_aligned = torch.einsum("nij,nj->ni", R[batch], X) + c_ref[batch]

        # 6. Per-graph RMSD; clamp avoids 0/0 for graph ids without nodes.
        sq_err = seg_sum(((x_aligned - x_ref) ** 2).sum(dim=1))
        rmsd = torch.sqrt(sq_err / count.squeeze(-1).clamp(min=1))
        return x_aligned, rmsd

    def continuous_var_interpolation_update(self, t, x, s0, s1, prior=None):
        """Sample the Gaussian interpolant between an N(0, s0^2) prior and the data ``x``.

        With gamma = t and s1' = (1 - gamma) * s1 the sample is
            u * x + e * eps,   eps ~ N(0, I)
        with D = (1 - gamma) s1'^2 + gamma s0^2 and
            e = sqrt(s0^2 s1'^2 / D),
            u = gamma s0^2 / D.
        At gamma = 0 this is N(0, s0^2); at gamma = 1 it returns ``x`` exactly.

        Args:
            t: interpolation time in [0, 1], broadcastable against ``x``.
            x: clean data.
            s0: prior standard deviation.
            s1: data-side noise scale.
            prior: accepted for API compatibility; not used.

        Returns:
            A tensor shaped like ``x``.
        """
        # NOTE(review): unlike dtime4continuous_interpolation_loss ("std8"), s1' is
        # not clamped to >= 1e-3 here ("std6/std7"). Confirm this is intended.
        # Earlier schedules: std4 -> where(gamma < 0.9, s1, 10 * (1 - gamma) * s1),
        #                    std5 -> where(gamma < 0.4, s1, (1 - gamma) * s1 / 0.6).
        gamma = t
        s1_mod = (1 - gamma) * s1  # std6, std7
        denom = (1 - gamma) * s1_mod**2 + gamma * s0**2
        e_coeff = torch.sqrt(s0**2 * s1_mod**2 / denom)
        u_coeff = gamma * s0**2 / denom
        return u_coeff * x + e_coeff * torch.randn_like(x)

    def _soft_dirichlet_alpha(self, gamma, probs, K, s1, prior):
        """Return the Dirichlet concentration interpolating ``prior`` -> smoothed ``probs``.

        The smoothing weight is w = clamp((1 - gamma) * s1, min=1e-3), remapped
        to K w / (1 - w + K w) ("std8").
        ``probs`` is mixed with the uniform distribution, (1 - w) probs + w / K,
        and the result is interpolated linearly with ``prior`` at time ``gamma``.
        """
        w = torch.clamp((1 - gamma) * s1, min=1e-3)
        w = K * w / (1 - w + K * w)  # std8
        soft = (1 - w) * probs + w / K
        return gamma * soft + (1 - gamma) * prior

    def discrete_var_interpolation_update(self, t, x, K, s1, prior=None):
        """Sample class probabilities from the Dirichlet interpolant at time ``t``.

        Args:
            t: interpolation time in [0, 1], broadcastable against ``x``.
            x: (N, K) target probabilities, e.g. one-hot.
            K: number of classes.
            s1: smoothing scale (see ``_soft_dirichlet_alpha``).
            prior: Dirichlet concentration at t = 0. Defaults to all ones, i.e.
                uniform on the simplex.

        Returns:
            An (N, K) sample from Dirichlet(alpha_t).
        """
        if prior is None:
            prior = torch.ones_like(x)
        alpha = self._soft_dirichlet_alpha(t, x, K, s1, prior)
        return torch.distributions.Dirichlet(alpha).sample()

    def exponential_geodesic(self, p1, p0, gamma, K, soft_p1=True, eps=1e-2):
        """Return the point at time ``gamma`` on the exponential geodesic from ``p0`` to ``p1``.

        Computes p_t ∝ p0^(1 - gamma) * p1^gamma, normalised over the last dim.

        Args:
            p1: target distribution(s); probabilities lie along the last dim.
            p0: start distribution(s), broadcastable against ``p1``.
            gamma: interpolation time, broadcastable against ``p1``.
            K: number of categories; used only for smoothing.
            soft_p1: if True, replace ``p1`` by (1 - eps) p1 + eps / K so that
                log(p1) stays finite for one-hot targets.
            eps: smoothing weight (float or tensor).

        Returns:
            A tensor of normalised distributions along the last dim.
        """
        if soft_p1:
            p1 = (1 - eps) * p1 + eps / K
        log_pt = gamma * torch.log(p1) + (1 - gamma) * torch.log(p0)
        # softmax == exp + normalise, but without underflow for very negative logs.
        return torch.softmax(log_pt, dim=-1)

    def dtime4continuous_interpolation_loss(self, t, N, x_pred, x, s0, s1, segment_ids=None):
        """KL between the Gaussian interpolants built from ``x_pred`` and from ``x``.

        Both interpolants use mean scale u and std e, as in
        ``continuous_var_interpolation_update``, but here s1' is clamped to
        >= 1e-3 ("std8"). Since the stds are equal, the per-node KL equals
        u^2 ||x_pred - x||^2 / (2 e^2).

        Args:
            t: interpolation time in [0, 1], broadcastable to (num_nodes, 1).
            N: unused; kept for API compatibility.
            x_pred: (num_nodes, D) predicted data.
            x: (num_nodes, D) target data.
            s0, s1: prior std and data-side noise scale.
            segment_ids: optional (num_nodes,) graph index of each node.

        Returns:
            The per-graph mean KL with shape (num_graphs,) when ``segment_ids`` is
            given. Otherwise the per-node KL with shape (num_nodes,).
        """
        gamma = t
        s0_expand = s0 * x_pred.new_ones(x_pred.shape[0], 1)
        s1_mod = torch.clamp((1 - gamma) * s1, min=1e-3)  # std8
        s1_expand = s1_mod * x_pred.new_ones(x_pred.shape[0], 1)

        e_coeff = torch.sqrt(
            s0_expand**2 * s1_expand**2
            / ((1 - gamma) * s1_expand**2 + gamma * s0_expand**2)
        )
        u_coeff = gamma * s0**2 / ((1 - gamma) * s1_mod**2 + gamma * s0**2)

        kl = self.gauss_kl_batch((u_coeff * x_pred, e_coeff), (u_coeff * x, e_coeff))
        if segment_ids is None:
            return kl
        return scatter_mean(kl, segment_ids, dim=0)

    def gauss_kl_batch(self, theta1, theta2) -> torch.Tensor:
        """Return KL( N(u1, diag(s1^2)) || N(u2, diag(s2^2)) ) for each row.

        Args:
            theta1, theta2: ``(mean, std)`` pairs.
                - Means are (N, D).
                - Stds must be tensors that broadcast against the means: (N, 1)
                  for one isotropic std per row, (N, D) for diagonal, or 0-dim.

        Returns:
            An (N,) tensor equal to
            sum_d [log(s2 / s1) + (s1^2 + (u1 - u2)^2) / (2 s2^2) - 1/2].
        """
        u1, s1 = theta1
        u2, s2 = theta2
        # Sum the per-dimension KL so the log-ratio is counted once per dimension,
        # and so (N, 1) stds broadcast over D instead of over N.
        per_dim = torch.log(s2 / s1) + (s1**2 + (u1 - u2) ** 2) / (2 * s2**2) - 0.5
        return per_dim.sum(dim=1)

    def fisher_info_isotropic_gaussian_batch(self, a):
        """Fisher information of isotropic Gaussians in natural parameters.

        Each row of ``a`` is (eta_1, ..., eta_D, tau), where eta = mu / sigma^2 and
        tau = -1 / (2 sigma^2) < 0. The FIM is the Hessian of the log-partition
        A = -||eta||^2 / (4 tau) - (D / 2) log(-2 tau):
            d2A/deta2     = -1 / (2 tau) * I
            d2A/deta dtau = eta / (2 tau^2)
            d2A/dtau2     = -||eta||^2 / (2 tau^3) + D / (2 tau^2)

        Args:
            a: (N, D + 1) natural parameters. D = 3 for (N, 4) input, which is the
                original use.

        Returns:
            An (N, D + 1, D + 1) tensor of Fisher information matrices, with the
            same dtype and device as ``a``.
        """
        num, cols = a.shape
        dim = cols - 1
        eta = a[:, :dim]
        tau = a[:, dim]
        sq_norm = (eta**2).sum(dim=1)

        fim = a.new_zeros(num, cols, cols)
        eye = torch.eye(dim, dtype=a.dtype, device=a.device)
        fim[:, :dim, :dim] = (-0.5 / tau).view(num, 1, 1) * eye
        cross = (0.5 / tau**2).unsqueeze(1) * eta  # (N, D)
        fim[:, :dim, dim] = cross
        fim[:, dim, :dim] = cross
        fim[:, dim, dim] = -sq_norm / (2 * tau**3) + dim / (2 * tau**2)
        return fim

    def dtime4discrete_interpolation_loss_prob(
        self, t, N, p_0, one_hot_x, K, s1, segment_ids=None
    ):
        """KL between the Dirichlet interpolants built from ``p_0`` and from ``one_hot_x``.

        Both concentrations come from ``_soft_dirichlet_alpha``, using an
        all-ones prior (the same construction as
        ``discrete_var_interpolation_update``).

        Args:
            t: interpolation time in [0, 1], broadcastable against ``p_0``.
            N: unused; kept for API compatibility.
            p_0: (num_nodes, K) predicted class probabilities.
            one_hot_x: (num_nodes, K) target one-hot labels.
            K: number of classes.
            s1: smoothing scale.
            segment_ids: optional (num_nodes,) graph index of each node.

        Returns:
            The per-graph mean KL with shape (num_graphs,) when ``segment_ids`` is
            given. Otherwise the per-node KL with shape (num_nodes,).
        """
        prior = torch.ones_like(p_0)
        alpha1 = self._soft_dirichlet_alpha(t, p_0, K, s1, prior)
        alpha2 = self._soft_dirichlet_alpha(t, one_hot_x, K, s1, prior)

        kl = self.dirichlet_kl_batch(alpha1, alpha2)
        if segment_ids is None:
            return kl
        return scatter_mean(kl, segment_ids, dim=0)

    def dirichlet_kl_batch(self, alpha1: torch.Tensor, alpha2: torch.Tensor) -> torch.Tensor:
        """Batched KL( Dir(alpha1) || Dir(alpha2) ).

        Args:
            alpha1, alpha2: [B, K] concentration parameters, one row per sample.

        Returns:
            A [B] tensor of KL divergences.
        """
        # 1) Concentration sums per row, shape [B].
        sum1 = alpha1.sum(dim=1)
        sum2 = alpha2.sum(dim=1)

        # 2) ln Γ(Σα1) - ln Γ(Σα2)
        term1 = torch.lgamma(sum1) - torch.lgamma(sum2)

        # 3) Σ [ln Γ(α2_i) - ln Γ(α1_i)]
        term2 = torch.lgamma(alpha2).sum(dim=1) - torch.lgamma(alpha1).sum(dim=1)

        # 4) Σ (α1_i - α2_i) [ψ(α1_i) - ψ(Σα1)]
        term3 = (
            (alpha1 - alpha2)
            * (torch.digamma(alpha1) - torch.digamma(sum1).unsqueeze(1))
        ).sum(dim=1)

        return term1 + term2 + term3

    def fisher_info_dirichlet_batch(self, eta):
        """Fisher information of Dirichlet distributions in natural parameters eta = alpha - 1.

        F = diag(trigamma(alpha)) - trigamma(sum(alpha)) * 1 1^T.

        Args:
            eta: (N, K) natural parameters (> -1).

        Returns:
            An (N, K, K) tensor with the same dtype and device as ``eta``.
        """
        alpha = eta + 1  # back to concentration parameters, alpha_i > 0
        trigamma_alpha = torch.polygamma(1, alpha)  # (N, K)
        trigamma_sum = torch.polygamma(1, alpha.sum(dim=-1, keepdim=True))  # (N, 1)
        return torch.diag_embed(trigamma_alpha) - trigamma_sum.unsqueeze(-1)

    def categorize_kl_batch(self, p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        """Return the categorical KL(p || q) summed over the last dim.

        Uses xlogy so that entries with p_i = 0 contribute 0 instead of NaN.
        """
        return (torch.xlogy(p, p) - torch.xlogy(p, q)).sum(dim=-1)
