import os.path
import warnings

import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from mog_util import p_t, schedule, grad_log_p_t, grad_log_p_t_1d, abdot, v_t, p_t_1d, v_t_1d
from sde_lib import *

# from data_config import samples_0, samples_1, save_path, pis, d, mus, sigmas, probability_flow
# from data_config import config

class Flow(object):
    """Base class shared by the flow / score-matching models in this file.

    Subclasses provide ``get_train_tuple`` (how a training batch is built),
    ``sample_ode`` and ``gt`` (the analytic quantity the network should
    match).  The analytic helpers ``p``, ``score`` and ``velocity`` evaluate
    the ``mog_util`` closed forms with ``config['mus']``, ``config['sigmas']``
    and ``config['pis']`` (presumably Gaussian-mixture parameters -- confirm
    against ``mog_util``), using the 1-D variants when ``config['d'] == 1``.

    Args:
        model: callable invoked as ``model(x, t)``.
        trajectory: schedule name forwarded to ``schedule`` / ``abdot``.
        config: mapping read with ``[]``; besides the keys above, ``train``
            reads ``'save_path'``.
    """

    def __init__(self, model, trajectory, config):
        self.model = model
        self.trajectory = trajectory
        self.name = f'{self.__class__.__name__}_{self.trajectory}'
        self.config = config

        # Re-parameterization hooks: how the raw output of ``model(x_t, t)``
        # (the "target") relates to the quantity used downstream (the "pred").
        # The two must be inverses of each other for fixed (x, t); identity by
        # default.
        self.from_pred_to_target = lambda output, x, t: output
        self.from_target_to_pred = lambda output, x, t: output

    def get_random_train_pairs(self, z0=None, z1=None, batch_size=None):
        """Draw ``batch_size`` rows from ``z0`` and from ``z1``.

        Rows are sampled with replacement and independently for each tensor,
        so the returned rows are not paired with each other.

        Args:
            z0: source samples, indexed along dim 0.
            z1: target samples, indexed along dim 0.
            batch_size: number of rows to draw; defaults to ``len(z1)``.

        Returns:
            ``(z0_batch, z1_batch)``, each with ``batch_size`` rows.  When both
            inputs already have exactly ``batch_size`` rows they are returned
            unchanged.
        """
        length_z0 = len(z0)
        length_z1 = len(z1)
        if batch_size is None:
            batch_size = length_z1

        # Only skip resampling when *both* tensors already have the requested
        # number of rows; otherwise the two halves would differ in length.
        if batch_size == length_z0 == length_z1:
            return z0, z1

        # randint (with replacement) rather than randperm, so batch_size may
        # exceed the number of available samples.
        indices_z0 = torch.randint(0, length_z0, (batch_size,), device=z0.device)
        indices_z1 = torch.randint(0, length_z1, (batch_size,), device=z1.device)
        return z0[indices_z0], z1[indices_z1]

    def get_train_tuple(self, z0=None, z1=None, t=None):
        raise NotImplementedError

    @torch.no_grad()
    def sample_ode(self, model, z0, N):
        raise NotImplementedError

    @torch.no_grad()
    def p(self, x, t):
        """Analytic marginal ``p_t`` at ``(x, t)`` (non-log; ``p_t_1d`` when d == 1)."""
        if self.config['d'] == 1:
            return p_t_1d(x, *schedule(t, type=self.trajectory), self.config['mus'],
                          self.config['sigmas'] ** 2, self.config['pis'])
        return p_t(x, *schedule(t, type=self.trajectory), self.config['mus'],
                   self.config['sigmas'] ** 2, self.config['pis'], log=False)

    @torch.no_grad()
    def score(self, x, t):
        """Analytic score ``grad_x log p_t(x)`` (``grad_log_p_t_1d`` when d == 1)."""
        if self.config['d'] == 1:
            return grad_log_p_t_1d(x, *schedule(t, type=self.trajectory), self.config['mus'],
                                   self.config['sigmas'] ** 2, self.config['pis'])
        return grad_log_p_t(x, *schedule(t, type=self.trajectory), self.config['mus'],
                            self.config['sigmas'] ** 2, self.config['pis'])

    # No @torch.no_grad() here: gradients may flow through the result.
    def velocity(self, x, t):
        """Analytic velocity ``v_t(x)`` from ``schedule`` and ``abdot`` coefficients."""
        if self.config['d'] == 1:
            return v_t_1d(x, *schedule(t, type=self.trajectory), *abdot(t, type=self.trajectory),
                          self.config['mus'],
                          self.config['sigmas'] ** 2, self.config['pis'])
        return v_t(x, *schedule(t, type=self.trajectory), *abdot(t, type=self.trajectory),
                   self.config['mus'],
                   self.config['sigmas'] ** 2, self.config['pis'])

    @torch.no_grad()
    def gt(self, x, t):
        raise NotImplementedError

    @torch.no_grad()
    def target_gt(self, x, t):
        """Analytic ``gt`` mapped into the model's target space."""
        return self.from_pred_to_target(self.gt(x, t), x, t)

    # No @torch.no_grad() here: used inside training/evaluation graphs too.
    def pred(self, x, t):
        """Model output mapped back through ``from_target_to_pred``.

        To sanity-check a re-parameterization, temporarily return
        ``self.from_target_to_pred(self.target_gt(x, t), x, t)`` instead; if
        the two hooks are inverses it equals ``self.gt(x, t)``.
        """
        return self.from_target_to_pred(self.model(x, t), x, t)

    @torch.no_grad()
    def empirical_risk(self, z0, z1, t, n_rounds=2000):
        """Collect squared errors between training targets and ``target_gt``.

        Args:
            z0, z1: sample pools passed to ``get_random_train_pairs``.
            t: times forwarded to ``get_train_tuple``.
            n_rounds: number of batches to draw (default 2000).

        Returns:
            ``(z_t, risk)`` concatenated over all rounds along dim 0.
        """
        z_t, risk = [], []
        for _ in range(n_rounds):
            _z0, _z1 = self.get_random_train_pairs(z0, z1, batch_size=z0.shape[0])
            _z_t, _t, target = self.get_train_tuple(_z0, _z1, t)[:3]
            # Evaluate the reference at the time actually used to build _z_t;
            # subclasses may rescale/broadcast the incoming t.
            risk.append((target - self.target_gt(_z_t, _t)) ** 2)
            z_t.append(_z_t)
        return torch.cat(z_t, dim=0), torch.cat(risk, dim=0)

    @torch.no_grad()
    def empirical_target(self, z0, z1, t, batch_size=20480, n_iters=20):
        """Sample ``(z_t, target)`` pairs at the single time ``t[:1]``.

        Targets are mapped back through ``from_target_to_pred``.

        Returns:
            ``(z_t, targets)`` with ``n_iters * batch_size`` rows each.
        """
        z_t, targets = [], []
        t = t[:1].repeat(batch_size, 1)
        for _ in tqdm(range(n_iters), leave=False):
            _z0, _z1 = self.get_random_train_pairs(z0, z1, batch_size)
            _z_t, _t, target = self.get_train_tuple(_z0, _z1, t)[:3]
            z_t.append(_z_t)
            targets.append(self.from_target_to_pred(target, _z_t, _t))
        return torch.cat(z_t, dim=0), torch.cat(targets, dim=0)

    def train(self, optimizer, z0, z1, batchsize, inner_iters):
        """Plain regression training loop; saves the weights at the end.

        Runs ``inner_iters + 1`` optimizer steps minimizing
        ``mean_b sum_d (target - model(z_t, t))**2``, then writes the model
        ``state_dict`` to ``<config['save_path']>/<self.name>/model.pth``
        (creating the directory if needed).

        Returns:
            List of per-step loss values.
        """
        loss_curve = []
        pbar = tqdm(total=inner_iters + 1, leave=False)
        for _ in range(inner_iters + 1):
            optimizer.zero_grad()
            _z0, _z1 = self.get_random_train_pairs(z0, z1, batchsize)
            # Subclasses may return extra items (e.g. a loss weight); use the first three.
            z_t, t, target = self.get_train_tuple(z0=_z0, z1=_z1)[:3]

            pred = self.model(z_t, t)
            loss = (target - pred).reshape(pred.shape[0], -1).abs().pow(2).sum(dim=1).mean()
            loss.backward()
            optimizer.step()

            loss_curve.append(loss.item())
            pbar.update(1)
        pbar.close()
        save_dir = os.path.join(self.config['save_path'], self.name)
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(save_dir, 'model.pth'))
        return loss_curve

    @torch.no_grad()
    def predicted_velocity(self, x, t):
        raise NotImplementedError

