import numpy as np
import torch.nn as nn
import torch
from model.diffusion import GradLogPEstimator2d
from model.diffsinger import DiffSingerNet
from torchdyn.core import NeuralODE
from model.optimal_transport import OTPlanSampler
import math
import time
from typing import overload

class Wrapper(nn.Module):
    """Adapter that exposes a conditional vector-field network as ``f(t, x, args)``.

    The conditioning inputs (``mask``, ``mu``, ``spk``) are captured once at
    construction and forwarded unchanged on every call.
    """

    def __init__(self, vector_field_net, mask, mu, spk):
        super().__init__()
        self.net, self.mask, self.mu, self.spk = vector_field_net, mask, mu, spk

    def forward(self, t, x, args):
        """Evaluate the wrapped network at state ``x`` and time ``t``.

        NOTE: ``args`` is unused but cannot be dropped; this exact signature is
        required by the NeuralODE class.
        """
        # Re-wrap the solver's time value as a one-element tensor on its device.
        t_batch = torch.tensor([t], device=t.device)
        return self.net(x, self.mask, self.mu, t_batch, self.spk)


class FM(nn.Module):
    """Flow-matching model built around a vector-field estimator network.

    The estimator is called as ``estimator(x, mask, mu, t, spk)``. Training
    (``forward`` / ``loss_t``) regresses it onto an analytic target field;
    ``inference`` and ``backward`` integrate it with a ``NeuralODE`` solver.
    """

    def __init__(self, n_feats, dim, spk_emb_dim=64, sigma_min: float = 0.1, pe_scale=1000, net_type="unet", encoder_output_dim=80):
        """Build the estimator selected by ``net_type`` ("unet" or "diffsinger").

        Raises:
            NotImplementedError: if ``net_type`` is not one of the two supported values.
        """
        super(FM, self).__init__()
        self.n_feats = n_feats
        self.dim = dim
        self.spk_emb_dim = spk_emb_dim
        self.sigma_min = sigma_min
        self.pe_scale = pe_scale

        print(f"Using flow matching net type: {net_type}")
        if net_type == "unet":
            self.estimator = GradLogPEstimator2d(dim,
                                                 spk_emb_dim=spk_emb_dim,
                                                 pe_scale=pe_scale,
                                                 n_feats=n_feats)
        elif net_type == "diffsinger":
            self.estimator = DiffSingerNet(residual_channels=dim, in_dims=n_feats,
                                           spk_emb_dim=spk_emb_dim, pe_scale=pe_scale,
                                           encoder_hidden=encoder_output_dim)
        else:
            raise NotImplementedError(f"Unsupported net_type: {net_type!r}")

        self.criterion = torch.nn.MSELoss()

    def ode_wrapper(self, mask, mu, spk):
        """Bind the conditioning inputs so the estimator can be handed to NeuralODE."""
        # self.estimator receives x, mask, mu, t, spk as arguments
        return Wrapper(self.estimator, mask, mu, spk)

    @torch.no_grad()
    def inference(self, z, mask, mu, n_timesteps, spk=None, solver="dopri5"):
        """Integrate the estimated field from t=0 to t=1 starting at ``z``.

        Returns the trajectory produced by the solver, evaluated on
        ``n_timesteps + 1`` equally spaced time points.
        """
        print("FM inference")
        t_span = torch.linspace(0, 1, n_timesteps+1)
        neural_ode = NeuralODE(self.ode_wrapper(mask, mu, spk), solver=solver, sensitivity="adjoint", atol=1e-4, rtol=1e-4)

        print("neural_ode: ", neural_ode)
        x = z

        eval_points, traj = neural_ode(x, t_span)  # reverse equation
        # With 100 timesteps there are 101 entries: the full path x follows
        # along the vector field from t=0 to t=1.
        print("traj: ", traj)
        return traj

    def backward(self, x, mask, mu, n_timesteps, spk=None, solver="euler"):
        """Integrate the estimated field from t=1 down to t=0 starting at ``x``."""
        print("FM backward")
        t_span = torch.linspace(1, 0, n_timesteps+1)
        neural_ode = NeuralODE(self.ode_wrapper(mask, mu, spk), solver=solver, sensitivity="adjoint", atol=1e-4, rtol=1e-4)
        _, traj = neural_ode(x, t_span)
        return traj

    def compute_likelihood(self, x, mask, mu, n_timesteps, spk=None, solver="euler"):
        """Estimate the log-likelihood of ``x`` via the backward trajectory.

        The trajectory end point is scored under a standard normal, and the
        divergence of the field along the trajectory is estimated with a
        Hutchinson-style noise probe (``noise^T J noise``), processed a few
        time steps at a time.

        The code squeezes dim 1 of the trajectory and indexes the end point as
        ``[1, D, L]``, so it presumably expects a single-item batch -- TODO confirm.

        Returns:
            ``(prior_loglike, estimated_loglike, length)`` where ``length`` is
            the size of the last trajectory dimension.
        """
        print("FM compute_likelihood")
        device = x.device
        with torch.no_grad():
            back_traj = self.backward(x, mask, mu, n_timesteps, spk=spk, solver=solver).cpu()
            last_sample = back_traj[-1]
            del x  # drop the local reference so cuda memory can be freed
            back_traj = back_traj[:-1].squeeze(1)  # Omit the last one. [timesteps, 80, L]
            time_interval = 1/n_timesteps

            t_span = torch.linspace(1, time_interval, n_timesteps)  # cpu

            D, L = last_sample.shape[1], last_sample.shape[2]  # last sample has shape [1, 80, L]
            # compute its likelihood given N(0,1)
            last_sample_loglike = -0.5 * (last_sample**2).sum() - D*L/2 * math.log(2*math.pi)

        batch_size = 3
        num_runs = 1
        trace_estimate = 0
        for _ in range(num_runs):
            for start in range(0, n_timesteps, batch_size):
                end_index = min(start + batch_size, n_timesteps)
                n_rows = end_index - start

                back_traj_segment = back_traj[start:end_index].to(device)
                back_traj_segment.requires_grad_(True)
                mu_segment = torch.cat([mu] * n_rows, 0)
                mask_segment = torch.cat([mask] * n_rows, 0)
                # spk is optional; only tile it when a speaker tensor was given.
                spk_segment = None if spk is None else torch.cat([spk] * n_rows, 0)
                t_segment = t_span[start:end_index].to(device)

                vf = self.estimator(back_traj_segment, mask_segment, mu_segment, t_segment, spk_segment)
                del mu_segment, spk_segment, mask_segment, t_segment
                noise = torch.randn_like(back_traj_segment)  # [timesteps, 80, L]
                # Differentiate w.r.t. the input only: this leaves the estimator
                # parameters' .grad untouched and frees the graph immediately.
                (grad,) = torch.autograd.grad((vf * noise).sum(), back_traj_segment)
                trace_estimate = trace_estimate - (grad * noise).sum().item()

                del back_traj_segment, noise, vf, grad
        trace_estimate /= num_runs  # average over runs
        integral_estimate = trace_estimate * time_interval

        estimate_loglike = last_sample_loglike + integral_estimate
        return last_sample_loglike.item(), estimate_loglike.item(), back_traj.shape[-1]

    def forward(self, x1, mask, mu, spk=None, offset=1e-5):
        """Sample one random time per batch item and return ``loss_t`` at that time."""
        # x1 is the original mel, mu is the text+mel alignment

        print("FM forward")  # forward does not actually integrate the ODE.

        # t is a random time point. The ODE is continuous, so sampling it
        # randomly gives the training more variety.
        t = torch.rand(x1.shape[0], dtype=x1.dtype, device=x1.device, requires_grad=False)
        t = torch.clamp(t, offset, 1.0 - offset)
        return self.loss_t(x1, mask, mu, t, spk)

    def loss_t(self, x1, mask, mu, t, spk=None):
        """Return ``(mse_loss, x)`` for the flow-matching objective at times ``t``."""
        print("FM loss_t")

        # Forward equation: x(t) = mu(t) + sigma(t) * eps
        # x1: clean data, mu(t) = t * x1, sigma(t) = 1 - (1 - sigma_min) * t
        t_unsqueeze = t.unsqueeze(1).unsqueeze(1)
        mu_t = t_unsqueeze * x1
        sigma_t = 1 - (1-self.sigma_min) * t_unsqueeze  # noise schedule
        x = mu_t + sigma_t * torch.randn_like(x1)  # intermediate sample x(t)

        # Ground-truth vector field u(t) = dx(t)/dt; this is the regression target.
        ut = (self.sigma_min - 1) / sigma_t * (x - mu_t) + x1

        # Predicted vector field; training drives it towards ut.
        vector_field_estimation = self.estimator(x, mask, mu, t, spk)

        mse_loss = self.criterion(ut, vector_field_estimation)
        return mse_loss, x

    @property
    def nparams(self):
        """Total number of elements across all parameters of this module."""
        return sum(p.numel() for p in self.parameters())




