import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def clampped_one_hot(x, num_classes):
    """One-hot encode integer labels, mapping out-of-range labels to all-zero rows.

    Args:
        x: Integer tensor of class indices (any shape).
        num_classes: Number of classes ``K``.

    Returns:
        Integer tensor of shape ``x.shape + (K,)``. Positions where ``x < 0``
        or ``x >= K`` get an all-zero row instead of raising.
    """
    in_range = (x >= 0) & (x < num_classes)
    # Clamp first so F.one_hot never sees an invalid index; the mask then
    # zeroes the rows that had to be clamped.
    safe_idx = torch.clamp(x, 0, num_classes - 1)
    return F.one_hot(safe_idx, num_classes) * in_range.unsqueeze(-1)


def random_normal_so3(std_idx, angular_distrib, device='cpu'):
    """Draw random so(3) vectors as a uniformly random unit axis times a sampled angle.

    Args:
        std_idx: Tensor of indices forwarded to ``angular_distrib.sample``.
        angular_distrib: Object whose ``sample(std_idx)`` returns one angle per
            entry of ``std_idx`` (assumed same shape as ``std_idx``).
        device: Device used for the random axis draw.

    Returns:
        Tensor of shape ``std_idx.shape + (3,)``; its norm along the last dim is
        the magnitude of the sampled angle.
    """
    axis_shape = list(std_idx.size()) + [3]
    # Normalizing an isotropic Gaussian gives a direction uniform on the sphere.
    axis = F.normalize(torch.randn(axis_shape, device=device), dim=-1)
    angle = angular_distrib.sample(std_idx)
    return axis * angle.unsqueeze(-1)


def so3vec_to_rotation(w):
    """Map so(3) vectors ``w`` of shape (..., 3) to 3x3 matrices of shape (..., 3, 3).

    Builds the skew-symmetric matrix of ``w`` and takes its matrix exponential.
    """
    skew = so3vec_to_skewsym(w)
    return exp_skewsym(skew)

def so3vec_to_skewsym(w):
    """Build the skew-symmetric matrix for so(3) vectors ``w`` = (x, y, z).

    The layout is ``[[0, z, -y], [-z, 0, x], [y, -x, 0]]``. It is the inverse of
    reading ``(S[1, 2], S[2, 0], S[0, 1])``.

    Args:
        w: Tensor of shape (..., 3).

    Returns:
        Tensor of shape (..., 3, 3).
    """
    x, y, z = w.unbind(-1)
    zero = torch.zeros_like(x)
    rows = (
        torch.stack([zero, z, -y], dim=-1),
        torch.stack([-z, zero, x], dim=-1),
        torch.stack([y, -x, zero], dim=-1),
    )
    return torch.stack(rows, dim=-2)

def exp_skewsym(S):
    """Matrix exponential of skew-symmetric matrices via Rodrigues' formula.

    exp(S) = I + (sin x / x) * S + ((1 - cos x) / x^2) * S @ S, where x is the norm
    of the vector read out of ``S``. The epsilons make the two coefficients
    evaluate to their x -> 0 limits (1 and 1/2) when x is exactly 0.

    Args:
        S: Tensor of shape (..., 3, 3).

    Returns:
        Tensor of shape (..., 3, 3).
    """
    angle = torch.linalg.norm(skewsym_to_so3vec(S), dim=-1)
    leading_ones = (1,) * (S.dim() - 2)
    eye = torch.eye(3, device=S.device, dtype=S.dtype).view(*leading_ones, 3, 3)

    sin_coef = (torch.sin(angle) + 1e-8) / (angle + 1e-8)
    cos_coef = (1 - torch.cos(angle) + 1e-8) / (angle**2 + 2e-8)  # -> 0.5 as angle -> 0

    S_squared = S @ S
    return eye + sin_coef[..., None, None] * S + cos_coef[..., None, None] * S_squared