class RectifiedFlow(Flow):
    """Velocity-matching flow along the ``trajectory`` interpolation.

    Training inputs are ``z_t = alpha_t * z1 + beta_t * z0`` with regression
    target ``alpha_t_dot * z1 + beta_t_dot * z0`` (``schedule`` / ``abdot``).
    ``self.sde`` is an ``sde_lib`` SDE whose ``marginal_prob`` and ``sde`` are
    wrapped to run in reversed time (``1 - t``), so smaller ``t`` is closer to
    the Gaussian end.
    """

    def __init__(self, model, trajectory, config, reparam=None):
        super(RectifiedFlow, self).__init__(model, trajectory, config)
        if self.trajectory == 'vp':
            self.sde = VPSDE()
        elif self.trajectory == 'subvp':
            self.sde = subVPSDE()
        elif self.trajectory == 'linear':
            self.sde = Linear()
        else:
            raise NotImplementedError
        assert isinstance(self.sde, SDE)
        # Reverse t: smaller t is closer to the Gaussian end.
        original_marginal_prob = self.sde.marginal_prob
        self.sde.marginal_prob = lambda x, t, return_coef=False: original_marginal_prob(x, 1 - t, return_coef)
        original_sde = self.sde.sde
        self.sde.sde = lambda x, t: original_sde(x, 1 - t)

        # Identity time re-mapping.
        self.t_trans = lambda t: t
        self.t_trans_dot = lambda t: 1
        self.t_inverse = lambda t: t

        self.from_pred_to_target = lambda output, x, t: output
        self.from_target_to_pred = lambda output, x, t: output
        self.reparam = reparam
        if reparam is not None:
            self.name += f'_{reparam}'
        # ``reparameterize`` presumably comes from the ``sde_lib`` star import
        # and installs the from_*_to_* hooks for ``reparam`` -- confirm there.
        reparameterize(self, reparam)

    def get_train_tuple(self, z0=None, z1=None, t=None):
        """Build ``(z_t, t, target)`` for one training batch.

        Args:
            z0: source-side samples, same shape as ``z1``.
            z1: data-side samples; dim 0 is the batch.
            t: optional times.  When ``None``, one time per sample is drawn
                uniformly from ``[0, 1 - 1e-5)`` with shape ``(B, 1, ..., 1)``
                so it broadcasts over the non-batch dims of ``z1``.

        Returns:
            ``(z_t, t, target)``.  ``target`` is the interpolation velocity
            mapped through ``from_pred_to_target``.  For ``x_pred`` /
            ``epsilon_pred`` it is ``z1`` / ``z0`` directly, which avoids a
            numerically lossy round-trip.
        """
        if t is None:
            # One time per sample (not per coordinate), on the data's device,
            # matching the (B, 1) times used by ``sample_ode``.
            t_shape = (z1.shape[0],) + (1,) * (z1.dim() - 1)
            t = torch.rand(t_shape, device=z1.device) * (1 - 1e-5)
        alpha_t, beta_t = schedule(t, self.trajectory)
        z_t = alpha_t * z1 + beta_t * z0
        alpha_t_dot, beta_t_dot = abdot(t, self.trajectory)
        target = alpha_t_dot * z1 + beta_t_dot * z0
        target = self.from_pred_to_target(target, z_t, t)
        # to avoid numerical issues
        if self.reparam == 'x_pred':
            target = z1
        if self.reparam == 'epsilon_pred':
            target = z0
        return z_t, t, target

    @torch.no_grad()
    def sample_ode(self, z0, N, use_gt=False):
        """Integrate from ``z0`` with N steps on a time grid over ``[1e-2, 1)``.

        If ``config['probability_flow']`` is true, take plain Euler steps along
        the predicted velocity.  Otherwise convert the velocity into a score
        through the reversed SDE and take Euler-Maruyama steps.

        Args:
            z0: starting samples, batch along dim 0.
            N: number of steps.
            use_gt: use the analytic ``gt`` velocity instead of the model.

        Returns:
            Tensor ``(N + 1, *z0.shape)`` holding every intermediate state.
        """
        dt = 1. / N
        z = z0.detach().clone()
        batchsize = z.shape[0]
        velocity_fn = self.gt if use_gt else self.pred

        traj = [z.clone()]
        for i in range(N):
            # Times live on the same device as the samples.
            t = torch.ones((batchsize, 1), device=z0.device) * i / N * (1 - 1e-2) + 1e-2
            pred = velocity_fn(z, t)

            if not self.config['probability_flow']:
                drift, diffusion = self.sde.sde(z, t)
                # Solve pred = 0.5 * g**2 * score - drift for the score.
                score = (-pred - drift) / (-diffusion ** 2 * 0.5)
                drift = drift - diffusion ** 2 * score
                z_mean = z - drift * dt
                z = z_mean + diffusion * dt ** 0.5 * torch.randn_like(z)
            else:
                z = z + pred * dt

            traj.append(z.clone())
        return torch.stack(traj)

    def gt(self, x, t):
        """Analytic velocity at ``(x, t)``."""
        return self.velocity(x, t)

    @torch.no_grad()
    def predicted_velocity(self, x, t):
        """Learned velocity at ``(x, t)``."""
        return self.pred(x, t)

