import torch
from abc import ABC, abstractmethod
import os
# from noise_schedulers import NoiseScheduleVE, NoiseScheduleVP

import numpy as np
from scipy.optimize import minimize
from scipy.optimize import LinearConstraint

class StepOptim(object):
    """Optimise the sampling time-step discretisation of a diffusion model.

    The optimiser works in log-SNR ("lambda") space: the two end points are
    fixed by ``ns.T`` and ``eps`` and the interior points are chosen by
    minimising ``sel_lambdas_lof_obj`` under a monotonicity constraint.

    Args:
        ns: Noise schedule. This class reads ``ns.T`` and calls
            ``ns.marginal_alpha(t)`` and ``ns.inverse_lambda(lamb)``.
        is_latent_space: Selects the VP-style log-SNR (``lambda_func``),
            the squared-sigma error weighting and the ``'unif_t'``
            initialisation when true; the EDM-style log-SNR
            (``edm_lambda_func``), linear-sigma weighting and ``'unif'``
            initialisation otherwise. When ``None`` it is inferred from
            whether a class named ``NoiseScheduleVP`` appears in
            ``type(ns).__mro__`` - this mirrors the ``isinstance`` check that
            was disabled together with the ``noise_schedulers`` import.
    """

    def __init__(self, ns, is_latent_space=None):
        super().__init__()
        self.ns = ns
        self.T = self.ns.T  # t_T of diffusion sampling, for VP models, T=1.0; for EDM models, T=80.0
        if is_latent_space is None:
            # Name-based stand-in for isinstance(ns, NoiseScheduleVP); the
            # class itself cannot be imported in this module.
            is_latent_space = any(
                klass.__name__ == "NoiseScheduleVP" for klass in type(ns).__mro__
            )
        self.is_latent_space = bool(is_latent_space)

    def alpha(self, t):
        """Return ``ns.marginal_alpha(t)`` as a NumPy value (float64 input)."""
        t = torch.as_tensor(t, dtype=torch.float64)
        return self.ns.marginal_alpha(t).numpy()

    def sigma(self, t):
        """Return ``sqrt(1 - alpha(t)^2)``."""
        return np.sqrt(1 - self.alpha(t) * self.alpha(t))

    def lambda_func(self, t):
        """Return ``log(alpha(t) / sigma(t))``."""
        return np.log(self.alpha(t) / self.sigma(t))

    def edm_lambda_func(self, t):
        """Return ``log(alpha(t) / edm_sigma(t))``."""
        return np.log(self.alpha(t) / self.edm_sigma(t))

    def H0(self, h):
        """Return ``exp(h) - 1``."""
        return np.exp(h) - 1

    def H1(self, h):
        """Return ``exp(h) * h - H0(h)``."""
        return np.exp(h) * h - self.H0(h)

    def H2(self, h):
        """Return ``exp(h) * h^2 - 2 * H1(h)``."""
        return np.exp(h) * h * h - 2 * self.H1(h)

    def H3(self, h):
        """Return ``exp(h) * h^3 - 3 * H2(h)``."""
        return np.exp(h) * h * h * h - 3 * self.H2(h)

    def inverse_lambda(self, lamb):
        """Map log-SNR values back to times via ``ns.inverse_lambda``."""
        lamb = torch.as_tensor(lamb, dtype=torch.float64)
        return self.ns.inverse_lambda(lamb)

    def edm_sigma(self, t):
        """Return the EDM noise level for time ``t`` (the identity map)."""
        return t

    def edm_inverse_sigma(self, edm_sigma):
        """Map EDM noise levels to times through the log-SNR inverse.

        Accepts anything ``torch.as_tensor`` can convert (the original
        required a tensor because it called ``.sqrt()`` on the input).
        """
        edm_sigma = torch.as_tensor(edm_sigma)
        alpha = 1 / (edm_sigma * edm_sigma + 1).sqrt()
        sigma = alpha * edm_sigma
        lambda_t = torch.log(alpha / sigma)
        return self.inverse_lambda(lambda_t)

    def sel_lambdas_lof_obj(self, lambda_vec, eps):
        """Objective minimised over the interior log-SNR points.

        Args:
            lambda_vec: 1-D array of the interior log-SNR values.
            eps: End time of sampling (t_0).

        Returns:
            A non-negative scalar: the sum of absolute / root-sum-square
            weighted coefficients accumulated below.
        """
        lambda_func = self.lambda_func if self.is_latent_space else self.edm_lambda_func
        lambda_eps, lambda_T = lambda_func(eps).item(), lambda_func(self.T).item()
        # Full grid: fixed start point, free interior points, fixed end point.
        lambda_vec_ext = np.concatenate((np.array([lambda_T]), lambda_vec, np.array([lambda_eps])))
        N = len(lambda_vec_ext) - 1

        # hv[i] is the step size lambda_{i+1} - lambda_i.
        hv = np.zeros(N)
        for i in range(N):
            hv[i] = lambda_vec_ext[i+1] - lambda_vec_ext[i]
        elv = np.exp(lambda_vec_ext)
        emlv_sq = np.exp(-2*lambda_vec_ext)
        alpha_vec = 1./np.sqrt(1+emlv_sq)
        sigma_vec = 1./np.sqrt(1+np.exp(2*lambda_vec_ext))
        if self.is_latent_space:
            data_err_vec = (sigma_vec**2)/alpha_vec
        else:
            # for pixel-space diffusion models, we empirically find (sigma_vec**1)/alpha_vec will be better
            data_err_vec = (sigma_vec**1)/alpha_vec

        if N <= 7:
            truncNum = 3  # For NFEs <= 7, set truncNum = 3 to avoid numerical instability; for NFEs > 7, truncNum = 0
        else:
            truncNum = 0

        res = 0.
        # Coefficients of steps with s >= truncNum are accumulated per grid
        # point and only turned into absolute values at the very end.
        c_vec = np.zeros(N)
        for s in range(N):
            if s in [0, N-1]:
                # First and last step: single-coefficient term.
                n = s
                J_n_kp_0 = elv[n+1] - elv[n]
                res += abs(J_n_kp_0 * data_err_vec[n])
            elif s in [1, N-2]:
                # Second and second-to-last step: two-coefficient term.
                n = s-1
                J_n_kp_0 = -elv[n+1] * self.H1(hv[n+1]) / hv[n]
                J_n_kp_1 = elv[n+1] * (self.H1(hv[n+1])+hv[n]*self.H0(hv[n+1])) / hv[n]
                if s >= truncNum:
                    c_vec[n] += data_err_vec[n] * J_n_kp_0
                    c_vec[n+1] += data_err_vec[n+1] * J_n_kp_1
                else:
                    res += np.sqrt((data_err_vec[n] * J_n_kp_0)**2 + (data_err_vec[n+1] * J_n_kp_1)**2)
            else:
                # Remaining steps: three-coefficient term.
                n = s-2
                J_n_kp_0 = elv[n+2] * (self.H2(hv[n+2])+hv[n+1]*self.H1(hv[n+2])) / (hv[n]*(hv[n]+hv[n+1]))
                J_n_kp_1 = -elv[n+2] * (self.H2(hv[n+2])+(hv[n]+hv[n+1])*self.H1(hv[n+2])) / (hv[n]*hv[n+1])
                J_n_kp_2 = elv[n+2] * (self.H2(hv[n+2])+(2*hv[n+1]+hv[n])*self.H1(hv[n+2])+hv[n+1]*(hv[n]+hv[n+1])*self.H0(hv[n+2])) / (hv[n+1]*(hv[n]+hv[n+1]))
                if s >= truncNum:
                    c_vec[n] += data_err_vec[n] * J_n_kp_0
                    c_vec[n+1] += data_err_vec[n+1] * J_n_kp_1
                    c_vec[n+2] += data_err_vec[n+2] * J_n_kp_2
                else:
                    res += np.sqrt((data_err_vec[n] * J_n_kp_0)**2 + (data_err_vec[n+1] * J_n_kp_1)**2 + (data_err_vec[n+2] * J_n_kp_2)**2)
        res += sum(abs(c_vec))
        return res

    def get_ts_lambdas(self, N, eps):
        """Return optimised time steps and their log-SNR values.

        Args:
            N: Number of sampling steps; the result has ``N + 1`` points.
            eps: t_0 of diffusion sampling, e.g. 1e-3 for VP models.

        Returns:
            ``(t_res, lambda_res)``: the times obtained from
            ``inverse_lambda`` and the matching log-SNR grid, both running
            from ``self.T`` down to ``eps``.

        Raises:
            ValueError: If ``N < 1`` or the initialisation type is unknown.
        """
        if N < 1:
            raise ValueError("N must be at least 1, got {}".format(N))

        # initTypes with '_origin' are baseline time step discretizations
        # (without optimization); initTypes without '_origin' are optimized
        # discretizations that use the corresponding baseline as the starting
        # point. For latent-space diffusion models 'unif_t' is recommended,
        # for pixel-space models 'unif' (logSNR initialization).
        initType = "unif_t" if self.is_latent_space else "unif"

        lambda_func = self.lambda_func if self.is_latent_space else self.edm_lambda_func
        lambda_eps, lambda_T = lambda_func(eps).item(), lambda_func(self.T).item()

        # initial vector
        if initType in ['unif', 'unif_origin']:
            lambda_vec_ext = torch.linspace(lambda_T, lambda_eps, N+1)
        elif initType in ['unif_t', 'unif_t_origin']:
            t_vec = torch.linspace(self.T, eps, N+1)
            lambda_vec_ext = self.lambda_func(t_vec)
        elif initType in ['edm', 'edm_origin']:
            rho = 7
            # float() works for plain numbers and 0-d tensors alike.
            edm_sigma_min, edm_sigma_max = float(self.edm_sigma(eps)), float(self.edm_sigma(self.T))
            edm_sigma_vec = torch.linspace(edm_sigma_max**(1. / rho), edm_sigma_min**(1. / rho), N + 1).pow(rho)
            t_vec = self.edm_inverse_sigma(edm_sigma_vec)
            lambda_vec_ext = self.lambda_func(t_vec)
        elif initType in ['quad', 'quad_origin']:
            t_order = 2
            t_vec = torch.linspace(self.T**(1./t_order), eps**(1./t_order), N+1).pow(t_order)
            lambda_vec_ext = self.lambda_func(t_vec)
        else:
            # Callers unpack two values, so a silent ``None`` would only
            # surface later as an unrelated TypeError.
            raise ValueError("InitType not found: {!r}".format(initType))

        if initType in ['unif_origin', 'unif_t_origin', 'edm_origin', 'quad_origin']:
            lambda_res = lambda_vec_ext
        elif N == 1:
            # Only the two fixed end points exist; nothing to optimise.
            lambda_res = torch.tensor([lambda_T, lambda_eps], dtype=torch.float64)
        else:
            # Linear constraints A @ x >= lb that keep the grid monotone:
            #   x_0 >= lambda_T, x_i - x_{i-1} >= 0, -x_{N-2} >= -lambda_eps.
            constr_mat = np.zeros((N, N-1))
            for i in range(N-1):
                constr_mat[i][i] = 1.
                constr_mat[i+1][i] = -1
            lb_vec = np.zeros(N)
            lb_vec[0], lb_vec[-1] = lambda_T, -lambda_eps
            ub_vec = np.full(N, np.inf)
            linear_constraint = LinearConstraint(constr_mat, lb_vec, ub_vec)

            lambda_vec_init = np.asarray(lambda_vec_ext[1:-1], dtype=np.float64)
            res = minimize(
                self.sel_lambdas_lof_obj,
                lambda_vec_init,
                method='trust-constr',
                args=(eps,),
                constraints=[linear_constraint],
                options={'verbose': 1},
            )
            lambda_res = torch.tensor(np.concatenate((np.array([lambda_T]), res.x, np.array([lambda_eps]))))

        # as_tensor avoids re-wrapping a tensor returned by the noise schedule.
        t_res = torch.as_tensor(self.inverse_lambda(lambda_res))
        return t_res, lambda_res
    
    
    