def skewsym_to_so3vec(S):
    """Read the vector (S[1, 2], S[2, 0], S[0, 1]) out of matrices ``S`` (..., 3, 3).

    Returns:
        Tensor of shape (..., 3).
    """
    components = (S[..., 1, 2], S[..., 2, 0], S[..., 0, 1])
    return torch.stack(components, dim=-1)


def rotation_to_so3vec(R):
    """Map rotation matrices ``R`` (..., 3, 3) to so(3) vectors (..., 3).

    Takes the matrix logarithm and reads the vector out of the resulting
    skew-symmetric matrix.
    """
    return skewsym_to_so3vec(log_rotation(R))


def log_rotation(R):
    """Matrix logarithm of rotation matrices, returned as skew-symmetric matrices.

    Uses log(R) = theta / (2 sin theta) * (R - R^T) with
    cos(theta) = (trace(R) - 1) / 2.

    Args:
        R: Tensor of shape (..., 3, 3).

    Returns:
        Tensor of shape (..., 3, 3).
    """
    trace = torch.diagonal(R, dim1=-2, dim2=-1).sum(-1)
    # acos has an infinite derivative at -1, so stay away from it when gradients
    # are being tracked.
    min_cos = -0.999 if torch.is_grad_enabled() else -1.0
    # Round-off (e.g. products of float32 rotations) can push the cosine slightly
    # above 1; acos / sqrt would then return NaN, so clamp from above as well.
    cos_theta = ((trace - 1) / 2).clamp(min=min_cos, max=1.0)
    sin_theta = torch.sqrt(1 - cos_theta**2)
    theta = torch.acos(cos_theta)
    # lim_{theta -> 0} theta / (2 sin theta) = 1/2. The epsilons (1e-8 over 2e-8)
    # reproduce that limit when theta rounds to exactly 0 while R - R^T is still
    # non-zero (tiny float32 rotations); a 1e-8 denominator would give 1 and
    # double the result.
    coef = ((theta + 1e-8) / (2 * sin_theta + 2e-8))[..., None, None]
    return coef * (R - R.transpose(-1, -2))


class ApproxAngularDistribution(nn.Module):
    """Rotation-angle sampler holding one angular distribution per entry of ``stddevs``.

    For every stddev, a density over [0, pi] is tabulated on ``num_bins`` grid
    points by ``_pdf``, a series truncated after ``num_iters`` terms (it appears
    to be the IGSO(3) angle marginal -- TODO confirm). Stddevs at or below
    ``std_threshold`` are instead sampled from the folded Gaussian
    |N(2*std, std^2)|, wrapped into [0, pi).

    Buffers:
        stddevs: (S,) float stddevs.
        approx_flag: (S,) bool, True where the Gaussian shortcut is used.
        X: (S, num_bins) angle grid.
        Y: (S, num_bins) non-negative tabulated density.
    """

    def __init__(self, stddevs, std_threshold=0.1, num_bins=8192, num_iters=1024):
        super().__init__()
        self.std_threshold = std_threshold
        self.num_bins = num_bins
        self.num_iters = num_iters
        self.register_buffer('stddevs', torch.FloatTensor(stddevs))
        self.register_buffer('approx_flag', self.stddevs <= std_threshold)
        self._precompute_histograms()

    @staticmethod
    def _pdf(x, e, L):
        """Evaluate the truncated series density at angles ``x`` (N,) for stddev ``e`` with ``L`` terms."""
        angle = x[:, None]                                        # (N, 1)
        prefactor = ((1 - torch.cos(angle)) / math.pi)            # (N, 1)
        degree = torch.arange(0, L, device=angle.device)[None]    # (1, L)
        weight = (2*degree + 1) * torch.exp(-degree*(degree+1)*(e**2))
        # The 1e-6 offsets keep the ratio finite at angle == 0.
        kernel = (torch.sin((degree + 0.5) * angle) + 1e-6) / (torch.sin(angle / 2) + 1e-6)
        return (prefactor * weight * kernel).sum(dim=1)

    def _precompute_histograms(self):
        grid = torch.linspace(0, math.pi, self.num_bins)
        grids, densities = [], []
        for std in self.stddevs.tolist():
            density = self._pdf(grid, std, self.num_iters)
            # Truncation can yield NaN or slightly negative values; sanitize them.
            densities.append(torch.nan_to_num(density).clamp_min(0))
            grids.append(grid)
        self.register_buffer('X', torch.stack(grids, dim=0))
        self.register_buffer('Y', torch.stack(densities, dim=0))

    def sample(self, std_idx):
        """Draw one angle per entry of ``std_idx`` (indices into ``stddevs``); output has ``std_idx``'s shape."""
        out_shape = std_idx.size()
        flat_idx = std_idx.flatten()

        # Histogram branch: pick a bin with probability proportional to the
        # tabulated density, then draw uniformly inside it. The last grid point
        # has no right edge, so it is excluded from the bin choice.
        bin_weights = self.Y[flat_idx][:, :-1]
        chosen = torch.multinomial(bin_weights, num_samples=1).squeeze(-1)
        left_edge = self.X[flat_idx, chosen]
        width = self.X[flat_idx, chosen + 1] - left_edge
        from_hist = left_edge + torch.rand_like(left_edge) * width

        # Gaussian branch: |N(2*std, std^2)| folded into [0, pi).
        sigma = self.stddevs[flat_idx]
        mu = sigma * 2
        from_gauss = (mu + torch.randn_like(mu) * sigma).abs() % math.pi

        use_gauss = self.approx_flag[flat_idx]
        return torch.where(use_gauss, from_gauss, from_hist).reshape(out_shape)