class ScoreMatch(Flow):
    """Score-based model trained by denoising score matching.

    ``self.sde`` is an ``sde_lib`` SDE (VP, sub-VP or linear).  Its
    ``marginal_prob`` and ``sde`` are wrapped to run in reversed time
    ``1 - t_trans(t)``, so smaller ``t`` is closer to the Gaussian end.
    ``self.pred`` (learned) and ``self.gt`` (analytic) are the score
    functions handed to ``self.sde.reverse`` by the samplers.
    """

    def __init__(self, model, trajectory, config, reparam=None):
        super(ScoreMatch, self).__init__(model, trajectory, config)
        if self.trajectory == 'vp':
            self.sde = VPSDE()
        elif self.trajectory == 'subvp':
            self.sde = subVPSDE()
        elif self.trajectory == 'linear':
            self.sde = Linear()
        else:
            raise NotImplementedError
        assert isinstance(self.sde, SDE)

        # Optional re-mapping of the time axis (identity by default), its
        # derivative (rescales the sampler step in ``sample_ode``) and inverse.
        # Example non-identity mapping:
        #   t_trans     = lambda t: torch.log(8*t + 1) / torch.log(torch.tensor(9).to('cuda'))
        #   t_trans_dot = lambda t: 8 / torch.log(torch.tensor(9).to('cuda')) / (8 * t + 1)
        #   t_inverse   = lambda t: (torch.exp(torch.log(torch.tensor(9).to('cuda')) * t) - 1) / 8
        self.t_trans = lambda t: t
        self.t_trans_dot = lambda t: 1
        self.t_inverse = lambda t: t

        # Reverse (and re-map) t: smaller t is closer to the Gaussian end.
        original_marginal_prob = self.sde.marginal_prob
        self.sde.marginal_prob = lambda x, t, return_coef=False: original_marginal_prob(
            x, 1 - self.t_trans(t), return_coef)
        original_sde = self.sde.sde
        self.sde.sde = lambda x, t: original_sde(x, 1 - self.t_trans(t))

        self.reparam = reparam
        if self.reparam is not None:
            self.name += f'_{self.reparam}'
        reparameterize(self, reparam)

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _weighted_sq_error(predicted, target, loss_weight):
        """Batch mean of the per-sample sum of squared, weighted residuals."""
        residual = loss_weight * (predicted - target)
        return residual.reshape(predicted.shape[0], -1).abs().pow(2).sum(dim=1).mean()

    def _save_model(self):
        """Save weights to ``<save_path>/<self.name>/<model.name>.pth`` (mkdir -p)."""
        save_dir = os.path.join(self.config['save_path'], self.name)
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(save_dir, f'{self.model.name}.pth'))

    # ------------------------------------------------------------ inference
    def predicted_velocity(self, x, t):
        """Velocity of the probability-flow ODE built from the learned score."""
        rsde = self.sde.reverse(self.pred, probability_flow=True)
        drift, _ = rsde.sde(x, t)
        return -drift

    def one_step_x0_pred(self, zt, t):
        """Tweedie estimate of the clean sample from ``zt`` via the analytic score.

        Returns ``(zt + beta_t**2 * gt(zt, t)) / alpha_t`` with
        ``(alpha_t, beta_t) = schedule(t, trajectory)``; ``t`` is reshaped to
        broadcast over the non-batch dims of ``zt``.
        """
        _t = t.view(-1, *([1] * (zt.dim() - 1)))  # (N, 1, ..., 1)
        alpha_t, beta_t = schedule(_t, self.trajectory)
        score = self.gt(zt, t)
        return (zt + score * beta_t ** 2) / alpha_t

    @torch.no_grad()
    def sample_ode(self, z0, N, use_gt=False, batch_size=None, corrector=False):
        """Euler(-Maruyama) integration of the reverse SDE / probability-flow ODE.

        The reverse process comes from ``self.sde.reverse`` with
        ``probability_flow=config['probability_flow']``.  It is driven by the
        analytic score (``use_gt``) or by the learned one.  The N steps run over
        times in ``[1e-2, 1)``, each scaled by ``t_trans_dot(t)``.

        Args:
            z0: starting samples, batch along dim 0.
            N: number of steps.
            use_gt: use ``self.gt`` instead of ``self.pred`` (also in the
                corrector).
            batch_size: process ``z0`` in chunks of this size (default: all).
            corrector: after each step, take one Langevin step with step
                size ``std(t)**2 / 500``.

        Returns:
            Tensor ``(N + 1, len(z0), ...)`` holding every intermediate state.
        """
        dt = 1. / N
        if batch_size is None:
            batch_size = z0.shape[0]
        score_fn = self.gt if use_gt else self.pred
        # The reverse process does not depend on the chunk, so build it once.
        rsde = self.sde.reverse(score_fn, probability_flow=self.config['probability_flow'])

        all_traj = []
        for z0_batch in torch.split(z0, batch_size):
            z = z0_batch.detach().clone()
            traj = [z.clone()]
            for i in range(N):
                t = torch.ones((z0_batch.shape[0], 1)).to(z0.device) * i / N * (1 - 1e-2) + 1e-2
                noise = torch.randn_like(z)
                drift, diffusion = rsde.sde(z, t)

                # Chain rule for the re-mapped time axis (t_dot == 1 by default).
                t_dot = self.t_trans_dot(t)
                z_mean = z - drift * dt * t_dot
                z = z_mean + diffusion * (dt * t_dot) ** 0.5 * noise

                if corrector:
                    _, std = self.sde.marginal_prob(z, t)
                    step_size = std ** 2 / 500
                    z = z + step_size / 2 * score_fn(z, t) + step_size ** 0.5 * torch.randn_like(z)

                traj.append(z.clone())
            all_traj.append(torch.stack(traj, dim=0))
        # Concatenate the per-chunk trajectories along the batch dimension.
        return torch.cat(all_traj, dim=1)

    @torch.no_grad()
    def sample_ddim(self, z0, N, use_gt=False, batch_size=None, corrector=False):
        """First-order exponential-integrator (DDIM-style) sampler.

        Uses the *un-wrapped* ``sde_lib`` helpers (``marginal_log_mean_coeff``,
        ``marginal_std``, ``marginal_lambda``).  ``t`` steps down from about 1
        to 2e-2, and the network is queried at ``1 - t``.  The raw
        ``self.model`` output is used as the noise prediction, so this
        presumably requires ``reparam='epsilon_pred'``.  ``use_gt`` and
        ``corrector`` are accepted for signature compatibility but unused.

        Returns:
            Tensor ``(N + 1, len(z0), ...)`` holding every intermediate state.
        """
        dt = 1. / N
        if batch_size is None:
            batch_size = z0.shape[0]
        sde = self.sde

        all_traj = []
        for z0_batch in torch.split(z0, batch_size):
            z = z0_batch.detach().clone()
            traj = [z.clone()]
            for i in reversed(range(N)):
                # NOTE(review): grid spacing is (1 - 2e-2) / N but each step
                # advances by dt = 1 / N, and t - dt <= 0 when N <= 50 --
                # confirm this is intended.
                t = torch.ones((z0_batch.shape[0], 1)).to(z0.device) * i / N * (1 - 2e-2) + 2e-2
                log_alpha_s = sde.marginal_log_mean_coeff(t)
                log_alpha_t = sde.marginal_log_mean_coeff(t - dt)
                sigma_t_next = sde.marginal_std(t - dt)
                h = sde.marginal_lambda(t - dt) - sde.marginal_lambda(t)
                phi_1 = torch.expm1(h)
                z = (
                    torch.exp(log_alpha_t - log_alpha_s) * z
                    - (sigma_t_next * phi_1) * self.model(z, 1 - t)
                )
                traj.append(z.clone())
            all_traj.append(torch.stack(traj, dim=0))
        # Concatenate the per-chunk trajectories along the batch dimension.
        return torch.cat(all_traj, dim=1)

    @torch.no_grad()
    def sample_second_order_dpm(self, z0, N, use_gt=False, batch_size=None, corrector=False):
        """Second-order single-step DPM-Solver sampler (midpoint ratio r1 = 0.5).

        Same time convention and model assumptions as ``sample_ddim``: it uses
        the un-wrapped ``sde_lib`` helpers (including ``inverse_lambda``), and
        the raw ``self.model`` output, queried at ``1 - s``, is the noise
        prediction.  ``use_gt`` and ``corrector`` are accepted but unused.

        Returns:
            Tensor ``(N + 1, len(z0), ...)`` holding every intermediate state.
        """
        dt = 1. / N
        if batch_size is None:
            batch_size = z0.shape[0]
        ns = self.sde
        r1 = 0.5

        all_traj = []
        for z0_batch in torch.split(z0, batch_size):
            z = z0_batch.detach().clone()
            traj = [z.clone()]
            for i in reversed(range(N)):
                # NOTE(review): same grid/step mismatch as in ``sample_ddim``.
                s = torch.ones((z0_batch.shape[0], 1)).to(z0.device) * i / N * (1 - 2e-2) + 2e-2
                t = s - dt
                lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
                h = lambda_t - lambda_s
                s1 = ns.inverse_lambda(lambda_s + r1 * h)
                log_alpha_s = ns.marginal_log_mean_coeff(s)
                log_alpha_s1 = ns.marginal_log_mean_coeff(s1)
                log_alpha_t = ns.marginal_log_mean_coeff(t)
                sigma_s1, sigma_t = ns.marginal_std(s1), ns.marginal_std(t)

                phi_11 = torch.expm1(r1 * h)
                phi_1 = torch.expm1(h)

                model_s = self.model(z, 1 - s)
                x_s1 = (
                    torch.exp(log_alpha_s1 - log_alpha_s) * z
                    - (sigma_s1 * phi_11) * model_s
                )
                model_s1 = self.model(x_s1, 1 - s1)
                z = (
                    torch.exp(log_alpha_t - log_alpha_s) * z
                    - (sigma_t * phi_1) * model_s
                    - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
                )
                traj.append(z.clone())
            all_traj.append(torch.stack(traj, dim=0))
        # Concatenate the per-chunk trajectories along the batch dimension.
        return torch.cat(all_traj, dim=1)

    # ------------------------------------------------------------- training
    def get_train_tuple(self, z0=None, z1=None, t=None):
        """Build ``(z_t, t, target, loss_weight)`` for denoising score matching.

        Args:
            z0: unused (kept for the ``Flow`` interface).
            z1: data samples, batch along dim 0.
            t: optional time(s).  When ``None``, one time per sample is drawn
                from ``[0, 1 - 1e-4)``.  If
                ``config['train_only_data_and_gaussian']`` is set, the times
                are instead 0 or 0.9.  A given ``t`` is broadcast to
                ``(B, 1)`` and scaled by ``1 - 1e-4``.

        Returns:
            ``z_t = mean + std * eps``, the times used, the regression target
            and a loss weight.  The target is ``eps`` for ``'epsilon_pred'``,
            ``z1`` (weight ``std / log(coef)``) for ``'x_pred'``, and otherwise
            ``from_pred_to_target(-eps / std)``.  The weight is ``std`` for
            ``None`` / ``''`` / ``'negative'``.
        """
        if t is None:
            # Near t = 1 the score blows up and the network cannot learn it.
            t = torch.rand((z1.shape[0], 1)).to(z1.device) * (1 - 1e-4)
            if self.config['train_only_data_and_gaussian']:
                t = torch.randint(0, 2, (z1.shape[0], 1)).to(z1.device)
                t = t * 0.9
        else:
            t = t * torch.ones((z1.shape[0], 1)).to(z1.device) * (1 - 1e-4)
        mean, std = self.sde.marginal_prob(z1, t)

        z = torch.randn_like(z1)
        z_t = mean + std * z
        target = self.from_pred_to_target(-z / std, z_t, t)
        loss_weight = 1
        if self.reparam == 'epsilon_pred':
            # Predict the noise directly: (z / std) * std need not equal z
            # exactly when std is very small.
            target = z
        if self.reparam == 'x_pred':
            target = z1
            coef, _ = self.sde.marginal_prob(z1, t, return_coef=True)
            loss_weight = std / torch.log(coef)
        if self.reparam in [None, '', 'negative']:
            loss_weight = std

        return z_t, t, target, loss_weight

    def train(self, optimizer, z0, z1, batchsize, inner_iters):
        """Denoising score-matching training loop; saves the weights at the end.

        Runs ``inner_iters + 1`` optimizer steps on random batches of
        ``batchsize``, minimizing the weighted squared error between
        ``model(z_t, t)`` and the ``get_train_tuple`` target.  It then saves
        the weights to ``<config['save_path']>/<self.name>/<model.name>.pth``.

        Returns:
            List of per-step loss values.
        """
        loss_curve = []
        pbar = tqdm(total=inner_iters + 1, leave=False)
        for _ in range(inner_iters + 1):
            optimizer.zero_grad()
            _z0, _z1 = self.get_random_train_pairs(z0, z1, batchsize)
            z_t, t, target, loss_weight = self.get_train_tuple(_z0, _z1)
            predicted_target = self.model(z_t, t)
            # TODO: maybe (std * predicted_target - z) is numerically better
            # for the identity re-parameterization.
            loss = self._weighted_sq_error(predicted_target, target, loss_weight)
            loss.backward()
            optimizer.step()
            loss_curve.append(loss.item())
            pbar.update(1)
        pbar.close()
        self._save_model()
        return loss_curve

    def get_random_pairs(self, z0, z1, N):
        """Draw ``N`` rows with replacement from each of ``z0`` and ``z1`` (unpaired)."""
        indices_z0 = torch.randint(low=0, high=z0.shape[0], size=(N,))
        indices_z1 = torch.randint(low=0, high=z1.shape[0], size=(N,))
        return z0[indices_z0], z1[indices_z1]

    def get_train_tuple_fixed_t(self, z0, z1, t, per_noise):
        """Like ``get_train_tuple`` at a fixed ``t``, with ``z1`` repeated ``per_noise`` times.

        ``z0`` is unused.  The time is broadcast to ``(len(z1) * per_noise, 1)``
        and scaled by ``1 - 1e-4``.
        """
        z1 = z1.repeat(per_noise, 1)
        t = t * torch.ones((z1.shape[0], 1)).to(z1.device) * (1 - 1e-4)
        mean, std = self.sde.marginal_prob(z1, t)

        noise = torch.randn_like(z1)
        z_t = mean + std * noise
        target = self.from_pred_to_target(-noise / std, z_t, t)
        loss_weight = 1
        if self.reparam == 'epsilon_pred':
            # Predict the noise directly to avoid (noise / std) * std round-off.
            target = noise
        if self.reparam == 'x_pred':
            target = z1
            coef, _ = self.sde.marginal_prob(z1, t, return_coef=True)
            loss_weight = std / torch.log(coef)
        if self.reparam in [None, '', 'negative']:
            loss_weight = std

        return z_t, t, target, loss_weight

    def train_fixed_t(self, optimizer, z0, z1, batchsize, inner_iters, t, per_noise, ema):
        """Train at one fixed time ``t`` and report the final score error.

        Each of the ``inner_iters + 1`` steps re-noises ``z1`` at ``t`` via
        ``get_train_tuple_fixed_t``, takes one optimizer step and calls
        ``ema.update(self.model)``.  ``batchsize`` is unused, and the model is
        not saved.

        Returns:
            Mean squared difference between ``self.pred`` and ``self.gt`` on
            the last batch, evaluated at the time that batch was noised at.
        """
        pbar = tqdm(total=inner_iters + 1, leave=False)
        for _ in range(inner_iters + 1):
            optimizer.zero_grad()
            # Bind the batch times to a new name.  Re-binding ``t`` would feed
            # the already-scaled tensor back in and shrink t by a factor
            # (1 - 1e-4) every iteration.
            z_t, t_batch, target, loss_weight = self.get_train_tuple_fixed_t(
                z0, z1, t, per_noise=per_noise)
            predicted_target = self.model(z_t, t_batch)
            loss = self._weighted_sq_error(predicted_target, target, loss_weight)
            loss.backward()
            optimizer.step()
            ema.update(self.model)
            pbar.update(1)
        pbar.close()

        error = ((self.pred(z_t, t_batch) - self.gt(z_t, t_batch)) ** 2).mean().item()
        return error

    def train_by_gt(self, optimizer, z0, z1, batchsize, inner_iters):
        """Regress the model output directly onto the analytic score.

        This is the same loop as ``train``, but the target is
        ``self.score(z_t, t)`` and the model is queried at ``t_trans(t)``.
        Entries where the analytic score is NaN contribute zero loss and zero
        gradient.  Weights are saved at the end.

        Returns:
            List of per-step loss values.
        """
        loss_curve = []
        pbar = tqdm(total=inner_iters + 1, leave=False)
        for _ in range(inner_iters + 1):
            optimizer.zero_grad()
            _z0, _z1 = self.get_random_train_pairs(z0, z1, batchsize)
            z_t, t, _, loss_weight = self.get_train_tuple(_z0, _z1)
            predicted_target = self.model(z_t, self.t_trans(t))
            gt_score = self.score(z_t, t)
            # Zero out NaN entries instead of boolean-indexing them away:
            # ``x[mask]`` flattens to 1-D, so the per-sample reshape fails (or
            # mixes rows) as soon as a single entry is dropped.
            valid = ~torch.isnan(gt_score)
            safe_gt = torch.where(valid, gt_score, torch.zeros_like(gt_score))
            loss = self._weighted_sq_error(predicted_target * valid, safe_gt, loss_weight)
            loss.backward()
            optimizer.step()
            loss_curve.append(loss.item())
            pbar.update(1)
        pbar.close()
        self._save_model()
        return loss_curve

    def gt(self, x, t):
        """Analytic score at the re-mapped time ``t_trans(t)``."""
        return self.score(x, self.t_trans(t))

    @torch.no_grad()
    def p(self, x, t):
        """Analytic marginal density at the re-mapped time ``t_trans(t)`` (non-log)."""
        if self.config['d'] == 1:
            return p_t_1d(x, *schedule(self.t_trans(t), type=self.trajectory),
                          self.config['mus'],
                          self.config['sigmas'] ** 2, self.config['pis'])
        # log=False to match Flow.p and the 1-D branch (both return densities).
        return p_t(x, *schedule(self.t_trans(t), type=self.trajectory),
                   self.config['mus'],
                   self.config['sigmas'] ** 2, self.config['pis'], log=False)

