import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from bindgen.encoder import *
from bindgen.utils import *
from bindgen.nnutils import *
from bindgen.data import make_batch
import math

class PositionalEmbedding(nn.Module):
    r"""Sinusoidal embedding of scalar timesteps.

    Every timestep is multiplied by ``scale`` and then by ``dim // 2``
    geometrically spaced frequencies ``exp(-k * ln(10000) / (dim // 2))``
    for ``k = 0 .. dim // 2 - 1``.  The sines of all resulting angles are
    followed by their cosines along the last axis.

    Input:
        x: tensor of shape (N)
    Output:
        tensor of shape (N, dim)
    Args:
        dim (int): embedding dimension; must be even
        scale (float): linear scale to be applied to timesteps. Default: 1.0
    """

    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        n_freqs = self.dim // 2
        log_step = math.log(10000) / n_freqs
        # Frequencies are created on the input's device.
        freqs = torch.exp(-log_step * torch.arange(n_freqs, device=x.device))
        angles = torch.outer(self.scale * x, freqs)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

class RefineDocker(ABModel):

    def __init__(self, args, time_emb_scale=1.0, time_emb_dim=None):
        """Build the target encoder, the structure refiner and the time-embedding heads.

        Args:
            args: configuration object. This constructor reads
                ``args.rstep``, ``args.hidden_size`` and
                ``args.hierarchical`` and forwards ``args`` to the parent
                class and to the encoders.
            time_emb_scale (float): scale handed to ``PositionalEmbedding``.
                Default: 1.0
            time_emb_dim (int, optional): width of the time embedding.
                ``None`` selects 123, the value that used to be hard-coded,
                so default-constructed models keep the same parameter
                shapes.
        """
        super().__init__(args)
        self.rstep = args.rstep
        self.U_i = nn.Linear(self.embedding.dim(), args.hidden_size)
        self.target_mpn = EGNNEncoder(args, update_X=False)
        self.hierarchical = args.hierarchical
        # Width of the sinusoidal features fed into ``time_mlp``.
        self.time_generate_dim = 256
        # Honour the constructor arguments; they used to be silently
        # overwritten with 1.0 and 123 regardless of what the caller passed.
        self.time_emb_scale = time_emb_scale
        self.time_emb_dim = 123 if time_emb_dim is None else time_emb_dim
        # Number of diffusion steps; read by ``forward`` as ``num_steps``.
        self.steps = 9
        self.weight = 1
        if args.hierarchical:
            self.struct_mpn = HierEGNNEncoder(args)
        else:
            self.struct_mpn = EGNNEncoder(args)

        # NOTE: modules are created in a fixed order on purpose; changing
        # it would change both the state_dict layout and the RNG stream
        # consumed by the default initialisers.
        self.W_x0 = nn.Sequential(
            nn.Linear(args.hidden_size, args.hidden_size),
            nn.ReLU(),
            nn.Linear(args.hidden_size, args.hidden_size),
        )
        self.U_x0 = nn.Sequential(
            nn.Linear(args.hidden_size, args.hidden_size),
            nn.ReLU(),
            nn.Linear(args.hidden_size, args.hidden_size),
        )
        self.W_a = nn.Linear(self.time_emb_dim * 2, self.time_emb_dim)
        # ``time_emb_dim`` is always an int at this point, so the MLP is
        # always built.
        self.time_mlp = nn.Sequential(
            PositionalEmbedding(self.time_generate_dim, self.time_emb_scale),
            nn.Linear(self.time_generate_dim, self.time_emb_dim),
            nn.SiLU(),
            nn.Linear(self.time_emb_dim, self.time_emb_dim),
        )

        # Xavier-initialise every weight matrix; 1-D parameters (biases)
        # keep the initialisation given by their own modules.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    # def linear_beta_schedule(self, timesteps):
    #     scale = 1000 / timesteps
    #     beta_start = scale * 0.0001
    #     beta_end = scale * 0.02
    #     return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

    def linear_beta_schedule(self, timesteps):
        #beta_start = 0.0001
        #beta_end = 0.02
        beta_start = 0.0001
        beta_end = 0.7
        return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

    def get_alpha(self, beta_schedule):
        alphas = 1. - beta_schedule
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

        posterior_variance = beta_schedule * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        return sqrt_one_minus_alphas_cumprod, sqrt_alphas_cumprod, alphas

    def perturb_x(self, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, x, t, noise):
        return (
            self.extract(sqrt_alphas_cumprod, t, x.shape) * x +
            self.extract(sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
        )

    def sample_normal(self, size, scale=1.0, key=None):
        return np.random.normal(scale=scale, size=size)

    def extract(self, a, t, x_shape):
        a = a.cuda()
        t = torch.tensor(t).cuda()
        out = a.gather(-1, t)
        return out.reshape(1, *((1,) * (len(x_shape) - 1)))
    #def extract(self, a, t, x_shape):
    #    b, *_ = t.shape
    #    a = a.cuda()
    #    out = a.gather(-1, t)
    #    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
    def struct_loss(self, bind_X, tgt_X, true_V, true_R, true_D, inter_D, true_C):
        """Return the five element-wise structure loss terms for a predicted binder.

        Nothing is masked or reduced here; the caller applies the masks and
        the normalisation.  Every distance term adds a hinge
        ``10 * relu(c - d)`` on the predicted squared distance ``d`` (``c``
        is 1.5 for the all-atom terms and 14.4 for the alpha-carbon terms),
        which is non-zero only where ``d`` falls below ``c``.

        Args:
            bind_X: predicted binder coordinates
                (assumed ``(B, N, atoms, 3)`` -- TODO confirm).
            tgt_X: target coordinates (assumed same trailing layout).
            true_V: reference dihedral features.
            true_R: reference values for the local (within-residue) term.
            true_D: reference values for the binder alpha-carbon term.
            inter_D: reference values for the binder/target alpha-carbon term.
            true_C: reference values for the full binder/target term.

        Returns:
            tuple: ``(dloss, vloss, rloss, iloss, closs)``.
        """
        # Dihedral term: predicted dihedral features against the reference.
        pred_V = self.features._dihedrals(bind_X)
        vloss = self.mse_loss(pred_V, true_V).sum(dim=-1)

        # Local term: squared distances between entries of axis 2.
        atom_gap = bind_X.unsqueeze(2) - bind_X.unsqueeze(3)
        intra_sq = (atom_gap ** 2).sum(dim=-1)
        rloss = self.huber_loss(intra_sq, true_R) + 10 * F.relu(1.5 - intra_sq)

        # Full term: binder vs. target with all-ones masks, i.e. nothing is
        # masked out at this stage.
        bind_all = torch.ones_like(bind_X)[..., 0]
        tgt_all = torch.ones_like(tgt_X)[..., 0]
        full_sq, _ = full_square_dist(bind_X, tgt_X, bind_all, tgt_all)
        closs = self.huber_loss(full_sq, true_C) + 10 * F.relu(1.5 - full_sq)

        # The remaining terms only use index 1 of axis 2 (alpha carbon).
        bind_ca = bind_X[:, :, 1]
        tgt_ca = tgt_X[:, :, 1]

        # Binder self-distance term.
        self_gap = bind_ca.unsqueeze(1) - bind_ca.unsqueeze(2)
        self_sq = (self_gap ** 2).sum(dim=-1)
        dloss = self.huber_loss(self_sq, true_D) + 10 * F.relu(14.4 - self_sq)

        # Binder-to-target distance term.
        cross_gap = bind_ca.unsqueeze(2) - tgt_ca.unsqueeze(1)
        cross_sq = (cross_gap ** 2).sum(dim=-1)
        iloss = self.huber_loss(cross_sq, inter_D) + 10 * F.relu(14.4 - cross_sq)

        return dloss, vloss, rloss, iloss, closs

    def _diffused_targets(self, true_X, tgt_mean, noise, num_steps):
        """Build the regression targets ``x_0 .. x_{num_steps - 2}``.

        ``x_0`` is ``true_X`` itself.  For ``i >= 1``::

            x_i = sqrt(alpha_bar_i) * true_X
                  + sqrt(1 - alpha_bar_i) * noise
                  + i / (num_steps - 1) * tgt_mean

        The blend is evaluated in float64 (the dtype of the schedule) on
        the device of ``true_X`` and then cast to float32.
        """
        beta_schedule = self.linear_beta_schedule(num_steps)
        sqrt_one_minus_alphas_cumprod, sqrt_alphas_cumprod, _ = self.get_alpha(beta_schedule)

        # Per-step shift toward the target centre, broadcast over residues
        # and atoms.
        drift = 1 / (num_steps - 1) * tgt_mean[:, None, None, :]
        clean = true_X.double()
        noise = noise.double()

        targets = [true_X]
        # The last schedule entry is never used as a label, so it is not built.
        for i in range(1, num_steps - 1):
            keep = sqrt_alphas_cumprod[i].item()
            sigma = sqrt_one_minus_alphas_cumprod[i].item()
            noisy = keep * clean + sigma * noise + i * drift
            targets.append(noisy.float())
        return targets

    def forward(self, binder, target, surface):
        """Run the step-wise denoising refinement and return its training loss.

        The binder starts as Gaussian noise centred on the target's mean
        alpha-carbon position and is refined ``self.steps - 1`` times by
        ``self.struct_mpn``.  After each refinement the prediction is
        compared with a progressively less noisy version of the true
        coordinates (see ``_diffused_targets``), ending with the clean
        structure.

        Earlier experimental variants of this method sat behind hard-coded
        local flags that were always ``False``; those unreachable branches
        have been removed.

        Args:
            binder: 4-tuple ``(true_X, true_S, true_A, _)`` -- coordinates
                (assumed ``(B, N, atoms, 3)`` -- TODO confirm), sequence
                tokens, and an integer per-atom array whose non-zero
                entries are treated as present atoms.
            target: 4-tuple ``(tgt_X, tgt_S, tgt_A, _)`` with the same
                layout for the target.
            surface: pair ``(bind_surface, tgt_surface)``; only
                ``tgt_surface`` is used, as the selector passed to
                ``self.select_target``.

        Returns:
            ReturnType: ``loss`` is the sum of the distance, inter-distance,
            dihedral, local and coordinate regression terms; ``bind_X`` is
            the detached final prediction; ``handle`` is
            ``(tgt_X, tgt_A)`` after target selection.
        """
        true_X, true_S, true_A, _ = binder
        tgt_X, tgt_S, tgt_A, _ = target
        _, tgt_surface = surface

        # Encode target
        tgt_S = self.embedding(tgt_S)
        tgt_V = self.features._dihedrals(tgt_X)
        # NOTE: tgt_h is not consumed further down.  The encoder call is
        # kept so the forward pass performs the same work (and draws the
        # same random numbers, if the encoder draws any) as before.
        tgt_h, _ = self.target_mpn(tgt_X, tgt_V, self.U_i(tgt_S), tgt_A)
        _, tgt_S, _ = self.select_target(tgt_X, tgt_S, tgt_A, tgt_surface)
        tgt_X, tgt_h, tgt_A = self.select_target(tgt_X, tgt_h, tgt_A, tgt_surface)
        tgt_V = self.features._dihedrals(tgt_X)

        N = true_S.size(1)
        true_mask = true_A[:, :, 1].clamp(max=1).float()
        tgt_mask = tgt_A[:, :, 1].clamp(max=1).float()
        # Coordinates that are exactly zero are excluded from the
        # coordinate regression term.
        bb_mask = (true_X != 0).float()

        # Masked mean of the selected target's alpha-carbon coordinates.
        tgt_mean = (tgt_X[:, :, 1] * tgt_mask[..., None]).sum(dim=1) / tgt_mask[..., None].sum(dim=1).clamp(min=1e-4)

        # Two independent Gaussian draws, in this order: the starting pose
        # and the noise shared by every diffused target.
        bind_X = tgt_mean[:, None, None, :] + torch.randn_like(true_X)  # tgt_mean [B, 3] true_X [B, len(aa), len(atom), 3]
        num_steps = self.steps
        noise = torch.randn_like(true_X)
        targets = self._diffused_targets(true_X, tgt_mean, noise, num_steps)

        dloss = vloss = rloss = iloss = loss_predict = 0
        # Walk from the noisiest label down to the clean structure.
        for m in reversed(range(num_steps - 1)):
            label_X = targets[m]

            true_V = self.features._dihedrals(label_X)
            true_R, rmask_2D = inner_square_dist(label_X, true_A.clamp(max=1).float())
            true_D, mask_2D = self_square_dist(label_X, true_mask)
            true_C, _ = full_square_dist(label_X, tgt_X, true_A, tgt_A)
            inter_D, imask_2D = cross_square_dist(label_X, tgt_X, true_mask, tgt_mask)

            bind_V = self.features._dihedrals(bind_X)

            # Inputs are detached, so gradients do not flow from one
            # refinement step into the previous one.
            V = torch.cat([bind_V, tgt_V], dim=1).detach()
            X = torch.cat([bind_X, tgt_X], dim=1).detach()
            A = torch.cat([true_A, tgt_A], dim=1).detach()
            S = torch.cat([self.embedding(true_S), tgt_S], dim=1).detach()

            h_S = self.W_i(S)
            _, X = self.struct_mpn(X, V, h_S, A)

            bind_X = X[:, :N]

            # The fifth term returned by struct_loss (full binder/target
            # distances) is not part of the objective.
            dloss_t, vloss_t, rloss_t, iloss_t, _ = self.struct_loss(
                bind_X, tgt_X, true_V, true_R, true_D, inter_D, true_C
            )
            vloss = vloss + vloss_t * true_mask
            dloss = dloss + dloss_t * mask_2D
            iloss = iloss + iloss_t * imask_2D
            rloss = rloss + rloss_t * rmask_2D

            # Direct coordinate regression against the current label.
            errors = (bind_X - label_X) * bb_mask
            loss_predict = loss_predict + errors ** 2

        # Each term is summed over all steps and normalised by the mask
        # size (the masks produced in the last loop iteration are used).
        loss_predict = torch.sum(loss_predict) / (torch.sum(bb_mask) + 1e-10)
        dloss = torch.sum(dloss) / mask_2D.sum()
        iloss = torch.sum(iloss) / imask_2D.sum()
        vloss = torch.sum(vloss) / true_mask.sum()
        if self.hierarchical:
            rloss = torch.sum(rloss) / rmask_2D.sum()
        else:
            # Non-hierarchical models only score the first four atoms.
            rloss = torch.sum(rloss[:, :, :4, :4]) / rmask_2D[:, :, :4, :4].sum()

        loss = dloss + iloss + vloss + rloss + loss_predict

        return ReturnType(loss=loss, bind_X=bind_X.detach(), handle=(tgt_X, tgt_A))


    def test222(self, binder, target, surface):
        true_X, true_S, true_A, _ = binder
        tgt_X, tgt_S, tgt_A, _ = target
        bind_surface, tgt_surface = surface

        true_mask = true_A[:, :, 1].clamp(max=1).float()

        # Encode target
        tgt_S = self.embedding(tgt_S)
        tgt_V = self.features._dihedrals(tgt_X)
        tgt_h, _ = self.target_mpn(tgt_X, tgt_V, self.U_i(tgt_S), tgt_A)
        _, tgt_S, _ = self.select_target(tgt_X, tgt_S, tgt_A, tgt_surface)
        tgt_X, tgt_h, tgt_A = self.select_target(tgt_X, tgt_h, tgt_A, tgt_surface)
        tgt_V = self.features._dihedrals(tgt_X)

        B, N, M = true_S.size(0), true_S.size(1), tgt_X.size(1)
        true_mask = true_A[:, :, 1].clamp(max=1).float()
        tgt_mask = tgt_A[:, :, 1].clamp(max=1).float()

        tgt_mean = (tgt_X[:, :, 1] * tgt_mask[..., None]).sum(dim=1) / tgt_mask[..., None].sum(dim=1).clamp(min=1e-4)

        init_loss = 0
        dloss = vloss = rloss = iloss = closs = loss_predict = 0
        loss_noise = 0
        # Refine

        origin_change3 = False
        origin_change4 = True

        if origin_change3:
            tgt_mean = (tgt_X[:, :, 1] * tgt_mask[..., None]).sum(dim=1) / tgt_mask[..., None].sum(dim=1).clamp(
                min=1e-4)
            bind_X = tgt_mean[:, None, None, :] + torch.randn_like(
                true_X)  # tgt_mean [B, 3] true_X [B, len(aa), len(atom), 3]

            num_steps = self.steps
            beta_schedule = self.linear_beta_schedule(num_steps)

            alpha_schedule = 1 - beta_schedule
            sqrt_one_minus_alphas_cumprod, sqrt_alphas_cumprod, alphas = self.get_alpha(beta_schedule)

            # print(ss)
            b = true_X.shape[0]
            device = true_X.device
            bb_mask = (true_X != 0).float()
            # t = torch.randint(1, num_steps, (b,), device=device)

            # test

            # test
            noise1 = torch.randn_like(true_X).cpu()
            noise2 = 1 / (num_steps - 1) * tgt_mean[:, None, None, :]
            noise2 = noise2.cpu()
            noise3 = tgt_mean[:, None, None, :].cpu()

            perturbed_x_list = []

            for i in range(num_steps):

                if i == 0:
                    perturbed_x_list.append(true_X)
                else:

                    t = torch.full((b,), i)

                    perturbed_t = sqrt_alphas_cumprod[t][:, None, None, None] * true_X.cpu() + sqrt_one_minus_alphas_cumprod[t][
                                                                                       :, None, None, None] * (
                                      noise1 + t[:, None, None, None] * noise2).cpu()
                    # perturbed_t = sqrt_alphas_cumprod[t][:, None, None, None] * true_X.cpu() + sqrt_one_minus_alphas_cumprod[t][
                    # :, None, None, None] * (noise1 + noise3).cpu()
                    perturbed_t = perturbed_t.float().cuda()
                    perturbed_x_list.append(perturbed_t)

            # linear
            # linear_weight = [0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1]
            # linear_weight_tensor = torch.from_numpy(np.array(linear_weight))
            # for i in range(num_steps):
            #
            #     if i == 0:
            #         perturbed_x_list.append(true_X)
            #     else:
            #
            #         t = torch.full((b,), i, device=device)
            #         t2 = torch.full((b,), 8-i, device=device)
            #         # perturbed_t = linear_weight_tensor[t2][:, None, None, None] * true_X.cpu() + linear_weight_tensor[t][
            #         #                                                                    :, None, None, None] * (
            #         #                   noise1 + t[:, None, None, None] * noise2).cpu()
            #         perturbed_t = linear_weight_tensor[t2][:, None, None, None] * true_X.cpu() + linear_weight_tensor[t][
            #                       :, None, None, None] * (noise1 + noise3).cpu()
            #         # perturbed_t = linear_weight[8 - i] * true_X.cpu() + \
            #         #               linear_weight[i] * (noise1 + noise3).cpu()
            #         perturbed_t = perturbed_t.float().cuda()
            #         perturbed_x_list.append(perturbed_t)
            # linear

            #zhubudenoise
            for m in reversed(range(num_steps - 1)):
                # Interpolated label
                # for m in range(num_steps-1):
                #     ratio = (m + 1) / self.rstep
                #     label_X = true_X * ratio + bind_X.detach() * (1 - ratio)

                label_X = perturbed_x_list[m]

                true_V = self.features._dihedrals(label_X)
                true_R, rmask_2D = inner_square_dist(label_X, true_A.clamp(max=1).float())
                true_D, mask_2D = self_square_dist(label_X, true_mask)
                true_C, cmask_2D = full_square_dist(label_X, tgt_X, true_A, tgt_A)
                inter_D, imask_2D = cross_square_dist(label_X, tgt_X, true_mask, tgt_mask)

                bind_V = self.features._dihedrals(bind_X)
                V = torch.cat([bind_V, tgt_V], dim=1).detach()
                X = torch.cat([bind_X, tgt_X], dim=1).detach()
                A = torch.cat([true_A, tgt_A], dim=1).detach()
                S = torch.cat([self.embedding(true_S), tgt_S], dim=1).detach()

                h_S = self.W_i(S)
               
                h, X = self.struct_mpn(X, V, h_S, A)

                bind_X = X[:, :N]


                # #add_z
                # noise_scale = 1.0
                # z = None
                # if z is None and m > 0:
                #     # z = torch.randn_like(x_t) + i/num_steps * tgt_mean[:, None, None, :]
                #     z = torch.randn_like(true_X).float()
                # elif z is None:
                #     z = 0.0
                # current_time = torch.full((b,), m, device=device)
                # b_t = beta_schedule[current_time]
                #
                # bind_X = bind_X + z * np.sqrt(b_t).float().cuda() * noise_scale
                # # add_z



                dloss_t, vloss_t, rloss_t, iloss_t, closs_t = self.struct_loss(
                    bind_X, tgt_X, true_V, true_R, true_D, inter_D, true_C
                )
                vloss = vloss + vloss_t * true_mask
                dloss = dloss + dloss_t * mask_2D
                iloss = iloss + iloss_t * imask_2D
                rloss = rloss + rloss_t * rmask_2D
                closs = closs + closs_t * cmask_2D

                # add
                errors = bind_X - label_X
                errors = errors * bb_mask
                loss_predict = loss_predict + errors ** 2
                # add

            loss_predict = torch.sum(loss_predict) / (torch.sum(bb_mask) + 1e-10)

            dloss = torch.sum(dloss) / mask_2D.sum()
            iloss = torch.sum(iloss) / imask_2D.sum()
            vloss = torch.sum(vloss) / true_mask.sum()
            if self.hierarchical:
                rloss = torch.sum(rloss) / rmask_2D.sum()
            else:
                rloss = torch.sum(rloss[:, :, :4, :4]) / rmask_2D[:, :, :4, :4].sum()

            loss_noise = init_loss + dloss + iloss + vloss + rloss + loss_predict
            #loss_noise = init_loss + dloss + iloss + loss_predict
            # step-by-step denoising (end)

        elif origin_change4:
            tgt_mean = (tgt_X[:, :, 1] * tgt_mask[..., None]).sum(dim=1) / tgt_mask[..., None].sum(dim=1).clamp(
                min=1e-4)
            bind_X = tgt_mean[:, None, None, :] + torch.randn_like(
                true_X)  # tgt_mean [B, 3] true_X [B, len(aa), len(atom), 3]

            num_steps = self.steps
            beta_schedule = self.linear_beta_schedule(num_steps)

            alpha_schedule = 1 - beta_schedule
            sqrt_one_minus_alphas_cumprod, sqrt_alphas_cumprod, alphas = self.get_alpha(beta_schedule)

            # print(ss)
            b = true_X.shape[0]
            device = true_X.device
            bb_mask = (true_X != 0).float()
            # t = torch.randint(1, num_steps, (b,), device=device)

            # test

            # test
            noise1 = torch.randn_like(true_X).cpu()
            noise2 = 1 / (num_steps - 1) * tgt_mean[:, None, None, :]
            noise2 = noise2.cpu()
        

            perturbed_x_list = []

            for i in range(num_steps):

                if i == 0:
                    perturbed_x_list.append(true_X)
                else:

                    t = torch.full((b,), i)

                    perturbed_t = sqrt_alphas_cumprod[t][:, None, None, None] * true_X.cpu() + \
                                  sqrt_one_minus_alphas_cumprod[t][
                                  :, None, None, None] * (noise1).cpu()
                    # perturbed_t = sqrt_alphas_cumprod[t][:, None, None, None] * true_X.cpu() + sqrt_one_minus_alphas_cumprod[t][
                    # :, None, None, None] * (noise1 + noise3).cpu()
                    perturbed_t = perturbed_t + t[:, None, None, None] * noise2
                    perturbed_t = perturbed_t.float().cuda()
                    perturbed_x_list.append(perturbed_t)

            # linear
            # linear_weight = [0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1]
            # linear_weight_tensor = torch.from_numpy(np.array(linear_weight))
            # for i in range(num_steps):
            #
            #     if i == 0:
            #         perturbed_x_list.append(true_X)
            #     else:
            #
            #         t = torch.full((b,), i, device=device)
            #         t2 = torch.full((b,), 8-i, device=device)
            #         # perturbed_t = linear_weight_tensor[t2][:, None, None, None] * true_X.cpu() + linear_weight_tensor[t][
            #         #                                                                    :, None, None, None] * (
            #         #                   noise1 + t[:, None, None, None] * noise2).cpu()
            #         perturbed_t = linear_weight_tensor[t2][:, None, None, None] * true_X.cpu() + linear_weight_tensor[t][
            #                       :, None, None, None] * (noise1 + noise3).cpu()
            #         # perturbed_t = linear_weight[8 - i] * true_X.cpu() + \
            #         #               linear_weight[i] * (noise1 + noise3).cpu()
            #         perturbed_t = perturbed_t.float().cuda()
            #         perturbed_x_list.append(perturbed_t)
            # linear

            # step-by-step denoising (begin)
            for m in reversed(range(num_steps - 1)):
                # Interpolated label
                # for m in range(num_steps-1):
                #     ratio = (m + 1) / self.rstep
                #     label_X = true_X * ratio + bind_X.detach() * (1 - ratio)

                label_X = perturbed_x_list[m]

                true_V = self.features._dihedrals(label_X)
                true_R, rmask_2D = inner_square_dist(label_X, true_A.clamp(max=1).float())
                true_D, mask_2D = self_square_dist(label_X, true_mask)
                true_C, cmask_2D = full_square_dist(label_X, tgt_X, true_A, tgt_A)
                inter_D, imask_2D = cross_square_dist(label_X, tgt_X, true_mask, tgt_mask)

                bind_V = self.features._dihedrals(bind_X)
                V = torch.cat([bind_V, tgt_V], dim=1).detach()
                X = torch.cat([bind_X, tgt_X], dim=1).detach()
                A = torch.cat([true_A, tgt_A], dim=1).detach()
                S = torch.cat([self.embedding(true_S), tgt_S], dim=1).detach()

                h_S = self.W_i(S)
               
                h, X = self.struct_mpn(X, V, h_S, A)

                bind_X = X[:, :N]


                # #add_z
                # noise_scale = 1.0
                # z = None
                # if z is None and m > 0:
                #     # z = torch.randn_like(x_t) + i/num_steps * tgt_mean[:, None, None, :]
                #     z = torch.randn_like(true_X).float()
                # elif z is None:
                #     z = 0.0
                # current_time = torch.full((b,), m, device=device)
                # b_t = beta_schedule[current_time]
                #
                # bind_X = bind_X + z * np.sqrt(b_t).float().cuda() * noise_scale
                # # add_z



                dloss_t, vloss_t, rloss_t, iloss_t, closs_t = self.struct_loss(
                    bind_X, tgt_X, true_V, true_R, true_D, inter_D, true_C
                )
                vloss = vloss + vloss_t * true_mask
                dloss = dloss + dloss_t * mask_2D
                iloss = iloss + iloss_t * imask_2D
                rloss = rloss + rloss_t * rmask_2D
                closs = closs + closs_t * cmask_2D

                # add
                # #add_z
                # noise_scale = 1.0
                # z = None
                # if z is None and m > 0:
                #     # z = torch.randn_like(x_t) + i/num_steps * tgt_mean[:, None, None, :]
                #     z = torch.randn_like(true_X).float()
                # elif z is None:
                #     z = 0.0
                # current_time = torch.full((b,), m)
                # b_t = beta_schedule[current_time]
                # bind_X = bind_X + z * np.sqrt(b_t).float().cuda() * noise_scale

                errors = bind_X - label_X
                errors = errors * bb_mask
                loss_predict = loss_predict + errors ** 2
                # add

            loss_predict = torch.sum(loss_predict) / (torch.sum(bb_mask) + 1e-10)

            dloss = torch.sum(dloss) / mask_2D.sum()
            iloss = torch.sum(iloss) / imask_2D.sum()
            vloss = torch.sum(vloss) / true_mask.sum()
            if self.hierarchical:
                rloss = torch.sum(rloss) / rmask_2D.sum()
            else:
                rloss = torch.sum(rloss[:, :, :4, :4]) / rmask_2D[:, :, :4, :4].sum()

            loss_noise = init_loss + dloss + iloss + vloss + rloss + loss_predict
            #loss_noise = init_loss + dloss + iloss + loss_predict
            # step-by-step denoising (end)




        loss = loss_noise

        return ReturnType(loss=loss, bind_X=bind_X.detach(), handle=(tgt_X, tgt_A))

    def predict(self, bind_X, tgt_V, tgt_X, tgt_A, tgt_S, true_A, true_S, N):
        """Run a single refinement pass and return the updated binder coordinates.

        Binder-side tensors are concatenated in front of the corresponding
        target-side tensors along dim 1, detached from the autograd graph,
        projected with ``self.W_i`` and passed to ``self.struct_mpn``.
        ``predict_time`` is the variant that additionally conditions on a
        timestep.

        Args:
            bind_X: Current binder coordinates; dim 1 indexes binder positions
                (presumably ``[B, N, atoms, 3]`` -- TODO confirm with caller).
            tgt_V: Target-side features, concatenated after
                ``self.features._dihedrals(bind_X)`` along dim 1.
            tgt_X: Target coordinates, concatenated after ``bind_X``.
            tgt_A: Target-side tensor concatenated after ``true_A``.
            tgt_S: Target-side sequence features, concatenated after
                ``self.embedding(true_S)``; must therefore share its
                trailing dimensions.
            true_A: Binder-side tensor concatenated before ``tgt_A``.
            true_S: Binder sequence, fed to ``self.embedding``.
            N: Number of leading dim-1 entries of the encoder output to keep
                (the binder part, since binder entries are concatenated first).

        Returns:
            ``X[:, :N]`` where ``X`` is the coordinate output of
            ``self.struct_mpn``.
        """
        bind_V = self.features._dihedrals(bind_X)

        # All encoder inputs are detached, so no gradient flows back into the
        # incoming coordinates/features or into the embedding table.
        V = torch.cat([bind_V, tgt_V], dim=1).detach()
        X = torch.cat([bind_X, tgt_X], dim=1).detach()
        A = torch.cat([true_A, tgt_A], dim=1).detach()

        # Embed the binder sequence exactly once (the previous code evaluated
        # the embedding twice and discarded the first result).
        aa_emb = self.embedding(true_S)
        S = torch.cat([aa_emb, tgt_S], dim=1).detach()

        h_S = self.W_i(S)
        # Only the coordinate output is needed; hidden states are ignored.
        _, X = self.struct_mpn(X, V, h_S, A)

        return X[:, :N]

    def predict_time(self, bind_X, tgt_V, tgt_X, tgt_A, tgt_S, true_A, true_S, N, t):
        """Timestep-conditioned refinement pass; returns updated binder coordinates.

        The binder sequence embedding is concatenated (on the last dim) with
        the output of ``self.time_mlp(t)`` broadcast over ``N`` positions,
        projected by ``self.W_a``, and then placed in front of ``tgt_S`` along
        dim 1. The result, together with the concatenated coordinates and
        features, is detached and run through ``self.W_i`` and
        ``self.struct_mpn``.

        Args:
            bind_X: Current binder coordinates (binder positions on dim 1).
            tgt_V: Target-side features appended after
                ``self.features._dihedrals(bind_X)``.
            tgt_X: Target coordinates appended after ``bind_X``.
            tgt_A: Target-side tensor appended after ``true_A``.
            tgt_S: Target-side sequence features appended after the fused
                binder embedding.
            true_A: Binder-side tensor placed before ``tgt_A``.
            true_S: Binder sequence fed to ``self.embedding``; its embedding
                must have ``N`` entries on dim 1 for the concatenation with
                the expanded time embedding to succeed.
            N: Number of binder positions; also the number of leading dim-1
                entries of the encoder output that is returned.
            t: Timestep input for ``self.time_mlp`` (presumably one value per
                batch element -- TODO confirm with caller).

        Returns:
            The first ``N`` dim-1 entries of the coordinates produced by
            ``self.struct_mpn``.
        """
        dihedral_feats = self.features._dihedrals(bind_X)

        # Binder entries first, target entries second; everything detached.
        coords = torch.cat([bind_X, tgt_X], dim=1).detach()
        geom_feats = torch.cat([dihedral_feats, tgt_V], dim=1).detach()
        atom_feats = torch.cat([true_A, tgt_A], dim=1).detach()

        residue_emb = self.embedding(true_S)
        # Broadcast the per-sample time embedding to every binder position.
        step_emb = self.time_mlp(t).unsqueeze(1).expand(-1, N, -1)
        fused_emb = self.W_a(torch.cat((residue_emb, step_emb), dim=2))

        # NOTE(review): detaching here also blocks gradients into time_mlp /
        # W_a along this path -- confirm that this is intended.
        seq_feats = torch.cat([fused_emb, tgt_S], dim=1).detach()

        _, refined = self.struct_mpn(coords, geom_feats, self.W_i(seq_feats), atom_feats)
        return refined[:, :N]


