import jax.numpy as jnp
import jax
import numpy as np
import os, sys
import einops
import optax
from functools import partial
from dataclasses import replace

# Put the parent of this file's directory on sys.path so that the `util.*`
# imports below resolve (presumably this is the repository root -- confirm).
BASEDIR = os.path.dirname(os.path.dirname(__file__))
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)

import util.latent_obj_util as loutil
import util.structs as structs
# import util.model_util as mutil

# Time-warping constants shared by sample_t_train and get_t_schedule_for_sampling.
# When SIGMA_MIN is non-zero those functions map
#   t -> (t ** SIGMA_POW_FACTOR + SIGMA_MIN) / (1 + SIGMA_MIN);
# when it is zero they only apply the power.
SIGMA_MIN = 0.0
# SIGMA_MIN = 1e-3
# Exponent applied to the (uniform / linear) time values before the shift above.
SIGMA_POW_FACTOR = 1
def get_sigma(t, dm_type, s=0.008):
    '''
    Map a time value `t` to the noise level sigma of the given model type.

    ddpm: t:[0,1] -> 0-small noise -> 1-large noise
    edm: t:[0,inf) -> 0-small noise -> inf-large noise

    Args:
        t: time value(s), scalar or array.
        dm_type: 'ddpm', 'edm', 'regression', or any name starting with 'fm'.
        s: offset of the cosine schedule; only used for 'ddpm'.

    Returns:
        ddpm: sqrt(1 - cos((t + s) / (1 + s) * pi / 2) ** 2)
        edm / fm*: `t` itself as a jax array
        regression: zeros shaped like `t`

    Raises:
        NotImplementedError: if `dm_type` is none of the above.
    '''
    if dm_type == 'ddpm':
        return jnp.sqrt(1-jnp.cos((t + s) / (1 + s) * np.pi * 0.5) ** 2)
    elif dm_type == 'edm' or dm_type[:2] == 'fm':
        return jnp.array(t)
    elif dm_type == 'regression':
        return jnp.zeros_like(t)
    # Falling off the end used to return None, which only surfaced later as an
    # AttributeError in the caller; fail here with the offending value instead.
    raise NotImplementedError(f'unsupported dm_type: {dm_type!r}')

def _queries_to_latent_obj(learnable_queries, h_shape, latent_shape):
    '''
    Wrap `learnable_queries` in a LatentObjects and call `extend_outer_shape(0)`
    once per missing outer axis, i.e. len(h_shape) - 1 - len(outer_shape) times.
    '''
    query_obj = loutil.LatentObjects().init_h(learnable_queries, latent_shape)
    for _ in range(len(h_shape) -1 - len(query_obj.outer_shape)):
        query_obj = query_obj.extend_outer_shape(0)
    return query_obj