class ScoreMatchRF(Flow):
    """Denoising score matching on a VP, sub-VP or linear SDE, at two fixed times.

    Unless ``get_train_tuple`` is given explicit times, training samples only
    t = ``T_MIN`` and t = 1. The ``close_form_*`` methods use the same two
    times with ReLU features ``relu([z_t, t] @ self.model.W)`` and write a
    least-squares solution into ``self.model.theta`` (presumably the model's
    linear readout on those features — confirm against the model class).

    Args:
        model: network called as ``model(z_t, t)``; must expose ``name``.
        config: dict with at least ``save_path``; evaluation methods also read
            ``samples_0``, ``samples_1``, ``samples_1_test`` and ``batchsize``.
        trajectory: 'vp', 'subvp' or 'linear'.
        reparam: training target selector, see ``get_train_tuple``.
    """

    # Smallest training time: at t = 0 the score blows up and cannot be learned.
    T_MIN = 0.1

    def __init__(self, model, config, trajectory, reparam='epsilon_pred'):
        super(ScoreMatchRF, self).__init__(model, trajectory, config)
        if self.trajectory == 'vp':
            self.sde = VPSDE()
        elif self.trajectory == 'subvp':
            self.sde = subVPSDE()
        elif self.trajectory == 'linear':
            self.sde = Linear()
        else:
            raise NotImplementedError
        assert isinstance(self.sde, SDE)
        self.reparam = reparam

    # ------------------------------------------------------------------ helpers

    def _checkpoint_dir(self):
        """Return ``<save_path>/<self.name>``, creating the directory if needed."""
        directory = f"{self.config['save_path']}/{self.name}"
        os.makedirs(directory, exist_ok=True)
        return directory

    @staticmethod
    def _weighted_sq_error(predicted, target, loss_weight=1):
        """Batch mean of the per-sample sum of squared ``loss_weight``-scaled residuals."""
        residual = (loss_weight * (predicted - target)).view(predicted.shape[0], -1)
        return residual.abs().pow(2).sum(dim=1).mean()

    def _two_time_noisy_inputs(self, z1, epsilon):
        """Perturb ``z1`` at t = T_MIN and t = 1 using the two halves of ``epsilon``.

        ``epsilon`` has ``2 * len(z1)`` rows: rows ``[:n]`` noise the t = T_MIN
        copy and rows ``[n:]`` the t = 1 copy.

        Returns:
            (t_0, t_1, z_t_0, z_t_1) with t_0 / t_1 as (n, 1) time columns.
        """
        n = z1.shape[0]
        t_0 = torch.zeros(n, 1).to(z1.device) + self.T_MIN
        t_1 = torch.ones(n, 1).to(z1.device)
        mean, std = self.sde.marginal_prob(z1, t_0)
        z_t_0 = mean + std * epsilon[:n]
        mean, std = self.sde.marginal_prob(z1, t_1)
        z_t_1 = mean + std * epsilon[n:]
        return t_0, t_1, z_t_0, z_t_1

    def _random_features(self, z_t, t):
        """ReLU features of ``[z_t, t]`` under the model's matrix ``W``; shape (n, p)."""
        import torch.nn.functional as F
        return F.relu(torch.matmul(torch.cat([z_t, t], dim=1), self.model.W))

    @staticmethod
    def _least_squares(Z, Xi):
        """Least-squares coefficients ``pinv(Z^T Z) @ Z^T @ Xi`` of shape (p, d)."""
        Z_T = Z.T
        return torch.linalg.pinv(Z_T @ Z) @ Z_T @ Xi

    def _close_form_eval_data(self, train):
        """Return (z1, epsilon): stored training noise, or fresh noise for the test set."""
        if train:
            return self.config['samples_1'], self.epsilon
        z1 = self.config['samples_1_test']
        return z1, torch.randn(2 * z1.shape[0], z1.shape[1]).to(z1.device)

    # ----------------------------------------------------------------- training

    def get_train_tuple(self, z0=None, z1=None, t=None):
        """Build one denoising-score-matching batch from data ``z1`` (``z0`` unused).

        When ``t`` is None, each sample gets t in {T_MIN, 1}: ``randint(0, 2)``
        yields 0 or 1, which is mapped affinely onto those two times.

        Returns:
            (z_t, t, target, loss_weight), where z_t = mean + std * z, z ~ N(0, I).
            Depending on ``self.reparam``:
              'epsilon_pred'         -> target z,  weight 1
              'x_pred'               -> target z1, weight std / log(coef)
              None / '' / 'negative' -> target from_pred_to_target(-z / std), weight std
              any other value        -> target from_pred_to_target(-z / std), weight 1
        """
        if t is None:
            # Drawn on the CPU generator (as before), then moved to z1's device.
            t = torch.randint(0, 2, (z1.shape[0], 1)).to(z1.device)
            t = t * (1 - self.T_MIN) + self.T_MIN

        mean, std = self.sde.marginal_prob(z1, t)
        z = torch.randn_like(z1)
        z_t = mean + std * z
        if self.reparam == 'epsilon_pred':
            # Regress the noise directly; avoids dividing and re-multiplying by a
            # possibly tiny std.
            target, loss_weight = z, 1
        elif self.reparam == 'x_pred':
            target = z1
            coef, _ = self.sde.marginal_prob(z1, t, return_coef=True)
            loss_weight = std / torch.log(coef)
        else:
            target = self.from_pred_to_target(-z / std, z_t, t)
            loss_weight = std if self.reparam in [None, '', 'negative'] else 1
        return z_t, t, target, loss_weight

    def train(self, optimizer, z0, z1, batchsize, inner_iters):
        """Run ``inner_iters + 1`` score-matching steps and checkpoint model + optimizer.

        Returns:
            List with one loss value per step.
        """
        loss_curve = []
        pbar = tqdm(total=inner_iters + 1, leave=False)
        for _ in range(inner_iters + 1):
            optimizer.zero_grad()

            _z0, _z1 = self.get_random_train_pairs(z0, z1, batchsize)
            z_t, t, target, loss_weight = self.get_train_tuple(_z0, _z1)

            predicted_target = self.model(z_t, t)
            # TODO: maybe (std * predicted_target - z) would be numerically safer.
            loss = self._weighted_sq_error(predicted_target, target, loss_weight)
            loss.backward()

            optimizer.step()
            loss_curve.append(loss.item())  # store the loss curve
            pbar.update(1)

        pbar.close()
        directory = self._checkpoint_dir()
        torch.save(self.model.state_dict(), f'{directory}/{self.model.name}.pth')
        torch.save(optimizer.state_dict(), f'{directory}/op{self.model.name}.pth')
        return loss_curve

    def close_form_train(self, z1):
        """Fit ``self.model.theta`` by least squares on the two-time training set.

        Draws 2n noise rows (kept on ``self.epsilon`` for later evaluation),
        builds noisy inputs at t = T_MIN and t = 1, stacks their random
        features into Z (2n, p), and copies the least-squares solution of
        ``Z @ theta ~= epsilon`` into ``self.model.theta`` before saving the model.
        """
        epsilon = torch.randn(2 * z1.shape[0], z1.shape[1]).to(z1.device)
        self.epsilon = epsilon

        t_0, t_1, z_t_0, z_t_1 = self._two_time_noisy_inputs(z1, epsilon)
        Z = torch.cat([self._random_features(z_t_0, t_0),
                       self._random_features(z_t_1, t_1)], dim=0)  # (2n, p)
        theta_close_form = self._least_squares(Z, epsilon)  # (p, d)

        with torch.no_grad():
            self.model.theta.copy_(theta_close_form)
        torch.save(self.model.state_dict(), f'{self._checkpoint_dir()}/{self.model.name}.pth')
        return None

    # --------------------------------------------------------------- evaluation

    def eval_loss_close_form(self, train=True):
        """Mean squared noise-prediction error at t = T_MIN and t = 1.

        Args:
            train: use ``samples_1`` with the noise stored by ``close_form_train``;
                otherwise ``samples_1_test`` with fresh noise.

        Returns:
            (loss_t0, loss_t1) as numpy scalars.
        """
        z1, epsilon = self._close_form_eval_data(train)
        n = z1.shape[0]
        t_0, t_1, z_t_0, z_t_1 = self._two_time_noisy_inputs(z1, epsilon)
        loss_t0 = torch.mean((epsilon[:n] - self.model(z_t_0, t_0)) ** 2)
        loss_t1 = torch.mean((epsilon[n:] - self.model(z_t_1, t_1)) ** 2)
        return loss_t0.detach().cpu().numpy(), loss_t1.detach().cpu().numpy()

    def eval_theta_distance_close_form(self, train=True):
        """Feature-Gram-weighted distance from ``model.theta`` to per-time LS solutions.

        For each time, solves least squares on that time's features alone,
        giving theta_k, and returns mean((theta_k - theta)^T H_k (theta_k - theta))
        with H_k = Z_k^T Z_k.
        """
        z1, epsilon = self._close_form_eval_data(train)
        n = z1.shape[0]
        t_0, t_1, z_t_0, z_t_1 = self._two_time_noisy_inputs(z1, epsilon)

        features_0 = self._random_features(z_t_0, t_0)  # (n, p)
        features_1 = self._random_features(z_t_1, t_1)  # (n, p)
        theta_0_close_form = self._least_squares(features_0, epsilon[:n])
        theta_1_close_form = self._least_squares(features_1, epsilon[n:])

        H_0 = features_0.T @ features_0
        H_1 = features_1.T @ features_1
        delta_0 = theta_0_close_form - self.model.theta
        delta_1 = theta_1_close_form - self.model.theta
        # NOTE(review): this averages every entry of the (d, d) matrix
        # delta^T H delta, off-diagonals included; confirm trace/diag wasn't meant.
        dis_0 = torch.mean(delta_0.T @ H_0 @ delta_0)
        dis_1 = torch.mean(delta_1.T @ H_1 @ delta_1)
        return dis_0.detach().cpu().numpy(), dis_1.detach().cpu().numpy()

    def eval_loss(self):
        """Take 11 further training steps, measuring fixed-time losses after each.

        Rebuilds an Adam optimizer (lr 5e-3), restores its state from
        ``<save_path>/<self.name>/op<model.name>.pth``, and in each of 11
        iterations performs one step exactly like ``train`` (so the model's
        weights ARE updated). It then measures the unweighted noise-prediction
        loss on all of ``samples_1`` at t = T_MIN and t = 1 with fresh noise.

        Returns:
            (loss_0, loss_1): means of the 11 measurements.
        """
        import numpy as np

        z0 = self.config['samples_0']
        z1 = self.config['samples_1']
        save_path = self.config['save_path']
        optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-3)
        batchsize = self.config['batchsize']
        optimizer.load_state_dict(torch.load(f'{save_path}/{self.name}/op{self.model.name}.pth'))
        losses_0, losses_1 = [], []
        for _ in range(10 + 1):
            optimizer.zero_grad()
            _z0, _z1 = self.get_random_train_pairs(z0, z1, batchsize)
            z_t, t, target, loss_weight = self.get_train_tuple(_z0, _z1)
            loss = self._weighted_sq_error(self.model(z_t, t), target, loss_weight)
            loss.backward()
            optimizer.step()

            # Pure measurement: no autograd graph needed.
            with torch.no_grad():
                epsilon = torch.cat([torch.randn_like(z1), torch.randn_like(z1)], dim=0)
                t_0, t_1, z_t_0, z_t_1 = self._two_time_noisy_inputs(z1, epsilon)
                n = z1.shape[0]
                loss_0 = self._weighted_sq_error(self.model(z_t_0, t_0), epsilon[:n])
                loss_1 = self._weighted_sq_error(self.model(z_t_1, t_1), epsilon[n:])
            losses_0.append(loss_0.detach().cpu().numpy())
            losses_1.append(loss_1.detach().cpu().numpy())
        return np.mean(losses_0), np.mean(losses_1)



