"""
This file contains the iPNDM sampler implementation from LD3 codebase
with minimal modifications for integration into our framework.

Original code:
    https://github.com/vinhsuhi/LD3/blob/ec1bf603fb19696966ca30198ed209ae6488a3e5/samplers/ipndm.py
"""

import torch
from src.ld3.general_solver import ODESolver


def einsum_float_double(string, a, b):
    """Evaluate ``torch.einsum(string, a, b)`` with both operands promoted to float64.

    Args:
        string: The einsum equation, e.g. ``"ij,jk->ik"``.
        a: First operand tensor (any dtype convertible to float64).
        b: Second operand tensor (any dtype convertible to float64).

    Returns:
        The einsum result in float64; it is not cast back to the inputs' dtype.
    """
    lhs = a.to(torch.float64)
    rhs = b.to(torch.float64)
    return torch.einsum(string, lhs, rhs)

class iPNDM(ODESolver):
    """iPNDM sampler for noise-prediction diffusion models.

    Each step applies an exponential-integrator update whose noise estimate is an
    Adams-Bashforth combination of the current model output and up to three
    previous outputs (``order`` in 1..4). Adapted from the LD3 codebase.
    """

    # Multistep orders implemented by ``_combine_epsilons``.
    _SUPPORTED_ORDERS = (1, 2, 3, 4)

    def __init__(
        self,
        noise_schedule,
        algorithm_type="noise_prediction",
    ):
        """Create the sampler.

        Args:
            noise_schedule: Schedule object. This class calls its
                ``marginal_lambda``, ``marginal_log_mean_coeff`` and
                ``marginal_std`` methods and reads its ``eps`` / ``T`` attributes.
            algorithm_type: Must be ``"noise_prediction"``; the update rule
                consumes epsilon (noise) predictions only.

        Raises:
            ValueError: If ``algorithm_type`` is not ``"noise_prediction"``.
        """
        super().__init__(noise_schedule, algorithm_type)
        self.noise_schedule = noise_schedule  # presumably a NoiseScheduleVP — confirm
        # The iPNDM update below is written for epsilon predictions only.
        if algorithm_type != "noise_prediction":
            raise ValueError(
                f"iPNDM requires algorithm_type='noise_prediction', got {algorithm_type!r}"
            )
        # Always False given the check above.
        self.predict_x0 = algorithm_type == "data_prediction"

    def sample(
        self,
        model_fn,
        x,
        steps=20,
        t_start=None,
        t_end=None,
        order=2,
        skip_type='time_uniform',
        flags=None,
    ):
        """Sample from ``t_start`` down to ``t_end`` using ``steps`` iPNDM steps.

        Args:
            model_fn: Callable ``(x, t) -> epsilon``. It is wrapped so the timestep
                passed to it is expanded to ``x.shape[0]`` entries.
            x: Initial noisy sample.
            steps: Number of solver steps, forwarded to ``prepare_timesteps``.
            t_start: Starting time; defaults to ``noise_schedule.T``.
            t_end: Final time; defaults to ``noise_schedule.eps``.
            order: Maximum multistep order (1-4).
            skip_type: Timestep spacing, forwarded to ``prepare_timesteps``.
            flags: Optional config object. Its ``load_from`` attribute, if present,
                is forwarded to ``prepare_timesteps``; otherwise ``None`` is passed.

        Returns:
            The final sample from ``sample_simple``, computed under ``torch.no_grad()``.

        Raises:
            ValueError: If ``order`` is not in 1..4.
        """

        def batched_model(x_in, t):
            # Broadcast the per-step timestep to one entry per batch element.
            return model_fn(x_in, t.expand((x_in.shape[0])))

        self.model = batched_model
        t_0 = self.noise_schedule.eps if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        # flags defaults to None, so it must not be dereferenced unconditionally.
        load_from = getattr(flags, "load_from", None)
        timesteps, timesteps2 = self.prepare_timesteps(
            steps=steps,
            t_start=t_T,
            t_end=t_0,
            skip_type=skip_type,
            device=x.device,
            load_from=load_from,
        )

        with torch.no_grad():
            # sample_simple takes (timesteps, timesteps2) before `order`, so pass
            # `order` by keyword. Pass the batched wrapper because sample_simple
            # reassigns self.model to whatever callable it receives.
            return self.sample_simple(batched_model, x, timesteps, timesteps2, order=order)

    @staticmethod
    def _combine_epsilons(epsilon_cur, history, step_order):
        """Return the Adams-Bashforth noise estimate of ``step_order`` (1-4).

        ``history`` holds earlier model outputs, oldest first. Only its last
        ``step_order - 1`` entries are used.
        """
        if step_order == 1:
            return epsilon_cur
        if step_order == 2:
            return (3 * epsilon_cur - 1 * history[-1]) / 2
        if step_order == 3:
            return (23 * epsilon_cur - 16 * history[-1] + 5 * history[-2]) / 12
        return (55 * epsilon_cur - 59 * history[-1] + 37 * history[-2] - 9 * history[-3]) / 24

    def sample_simple(self, model_fn, x, timesteps, timesteps2, order=2, condition=None, unconditional_condition=None, **kwargs):
        """Run the iPNDM multistep update along a precomputed timestep grid.

        Each step evaluates the model once at ``timesteps2[step]``. It then advances
        the state from ``t_s = timesteps[step]`` to ``t_t = timesteps[step + 1]`` with::

            x_next = exp(log_alpha_t - log_alpha_s) * x_cur - sigma_t * expm1(h) * eps_hat

        Here ``h = lambda_t - lambda_s``, and ``eps_hat`` is the Adams-Bashforth
        combination of the current and up to ``order - 1`` previous model outputs.
        The effective order ramps up from 1 while the history fills.

        Args:
            model_fn: Callable stored as ``self.model``. The model is evaluated via
                ``self.model_fn``, presumably inherited from ``ODESolver`` and
                expected to delegate to ``self.model`` — confirm.
            x: Initial state at ``timesteps[0]``.
            timesteps: Grid used for the update coefficients.
            timesteps2: Grid at which the model is evaluated; needs at least
                ``len(timesteps) - 1`` entries.
            order: Maximum multistep order, one of 1, 2, 3, 4.
            condition: Accepted for interface compatibility; unused.
            unconditional_condition: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The state after the last step. This is ``x`` itself if ``timesteps`` has
            fewer than two entries.

        Raises:
            ValueError: If ``order`` is not in 1..4.
        """
        if order not in self._SUPPORTED_ORDERS:
            raise ValueError(f"iPNDM supports order 1-4, got {order!r}")
        self.model = model_fn

        ns = self.noise_schedule
        # Previous model outputs, oldest first; never holds more than order - 1 entries.
        epsilon_buffer = []
        x_next = x
        for step in range(len(timesteps) - 1):
            # Cannot use more history than has been collected so far.
            step_order = min(order, step + 1)

            t_cur1, t_next1 = timesteps[step], timesteps[step + 1]
            t_cur2 = timesteps2[step]

            x_cur = x_next
            epsilon_cur = self.model_fn(x_cur, t_cur2)
            eps_hat = self._combine_epsilons(epsilon_cur, epsilon_buffer, step_order)

            lambda_s, lambda_t = ns.marginal_lambda(t_cur1), ns.marginal_lambda(t_next1)
            h = lambda_t - lambda_s
            log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(t_cur1), ns.marginal_log_mean_coeff(t_next1)
            sigma_t = ns.marginal_std(t_next1)
            phi_1 = torch.expm1(h)
            x_next = (
                torch.exp(log_alpha_t - log_alpha_s) * x_cur
                - (sigma_t * phi_1) * eps_hat
            )

            # Slide the history window. Order 1 keeps no history; the old in-place
            # shift indexed an empty list in that case.
            if order > 1:
                epsilon_buffer.append(epsilon_cur)
                if len(epsilon_buffer) > order - 1:
                    epsilon_buffer.pop(0)

        return x_next
        