def expand_dims(x, dims):
    """Append ``dims`` trailing singleton axes to the tensor ``x``.

    With ``dims`` of zero (or less) ``x`` is returned unchanged.
    """
    trailing_axes = (None,) * dims
    if not trailing_axes:
        return x
    # Indexing with None inserts a size-1 axis, exactly like unsqueeze(-1).
    return x[(...,) + trailing_axes]

def update_lists(t_list, model_list, t_, model_x, order, first=False):
    """Push a new (time, model output) pair into the history lists in place.

    When ``first`` is true the pair is simply appended. Otherwise the leading
    ``order - 1`` slots are each overwritten by their right-hand neighbour and
    the final slot of each list receives the new value. Returns ``None``.
    """
    if first:
        t_list.append(t_)
        model_list.append(model_x)
    else:
        # Slide the window one slot to the left, both lists in lock-step.
        for dst, src in zip(range(order - 1), range(1, order)):
            t_list[dst] = t_list[src]
            model_list[dst] = model_list[src]
        t_list[-1], model_list[-1] = t_, model_x

class ODESolver(ABC):
    """Abstract base for diffusion ODE samplers.

    Holds the noise schedule, converts the raw model into a noise- or
    data-prediction function and builds time-step discretisations.

    NOTE(review): ``noise_prediction_fn`` calls ``self.model`` and
    ``dynamic_thresholding_fn`` reads ``self.dynamic_thresholding_ratio`` and
    ``self.thresholding_max_val``; none of these are assigned in this class,
    so subclasses are expected to provide them.

    Args:
        noise_schedule: Schedule object. Methods here call its ``ft``,
            ``gt``, ``marginal_std``, ``marginal_alpha``, ``marginal_lambda``
            and ``inverse_lambda``.
        algorithm_type: ``"data_prediction"`` or ``"noise_prediction"``.
        correcting_x0_fn: Optional callable applied to the predicted x0.
    """

    def __init__(
        self,
        noise_schedule,
        algorithm_type="data_prediction",
        correcting_x0_fn=None,
    ):
        self.noise_schedule = noise_schedule
        assert algorithm_type in ["data_prediction", "noise_prediction"]
        self.predict_x0 = algorithm_type == "data_prediction"
        self.correcting_x0_fn = correcting_x0_fn

    def dx_dt_for_blackbox_solvers(self, x, t1, t2):
        '''
        Return ``ft * x + gt^2 / (2 * sigma_t) * noise``.

        The schedule terms are evaluated at ``t1`` and the noise prediction
        at ``t2``. For edm, dx_dt = noise (ft should be 0 and gt should be 1).
        '''
        ft = self.noise_schedule.ft(t1)
        gt = self.noise_schedule.gt(t1)
        sigma_t = self.noise_schedule.marginal_std(t1)
        noise = self.noise_prediction_fn(x, t2)
        return ft * x + gt ** 2 / (2 * sigma_t) * noise

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model output ``self.model(x, t)``.
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction ``(x - sigma_t * noise) / alpha_t``,
        passed through ``correcting_x0_fn`` when one is configured.
        """
        noise = self.noise_prediction_fn(x, t)
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        x0 = (x - sigma_t * noise) / alpha_t
        if self.correcting_x0_fn is not None:
            x0 = self.correcting_x0_fn(x0)
        return x0

    def model_fn(self, x, t):
        """
        Dispatch to the data- or noise-prediction function according to
        ``self.predict_x0``.
        """
        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        return self.noise_prediction_fn(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the ``N + 1`` time steps for sampling, from ``t_T`` to ``t_0``.

        Args:
            skip_type: One of ``'logSNR'``, ``'time_uniform'``,
                ``'time_quadratic'``, ``'edm'``, ``'dmn'`` or any string
                containing ``'poly'`` whose last ``_``-separated field is
                the exponent (e.g. ``'poly_3'``).
            t_T: Start time.
            t_0: End time.
            N: Number of steps.
            device: Device the result is placed on.

        Raises:
            ValueError: If ``skip_type`` is not supported.
        """
        if skip_type == 'logSNR':
            lambda_T = self.noise_schedule.marginal_lambda(torch.as_tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.as_tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            t = self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == 'time_uniform':
            t = torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == 'time_quadratic':
            t = self.get_time_step_poly(t_T, t_0, N, device, 2.0)
        elif skip_type == "edm":
            rho = 7.0  # 7.0 is the value used in the paper
            t = self.get_time_step_edm(t_T, t_0, N, device, rho)
        elif isinstance(skip_type, str) and "poly" in skip_type:
            rho = float(skip_type.split("_")[-1])
            t = self.get_time_step_poly(t_T, t_0, N, device, rho)
        elif skip_type == "dmn":
            optimizer = StepOptim(self.noise_schedule)
            t, _ = optimizer.get_ts_lambdas(N, t_0)
            t = t.to(device).to(torch.float32)
            print(t)
        else:
            raise ValueError(
                "Unsupported skip_type {}, need to be 'logSNR', 'time_uniform', "
                "'time_quadratic', 'edm', 'poly_<rho>' or 'dmn'".format(skip_type)
            )

        return t

    def append_zero(self, x):
        """Return ``x`` with a single zero appended."""
        return torch.cat([x, x.new_zeros([1])])

    def get_time_step_poly(self, t_T, t_0, N, device, rho=2.0):
        """Return times spaced like ``i ** rho`` rescaled to ``[t_0, t_T]``.

        The sequence is returned in decreasing order (``t_T`` first).
        """
        mono_sequence = torch.arange(0, N+1).pow(rho).to(device)
        sequence_min = mono_sequence.min()
        sequence_max = mono_sequence.max()
        ts = t_0 + (t_T - t_0) * (mono_sequence - sequence_min) / (sequence_max - sequence_min)
        return ts.flip(0)

    def get_time_step_edm_t(self, t_T, t_0, N, device, rho=7.0):
        """Return the rho-ramp ``(t_T^(1/rho) + r * (t_0^(1/rho) - t_T^(1/rho)))^rho``.

        ``r`` runs linearly from 0 to 1 over ``N + 1`` points, so the result
        starts at ``t_T`` and ends at ``t_0``.
        """
        t_min: float = t_0
        t_max: float = t_T
        ramp = torch.linspace(0, 1, N + 1).to(device)
        min_inv_rho = t_min ** (1 / rho)
        max_inv_rho = t_max ** (1 / rho)
        ts = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return ts

    def get_time_step_edm(self, t_T, t_0, N, device, rho=7.0):
        """Return the time steps used for ``skip_type='edm'``.

        Delegates to ``get_time_step_edm_t``. An earlier variant additionally
        ramped ``marginal_std`` and mapped back through ``inverse_std`` for
        ``NoiseScheduleVE`` schedules; that path was disabled together with
        the ``noise_schedulers`` import, so every schedule is currently
        ramped directly in ``t``. Subclasses may override this method to
        restore schedule-specific handling.
        """
        return self.get_time_step_edm_t(t_T, t_0, N, device, rho)

    def prepare_learn_timesteps(self, load_from, load_rs=False, device=None):
        """Load learned time steps (and optionally ``rs``) from a checkpoint.

        The checkpoint's ``'best_t_steps'`` tensor is split in the middle
        into ``timesteps`` and ``timesteps2``. With ``load_rs`` the
        ``'best_rs'`` entry is split the same way; if it is missing or
        cannot be moved to ``device``, ``[0.5] * length`` is used for both
        halves.

        Returns:
            ``(timesteps, timesteps2)`` or, with ``load_rs``,
            ``(timesteps, timesteps2, rs, rs2)``.
        """
        # Read the file once; both entries come from the same checkpoint.
        checkpoint = torch.load(load_from)
        timesteps = checkpoint['best_t_steps'].to(device)

        length = timesteps.shape[0] // 2
        timesteps2 = timesteps[length:]
        timesteps = timesteps[:length]

        if not load_rs:
            return timesteps, timesteps2

        try:
            rs = checkpoint['best_rs'].to(device)
        except (KeyError, AttributeError, TypeError):
            # Best-effort fallback when no usable 'best_rs' was stored.
            rs = [0.5] * length
            rs2 = rs
        else:
            rs2 = rs[length:]
            rs = rs[:length]
        return timesteps, timesteps2, rs, rs2

    def prepare_timesteps(self, steps=None, t_start=None, t_end=None, skip_type=None, device=None, load_from=None):
        """Return ``(timesteps, timesteps2)``.

        Uses the checkpoint at ``load_from`` when it names an existing file,
        otherwise ``get_time_steps`` (in which case both results are the
        same tensor).
        """
        if load_from is not None and os.path.isfile(load_from):
            timesteps, timesteps2 = self.prepare_learn_timesteps(load_from=load_from, device=device)
        else:
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_start, t_0=t_end, N=steps, device=device)
            timesteps2 = timesteps
        return timesteps, timesteps2

    def prepare_timesteps_single(self, steps, NFEs, t_start, t_end, flags, device, skip_type='time_uniform'):
        """Return ``(timesteps, timesteps2, rs, rs2)``.

        With ``flags.learn`` everything is loaded from ``flags.load_from``;
        otherwise the steps come from ``get_time_steps`` and ``rs`` is
        ``[0.5] * steps``. ``NFEs`` is accepted but not used here.
        """
        if flags.learn:
            timesteps, timesteps2, rs, rs2 = self.prepare_learn_timesteps(load_from=flags.load_from, load_rs=True, device=device)
        else:
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_start, t_0=t_end, N=steps, device=device)
            timesteps2 = timesteps
            rs = [0.5] * steps
            rs2 = rs
        return timesteps, timesteps2, rs, rs2

    def sample(self, *args, **kwargs):
        """Placeholder; does nothing in the base class."""
        pass

    @abstractmethod
    def sample_simple(self, model_fn, x, timesteps, timesteps2=None, condition=None, unconditional_condition=None, **kwargs):
        pass

    def dynamic_thresholding_fn(self, x0, t):
        """
        The dynamic thresholding method (not used by anything so far).

        Clamps ``x0`` to ``[-s, s]`` and divides by ``s``, where ``s`` is the
        per-sample quantile of ``|x0|`` floored at ``thresholding_max_val``.
        ``t`` is not used.
        """
        dims = x0.dim()
        p = self.dynamic_thresholding_ratio
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
        x0 = torch.clamp(x0, -s, s) / s
        return x0

    