####################
######## CFM #######
####################
class CFM(FM): # FM을 상속해서, 조건부 Flow Matching을 수행. x0을 noise로부터 생성하고 x1과의 conditional trajectory를 학습
    def __init__(self, n_feats, dim, spk_emb_dim=64, sigma_min: float = 0.1, pe_scale=1000, shift_by_mu=False, condition_by_mu=True,
                 net_type="unet", encoder_output_dim=80):
        """Set up the parent FM model and store the two CFM switches.

        ``shift_by_mu`` adds ``mu`` to the sampled source noise (see ``sample_x0``);
        ``condition_by_mu`` selects whether the estimator is fed ``mu`` or zeros.
        """
        super(CFM, self).__init__(
            n_feats,
            dim,
            spk_emb_dim,
            sigma_min,
            pe_scale,
            net_type=net_type,
            encoder_output_dim=encoder_output_dim,
        )
        self.shift_by_mu = shift_by_mu
        self.condition_by_mu = condition_by_mu

    def sample_x0(self, mu, mask):
        x0 = torch.randn_like(mu)   # N(0, I)
        if self.shift_by_mu:
            x0 = x0 + mu  # N(mu, I)
        mask = mask.bool()
        x0.masked_fill_(~mask, 0)
        return x0

    def forward(self, x1, noise, mask, mu, spk=None, offset=1e-5):
        print("CFM Forward")
        t = torch.rand(x1.shape[0], dtype=x1.dtype, device=x1.device, requires_grad=False)
        t = torch.clamp(t, offset, 1.0 - offset)
        return self.loss_t(x1, noise, mask, mu, t, spk)


    def loss_t(self, x1, noise, mask, mu, t, spk=None):
        # construct noise (in CFM theory, this is x0)
        if noise is not None:
            x0 = noise  
        else:
            x0 = self.sample_x0(mu, mask)

        ut = x1 - x0
        t_unsqueeze = t.unsqueeze(1).unsqueeze(1)
        mu_t = t_unsqueeze * x1 + (1 - t_unsqueeze) * x0  # conditional Gaussian mean
        sigma_t = self.sigma_min

        x = mu_t + sigma_t * torch.randn_like(x1)  # sample p_t(x|x_0, x_1)
        if self.condition_by_mu:
            mu_input = mu
        else:
            mu_input = torch.zeros_like(mu)
        vector_field_estimation = self.estimator(x, mask, mu_input, t, spk)

        mse_loss = self.criterion(ut, vector_field_estimation)
        return mse_loss, x


    @torch.no_grad()
    def inference(self, z, mask, mu, n_timesteps, spk=None, solver="dopri5"):
        """Run the parent inference, feeding zeros instead of ``mu`` when conditioning is off."""
        print("CFM Inference")
        cond = mu if self.condition_by_mu else torch.zeros_like(mu)
        return super().inference(z, mask, cond, n_timesteps, spk=spk, solver=solver)
    


    # region 1. Naive
    @torch.no_grad()
    def naive_attack(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
        intermediate_reverse, intermediate_reverse2 = self.naive_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        intermediate_denoise = self.naive_denoise(x1, intermediate_reverse, intermediate_reverse2, mask, mu, n_timesteps, terminal_time, spk=spk)
        intermediate_reverse, intermediate_denoise = torch.stack(intermediate_reverse), torch.stack(intermediate_denoise)
        return ((intermediate_reverse - intermediate_denoise)**2).flatten(2).sum(dim=-1) 

    def naive_denoise(self, x1, intermediate_reverse, intermediate_reverse2, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
        intermediate_vector_field = []
        h = terminal_time / n_timesteps

        for i, x0 in enumerate(intermediate_reverse2):
            t = ((i + 0) * h) * torch.ones(x1.shape[0], dtype=x1.dtype, device=x1.device)
            t_unsqueeze = t.unsqueeze(-1).unsqueeze(-1)

            # intermediate_reverse 반영한 xt 구하고
            mu_t = t_unsqueeze * x1 + (1 - t_unsqueeze) * x0
            sigma_t = self.sigma_min
            xt = mu_t + sigma_t * torch.randn_like(x0)

            # xt의 노이즈 추정
            if self.condition_by_mu:
                mu_input = mu
            else:
                mu_input = torch.zeros_like(mu)

            vector_field_estimation = self.estimator(xt, mask, mu_input, t, spk)
            intermediate_vector_field.append(vector_field_estimation)

        return intermediate_vector_field


    def naive_reverse(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
        intermediate_result = []
        intermediate_result2 = []
        h = terminal_time / n_timesteps
        
        for i in range(n_timesteps):
            t = ((i + 0) * h) * torch.ones(x1.shape[0], dtype=x1.dtype, device=x1.device)
            t_unsqueeze = t.unsqueeze(-1).unsqueeze(-1)
            
            # 첫번째
            # eps = torch.randn_like(x1)
            # intermediate_result.append(eps)
            
            # 두번째
            x0 = self.sample_x0(mu, mask)
            ut = x1 - x0

            # mu_t = t_unsqueeze * x1 + (1 - t_unsqueeze) * x0
            # sigma_t = self.sigma_min
            # xt = mu_t + sigma_t * torch.randn_like(x1)

            intermediate_result.append(ut)
            intermediate_result2.append(x0)

        return intermediate_result, intermediate_result2



    # region 2. PIA
    @torch.no_grad()
    # 72.3 / 21.0
    # def pia_attack(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
    #     intermediate_reverse = self.pia_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
    #     intermediate_denoise = self.pia_denoise(intermediate_reverse, mask, mu, terminal_time, spk=spk)
    #     # del intermediate_reverse[-1]
    #     # del intermediate_denoise[0]
    #     intermediate_reverse, intermediate_denoise = torch.stack(intermediate_reverse), torch.stack(intermediate_denoise)
    #     return ((intermediate_reverse - intermediate_denoise).abs()**4).flatten(2).sum(dim=-1)

    # def pia_reverse(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
    #     h = terminal_time / n_timesteps
    #     xt = x1 * mask
    #     intermediate_result = []

    #     i = 10
    #     t = ((i + 0) * h) * torch.ones(x1.shape[0], dtype=x1.dtype, device=x1.device)
    #     time = t.unsqueeze(-1).unsqueeze(-1)
                
    #     eps = self.estimator(x1, mask, mu, t, spk)

    #     for i in range(n_timesteps):
    #         t = (((i + 1) + 0) * h) * torch.ones(eps.shape[0], dtype=eps.dtype, device=eps.device)
    #         t_unsqueeze = t.unsqueeze(-1).unsqueeze(-1)

    #         #  xt = self.forward_diffusion_eps(x1, mask, mu, eps, t) # x0에 eps 노이즈를 추가
    #         x0 = self.sample_x0(mu, mask)
    #         mu_t = t_unsqueeze * x1 + (1 - t_unsqueeze) * x0
    #         sigma_t = self.sigma_min
    #         xt = mu_t + sigma_t * torch.randn_like(x0)

    #         intermediate_result.append(xt)

    #     return intermediate_result


    # def pia_denoise(self, xt_intermediate, mask, mu, terminal_time=0.12, spk=None):
    #     n_timesteps = len(xt_intermediate)
    #     h = terminal_time / n_timesteps
    #     intermediate_result = []

    #     for i, xt in enumerate(xt_intermediate):
    #         t = (((i + 1) + 0) * h) * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
    #         xt = xt * mask
    #         intermediate_result.append(mu - self.estimator(xt, mask, mu, t, spk)) # clean data로 이동

    #     return intermediate_result

    # 74.0 / 23.0
    # def pia_attack(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
    #     intermediate_reverse = self.pia_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
    #     intermediate_denoise = self.pia_denoise(intermediate_reverse, mask, mu, terminal_time, spk=spk)
    #     # del intermediate_reverse[-1]
    #     # del intermediate_denoise[0]
    #     intermediate_reverse, intermediate_denoise = torch.stack(intermediate_reverse), torch.stack(intermediate_denoise)
    #     return ((intermediate_reverse - intermediate_denoise).abs()**4).flatten(2).sum(dim=-1)

    # def pia_reverse(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
    #     h = terminal_time / n_timesteps
    #     x0 = self.sample_x0(mu, mask).detach()  # 고정 초기 노이즈
    #     intermediate_result = [x1 * mask]
        
    #     for i in range(n_timesteps):
    #         t = (i * h) * torch.ones(x1.shape[0], device=x1.device)
    #         # 조건부 특징 처리[4]
    #         mu_input = mu if self.condition_by_mu else torch.zeros_like(mu)
    #         eps = self.estimator(intermediate_result[-1], mask, mu_input, t, spk)
    #         # 선형 보간[3]
    #         xt = x0 + (x1 - x0) * t.unsqueeze(-1).unsqueeze(-1)  
    #         intermediate_result.append(xt + eps * h)
    #     return intermediate_result


    # def pia_denoise(self, xt_intermediate, mask, mu, terminal_time=0.12, spk=None):
    #     n_timesteps = len(xt_intermediate)
    #     h = terminal_time / n_timesteps
    #     intermediate_result = []

    #     for i, xt in enumerate(xt_intermediate):
    #         t = (((i + 1) + 0) * h) * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
    #         xt = xt * mask
    #         intermediate_result.append(mu - self.estimator(xt, mask, mu, t, spk)) # clean data로 이동

    #     return intermediate_result


    @torch.no_grad()
    def pia_attack(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
        intermediate_reverse = self.pia_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        intermediate_denoise = self.pia_denoise(intermediate_reverse, mask, mu, terminal_time, spk=spk)
        intermediate_reverse, intermediate_denoise = torch.stack(intermediate_reverse), torch.stack(intermediate_denoise)
        return ((intermediate_reverse - intermediate_denoise).abs()**4).flatten(2).sum(dim=-1)

    def pia_reverse(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
        h = terminal_time / n_timesteps
        x0 = self.sample_x0(mu, mask).detach()  # 고정 초기 노이즈

        i = 10
        t = ((i + 0) * h) * torch.ones(x0.shape[0], dtype=x0.dtype, device=x0.device)
        t_unsqueeze = t.unsqueeze(-1).unsqueeze(-1)

        mu_input = mu if self.condition_by_mu else torch.zeros_like(mu)
        eps = self.estimator(x1, mask, mu_input, t, spk)
        intermediate_result = []
        
        for i in range(n_timesteps):
            t = ((i + 1) * h) * torch.ones(x0.shape[0], dtype=x0.dtype, device=x0.device)
            t_unsqueeze = t.unsqueeze(-1).unsqueeze(-1)

            mu_input = mu if self.condition_by_mu else torch.zeros_like(mu)
            
            mu_t = t_unsqueeze * x1 + (1 - t_unsqueeze) * x0
            sigma_t = self.sigma_min
            # xt = mu_t + sigma_t * torch.randn_like(x0)
            xt = mu_t + sigma_t * eps

            intermediate_result.append(xt)
        return intermediate_result

    def pia_denoise(self, xt_intermediate, mask, mu, terminal_time=0.12, spk=None):
        n_timesteps = len(xt_intermediate)
        h = terminal_time / n_timesteps
        intermediate_result = []

        for i, xt in enumerate(xt_intermediate):
            t = (((i + 1) + 0) * h) * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
            xt = xt * mask
            
            intermediate_result.append(mu - self.estimator(xt, mask, mu, t, spk)) # 깨끗한 데이터를 복원하는 방식

        return intermediate_result


    @torch.no_grad()
    def pian_attack(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
        intermediate_reverse = self.pian_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        intermediate_denoise = self.pian_denoise(intermediate_reverse, mask, mu, terminal_time, spk=spk)
        intermediate_reverse, intermediate_denoise = torch.stack(intermediate_reverse), torch.stack(intermediate_denoise)
        return ((intermediate_reverse - intermediate_denoise).abs()**4).flatten(2).sum(dim=-1)

    def pian_reverse(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
        h = terminal_time / n_timesteps
        x0 = self.sample_x0(mu, mask).detach()  # 고정 초기 노이즈

        i = 10
        t = ((i + 0) * h) * torch.ones(x0.shape[0], dtype=x0.dtype, device=x0.device)
        t_unsqueeze = t.unsqueeze(-1).unsqueeze(-1)

        mu_input = mu if self.condition_by_mu else torch.zeros_like(mu)
        eps = self.estimator(x1, mask, mu_input, t, spk)
        eps = eps / eps.abs().mean(list(range(1, eps.ndim)), keepdim=True) * (2 / torch.pi) ** 0.5 # 정규화

        intermediate_result = []
        
        for i in range(n_timesteps):
            t = ((i + 1) * h) * torch.ones(x0.shape[0], dtype=x0.dtype, device=x0.device)
            t_unsqueeze = t.unsqueeze(-1).unsqueeze(-1)

            mu_input = mu if self.condition_by_mu else torch.zeros_like(mu)
            
            mu_t = t_unsqueeze * x1 + (1 - t_unsqueeze) * x0
            sigma_t = self.sigma_min
            # xt = mu_t + sigma_t * torch.randn_like(x0)
            xt = mu_t + sigma_t * eps

            intermediate_result.append(xt)
        return intermediate_result

    def pian_denoise(self, xt_intermediate, mask, mu, terminal_time=0.12, spk=None):
        n_timesteps = len(xt_intermediate)
        h = terminal_time / n_timesteps
        intermediate_result = []

        for i, xt in enumerate(xt_intermediate):
            t = (((i + 1) + 0) * h) * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
            xt = xt * mask
            
            intermediate_result.append(mu - self.estimator(xt, mask, mu, t, spk)) # 깨끗한 데이터를 복원하는 방식

        return intermediate_result
    # def pian_attack(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
    #     intermediate_reverse = self.pian_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
    #     intermediate_denoise = self.pian_denoise(intermediate_reverse, mask, mu, terminal_time, spk=spk)
    #     intermediate_reverse, intermediate_denoise = torch.stack(intermediate_reverse), torch.stack(intermediate_denoise)
    #     return ((intermediate_reverse - intermediate_denoise).abs()**4).flatten(2).sum(dim=-1)

    # def pian_reverse(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
    #     h = terminal_time / n_timesteps
    #     x0 = self.sample_x0(mu, mask).detach()
    #     intermediate_result = []
        
    #     for i in range(n_timesteps):
    #         t = (i * h) * torch.ones(x1.shape[0], device=x1.device)
    #         t = ((i + 0) * h) * torch.ones(x0.shape[0], dtype=x0.dtype, device=x0.device)
    #         t_unsqueeze = t.unsqueeze(-1).unsqueeze(-1)

    #         mu_input = mu if self.condition_by_mu else torch.zeros_like(mu)
    #         eps = self.estimator(x1, mask, mu_input, t, spk)
    #         eps = eps / eps.abs().mean(list(range(1, eps.ndim)), keepdim=True) * (2 / torch.pi) ** 0.5 # 정규화

    #         mu_t = t_unsqueeze * x1 + (1 - t_unsqueeze) * x0
    #         sigma_t = self.sigma_min
    #         xt = mu_t + sigma_t * torch.randn_like(x0)

    #         intermediate_result.append(xt)
    #     return intermediate_result

    # def pian_denoise(self, xt_intermediate, mask, mu, terminal_time=0.12, spk=None):
    #     n_timesteps = len(xt_intermediate)
    #     h = terminal_time / n_timesteps
    #     intermediate_result = []

    #     for i, xt in enumerate(xt_intermediate):
    #         t = (((i + 1) + 0) * h) * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
    #         xt = xt * mask
    #         intermediate_result.append(mu - self.estimator(xt, mask, mu, t, spk)) # clean data로 이동

    #     return intermediate_result


    # region 3. SecMI
    @torch.no_grad()
    # 70.8 / 13.5
    # def secmi_attack(self, x1, mask, mu, n_timesteps, terminal_time=1.0, spk=None):
    #     intermediate_reverse = self.secmi_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
    #     intermediate_denoise = self.secmi_denoise(intermediate_reverse, mask, mu, n_timesteps, terminal_time, spk=spk)
    #     del intermediate_reverse[-1]
    #     del intermediate_denoise[0]
    #     intermediate_reverse, intermediate_denoise = torch.stack(intermediate_reverse), torch.stack(intermediate_denoise)
    #     return ((intermediate_reverse - intermediate_denoise)**2).flatten(2).sum(dim=-1)

    # def secmi_denoise(self, intermediate_reverse, mask, mu, n_timesteps, terminal_time, spk=None):
    #     h = terminal_time / n_timesteps
    #     # t_span = torch.linspace(0, 1, n_timesteps+1)
    #     t_span = torch.linspace(1, 0, n_timesteps+1) 
    #     neural_ode = NeuralODE(self.ode_wrapper(mask, mu, spk), solver="euler", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
    #     intermediate_result = []

    #     for i in range(1, len(t_span)):
    #         x = intermediate_reverse[i - 1]
    #         t0 = t_span[i-1]
    #         t1 = t_span[i]
    #         eval_points, traj = neural_ode(x, torch.tensor([t0, t1]))
    #         dxt = (traj[-1] - traj[0]) * h
    #         x = (x - dxt) 
    #         intermediate_result.append(x)

    #     return intermediate_result

    # def secmi_reverse(self, x1, mask, mu, n_timesteps, terminal_time, spk=None):
    #     h = terminal_time / n_timesteps
    #     t_span = torch.linspace(0, 1, n_timesteps+1)
    #     neural_ode = NeuralODE(self.ode_wrapper(mask, mu, spk), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
    #     x = x1
    #     intermediate_result = []

    #     for i in range(1, len(t_span)):
    #         t0 = t_span[i-1]
    #         t1 = t_span[i]
    #         eval_points, traj = neural_ode(x, torch.tensor([t0, t1]))
    #         dxt = (traj[-1] - traj[0]) * h
    #         x = (x + dxt) 
    #         intermediate_result.append(x)
    #     return intermediate_result

    @torch.no_grad()
    def secmi_attack(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
        intermediate_reverse = self.secmi_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        intermediate_denoise = self.secmi_denoise(intermediate_reverse, mask, mu, terminal_time, spk=spk)
        del intermediate_reverse[-1]
        del intermediate_denoise[0]
        intermediate_reverse, intermediate_denoise = torch.stack(intermediate_reverse), torch.stack(intermediate_denoise)
        return ((intermediate_reverse - intermediate_denoise).abs()**4).flatten(2).sum(dim=-1)

    def secmi_reverse(self, x0, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
        h = terminal_time / n_timesteps
        xt = x0 * mask
        intermediates = []

        scaling_factor = 1
        
        for i in range(n_timesteps):
            t = (i * h) * torch.ones(x0.shape[0], device=x0.device)
            vt = self.estimator(xt, mask, mu, t, spk)
            #dxt = 0.5 * (mu - xt - vt) * h

            dxt = mu - xt - vt  # CFM 공식에 따른 drift vector

            xt2 = xt + dxt * h * scaling_factor
            intermediates.append(xt2)
        return intermediates

    def secmi_denoise(self, xt_list, mask, mu, terminal_time=0.12, spk=None):
        n_timesteps = len(xt_list)
        h = terminal_time / n_timesteps
        intermediates = []
        
        scaling_factor = 1

        for i, xt in enumerate(xt_list):
            t = ((i+1) * h) * torch.ones(xt.shape[0], device=xt.device)
            vt = self.estimator(xt, mask, mu, t, spk)
            #dxt = 0.5 * (mu - xt - vt) * h

            dxt = mu - xt - vt  # CFM 공식에 따른 drift vector
            
            xt2 = xt - dxt * h * scaling_factor
            #xt = (xt - dxt) * mask  # Denoise 방향 업데이트
            intermediates.append(xt2)
        return intermediates

    
    # region 4. Feedback Loop
    @torch.no_grad()
    def feedback_attack(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
        intermediate_reverse = self.feedback_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        intermediate_denoise = self.feedback_denoise(intermediate_reverse, mask, mu, terminal_time, spk=spk)
        del intermediate_reverse[-1]
        del intermediate_denoise[0]
        intermediate_reverse, intermediate_denoise = torch.stack(intermediate_reverse), torch.stack(intermediate_denoise)
        return ((intermediate_reverse - intermediate_denoise).abs()**4).flatten(2).sum(dim=-1)

    def feedback_reverse(self, x0, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
        h = terminal_time / n_timesteps
        xt = x0 * mask
        intermediates = []

        feedback_factor = 1
        xt_prev = None  # 이전 xt 저장 (피드백용)
        
        for i in range(n_timesteps):
            t = (i * h) * torch.ones(x0.shape[0], device=x0.device)
            vt = self.estimator(xt, mask, mu, t, spk)
            #dxt = 0.5 * (mu - xt - vt) * h

            dxt = mu - xt - vt  # CFM 공식에 따른 drift vector

            if xt_prev is not None:
                dxt = dxt + feedback_factor * (xt_prev - xt)

            xt_prev = xt.clone()  # 현재 xt를 저장하여 다음 스텝에서 사용

            xt2 = xt + dxt * h
            intermediates.append(xt2)
        return intermediates

    def feedback_denoise(self, xt_list, mask, mu, terminal_time=0.12, spk=None):
        n_timesteps = len(xt_list)
        h = terminal_time / n_timesteps
        intermediates = []

        for i, xt in enumerate(xt_list):
            t = ((i+1) * h) * torch.ones(xt.shape[0], device=xt.device)
            vt = self.estimator(xt, mask, mu, t, spk)
            #dxt = 0.5 * (mu - xt - vt) * h

            dxt = mu - xt - vt  # CFM 공식에 따른 drift vector
            
            xt2 = xt - dxt * h
            #xt = (xt - dxt) * mask  # Denoise 방향 업데이트
            intermediates.append(xt2)
        return intermediates




    # region 5. Scaling
    @torch.no_grad()
    def scaling_attack(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
        intermediate_reverse = self.scaling_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        intermediate_denoise = self.scaling_denoise(intermediate_reverse, mask, mu, terminal_time, spk=spk)
        del intermediate_reverse[-1]
        del intermediate_denoise[0]
        intermediate_reverse, intermediate_denoise = torch.stack(intermediate_reverse), torch.stack(intermediate_denoise)
        return ((intermediate_reverse - intermediate_denoise).abs()**4).flatten(2).sum(dim=-1)

    def scaling_reverse(self, x0, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
        h = terminal_time / n_timesteps
        xt = x0 * mask
        intermediates = []

        scaling_factor = 0.001
        
        for i in range(n_timesteps):
            t = (i * h) * torch.ones(x0.shape[0], device=x0.device)
            vt = self.estimator(xt, mask, mu, t, spk)
            #dxt = 0.5 * (mu - xt - vt) * h

            dxt = mu - xt - vt  # CFM 공식에 따른 drift vector

            xt2 = xt + dxt * h * scaling_factor
            intermediates.append(xt2)
        return intermediates

    def scaling_denoise(self, xt_list, mask, mu, terminal_time=0.12, spk=None):
        n_timesteps = len(xt_list)
        h = terminal_time / n_timesteps
        intermediates = []
        
        scaling_factor = 0.001

        for i, xt in enumerate(xt_list):
            t = ((i+1) * h) * torch.ones(xt.shape[0], device=xt.device)
            vt = self.estimator(xt, mask, mu, t, spk)
            #dxt = 0.5 * (mu - xt - vt) * h

            dxt = mu - xt - vt  # CFM 공식에 따른 drift vector
            
            xt2 = xt - dxt * h * scaling_factor
            #xt = (xt - dxt) * mask  # Denoise 방향 업데이트
            intermediates.append(xt2)
        return intermediates



    # region 6. Hybrid
    @torch.no_grad()
    def hybrid_attack1(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None):
        threshold = 11
        pia_timesteps = 100
        naive_timesteps = 100

        secmi_intermediate_reverse = self.secmi_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        secmi_intermediate_denoise = self.secmi_denoise(secmi_intermediate_reverse, mask, mu, terminal_time, spk=spk)
        
        naive_intermediate_reverse, intermediate_reverse2 = self.naive_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        naive_intermediate_denoise = self.naive_denoise(x1, naive_intermediate_reverse, intermediate_reverse2, mask, mu, n_timesteps, terminal_time, spk=spk)
        
        intermediate_reverse = naive_intermediate_reverse[:threshold] + secmi_intermediate_reverse[threshold:]
        intermediate_denoise = naive_intermediate_denoise[:threshold] + secmi_intermediate_denoise[threshold:]

        intermediate_reverse = torch.stack(intermediate_reverse)
        intermediate_denoise = torch.stack(intermediate_denoise)
        
        return ((intermediate_reverse - intermediate_denoise).abs()**4).flatten(2).sum(dim=-1)

    def hybrid_attack2(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None, *, threshold=21):
        """Score ``x1`` with a trajectory spliced from the naive and SecMI attacks.

        Both attacks are run in full. The first ``threshold`` trajectory entries
        are taken from the naive attack and every later entry from the SecMI
        attack; the reverse and denoise trajectories are spliced at the same index.
        The sibling ``hybrid_attackN`` methods use the same procedure with a
        different split index.

        Args:
            x1: Sample under test; forwarded unchanged to the trajectory helpers.
            mask: Forwarded unchanged to the trajectory helpers.
            mu: Forwarded unchanged to the trajectory helpers.
            n_timesteps: Forwarded to ``secmi_reverse``, ``naive_reverse`` and
                ``naive_denoise``.
            terminal_time: Forwarded to all four trajectory helpers.
            spk: Forwarded to all four trajectory helpers as a keyword argument.
            threshold: Keyword-only split index; the default of 21 reproduces the
                previously hard-coded behaviour.

        Returns:
            A tensor with one row per trajectory entry holding the 4th power of
            the absolute reverse/denoise difference, summed over every dimension
            after the first dimension of each entry.
        """
        # NOTE(review): the helpers are assumed to return Python sequences of
        # equally shaped tensors (they are sliced, concatenated with `+` and
        # stacked below) -- confirm against their definitions.
        # The call order is kept as-is in case the helpers consume random state.
        secmi_rev = self.secmi_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        secmi_den = self.secmi_denoise(secmi_rev, mask, mu, terminal_time, spk=spk)

        naive_rev, naive_rev_aux = self.naive_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        naive_den = self.naive_denoise(x1, naive_rev, naive_rev_aux, mask, mu, n_timesteps, terminal_time, spk=spk)

        # Naive entries before the split index, SecMI entries from it onwards.
        reverse = torch.stack(naive_rev[:threshold] + secmi_rev[threshold:])
        denoise = torch.stack(naive_den[:threshold] + secmi_den[threshold:])

        return ((reverse - denoise).abs() ** 4).flatten(2).sum(dim=-1)
    
    def hybrid_attack3(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None, *, threshold=31):
        """Score ``x1`` with a trajectory spliced from the naive and SecMI attacks.

        Both attacks are run in full. The first ``threshold`` trajectory entries
        are taken from the naive attack and every later entry from the SecMI
        attack; the reverse and denoise trajectories are spliced at the same index.
        The sibling ``hybrid_attackN`` methods use the same procedure with a
        different split index.

        Args:
            x1: Sample under test; forwarded unchanged to the trajectory helpers.
            mask: Forwarded unchanged to the trajectory helpers.
            mu: Forwarded unchanged to the trajectory helpers.
            n_timesteps: Forwarded to ``secmi_reverse``, ``naive_reverse`` and
                ``naive_denoise``.
            terminal_time: Forwarded to all four trajectory helpers.
            spk: Forwarded to all four trajectory helpers as a keyword argument.
            threshold: Keyword-only split index; the default of 31 reproduces the
                previously hard-coded behaviour.

        Returns:
            A tensor with one row per trajectory entry holding the 4th power of
            the absolute reverse/denoise difference, summed over every dimension
            after the first dimension of each entry.
        """
        # NOTE(review): the helpers are assumed to return Python sequences of
        # equally shaped tensors (they are sliced, concatenated with `+` and
        # stacked below) -- confirm against their definitions.
        # The call order is kept as-is in case the helpers consume random state.
        secmi_rev = self.secmi_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        secmi_den = self.secmi_denoise(secmi_rev, mask, mu, terminal_time, spk=spk)

        naive_rev, naive_rev_aux = self.naive_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        naive_den = self.naive_denoise(x1, naive_rev, naive_rev_aux, mask, mu, n_timesteps, terminal_time, spk=spk)

        # Naive entries before the split index, SecMI entries from it onwards.
        reverse = torch.stack(naive_rev[:threshold] + secmi_rev[threshold:])
        denoise = torch.stack(naive_den[:threshold] + secmi_den[threshold:])

        return ((reverse - denoise).abs() ** 4).flatten(2).sum(dim=-1)

    def hybrid_attack4(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None, *, threshold=41):
        """Score ``x1`` with a trajectory spliced from the naive and SecMI attacks.

        Both attacks are run in full. The first ``threshold`` trajectory entries
        are taken from the naive attack and every later entry from the SecMI
        attack; the reverse and denoise trajectories are spliced at the same index.
        The sibling ``hybrid_attackN`` methods use the same procedure with a
        different split index.

        Args:
            x1: Sample under test; forwarded unchanged to the trajectory helpers.
            mask: Forwarded unchanged to the trajectory helpers.
            mu: Forwarded unchanged to the trajectory helpers.
            n_timesteps: Forwarded to ``secmi_reverse``, ``naive_reverse`` and
                ``naive_denoise``.
            terminal_time: Forwarded to all four trajectory helpers.
            spk: Forwarded to all four trajectory helpers as a keyword argument.
            threshold: Keyword-only split index; the default of 41 reproduces the
                previously hard-coded behaviour.

        Returns:
            A tensor with one row per trajectory entry holding the 4th power of
            the absolute reverse/denoise difference, summed over every dimension
            after the first dimension of each entry.
        """
        # NOTE(review): the helpers are assumed to return Python sequences of
        # equally shaped tensors (they are sliced, concatenated with `+` and
        # stacked below) -- confirm against their definitions.
        # The call order is kept as-is in case the helpers consume random state.
        secmi_rev = self.secmi_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        secmi_den = self.secmi_denoise(secmi_rev, mask, mu, terminal_time, spk=spk)

        naive_rev, naive_rev_aux = self.naive_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        naive_den = self.naive_denoise(x1, naive_rev, naive_rev_aux, mask, mu, n_timesteps, terminal_time, spk=spk)

        # Naive entries before the split index, SecMI entries from it onwards.
        reverse = torch.stack(naive_rev[:threshold] + secmi_rev[threshold:])
        denoise = torch.stack(naive_den[:threshold] + secmi_den[threshold:])

        return ((reverse - denoise).abs() ** 4).flatten(2).sum(dim=-1)
    
    def hybrid_attack5(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None, *, threshold=51):
        """Score ``x1`` with a trajectory spliced from the naive and SecMI attacks.

        Both attacks are run in full. The first ``threshold`` trajectory entries
        are taken from the naive attack and every later entry from the SecMI
        attack; the reverse and denoise trajectories are spliced at the same index.
        The sibling ``hybrid_attackN`` methods use the same procedure with a
        different split index.

        Args:
            x1: Sample under test; forwarded unchanged to the trajectory helpers.
            mask: Forwarded unchanged to the trajectory helpers.
            mu: Forwarded unchanged to the trajectory helpers.
            n_timesteps: Forwarded to ``secmi_reverse``, ``naive_reverse`` and
                ``naive_denoise``.
            terminal_time: Forwarded to all four trajectory helpers.
            spk: Forwarded to all four trajectory helpers as a keyword argument.
            threshold: Keyword-only split index; the default of 51 reproduces the
                previously hard-coded behaviour.

        Returns:
            A tensor with one row per trajectory entry holding the 4th power of
            the absolute reverse/denoise difference, summed over every dimension
            after the first dimension of each entry.
        """
        # NOTE(review): the helpers are assumed to return Python sequences of
        # equally shaped tensors (they are sliced, concatenated with `+` and
        # stacked below) -- confirm against their definitions.
        # The call order is kept as-is in case the helpers consume random state.
        secmi_rev = self.secmi_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        secmi_den = self.secmi_denoise(secmi_rev, mask, mu, terminal_time, spk=spk)

        naive_rev, naive_rev_aux = self.naive_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        naive_den = self.naive_denoise(x1, naive_rev, naive_rev_aux, mask, mu, n_timesteps, terminal_time, spk=spk)

        # Naive entries before the split index, SecMI entries from it onwards.
        reverse = torch.stack(naive_rev[:threshold] + secmi_rev[threshold:])
        denoise = torch.stack(naive_den[:threshold] + secmi_den[threshold:])

        return ((reverse - denoise).abs() ** 4).flatten(2).sum(dim=-1)

    def hybrid_attack7(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None, *, threshold=71):
        """Score ``x1`` with a trajectory spliced from the naive and SecMI attacks.

        Both attacks are run in full. The first ``threshold`` trajectory entries
        are taken from the naive attack and every later entry from the SecMI
        attack; the reverse and denoise trajectories are spliced at the same index.
        The sibling ``hybrid_attackN`` methods use the same procedure with a
        different split index.

        Args:
            x1: Sample under test; forwarded unchanged to the trajectory helpers.
            mask: Forwarded unchanged to the trajectory helpers.
            mu: Forwarded unchanged to the trajectory helpers.
            n_timesteps: Forwarded to ``secmi_reverse``, ``naive_reverse`` and
                ``naive_denoise``.
            terminal_time: Forwarded to all four trajectory helpers.
            spk: Forwarded to all four trajectory helpers as a keyword argument.
            threshold: Keyword-only split index; the default of 71 reproduces the
                previously hard-coded behaviour.

        Returns:
            A tensor with one row per trajectory entry holding the 4th power of
            the absolute reverse/denoise difference, summed over every dimension
            after the first dimension of each entry.
        """
        # NOTE(review): the helpers are assumed to return Python sequences of
        # equally shaped tensors (they are sliced, concatenated with `+` and
        # stacked below) -- confirm against their definitions.
        # The call order is kept as-is in case the helpers consume random state.
        secmi_rev = self.secmi_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        secmi_den = self.secmi_denoise(secmi_rev, mask, mu, terminal_time, spk=spk)

        naive_rev, naive_rev_aux = self.naive_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        naive_den = self.naive_denoise(x1, naive_rev, naive_rev_aux, mask, mu, n_timesteps, terminal_time, spk=spk)

        # Naive entries before the split index, SecMI entries from it onwards.
        reverse = torch.stack(naive_rev[:threshold] + secmi_rev[threshold:])
        denoise = torch.stack(naive_den[:threshold] + secmi_den[threshold:])

        return ((reverse - denoise).abs() ** 4).flatten(2).sum(dim=-1)
    
    def hybrid_attack8(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None, *, threshold=81):
        """Score ``x1`` with a trajectory spliced from the naive and SecMI attacks.

        Both attacks are run in full. The first ``threshold`` trajectory entries
        are taken from the naive attack and every later entry from the SecMI
        attack; the reverse and denoise trajectories are spliced at the same index.
        The sibling ``hybrid_attackN`` methods use the same procedure with a
        different split index.

        Args:
            x1: Sample under test; forwarded unchanged to the trajectory helpers.
            mask: Forwarded unchanged to the trajectory helpers.
            mu: Forwarded unchanged to the trajectory helpers.
            n_timesteps: Forwarded to ``secmi_reverse``, ``naive_reverse`` and
                ``naive_denoise``.
            terminal_time: Forwarded to all four trajectory helpers.
            spk: Forwarded to all four trajectory helpers as a keyword argument.
            threshold: Keyword-only split index; the default of 81 reproduces the
                previously hard-coded behaviour.

        Returns:
            A tensor with one row per trajectory entry holding the 4th power of
            the absolute reverse/denoise difference, summed over every dimension
            after the first dimension of each entry.
        """
        # NOTE(review): the helpers are assumed to return Python sequences of
        # equally shaped tensors (they are sliced, concatenated with `+` and
        # stacked below) -- confirm against their definitions.
        # The call order is kept as-is in case the helpers consume random state.
        secmi_rev = self.secmi_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        secmi_den = self.secmi_denoise(secmi_rev, mask, mu, terminal_time, spk=spk)

        naive_rev, naive_rev_aux = self.naive_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        naive_den = self.naive_denoise(x1, naive_rev, naive_rev_aux, mask, mu, n_timesteps, terminal_time, spk=spk)

        # Naive entries before the split index, SecMI entries from it onwards.
        reverse = torch.stack(naive_rev[:threshold] + secmi_rev[threshold:])
        denoise = torch.stack(naive_den[:threshold] + secmi_den[threshold:])

        return ((reverse - denoise).abs() ** 4).flatten(2).sum(dim=-1)

    def hybrid_attack9(self, x1, mask, mu, n_timesteps, terminal_time=0.12, spk=None, *, threshold=91):
        """Score ``x1`` with a trajectory spliced from the naive and SecMI attacks.

        Both attacks are run in full. The first ``threshold`` trajectory entries
        are taken from the naive attack and every later entry from the SecMI
        attack; the reverse and denoise trajectories are spliced at the same index.
        The sibling ``hybrid_attackN`` methods use the same procedure with a
        different split index.

        Args:
            x1: Sample under test; forwarded unchanged to the trajectory helpers.
            mask: Forwarded unchanged to the trajectory helpers.
            mu: Forwarded unchanged to the trajectory helpers.
            n_timesteps: Forwarded to ``secmi_reverse``, ``naive_reverse`` and
                ``naive_denoise``.
            terminal_time: Forwarded to all four trajectory helpers.
            spk: Forwarded to all four trajectory helpers as a keyword argument.
            threshold: Keyword-only split index; the default of 91 reproduces the
                previously hard-coded behaviour.

        Returns:
            A tensor with one row per trajectory entry holding the 4th power of
            the absolute reverse/denoise difference, summed over every dimension
            after the first dimension of each entry.
        """
        # NOTE(review): the helpers are assumed to return Python sequences of
        # equally shaped tensors (they are sliced, concatenated with `+` and
        # stacked below) -- confirm against their definitions.
        # The call order is kept as-is in case the helpers consume random state.
        secmi_rev = self.secmi_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        secmi_den = self.secmi_denoise(secmi_rev, mask, mu, terminal_time, spk=spk)

        naive_rev, naive_rev_aux = self.naive_reverse(x1, mask, mu, n_timesteps, terminal_time, spk=spk)
        naive_den = self.naive_denoise(x1, naive_rev, naive_rev_aux, mask, mu, n_timesteps, terminal_time, spk=spk)

        # Naive entries before the split index, SecMI entries from it onwards.
        reverse = torch.stack(naive_rev[:threshold] + secmi_rev[threshold:])
        denoise = torch.stack(naive_den[:threshold] + secmi_den[threshold:])

        return ((reverse - denoise).abs() ** 4).flatten(2).sum(dim=-1)



class OTCFM(CFM):
    """CFM variant that re-pairs data and noise through an optimal-transport plan.

    The class is currently disabled: the constructor raises
    ``NotImplementedError`` before any state is initialised.
    """

    def __init__(self, n_feats, dim, spk_emb_dim=64, sigma_min: float = 0.1, pe_scale=1000, method="exact",
                 net_type="unet", encoder_output_dim=80):
        """Always raises ``NotImplementedError``; OT sampling is not supported yet."""
        raise NotImplementedError("CFM with Optimal Transport Sampling is currently not supported")
        # Everything below is unreachable while the guard above is in place; it is
        # retained as the intended initialisation for when OT sampling is re-enabled.
        super(OTCFM, self).__init__(
            n_feats, dim, spk_emb_dim, sigma_min, pe_scale,
            shift_by_mu=False, net_type=net_type, encoder_output_dim=encoder_output_dim,
        )
        assert method == 'exact', "OT methods except 'exact' are not considered currently"
        self.ot_sampler = OTPlanSampler(method=method)

    def loss_t(self, x1, noise, mask, mu, t, spk=None):
        """Compute the flow-matching MSE loss at time ``t`` on OT-resampled pairs.

        For every index along the last axis, an OT plan is sampled between the
        corresponding slices of ``x1`` and the noise, and both tensors are
        re-indexed along their first axis by the sampled pairs.

        Args:
            x1: Data tensor (the original comment gives its shape as [B, 80, L]).
            noise: Noise tensor to use as x0, or ``None`` to draw one with
                ``self.sample_x0(mu, mask)``.
            mask: Forwarded to ``self.sample_x0`` and to the estimator.
            mu: Forwarded to ``self.sample_x0`` and to the estimator.
            t: Time values; two singleton axes are inserted at position 1 before
                broadcasting against ``x1`` / ``x0``.
            spk: Forwarded to the estimator.

        Returns:
            ``(mse_loss, x)``: the criterion value between the target vector field
            and the estimator output, and the noisy sample fed to the estimator.
        """
        # In CFM notation the noise sample is x0; draw it only when not supplied.
        x0 = noise if noise is not None else self.sample_x0(mu, mask)

        # x1 and x0 are expected to be [B, 80, L]; the unpack enforces rank 3.
        _, _, num_frames = x0.shape
        paired_x1 = torch.zeros_like(x1)
        paired_x0 = torch.zeros_like(x0)
        for frame in range(num_frames):
            # Only the sampled index pairs are needed; the sampled tensors are discarded.
            _, _, idx_x1, idx_x0 = self.ot_sampler.sample_plan_with_index(x1[..., frame], x0[..., frame])
            # Order the pairs by the x1 index so both index arrays stay synchronised per frame.
            order = np.argsort(idx_x1)
            idx_x1 = idx_x1[order]
            idx_x0 = idx_x0[order]

            paired_x1[..., frame] = x1[idx_x1, :, frame]
            paired_x0[..., frame] = x0[idx_x0, :, frame]

        x1 = paired_x1
        x0 = paired_x0

        # Conditional vector field. This is actually x0 - x1 in paper.
        target_field = x1 - x0
        t_expanded = t.unsqueeze(1).unsqueeze(1)
        # Conditional Gaussian mean: linear interpolation between x0 and x1.
        mean_t = t_expanded * x1 + (1 - t_expanded) * x0
        # Sample p_t(x | x_0, x_1) with constant standard deviation sigma_min.
        x = mean_t + self.sigma_min * torch.randn_like(x1)
        predicted_field = self.estimator(x, mask, mu, t, spk)
        return self.criterion(target_field, predicted_field), x