def init_noise(jkey, h_shape, latent_shape, dm_type, learnable_queries=None):
    '''
    Build the initial (fully noised) flattened latent `h` for sampling.

    Args:
        jkey: jax PRNG key.
        h_shape: shape of the flattened latent to create.
        latent_shape: latent layout forwarded to `LatentObjects.init_h`.
        dm_type: selects the initialisation scheme (see below).
        learnable_queries: query latents; required by every scheme except
            'ddpm' and 'fm'.

    Returns:
        regression: `learnable_queries` broadcast against `h_shape`.
        ddpm: standard normal noise of shape `h_shape`.
        fm: normal noise whose `z` part is scaled by 0.3 and whose `rel_fps`
            part is replaced by independent 0.3-scaled normal noise.
        fm_cat: for each of `z`, `pos` and `fps_tf`, an independently drawn
            random entry of the learnable queries.
        fm_reg / fm_vel: the learnable queries plus Gaussian jitter
            (scale 0.1 on `z`, 0.05 on `pos`, 0.01 on `fps_tf`).

    Raises:
        NotImplementedError: if `dm_type` is none of the above (previously the
            function silently returned None).
    '''
    if dm_type == 'regression':
        assert learnable_queries is not None
        bc_shape = jnp.broadcast_shapes(h_shape, learnable_queries.shape)
        return jnp.broadcast_to(learnable_queries, bc_shape)

    if dm_type == 'ddpm':
        return jax.random.normal(jkey, shape=h_shape)

    if dm_type == 'fm':
        jkey, subkey = jax.random.split(jkey)
        noise = jax.random.normal(subkey, shape=h_shape)
        tmp_obj = loutil.LatentObjects().init_h(noise, latent_shape)
        tmp_obj = tmp_obj.replace(z=0.3*tmp_obj.z)
        # Same split sequence as before so the drawn values are unchanged; the
        # former third split produced keys that were never consumed.
        _, subkey = jax.random.split(jkey)
        random_rel_fps = 0.3*jax.random.normal(subkey, shape=tmp_obj.rel_fps.shape)
        tmp_obj = tmp_obj.replace(rel_fps=random_rel_fps)
        return tmp_obj.h

    if dm_type == 'fm_cat':
        tmp_obj = _queries_to_latent_obj(learnable_queries, h_shape, latent_shape)
        # Read fps_tf before z / pos are overwritten below.
        original_fps_tf = tmp_obj.fps_tf
        num_queries = learnable_queries.shape[-2]

        # select each component with its own random query index
        # (split into 4 keys, only 3 used: kept so the random stream is unchanged)
        jks = jax.random.split(jkey, 4)
        rand_int = jax.random.randint(jks[0], shape=h_shape[:-1] + (1,1,1), minval=0, maxval=num_queries)
        tmp_obj = tmp_obj.replace(z=jnp.take_along_axis(tmp_obj.z, rand_int, axis=-4))

        rand_int = jax.random.randint(jks[1], shape=h_shape[:-1] + (1,), minval=0, maxval=num_queries)
        tmp_obj = tmp_obj.replace(pos=jnp.take_along_axis(tmp_obj.pos, rand_int, axis=-2))

        rand_int = jax.random.randint(jks[2], shape=h_shape[:-1] + (1,1), minval=0, maxval=num_queries)
        tmp_obj = tmp_obj.set_fps_tf(jnp.take_along_axis(original_fps_tf, rand_int, axis=-3))

        return tmp_obj.h

    if dm_type in ['fm_reg', 'fm_vel']:
        tmp_obj = _queries_to_latent_obj(learnable_queries, h_shape, latent_shape)
        original_fps_tf = tmp_obj.fps_tf

        # jitter each component around the learnable queries
        # (split into 4 keys, only 3 used: kept so the random stream is unchanged)
        jks = jax.random.split(jkey, 4)
        tmp_obj = replace(tmp_obj, z=tmp_obj.z + jax.random.normal(jks[0], shape=h_shape[:-1] + tmp_obj.z.shape[-3:])*0.1)
        tmp_obj = replace(tmp_obj, pos=tmp_obj.pos + jax.random.normal(jks[1], shape=h_shape[:-1] + tmp_obj.pos.shape[-1:])*0.05)
        tmp_obj = tmp_obj.set_fps_tf(original_fps_tf + jax.random.normal(jks[2], shape=h_shape[:-1] + original_fps_tf.shape[-2:])*0.01)

        return tmp_obj.h

    raise NotImplementedError(f'unsupported dm_type: {dm_type!r}')


def _append_trailing_axes(arr, ndim):
    '''Append singleton axes to `arr` until it has at least `ndim` dimensions.'''
    for _ in range(ndim - arr.ndim):
        arr = arr[..., None]
    return arr