def reparameterize(self, method=None):
    """Install ``from_pred_to_target`` / ``from_target_to_pred`` on a flow in place.

    Each installed pair is mutually inverse in ``output`` for fixed ``x`` and
    ``t``. Only ``ScoreMatch`` and ``RectifiedFlow`` instances are modified;
    any other ``Flow`` is left untouched.

    Args:
        self: the ``Flow`` instance to modify.
        method: '' or None (identity), 'epsilon_pred' or 'x_pred'; ``ScoreMatch``
            additionally accepts 'negative'.

    Raises:
        NotImplementedError: ``method`` is not supported for this flow type.
    """
    assert isinstance(self, Flow)

    if isinstance(self, ScoreMatch):
        if method == 'epsilon_pred':
            # Scale by -std and by -1/std respectively (std from the SDE marginal).
            def pred_to_target(output, x, t):
                _, sigma = self.sde.marginal_prob(output, t)
                return -output * sigma

            def target_to_pred(output, x, t):
                _, sigma = self.sde.marginal_prob(output, t)
                return -output / sigma
        elif method == 'negative':
            def pred_to_target(output, x, t):
                return -output

            def target_to_pred(output, x, t):
                return -output
        elif method in ['', None]:
            def pred_to_target(output, x, t):
                return output

            def target_to_pred(output, x, t):
                return output
        elif method == 'x_pred':
            # With x = coef * x0 + sigma * eps: x0 = (x + sigma^2 * s) / coef.
            def pred_to_target(output, x, t):
                coef, sigma = self.sde.marginal_prob(output, t, return_coef=True)
                return (x + output * sigma ** 2) / coef

            def target_to_pred(output, x, t):
                coef, sigma = self.sde.marginal_prob(output, t, return_coef=True)
                return -(x - output * coef) / sigma ** 2
        else:
            raise NotImplementedError(method)
        self.from_pred_to_target = pred_to_target
        self.from_target_to_pred = target_to_pred

    if isinstance(self, RectifiedFlow):
        print('Reparameterizing using {}'.format(method))
        if method == 'epsilon_pred':
            def pred_to_target(output, x, t):
                a_dot, b_dot = abdot(t, self.trajectory)
                a, b = schedule(t, self.trajectory)
                denom = a * b_dot - b * a_dot
                return (a * output - a_dot * x) / denom

            def target_to_pred(output, x, t):
                a_dot, b_dot = abdot(t, self.trajectory)
                a, b = schedule(t, self.trajectory)
                denom = a * b_dot - b * a_dot
                return (output * denom + a_dot * x) / a
        elif method == 'x_pred':
            def pred_to_target(output, x, t):
                a_dot, b_dot = abdot(t, self.trajectory)
                a, b = schedule(t, self.trajectory)
                denom = a_dot * b - a * b_dot
                return (b * output - b_dot * x) / denom

            def target_to_pred(output, x, t):
                a_dot, b_dot = abdot(t, self.trajectory)
                a, b = schedule(t, self.trajectory)
                denom = a_dot * b - a * b_dot
                return (output * denom + b_dot * x) / b
        elif method in ['', None]:
            def pred_to_target(output, x, t):
                return output

            def target_to_pred(output, x, t):
                return output
        else:
            raise NotImplementedError(method)
        self.from_pred_to_target = pred_to_target
        self.from_target_to_pred = target_to_pred