class VarianceSchedule(nn.Module):
    """Cosine variance schedule with ``num_steps + 1`` entries (t = 0..num_steps).

    Buffers:
        alpha_bars: cos^2 schedule normalized so alpha_bars[0] == 1.
        betas: 1 - alpha_bars[t] / alpha_bars[t-1], with betas[0] = 0 and every
            beta clamped to at most 0.999.
        alphas: 1 - betas.
        sigmas: sqrt((1 - alpha_bars[t-1]) / (1 - alpha_bars[t]) * betas[t]),
            with sigmas[0] = 0.
    """

    def __init__(self, num_steps=100, s=0.01):
        super().__init__()
        steps = torch.arange(0, num_steps + 1, dtype=torch.float)
        f_t = torch.cos((np.pi / 2) * ((steps / num_steps) + s) / (1 + s)) ** 2
        alpha_bars = f_t / f_t[0]

        ratio_betas = 1 - (alpha_bars[1:] / alpha_bars[:-1])
        betas = torch.cat([torch.zeros([1]), ratio_betas], dim=0).clamp_max(0.999)

        # Posterior variance for t >= 1; entry 0 stays zero.
        variances = torch.zeros_like(betas)
        variances[1:] = ((1 - alpha_bars[:-1]) / (1 - alpha_bars[1:])) * betas[1:]
        sigmas = torch.sqrt(variances)

        self.register_buffer('betas', betas)
        self.register_buffer('alpha_bars', alpha_bars)
        self.register_buffer('alphas', 1 - betas)
        self.register_buffer('sigmas', sigmas)