class GUIDEDSolver(ODESolver):
    """Abstract solver interface with a forward pass and a matching backward pass.

    Adds three abstract hooks on top of ``ODESolver``:
    ``forward_sample_simple``, ``backward_sample_simple`` and ``sample``.
    """

    def __init__(
        self,
        noise_schedule,
        algorithm_type="data_prediction",
        correcting_x0_fn=None,
    ):
        # The base initialiser already validates ``algorithm_type`` and stores
        # ``noise_schedule``, ``predict_x0`` and ``correcting_x0_fn``;
        # repeating those assignments here only risks the two drifting apart.
        super().__init__(noise_schedule, algorithm_type, correcting_x0_fn)

    @abstractmethod
    def forward_sample_simple(self, latent, timesteps, timesteps2=None, return_image_list=False, **kwargs):
        pass

    @abstractmethod
    def backward_sample_simple(self, image_list, grad, timesteps=None, timesteps2=None, dis_model=None, **kwargs):
        pass

    @abstractmethod
    def sample(self, x, steps, t_start, t_end, order, skip_type, flags):
        pass


class MultiStepODESolver(GUIDEDSolver):
    """Multistep solver skeleton: subclasses supply ``_one_step``.

    The forward pass keeps a history of previous times and model outputs;
    the backward pass replays each step from stored intermediate states to
    propagate a gradient from the final sample back to the initial latent.
    """

    def __init__(self, model_fn, noise_schedule, algorithm_type="data_prediction"):
        '''
        Args:
            model_fn: Callable stored as ``self.model`` and invoked as
                ``self.model(x, t)`` by ``noise_prediction_fn``.
            noise_schedule: Forwarded to the base class.
            algorithm_type: needs to be data_prediction (the base class
                also accepts "noise_prediction").
        '''
        # The base signature is (noise_schedule, algorithm_type,
        # correcting_x0_fn); forwarding ``model_fn`` first shifted every
        # argument by one and tripped the base class's assertion.
        super().__init__(noise_schedule, algorithm_type)
        self.model = model_fn

    @abstractmethod
    def _one_step(self, t1, t2, t_prev_list, model_prev_list, step, x_next, order=None, update_list=False, first=True):
        pass

    def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', flags=None):
        """Run the forward sampler from ``x`` without tracking gradients.

        ``t_end`` defaults to ``1 / noise_schedule.total_N`` and ``t_start``
        to ``noise_schedule.T``. Learned time steps are loaded when
        ``flags.load_from`` names an existing file; ``flags`` may be ``None``.
        """
        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device
        # ``flags`` defaults to None, so read the attribute defensively.
        load_from = getattr(flags, "load_from", None)
        timesteps, timesteps2 = self.prepare_timesteps(steps=steps, t_start=t_T, t_end=t_0, skip_type=skip_type, device=device, load_from=load_from)

        with torch.no_grad():
            return self.forward_sample_simple(x, timesteps, timesteps2, order=order, return_image_list=False)

    def forward_sample_simple(self, latent, timesteps, timesteps2=None, return_image_list=False, **kwargs):
        """Integrate from ``timesteps[0]`` to ``timesteps[-1]``.

        Args:
            latent: Starting state; it is cloned, not modified.
            timesteps: Times passed to ``_one_step`` as ``t1``.
            timesteps2: Times passed as ``t2`` and used for the first model
                evaluation; defaults to ``timesteps``.
            return_image_list: When true, return every intermediate state
                (including the start) instead of only the final one.
            **kwargs: Must contain ``order``.

        Returns:
            The final state, or the list of all states.
        """
        assert 'order' in kwargs
        order = kwargs['order']
        if timesteps2 is None:
            timesteps2 = timesteps
        step = 0
        numsteps = len(timesteps) - 1
        with torch.no_grad():
            t_student1 = timesteps[step]
            t_student2 = timesteps2[step]
            t_prev_list_student = [t_student1]
            x_next_ = latent.clone()

            denoised_T = self.model_fn(x_next_, t_student2)
            model_prev_list_student = [denoised_T]

            if return_image_list:
                image_list = []
                image_list.append(x_next_)

            # Warm-up: the history is still shorter than ``order``, so the
            # step index is passed as the effective order. Clamped so that a
            # schedule shorter than ``order`` does not index past its end.
            for step in range(1, min(order, numsteps + 1)):
                t1 = timesteps[step]
                t2 = timesteps2[step]
                x_next_ = self._one_step(t1, t2, t_prev_list_student, model_prev_list_student, step, x_next_, order, update_list=True, first=True)
                if return_image_list:
                    image_list.append(x_next_)

            for step in range(order, numsteps + 1):
                t1 = timesteps[step]
                t2 = timesteps2[step]
                # The effective order shrinks over the last few steps.
                step_order = min(order, numsteps + 1 - step)
                x_next_ = self._one_step(t1, t2, t_prev_list_student, model_prev_list_student, step_order, x_next_, order, update_list=True, first=False)
                if return_image_list:
                    image_list.append(x_next_)
        if return_image_list:
            return image_list
        return x_next_

    def backward_sample_simple(self, image_list, grad, timesteps=None, timesteps2=None, dis_model=None, **kwargs):
        """Back-propagate ``grad`` through the stored sampling trajectory.

        Each step is recomputed from the states in ``image_list`` and
        ``backward`` is called on its output with the current upstream
        gradient; the ``.grad`` of the step's input then becomes the
        upstream gradient for the preceding step.

        Args:
            image_list: States as returned by ``forward_sample_simple`` with
                ``return_image_list=True``. Every element is switched to
                ``requires_grad=True``.
            grad: Gradient with respect to the last state.
            timesteps, timesteps2: Time grids; ignored when ``dis_model`` is
                given, in which case ``dis_model()`` is called before every
                step and must return ``(timesteps, timesteps2)``.
            **kwargs: Must contain ``order``.

        Returns:
            The gradient read from ``image_list[0]`` after the last step.
        """
        assert 'order' in kwargs
        order = kwargs['order']
        assert timesteps is None or len(timesteps) == len(image_list)
        numsteps = len(image_list) - 1

        for ele in image_list:
            ele.requires_grad = True
            ele.retain_grad()

        # Full-history steps, walked from the last step backwards.
        for step in range(numsteps, order - 1, -1):
            if dis_model is not None:
                timesteps, timesteps2 = dis_model()
            else:
                timesteps2 = timesteps2 if timesteps2 is not None else timesteps

            t1 = timesteps[step]
            t2 = timesteps2[step]

            # History of the ``order`` states preceding this step, oldest first.
            t_prev_list_student = [timesteps[step - i - 1] for i in range(order)][::-1]
            t_prev_list_student2 = [timesteps2[step - i - 1] for i in range(order)][::-1]
            this_image_list = [image_list[step - i - 1] for i in range(order)][::-1]
            model_prev_list_student = [self.model_fn(this_image_list[i], t_prev_list_student2[i]) for i in range(len(t_prev_list_student2))]

            x_next_input = image_list[step - 1]  # state this step starts from

            step_order = min(order, numsteps + 1 - step)
            x_next_ = self._one_step(t1, t2, t_prev_list_student, model_prev_list_student, step_order, x_next_input, update_list=False)
            x_next_.backward(grad, retain_graph=False)
            grad = x_next_input.grad.detach()  # upstream gradient for the previous step

        # Warm-up steps: only ``step`` earlier states exist.
        for step in range(order - 1, 0, -1):
            if dis_model is not None:
                timesteps, timesteps2 = dis_model()
            else:
                timesteps2 = timesteps2 if timesteps2 is not None else timesteps

            t1 = timesteps[step]
            t2 = timesteps2[step]
            t_prev_list_student = [timesteps[step - i - 1] for i in range(step)][::-1]
            t_prev_list_student2 = [timesteps2[step - i - 1] for i in range(step)][::-1]
            this_image_list = [image_list[step - i - 1] for i in range(step)][::-1]
            model_prev_list_student = [self.model_fn(this_image_list[i], t_prev_list_student2[i]) for i in range(len(t_prev_list_student2))]

            x_next_input = image_list[step - 1]

            x_next_ = self._one_step(t1, t2, t_prev_list_student, model_prev_list_student, step, x_next_input, update_list=False)
            x_next_.backward(grad, retain_graph=False)
            grad = x_next_input.grad.detach()

        return grad