# Device used by eval_flow; fall back to CPU so the module still works on
# hosts without CUDA (the previous hard-coded 'cuda' crashed there).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def eval_flow(flow, z0, times=(0, 0.99), x_range=(-3, 3), num_points=1000):
    """Integrate the density-weighted x-derivative of the first velocity component.

    For each time in ``times``, a grid of ``num_points`` points on ``x_range``
    is placed on the first coordinate (other coordinates zero). On that grid
    the function computes ``p(x, t) / sum(p) * d v_0(x, t) / d x_0``, where
    ``p`` is ``flow.p``, and integrates its absolute value with the
    trapezoidal rule.

    Args:
        flow: a ``Flow`` providing ``p``, ``predicted_velocity`` and ``velocity``.
        z0: tensor whose ``shape[1]`` is the ambient dimension.
        times, x_range, num_points: evaluation grid; the defaults reproduce the
            previously hard-coded values.

    Returns:
        (errors, gts): one integral per time, using ``flow.predicted_velocity``
        and ``flow.velocity`` respectively.
    """
    import numpy as np

    assert isinstance(flow, Flow)
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    def preprocess_x(x):
        # Embed the 1-D grid as the first coordinate of zero vectors in R^d.
        _x = torch.zeros(len(x), z0.shape[1]).to(device)
        _x[:, 0:1] = x.clone()
        return _x

    def weighted_gradient(velocity_fn, x_grid, t):
        # d v_0 / d x_0 along the grid, scaled by the normalised density.
        x_grid.requires_grad_()
        t = t.clone().detach().requires_grad_(True)
        p = flow.p(preprocess_x(x_grid.unsqueeze(-1)), t)
        p = (p / torch.sum(p)).detach()
        v0 = velocity_fn(preprocess_x(x_grid[:, None]), t)[:, 0:1]
        # autograd.grad (unlike .backward) leaves the model parameters' .grad
        # buffers untouched.
        (grad,) = torch.autograd.grad(v0, x_grid, grad_outputs=torch.ones_like(v0))
        return p.squeeze() * grad

    def integrate(velocity_fn):
        results = []
        for _t in torch.tensor(list(times)).to(device):
            x_values = torch.linspace(x_range[0], x_range[1], num_points).to(device)
            t_column = _t * torch.ones_like(x_values.unsqueeze(-1))
            func_values = weighted_gradient(velocity_fn, x_values, t_column).detach().cpu().numpy()
            results.append(trapezoid(np.abs(func_values), x_values.detach().cpu().numpy()))
        return results

    errors = integrate(flow.predicted_velocity)
    gts = integrate(flow.velocity)
    return errors, gts