class PositionTransition(nn.Module):
    """Gaussian forward noising and reverse steps for node positions.

    ``t`` holds one integer timestep per graph and ``batch`` maps every node to
    its graph, so ``t[batch]`` gives per-node timesteps.
    """

    def __init__(self, num_steps, var_sched_opt=None):
        super().__init__()
        self.var_sched = VarianceSchedule(num_steps, **(var_sched_opt or {}))

    def _t_per_node(self, batch, t):
        return t[batch]

    def add_noise(self, p_0, batch, t):
        """Return ``(p_t, eps)`` with p_t = sqrt(alpha_bar) * p_0 + sqrt(1 - alpha_bar) * eps."""
        alpha_bar = self.var_sched.alpha_bars[self._t_per_node(batch, t)]
        signal_scale = torch.sqrt(alpha_bar + 1e-8)[..., None]
        noise_scale = torch.sqrt(1 - alpha_bar + 1e-8)[..., None]

        eps = torch.randn_like(p_0)
        return signal_scale * p_0 + noise_scale * eps, eps

    def denoise(self, p_t, eps_pred, batch, t):
        """One reverse step from ``p_t`` given the predicted noise ``eps_pred``."""
        sched = self.var_sched
        t_node = self._t_per_node(batch, t)
        # Floor alpha at alphas[-2]; presumably keeps 1/sqrt(alpha) bounded where
        # beta hits the schedule's 0.999 clamp -- confirm.
        alpha = sched.alphas[t_node].clamp_min(sched.alphas[-2])
        alpha_bar = sched.alpha_bars[t_node]
        sigma = sched.sigmas[t_node][..., None]

        mean_scale = (1.0 / torch.sqrt(alpha + 1e-8))[..., None]
        eps_scale = ((1 - alpha) / torch.sqrt(1 - alpha_bar + 1e-8))[..., None]

        # Nodes at t <= 1 receive no fresh noise.
        keep_noise = (t_node > 1)[..., None]
        z = torch.randn_like(p_t)
        z = torch.where(keep_noise, z, torch.zeros_like(z))

        return mean_scale * (p_t - eps_scale * eps_pred) + sigma * z


class RotationTransition(nn.Module):
    """Forward noising and reverse steps for per-node rotations stored as so(3) vectors.

    ``t`` holds one integer timestep per graph and ``batch`` maps every node to
    its graph, so ``t[batch]`` gives per-node timesteps.
    """

    def __init__(self, num_steps, angular_distrib_fwd_opt=None, angular_distrib_inv_opt=None):
        super().__init__()
        self.var_sched = VarianceSchedule(num_steps)

        # Forward noise angle at step t uses stddev sqrt(1 - alpha_bar_t).
        c1 = torch.sqrt(1 - self.var_sched.alpha_bars)
        self.angular_distrib_fwd = ApproxAngularDistribution(
            c1.tolist(), **(angular_distrib_fwd_opt or {})
        )
        # Reverse-step noise angle at step t uses stddev sigma_t.
        sigma = self.var_sched.sigmas
        self.angular_distrib_inv = ApproxAngularDistribution(
            sigma.tolist(), **(angular_distrib_inv_opt or {})
        )
        # Empty buffer whose only job is to report the module's current device.
        self.register_buffer('_dummy', torch.empty([0]))

    def _t_per_node(self, batch, t):
        return t[batch]

    def add_noise(self, v_0, batch, t):
        """Noise rotations ``v_0`` to step ``t``.

        Composes a random rotation E (angle stddev sqrt(1 - alpha_bar_t)) with the
        rotation of sqrt(alpha_bar_t) * v_0.

        Returns:
            ``(v_t, e_scaled)``: the so(3) vector of E @ R(c0 * v_0) and the
            sampled so(3) noise vector.
        """
        device = self._dummy.device
        t_node = self._t_per_node(batch, t)
        alpha_bar = self.var_sched.alpha_bars[t_node]
        # Assumes v_0 is (N, 3) so the trailing unsqueeze broadcasts -- TODO confirm.
        c0 = torch.sqrt(alpha_bar + 1e-8).unsqueeze(-1)

        # The noise scale is encoded in the forward angular distribution itself.
        e_scaled = random_normal_so3(t_node, self.angular_distrib_fwd, device=device)
        E_scaled = so3vec_to_rotation(e_scaled)

        R0_scaled = so3vec_to_rotation(c0 * v_0)

        R_t = E_scaled @ R0_scaled
        v_t = rotation_to_so3vec(R_t)

        return v_t, e_scaled

    def denoise(self, v_t, v_next_pred, batch, t):
        """One reverse step: perturb the predicted rotation ``v_next_pred`` by sigma_t-scaled noise.

        Nodes with t <= 1 receive no noise. ``v_t`` is not read.
        """
        device = self._dummy.device
        t_node = self._t_per_node(batch, t)

        e = random_normal_so3(t_node, self.angular_distrib_inv, device=device)
        e = torch.where((t_node > 1).unsqueeze(-1), e, torch.zeros_like(e))
        E = so3vec_to_rotation(e)

        R_next = E @ so3vec_to_rotation(v_next_pred)
        v_prev = rotation_to_so3vec(R_next)
        return v_prev