def reverse_wth_zero_pred(jkey, x0_pred:loutil.LatentObjects, t:jnp.ndarray, previous_x_ptb:loutil.LatentObjects, previous_t, dm_type):
    '''
    Move the perturbed sample from time `previous_t` to time `t` using the
    model output `x0_pred`.

    Args:
        jkey: unused; kept for interface compatibility.
        x0_pred: model output (a velocity for 'fm_vel', otherwise the clean
            sample prediction). The result is built with `x0_pred.set_h`.
        t: target time.
        previous_x_ptb: perturbed sample at `previous_t`.
        previous_t: time of `previous_x_ptb`; required unless
            `dm_type == 'regression'`.
        dm_type: 'regression', 'ddpm', 'fm_vel', or any other name starting
            with 'fm'.

    Returns:
        regression: `x0_pred` unchanged.
        ddpm:   h = h_prev * (s/sp) + h0 * (sqrt(1-s^2) - (s/sp) * sqrt(1-sp^2))
        fm_vel: h = h_prev + (sp - s) * v
        fm*:    h = h0 + (s/sp) * (h_prev - h0)
        where s = sigma(t) and sp = sigma(previous_t).

    Raises:
        NotImplementedError: for a `dm_type` this function has no update rule
            for (e.g. 'edm'); this used to fail with an unbound local.
        ValueError: if `previous_t` is None for a non-regression model; this
            used to fail with an unbound local as well.
    '''
    if dm_type == 'regression':
        # No noise process: the prediction is the sample.
        return x0_pred

    # Validate before touching get_sigma so the error names the real problem.
    if dm_type != 'ddpm' and dm_type[:2] != 'fm':
        raise NotImplementedError(f'unsupported dm_type: {dm_type!r}')
    if previous_t is None:
        raise ValueError(f'previous_t is required for dm_type {dm_type!r}')

    # Broadcast both noise levels against the flattened latent.
    h_ndim = x0_pred.h.ndim
    sigma = _append_trailing_axes(get_sigma(t, dm_type=dm_type), h_ndim)
    previous_sigma = _append_trailing_axes(get_sigma(previous_t, dm_type=dm_type), h_ndim)

    if dm_type == 'ddpm':
        # Eliminating `noise` from
        #   x_prev = x0_pred*jnp.sqrt(1-previous_sigma**2) + previous_sigma * noise
        #   x_cur = x0_pred*jnp.sqrt(1-sigma**2) + sigma * noise
        # gives the expression below.
        ptb_obj = x0_pred.set_h(previous_x_ptb.h*(sigma/previous_sigma) + x0_pred.h*(-(sigma/previous_sigma)*jnp.sqrt(1-previous_sigma**2) + jnp.sqrt(1-sigma**2)))
    elif dm_type == 'fm_vel':
        # Euler step along the predicted velocity.
        ptb_obj = x0_pred.set_h((previous_sigma-sigma)*x0_pred.h + previous_x_ptb.h)
    else:
        # Remaining 'fm*' variants: interpolate between x0_pred and the
        # previous perturbed sample.
        ptb_obj = x0_pred.set_h((sigma/previous_sigma)*(previous_x_ptb.h - x0_pred.h) + x0_pred.h)
    return ptb_obj

def calculate_cs(t, args=None):
    '''
    Return the time values with a trailing singleton axis together with the
    (c_skip, c_out, c_in) coefficients, each shaped like that expanded `t`.

    Args:
        t: time value(s); converted with `jnp.array`.
        args: object with an `add_c_skip` attribute. Required, despite the
            None default in the signature.
            add_c_skip 0 or 1 -> c_skip = 0, c_out = 1
            add_c_skip 2      -> c_skip = 1, c_out = 1
            c_in is always 1.

    Returns:
        (t[..., None], (cskip, cout, cin))

    Raises:
        ValueError: if `args` is None or `args.add_c_skip` is not 0, 1 or 2
            (these used to fail with AttributeError / an unbound local).
    '''
    if args is None:
        raise ValueError('calculate_cs requires `args` with an `add_c_skip` attribute')
    t = jnp.array(t)[...,None]
    if args.add_c_skip in (0, 1):
        cskip = jnp.zeros_like(t)
    elif args.add_c_skip == 2:
        cskip = jnp.ones_like(t)
    else:
        raise ValueError(f'unsupported add_c_skip: {args.add_c_skip!r}')
    # c_out and c_in are one for every supported mode.
    cout = jnp.ones_like(t)
    cin = jnp.ones_like(t)
    return t, (cskip, cout, cin)

def sample_t_train(jkey, shape, dm_type='fm'):
    '''
    Draw training time values of the given shape.

    Uniform samples from `jax.random.uniform` are raised to SIGMA_POW_FACTOR;
    when SIGMA_MIN is non-zero they are additionally shifted by SIGMA_MIN and
    divided by (1 + SIGMA_MIN). `dm_type` is accepted but not used.
    '''
    warped = jnp.power(jax.random.uniform(jkey, shape=shape), SIGMA_POW_FACTOR)
    if SIGMA_MIN != 0:
        warped = (warped + SIGMA_MIN) / (1 + SIGMA_MIN)
    return warped