def cosine_beta_schedule_discrete(timesteps, s=0.008):
    """Discrete cosine beta schedule.

    Evaluates cos^2 on ``timesteps + 2`` evenly spaced points, normalizes it to
    start at 1, clips it to [1e-12, 1], and returns 1 minus the ratios of
    consecutive values.

    Returns:
        float32 numpy array of length ``timesteps + 1``.
    """
    n_points = timesteps + 2
    grid = np.linspace(0, n_points, n_points, dtype=np.float32)

    cumulative = np.cos(0.5 * np.pi * ((grid / n_points) + s) / (1 + s)) ** 2
    cumulative = np.clip(cumulative / cumulative[0], 1e-12, 1.0)  # avoid dividing by ~0

    step_alphas = (cumulative[1:] / cumulative[:-1]).astype(np.float32)
    return 1.0 - step_alphas

def custom_beta_schedule_discrete(timesteps, average_num_nodes=50, s=0.008):
    """Cosine beta schedule with a floor on the first 101 betas.

    Starts from ``cosine_beta_schedule_discrete(timesteps, s)`` and raises
    ``betas[:101]`` to at least ``updates_per_graph / (p * num_edges)``, where
    ``num_edges = n * (n - 1) / 2`` for ``n = average_num_nodes``.

    Returns:
        float32 numpy array of length ``timesteps + 1``.

    Raises:
        AssertionError: if ``timesteps < 100``.
    """
    # Check the precondition before doing any work.
    assert timesteps >= 100

    betas = cosine_beta_schedule_discrete(timesteps, s=s)

    # NOTE(review): p = 4/5 looks like 1 - 1/num_classes for 5 classes -- confirm.
    p = 4 / 5.0
    num_edges = average_num_nodes * (average_num_nodes - 1) / 2.0

    updates_per_graph = 1.2
    beta_first = updates_per_graph / (p * num_edges)
    betas[:101] = np.maximum(betas[:101], beta_first).astype(np.float32)

    return betas