def get_t_schedule_for_sampling(max_time_steps, dm_type, initial_time_step=None):
    '''
    Build the decreasing time schedule used by the sampler.

    regression: a numpy array of `max_time_steps` zeros.
    ddpm / fm*: `max_time_steps` values linearly spaced from
        `initial_time_step` (1.0 when None) towards 0, with 0 itself excluded,
        raised to SIGMA_POW_FACTOR and, when SIGMA_MIN is non-zero, shifted by
        SIGMA_MIN and divided by (1 + SIGMA_MIN).

    Raises NotImplementedError for any other `dm_type`.
    '''
    if dm_type == 'regression':
        return np.zeros(max_time_steps)
    if dm_type != 'ddpm' and dm_type[:2] != 'fm':
        raise NotImplementedError

    start = 1.0 if initial_time_step is None else initial_time_step
    linear = jnp.linspace(start, 0, max_time_steps, endpoint=False)
    warped = jnp.power(linear, SIGMA_POW_FACTOR)
    if SIGMA_MIN == 0:
        return warped
    return (warped + SIGMA_MIN) / (1 + SIGMA_MIN)


def euler_sampler_obj_single_fori(models, x_shape, model_apply_func, cond:structs.ImgFeatures, jkey:jnp.ndarray, max_time_steps:int=10, 
                                  output_obj_no:int=None, conf_filter_out:bool=True, conf_threshold:float=-0.5, update_cam_pq:bool=False, 
                                  den_repeat_no:int=None, initial_oriCORNs:loutil.LatentObjects=None, initial_time_step:float=None):
    '''
    Sample latent objects by iterating the model over a decreasing time
    schedule inside a single `jax.lax.scan`.

    Args:
        models: provides `dif_args.dm_type`, `dif_args.den_repeat_no`,
            `latent_shape`, `learnable_queries` and `init_emb(shape)`.
        x_shape: shape of the flattened latent `h` to sample. `x_shape[-2]`
            is compared against 1 when filtering (presumably the number of
            objects -- confirm); a length of 3 triggers the final axis swap.
        model_apply_func: called as
            `model_apply_func(x, cond, t, None, jkey, emb)` and expected to
            return `(x0_pred, conf, cam_preds, embs)`.
        cond: conditioning features; its leaves are converted to jax arrays
            and `cond.cam_posquat` seeds the carried camera prediction.
        jkey: jax PRNG key.
        max_time_steps: number of sampling steps (forced to 5 for
            'regression').
        output_obj_no: currently unused.
        conf_filter_out: if True and `x_shape[-2] != 1`, objects whose
            confidence is not above `conf_threshold` are passed to
            `deprecate_obj`. (The original note said their position is set to
            [0,0,10]; that depends on `deprecate_obj` -- confirm.)
        conf_threshold: confidence threshold for the filter above.
        update_cam_pq: if True, `cond.cam_posquat` is replaced by the carried
            camera prediction at every step.
        den_repeat_no: model applications per step; defaults to
            `models.dif_args.den_repeat_no`.
        initial_oriCORNs: optional objects to start from. They are padded with
            `valid_obj_padding` and combined with the initial noise via
            `reverse_wth_zero_pred` at time `initial_time_step`.
        initial_time_step: first time of the schedule
            (`get_t_schedule_for_sampling` uses 1.0 when None).

    Returns:
        (x, aux_infos): the final objects, sorted by confidence and optionally
        filtered, and a dict with keys 'cam_preds', 'conf', 'conf_list',
        'x_pred_list' and 'x_diffusion_list'.
    '''
    if den_repeat_no is None:
        den_repeat_no = models.dif_args.den_repeat_no
    dm_type = models.dif_args.dm_type
    latent_shape = models.latent_shape
    learnable_queries = models.learnable_queries
    # jax.tree_map is deprecated (and removed in recent JAX releases); use the
    # jax.tree_util spelling, as the end of this function already does.
    cond = jax.tree_util.tree_map(lambda x_: jnp.array(x_), cond)
    if dm_type == 'regression':
        max_time_steps = 5
    noise_h = init_noise(jkey, x_shape, latent_shape, dm_type, learnable_queries)
    _, jkey = jax.random.split(jkey)
    x = loutil.LatentObjects().init_h(noise_h, latent_shape)

    if initial_oriCORNs is not None:
        # Start from the given objects, perturbed to `initial_time_step`, with
        # the noise above treated as the sample at time 1.0.
        initial_oriCORNs = initial_oriCORNs.valid_obj_padding(initial_oriCORNs.nobj, jkey, z_scale=1.0)
        x = reverse_wth_zero_pred(jkey, initial_oriCORNs, initial_time_step, x, jnp.array(1.0), dm_type)

    x = x.init_conf_zero()
    t_schedule =get_t_schedule_for_sampling(max_time_steps, initial_time_step=initial_time_step, dm_type=dm_type)
    t_schedule = jnp.array(t_schedule).astype(jnp.float32)
    t_schedule = jnp.concat([t_schedule, jnp.zeros_like(t_schedule[:1])], axis=0) # add zero at last -> max_time_steps + 1
    
    def f_(carry, i):
        # One sampling step: predict at t_schedule[i], then move x to t_schedule[i+1].
        x, cam_preds, jkey = carry
        if update_cam_pq:
            cond_in = cond.replace(cam_posquat=cam_preds)
        else:
            cond_in = cond
        x0_pred = x
        init_emb = models.init_emb(x0_pred.h.shape[:-1])
        for den_itr in range(den_repeat_no):
            x0_pred, conf, cam_preds_, embs = model_apply_func(x0_pred, cond_in, t_schedule[i], None, jkey, init_emb)
            if den_itr == 0:
                # Later repeats reuse the embeddings of the first application.
                init_emb = embs
        # Take the model's camera prediction but keep index 0 along the
        # second-to-last axis from the carried value.
        cam_preds = cam_preds_.at[...,0,:].set(cam_preds[...,0,:])
        _, jkey = jax.random.split(jkey)
        if dm_type == 'fm_vel':
            # The model output is a velocity: integrate from t_schedule[i]
            # down to 0 to get the sample that is saved for this step.
            x0_pred_save = reverse_wth_zero_pred(jkey, x0_pred, jnp.zeros_like(t_schedule[...,i]), x, t_schedule[...,i], dm_type)
        else:
            x0_pred_save = x0_pred
        
        # apply confidence filter
        # x0_pred:loutil.LatentObjects
        # valid_conf_mask = conf > conf_threshold
        # x0_pred = x0_pred.valid_obj_padding(x0_pred.nobj, jkey, z_scale=1.0, obj_valid_mask=valid_conf_mask.squeeze(-1))

        x = reverse_wth_zero_pred(jkey, x0_pred, t_schedule[...,i+1], x, t_schedule[...,i], dm_type)


        _, jkey = jax.random.split(jkey)
        return (x, cam_preds, jkey), (x0_pred_save, x, conf)
    
    (x, cam_preds, jkey), (x_pred_list, x_diffusion_list, conf_list) = jax.lax.scan(f_, (x, cond.cam_posquat, jkey), jnp.arange(max_time_steps))
    x = x_pred_list[-1] # use the saved prediction of the last step as the result
    # conf = conf_list[-1]
    # conf = jnp.mean(conf_list[-max_time_steps//5:], axis=0) # average last 20% of confidences
    conf = jnp.mean(conf_list[-max(max_time_steps//10,1):], axis=0) # average the last max(max_time_steps//10, 1) confidences (~10% of the steps)
    x = x.set_conf(conf)

    # ### confidence out ###
    # if x_shape[-2] != 1 and output_obj_no is not None:
    #     pick_obj_indices = jnp.argsort(-conf, axis=-2)[...,:output_obj_no,:]
    # else:
    #     pick_obj_indices = jnp.argsort(-conf, axis=-2)
    # def align_rank(arr, rank):
    #     for _ in range(rank - arr.ndim):
    #         arr = arr[...,None]
    #     return arr
    # x, conf = jax.tree_map(lambda x: jnp.take_along_axis(x, align_rank(pick_obj_indices, x.ndim), axis=len(x_shape)-2), (x, conf))
    
    x = x.sort_by_conf()


    valid_mask = x.conf>conf_threshold
    # valid_mask = conf>jnp.max(conf)-1.0
    if x_shape[-2] != 1 and conf_filter_out:
        x = x.deprecate_obj(valid_mask)
        # x = x.replace(pos=jnp.where(valid_mask, x.pos, jnp.array([0,0,10.])))

    # add final state
    x_pred_list = x_pred_list.concat(x[None], axis=0)
    x_diffusion_list = x_diffusion_list.concat(x[None], axis=0)
    if len(x_shape) == 3:
        # make batch axis
        x_pred_list, x_diffusion_list = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), (x_pred_list, x_diffusion_list))
    aux_infos = {'cam_preds':cam_preds, 'conf':x.conf, 'conf_list':conf_list, 'x_pred_list':x_pred_list, 'x_diffusion_list':x_diffusion_list}

    return x, aux_infos


def perturb_recover_obj(models, obj:loutil.LatentObjects, model_apply_jit, cond, t, jkey):
    '''
    Perturb `obj` to time `t` and let the model recover it.

    Args:
        models: provides `dif_args.dm_type`, `dif_args.den_repeat_no`,
            `learnable_queries` and `init_emb(shape)`.
        obj: objects to perturb. An object with an empty `outer_shape` gets a
            leading axis for the computation, which is squeezed out again
            before returning.
        model_apply_jit: called as
            `model_apply_jit(x, cond, t, jnp.array([True]), jkey, emb)` and
            expected to return `(x_pred, conf, _, emb)`.
        cond: conditioning passed through to the model.
        t: time to perturb to.
        jkey: jax PRNG key.

    Returns:
        The recovered objects. For 'fm_vel' the model output is treated as a
        velocity and integrated from `t` to 0 starting at the perturbed
        objects.
    '''
    dm_type = models.dif_args.dm_type
    den_repeat_no = models.dif_args.den_repeat_no
    learnable_queries = models.learnable_queries
    extended = False
    if len(obj.outer_shape) == 0:
        extended = True
        obj = obj.extend_outer_shape(0)
    # Initial noise with the same flattened shape as the object.
    obj_ptb_init = obj.set_h(init_noise(jkey, obj.h.shape, obj.latent_shape, dm_type, learnable_queries))
    if dm_type == 'fm_vel':
        # For 'fm_vel' the reverse step expects a velocity: use the
        # displacement from the initial noise to the clean object.
        obj_vel = obj.set_h(obj.h - obj_ptb_init.h)
        obj_ptb = reverse_wth_zero_pred(jkey, obj_vel, t, obj_ptb_init, np.array([1.0]), dm_type)
    else:
        obj_ptb = reverse_wth_zero_pred(jkey, obj, t, obj_ptb_init, np.array([1.0]), dm_type)
    _, jkey = jax.random.split(jkey)
    obj_rec = obj_ptb
    init_emb = models.init_emb(obj_rec.h.shape[:-1])
    for den_itr in range(den_repeat_no):
        obj_rec, conf, _, emb = model_apply_jit(obj_rec, cond, t, jnp.array([True]), jkey, init_emb)
        if den_itr == 0:
            # Later repeats reuse the embeddings of the first application.
            init_emb = emb
    if dm_type == 'fm_vel':
        obj_rec = reverse_wth_zero_pred(jkey, obj_rec, jnp.array([0.0]), obj_ptb, t, dm_type)
    if extended:
        # jax.tree_map is deprecated (and removed in recent JAX releases).
        obj_rec = jax.tree_util.tree_map(lambda x: x.squeeze(0), obj_rec)
    return obj_rec

# %%



if __name__ == '__main__':
    # Manual visual check of the noise schedule and the c_skip/c_out/c_in
    # coefficients. The calls below follow the current function signatures;
    # the earlier version referenced an undefined `EDMP` object, keyword
    # arguments that no longer exist, and the unsupported 'ddpm_noise' type.
    import matplotlib.pyplot as plt
    jkey = jax.random.PRNGKey(0)

    sigma2 = get_sigma(jnp.ones(1000)*0.5, dm_type='ddpm').clip(1e-5, 1-1e-5)
    sigma = get_sigma(jnp.linspace(0,0.5,1000), dm_type='ddpm')
    sigma2 = sigma/sigma2*jnp.sqrt(1-(1-sigma2**2)/(1-sigma**2))
    
    plt.figure()
    plt.plot(sigma)
    plt.plot(sigma2)
    plt.show()

    t_sp = sample_t_train(jkey, (50000,), dm_type='ddpm')
    # sigma = get_sigma(t_sp, dm_type='ddpm')
    # plt.figure()
    # plt.hist(t_sp, bins=1000)
    # plt.show()

    # # sigma = get_sigma(np.arange(100)/100, dm_type='ddpm')
    # sigma = jnp.sqrt(1-jnp.cos(np.arange(100)/100 * np.pi * 0.5) ** 2.5)
    # plt.figure()
    # plt.plot(sigma)
    # plt.show()

    class Args:
        pass
    args = Args()
    args.dm_type = 'ddpm'
    args.add_c_skip = 0

    t_schedule = get_t_schedule_for_sampling(1000, dm_type=args.dm_type)
    time, (c_skip, c_out, c_in) = calculate_cs(t_schedule, args)

    plt.figure()
    # plt.plot(time)
    plt.plot(c_skip)
    plt.plot(c_out)
    plt.plot(c_in)
    plt.legend(['c_skip', 'c_out', 'c_in'])
    plt.show()