class PredefinedNoiseScheduleDiscrete(torch.nn.Module):
    """Lookup table of discrete betas and cumulative alphas indexed by timestep.

    Buffers:
        betas: (timesteps + 1,) from the chosen schedule.
        alphas: 1 - clamp(betas, 0, 0.9999) (non-persistent).
        alphas_bar: cumulative product of alphas (non-persistent).
    """

    def __init__(self, noise_schedule, timesteps, noise_type):
        super().__init__()
        self.timesteps = timesteps
        self.noise_type = noise_type

        if noise_schedule == 'cosine':
            betas = cosine_beta_schedule_discrete(timesteps)
        elif noise_schedule == 'custom':
            betas = custom_beta_schedule_discrete(timesteps)
        else:
            raise NotImplementedError(noise_schedule)

        self.register_buffer('betas', torch.from_numpy(betas).float())

        alphas = 1 - torch.clamp(self.betas, min=0, max=0.9999)
        # Cumulative product computed in log space.
        alphas_bar = torch.exp(torch.cumsum(torch.log(alphas), dim=0))
        # Registered as non-persistent buffers so module.to(device) moves them,
        # while state_dict (and therefore checkpoint loading) is unchanged.
        self.register_buffer('alphas', alphas, persistent=False)
        self.register_buffer('alphas_bar', alphas_bar, persistent=False)

    def _resolve_t_int(self, t_normalized, t_int):
        """Return integer-valued timesteps from exactly one of ``t_normalized`` / ``t_int``."""
        assert int(t_normalized is None) + int(t_int is None) == 1
        if t_int is None:
            t_int = torch.round(t_normalized * self.timesteps)
        return t_int

    def forward(self, t_normalized=None, t_int=None):
        """Return ``betas`` at the given timesteps (pass exactly one argument)."""
        t_int = self._resolve_t_int(t_normalized, t_int)
        if self.betas.device != t_int.device:
            self.betas = self.betas.to(t_int.device)
        return self.betas[t_int.long()]

    def get_alpha_bar(self, t_normalized=None, t_int=None):
        """Return ``alphas_bar`` at the given timesteps for ``noise_type == 'uniform'``.

        For any other noise type, ``t_normalized`` is returned unchanged.
        """
        if self.noise_type != 'uniform':
            return t_normalized
        t_int = self._resolve_t_int(t_normalized, t_int)
        if self.alphas_bar.device != t_int.device:
            self.alphas_bar = self.alphas_bar.to(t_int.device)
        return self.alphas_bar[t_int.long()]

class BlosumTransition:
    """Temperature-scaled softmax transition matrices built from a BLOSUM score table.

    The file at ``blosum_path`` must contain the keys ``'original_score'`` (K, K),
    ``'Qtb_temperature'`` and ``'Qt_temperature'`` (1-D schedules). Both
    schedules are linearly resampled to ``timestep + 1`` entries, and
    normalized times ``t_normal`` are mapped to ``round(t_normal * timestep)``.
    """

    def __init__(self, blosum_path='blosum_substitute.pt', x_classes=20, timestep=500):
        try:
            data = torch.load(blosum_path)
        except FileNotFoundError:
            # Fall back to the parent directory.
            blosum_path = '../' + blosum_path
            data = torch.load(blosum_path)
        self.original_score = data['original_score']
        self.temperature_list = data['Qtb_temperature']
        self.Qt_temperature = data['Qt_temperature']

        self.X_classes = x_classes
        self.timestep = timestep
        # One temperature per integer step 0..timestep.
        self.temperature_list = self._resample(self.temperature_list, timestep + 1)
        self.Qt_temperature = self._resample(self.Qt_temperature, timestep + 1)

    @staticmethod
    def _resample(schedule, size):
        """Linearly resample a 1-D schedule to ``size`` points (endpoints kept); no-op if already that long."""
        if schedule.dim() == 1 and schedule.shape[0] == size:
            return schedule
        resampled = F.interpolate(schedule[None, None], size=size, mode='linear', align_corners=True)
        return resampled.squeeze()

    def get_Qt_bar(self, t_normal, device):
        """Cumulative transition matrices at normalized times ``t_normal``.

        Rows are softmax(score / temperature) and entries are floored at 1e-6.
        With ``t_normal`` of shape (B, 1), the result is (B, K, K).
        """
        self.original_score = self.original_score.to(device)
        self.temperature_list = self.temperature_list.to(device)
        t_int = torch.round(t_normal * self.timestep).to(device)
        temperature = self.temperature_list[t_int.long()]
        q_x = torch.softmax(self.original_score.unsqueeze(0) / temperature.unsqueeze(2), dim=2)
        return q_x.clamp_min(1e-6)

    def get_Qt(self, t_normal, device):
        """Single-step transition matrices at normalized times ``t_normal``.

        Rows are softmax(score / temperature). With ``t_normal`` of shape (B, 1),
        the result is (B, K, K).
        """
        self.original_score = self.original_score.to(device)
        self.Qt_temperature = self.Qt_temperature.to(device)
        t_int = torch.round(t_normal * self.timestep).to(device)
        temperature = self.Qt_temperature[t_int.long()]
        q_x = self.original_score.unsqueeze(0) / temperature.unsqueeze(2)
        return torch.softmax(q_x, dim=2)

class AminoacidCategoricalTransition(nn.Module):
    """Categorical (amino-acid type) diffusion with uniform or BLOSUM noise.

    ``t`` holds one integer timestep per graph and ``batch`` maps every node to
    its graph. Probability vectors are (N, K) and use the row-vector convention
    q(x_t | x_0) = c0 @ Q.
    """

    def __init__(
        self,
        num_steps,
        cfg,
        num_classes=20,
        var_sched_opt=None,
        blosum_path: str = "blosum_substitute.pt",
    ):
        super().__init__()
        self.num_classes = num_classes
        self.noise_type = cfg.get('noise_type', 'uniform')
        self.var_sched = VarianceSchedule(num_steps, **(var_sched_opt or {}))
        self.num_steps = num_steps

        if self.noise_type == "blosum":
            self.blosum = BlosumTransition(
                blosum_path=blosum_path, x_classes=num_classes, timestep=num_steps
            )
            # No longer read by the BLOSUM posterior (which indexes its tables by
            # normalized time). It is still constructed so state_dicts containing
            # its 'betas' buffer keep loading.
            self.predef = PredefinedNoiseScheduleDiscrete(
                noise_schedule="cosine",
                timesteps=num_steps,
                noise_type="uniform",
            )

    @staticmethod
    def _sample(c):
        """Draw one class per row of the (N, K) probability matrix ``c``."""
        N, _ = c.size()
        x = torch.multinomial((c + 1e-8), 1).view(N)
        return x

    def _to_probs(self, x):
        """Convert (N,) labels or (N, K) weights into row-normalized (N, K) probabilities."""
        K = self.num_classes
        if x.dim() == 1:
            return clampped_one_hot(x, num_classes=K).float()
        elif x.dim() == 2 and x.size(-1) == K:
            c = x.float()
            return c / (c.sum(dim=-1, keepdim=True) + 1e-8)
        else:
            raise ValueError("x must be (N,) Long or (N,K) Float")

    @staticmethod
    def _t_per_node(batch, t):
        return t[batch]

    def _uniform_add_noise(self, c0, batch, t):
        """q(x_t | x_0) = alpha_bar_t * c0 + (1 - alpha_bar_t) / K; returns (probs, sample)."""
        t_node = self._t_per_node(batch, t)
        alpha_bar = self.var_sched.alpha_bars[t_node].unsqueeze(-1)
        K = self.num_classes
        c_t = (alpha_bar * c0) + ((1 - alpha_bar) / K)
        x_t = self._sample(c_t)
        return c_t, x_t

    def _uniform_posterior(self, x_t, c0_pred, batch, t):
        """q(x_{t-1} | x_t, x_0) under uniform noise, with x_0 replaced by ``c0_pred``.

        Proportional to [alpha_t x_t + (1 - alpha_t)/K] * [alpha_bar_{t-1} x_0 + (1 - alpha_bar_{t-1})/K].
        """
        K = self.num_classes
        c_t = x_t if (x_t.dim() == 2) else clampped_one_hot(x_t, num_classes=K).float()
        t_node = self._t_per_node(batch, t)
        alpha = self.var_sched.alphas[t_node].unsqueeze(-1)
        # The x_0 factor is q(x_{t-1} | x_0), so it needs alpha_bar at t - 1
        # (index clamped at 0), not at t.
        alpha_bar_prev = self.var_sched.alpha_bars[(t_node - 1).clamp_min(0)].unsqueeze(-1)

        theta = ((alpha * c_t) + (1 - alpha) / K) * ((alpha_bar_prev * c0_pred) + (1 - alpha_bar_prev) / K)
        theta = theta / (theta.sum(dim=-1, keepdim=True) + 1e-8)
        return theta

    def _blosum_add_noise(self, c0, batch, t):
        """q(x_t | x_0) = c0 @ Qtb(t) with BLOSUM cumulative matrices; returns (probs, sample)."""
        device = c0.device
        # One normalized time per graph; Qtb is then indexed per node via ``batch``.
        t_norm = (t.long().float() / self.num_steps).view(-1, 1)

        Qtb = self.blosum.get_Qt_bar(t_norm, device=device)    # (B, K, K)

        # Row-vector convention, matching the posterior's c0 @ Qsb / c0 @ Qtb terms.
        c_t = torch.bmm(c0.unsqueeze(1), Qtb[batch]).squeeze(1)
        c_t = c_t / (c_t.sum(dim=-1, keepdim=True) + 1e-8)
        x_t = self._sample(c_t)
        return c_t, x_t

    def _blosum_posterior(self, x_t, c0_pred, batch, t):
        """q(x_{t-1} | x_t, x_0) with BLOSUM matrices, with x_0 replaced by ``c0_pred``.

        theta ∝ (x_t @ Qt^T) * (c0 @ Qsb) / (c0 @ Qtb @ x_t^T), with s = t - 1.
        Nodes at t == 0 return ``c0_pred`` directly.
        """
        device = c0_pred.device
        K = self.num_classes
        c_t = x_t if (x_t.dim() == 2) else clampped_one_hot(x_t, num_classes=K).float()

        t_int = t.long()
        s_int = (t_int - 1).clamp_min(0)
        t_norm = (t_int.float() / self.num_steps).view(-1, 1)
        s_norm = (s_int.float() / self.num_steps).view(-1, 1)

        Qtb = self.blosum.get_Qt_bar(t_norm, device=device)
        Qsb = self.blosum.get_Qt_bar(s_norm, device=device)
        # get_Qt indexes its temperature table by round(t_normal * timestep), so it
        # takes the normalized time, not a beta value.
        Qt = self.blosum.get_Qt(t_norm, device=device)

        # q(x_t | x_{t-1}) evaluated at the observed x_t, for every x_{t-1}.
        left = torch.bmm(c_t.unsqueeze(1), Qt.transpose(-1, -2)[batch]).squeeze(1)
        # q(x_{t-1} | x_0).
        right = torch.bmm(c0_pred.unsqueeze(1), Qsb[batch]).squeeze(1)
        numerator = left * right

        # q(x_t | x_0), a scalar per node.
        denom = torch.bmm(
            torch.bmm(c0_pred.unsqueeze(1), Qtb[batch]),
            c_t.unsqueeze(-1),
        ).squeeze(-1).squeeze(-1)
        denom = torch.where(denom == 0, torch.full_like(denom, 1e-6), denom)

        theta = numerator / denom.unsqueeze(-1)
        theta = theta / (theta.sum(dim=-1, keepdim=True) + 1e-8)

        mask_t0 = self._t_per_node(batch, t) == 0
        if mask_t0.any():
            theta[mask_t0] = c0_pred[mask_t0]

        return theta

    def add_noise(self, x_0, batch, t):
        """Noise labels / probabilities ``x_0`` to step ``t``; returns (probs, sampled labels)."""
        c0 = self._to_probs(x_0)
        if self.noise_type == "blosum":
            return self._blosum_add_noise(c0, batch, t)
        return self._uniform_add_noise(c0, batch, t)

    def denoise(self, x_t, c0_pred, batch, t):
        """One reverse step; returns (posterior probs, sampled x_{t-1})."""
        c0_pred = c0_pred / (c0_pred.sum(dim=-1, keepdim=True) + 1e-8)

        if self.noise_type == "blosum":
            post = self._blosum_posterior(x_t, c0_pred, batch, t)
        else:
            post = self._uniform_posterior(x_t, c0_pred, batch, t)

        x_prev = self._sample(post)
        return post, x_prev
