from clus.models.model.cdm import *
from flax.training import train_state
from flax import struct
from copy import deepcopy
from jax import tree_util

### EWC loss of diffusion model ### 
# TODO think in lora
class EWCState(train_state.TrainState):
    '''
    TrainState extended with the two extra pytrees used by the EWC penalty.

    Both fields are required keyword arguments (e.g. of ``EWCState.create``)
    and are declared as pytree nodes.
    '''
    # `struct.field(...)` must be the field *default*, not its annotation;
    # as an annotation the pytree_node metadata is never attached.

    # parameters the penalty is measured against (see `calc_reg`)
    ewc_params: object = struct.field(pytree_node=True)
    # per-parameter importance weights, built from squared gradients
    fisher_params: object = struct.field(pytree_node=True)

def deep_copy_params(params):
    '''
    Return a copy of the pytree `params` with every leaf copied into a new array.

    The tree structure is preserved; each leaf is rebuilt with
    ``jnp.array(leaf, copy=True)``.
    '''
    # tree_util.tree_map is the supported spelling; the top-level
    # `jax.tree_map` alias is deprecated and absent from recent JAX releases.
    return tree_util.tree_map(lambda leaf: jnp.array(leaf, copy=True), params)


def clip_gradients(grads, max_norm):
    '''
    Rescale the gradient pytree so its global L2 norm does not exceed `max_norm`.

    grads    : pytree of arrays
    max_norm : upper bound on the global norm
    returns  : pytree with the same structure; every leaf is multiplied by
               ``max_norm / norm`` when ``norm > max_norm`` and by 1 otherwise.
    '''
    leaves = tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
    # jnp.where instead of Python's min(): min() needs a concrete boolean and
    # therefore fails when this function is traced by jax.jit. The unselected
    # branch may be inf for a zero norm, which is harmless here.
    scale = jnp.where(norm > max_norm, max_norm / norm, 1.0)
    return tree_util.tree_map(lambda g: g * scale, grads)
   
class EWCDiffusion(ConditionalDiffusion) : 
    '''
    for using EWC loss, you should use this class instead of ConditionalDiffusion
    NOTE : TrainState is forced to use EWCtrain_state
    '''
    def __init__(
            self,
            ewc_mode = 'ewc', # 'l2' : unweighted penalty, 'ewc' : fisher-weighted penalty
            ewc_ratio = 5e4,
            **kwargs
        ) :
        '''
        ewc_mode  : 'l2' or 'ewc'; the names 'GA', 'ERGA', 'ERSA' and 'ER'
                    are accepted too and stored as 'ewc'
        ewc_ratio : weight applied to the penalty when regularization is on
        kwargs    : forwarded unchanged to the parent constructor
        '''
        super().__init__(**kwargs)

        mapped_to_ewc = ('GA', 'ERGA', 'ERSA', 'ER')
        self.ewc_mode = 'ewc' if ewc_mode in mapped_to_ewc else ewc_mode
        self.ewc_ratio = ewc_ratio

        # bookkeeping for the importance (fisher) estimate
        self.imp_count = 0        # batches accumulated since the last update_importance
        self.first_phase = True   # True until update_importance has run once
        self.avg_ratio = 0.9      # weight kept on the previous fisher when blending
        self.ascent_ratio = 0.01  # scale applied to gradient-ascent terms
    def reinit_optimizer(self):
        '''
        Replace `self.train_state` with a new EWCState built around a freshly
        constructed optimizer.

        The current params and apply_fn are carried over, `ewc_params` becomes
        a copy of the current params and `fisher_params` is reset to zeros.
        '''
        # NOTE(review): unlike from_conditional_diffusion, the optimizer here is
        # not chained with optax.clip_by_global_norm - confirm this is intended.
        current = self.train_state
        opt_cfg = self.optimizer_config
        fresh_tx = opt_cfg['optimizer_cls'](**opt_cfg['optimizer_kwargs'])

        self.train_state = EWCState.create(
            apply_fn=current.apply_fn,
            params=current.params,
            tx=fresh_tx,
            ewc_params=deep_copy_params(current.params),
            fisher_params=tree_util.tree_map(jnp.zeros_like, current.params),
        )

    @classmethod
    def from_conditional_diffusion(
            cls, 
            conditional_diffusion_instance,
            optimizer_config=None,
            ewc_mode='ewc', 
            ewc_ratio=1e1,
        ):
        '''
        Alternate constructor: build an EWC model from an existing diffusion model.

        conditional_diffusion_instance : source model; a deep copy of its
            ``__dict__`` is forwarded to the constructor as keyword arguments
            and its train_state supplies apply_fn and params
        optimizer_config : dict with 'optimizer_cls' and 'optimizer_kwargs';
            defaults to the source instance's optimizer_config
        ewc_mode, ewc_ratio : forwarded to the constructor

        returns : new instance whose train_state is an EWCState with
            `ewc_params` copied from the source params, `fisher_params` zeroed
            and the optimizer wrapped with global-norm clipping at 1.0
        '''
        print("[EWC] mode : ", ewc_mode)
        print("[EWC] ratio : ", ewc_ratio)

        kwargs = deepcopy(conditional_diffusion_instance.__dict__)
        # When the source is itself an EWC model its __dict__ already holds
        # these two names; passing them twice would raise
        # "got multiple values for keyword argument". The explicit arguments win.
        kwargs.pop('ewc_mode', None)
        kwargs.pop('ewc_ratio', None)

        prev_train_state = conditional_diffusion_instance.train_state
        ewc_diffusion = cls(ewc_mode=ewc_mode, ewc_ratio=ewc_ratio, **kwargs)

        # fine tuning optimizer configs
        if optimizer_config is None:
            optimizer_config = conditional_diffusion_instance.optimizer_config
        ewc_diffusion.optimizer_config = optimizer_config

        optimizer = optimizer_config['optimizer_cls'](**optimizer_config['optimizer_kwargs'])
        optimizer = optax.chain(optax.clip_by_global_norm(1.0), optimizer)

        # the (un-copied) source train_state provides apply_fn and params
        ewc_diffusion.train_state = EWCState.create(
            apply_fn=prev_train_state.apply_fn,
            params=prev_train_state.params,
            tx=optimizer,
            ewc_params=deep_copy_params(prev_train_state.params),
            fisher_params=tree_util.tree_map(jnp.zeros_like, prev_train_state.params),
        )

        return ewc_diffusion

    def calc_reg(self, params, ewc_params, fisher_info=None):
        '''
        Quadratic penalty between `params` and the reference `ewc_params`.

        For every leaf the squared difference is summed, optionally weighted
        elementwise by the matching leaf of `fisher_info`; the per-leaf sums
        are added up and halved:

            0.5 * sum( fisher * (param - ewc_param) ** 2 )

        With `fisher_info=None` the weight is omitted (plain L2 penalty).
        All trees passed in must share the same structure.
        '''
        if fisher_info is None:
            def leaf_penalty(param, anchor):
                return jnp.sum(jnp.square(param - anchor))
            trees = (params, ewc_params)
        else:
            def leaf_penalty(param, anchor, weight):
                return jnp.sum(weight * jnp.square(param - anchor))
            trees = (params, ewc_params, fisher_info)

        per_leaf = tree_util.tree_map(leaf_penalty, *trees)
        return sum(tree_util.tree_leaves(per_leaf)) / 2

    ### Diffusion model ###
    def p_losses(self, params, state, x_start, t, cond, noise=None, rngs=None, alpha=0.0) :
        '''
        Parent loss plus `alpha` times the EWC penalty.

        The penalty is computed by `calc_reg` against `state.ewc_params`;
        in 'ewc' mode it is weighted by `state.fisher_params`, in 'l2' mode
        it is unweighted. The unscaled penalty is stored in the returned
        loss dict under the key 'ewc'.

        returns : (model_loss + alpha * ewc_loss, loss_dict)
        raises  : NotImplementedError for any other `self.ewc_mode`
        '''
        model_loss, loss_dict = super().p_losses(
            params, state, x_start, t, cond, noise=noise, rngs=rngs
        )

        mode = self.ewc_mode
        if mode not in ('l2', 'ewc'):
            raise NotImplementedError(f"ewc_mode : {self.ewc_mode} is not supported")

        # only the 'ewc' mode uses the importance weights
        weights = state.fisher_params if mode == 'ewc' else None
        ewc_loss = self.calc_reg(params, state.ewc_params, fisher_info=weights)

        loss_dict['ewc'] = ewc_loss
        return model_loss + alpha * ewc_loss, loss_dict
    
    ##### Gradient ascent Calculation ####
    @partial(jax.jit, static_argnums=(0,))
    def train_model_jit(self, state, x, t, cond, noise, rngs=None, alpha=0.0):
        '''
        One jitted descent step on `p_losses`.

        Differentiates the loss w.r.t. `state.params` and applies the
        gradients through `state.apply_gradients`.

        returns : (updated state, (None, loss_dict))
        '''
        loss_grad = jax.grad(self.p_losses, has_aux=True)
        grads, aux = loss_grad(
            state.params, state, x, t, cond, noise, rngs=rngs, alpha=alpha
        )
        updated = state.apply_gradients(grads=grads)
        return updated, (None, aux)
    
    @partial(jax.jit, static_argnums=(0,))
    def unlearn_model_jit(self, state, x, t, cond, noise, rngs=None, alpha=0.0):
        '''
        One jitted ascent step on `p_losses`.

        The gradients are negated and scaled by `self.ascent_ratio` before
        being handed to `state.apply_gradients`.

        returns : (updated state, (None, loss_dict))
        '''
        loss_grad = jax.grad(self.p_losses, has_aux=True)
        grads, aux = loss_grad(
            state.params, state, x, t, cond, noise, rngs=rngs, alpha=alpha
        )

        # `self` is a static jit argument, so this attribute is read while tracing
        ratio = self.ascent_ratio
        flipped = tree_util.tree_map(lambda g: -g * ratio, grads)

        updated = state.apply_gradients(grads=flipped)
        return updated, (None, aux)

    @partial(jax.jit, static_argnums=(0,))
    def unlearn_train_model_jit(
            self, 
            state, 
            x, t, cond, noise, 
            u_x, u_t, u_cond, u_noise,
            alpha=0.0,
            rngs=None
        ):
        '''
        One jitted step combining descent on (x, t, cond, noise) with ascent
        on (u_x, u_t, u_cond, u_noise).

        The applied gradient is ``descent - ascent_ratio * ascent``. The
        ascent pass calls `p_losses` without `alpha`, i.e. with its default.
        Only the loss dict of the descent pass is reported.

        returns : (updated state, (None, loss_dict))
        '''
        loss_grad = jax.grad(self.p_losses, has_aux=True)

        descent, aux = loss_grad(
            state.params, state, x, t, cond, noise, rngs=rngs, alpha=alpha
        )
        ascent, _ = loss_grad(
            state.params, state, u_x, u_t, u_cond, u_noise, rngs=rngs
        )

        ratio = self.ascent_ratio
        combined = tree_util.tree_map(
            lambda down, up: down - up * ratio, descent, ascent
        )
        updated = state.apply_gradients(grads=combined)
        return updated, (None, aux)
    
    @staticmethod
    def _ewc_with_mid_axis(arr):
        '''Return `arr` with a singleton axis inserted at position 1 when it is 2-D; other ranks pass through.'''
        return arr[:, None, :] if arr.ndim == 2 else arr

    def _ewc_draw_t_and_noise(self, batch_size, t):
        '''Sample timesteps when `t` is None (else check its length) and draw noise of shape (batch_size, 1, out_dim).'''
        if t is None :
            t = np.random.randint(0, self.num_timesteps, (batch_size,))
        else : # shape must (Batch,)
            assert t.shape[0] == batch_size , "t and x must have same batch size"
        noise = np.random.randn(batch_size, 1, self.out_dim) # shape (Batch, 1, F)
        return t, noise

    def train_model(self, x, cond, t=None, u_x=None, u_cond=None, reg=True) :
        '''
        Run one optimisation step and store the new train state on `self`.

        Which step is run depends on the arguments:
          * `u_x` and `u_cond` both None -> descent on (x, cond)
          * `x` and `cond` both None     -> ascent ("unlearning") on (u_x, u_cond)
          * otherwise                    -> combined descent / ascent step

        x, cond, u_x, u_cond : arrays; 2-D inputs get an axis inserted at position 1
        t   : optional timesteps, first dimension must equal the batch size;
              sampled uniformly from [0, num_timesteps) when None
        reg : when True the penalty weight is `self.ewc_ratio`, otherwise 0.0

        returns : the metric tuple produced by the jitted step
        '''
        # `== True` kept on purpose: only True (or a value equal to it) enables the penalty
        ewc_ratio = self.ewc_ratio if reg == True else 0.0

        if u_x is None and u_cond is None : # learning
            x = self._ewc_with_mid_axis(x)
            cond = self._ewc_with_mid_axis(cond)
            t, noise = self._ewc_draw_t_and_noise(x.shape[0], t)
            self.train_state , metric = self.train_model_jit(
                self.train_state, 
                x, t, cond, noise, 
                rngs=self.sample_rngs,
                alpha=ewc_ratio    
            )
        elif x is None and cond is None : # only unlearning
            u_x = self._ewc_with_mid_axis(u_x)
            u_cond = self._ewc_with_mid_axis(u_cond)
            # x is None in this branch, so the batch size has to come from u_x
            t, noise = self._ewc_draw_t_and_noise(u_x.shape[0], t)
            self.train_state , metric = self.unlearn_model_jit(
                self.train_state, 
                u_x, t, u_cond, noise, 
                rngs=self.sample_rngs,
                alpha=ewc_ratio    
            )
        else : # learning with replay
            x = self._ewc_with_mid_axis(x)
            cond = self._ewc_with_mid_axis(cond)
            u_x = self._ewc_with_mid_axis(u_x)
            u_cond = self._ewc_with_mid_axis(u_cond)
            t, noise = self._ewc_draw_t_and_noise(x.shape[0], t)
            u_noise = np.random.randn(u_x.shape[0], 1, self.out_dim)
            # NOTE(review): the same `t` is used for both batches, which presumably
            # requires u_x and x to have the same batch size - confirm with callers.
            self.train_state , metric = self.unlearn_train_model_jit(
                self.train_state, 
                x, t, cond, noise, 
                u_x, t, u_cond, u_noise,
                rngs=self.sample_rngs,
                alpha=ewc_ratio  
            )

        self.sample_rngs = update_rngs(self.sample_rngs)
        return metric
    
    @partial(jax.jit, static_argnums=(0,))
    def _get_grads(self, state, x, t, cond, noise=None, rngs=None) :
        '''
        Gradients of the parent class loss w.r.t. `state.params`.

        The parent `p_losses` is differentiated, not this class's override,
        so the EWC penalty does not enter these gradients.
        The auxiliary loss dict is discarded.
        '''
        base_loss = super().p_losses
        grads, _ = jax.grad(base_loss, has_aux=True)(
            state.params, state, x, t, cond, noise, rngs=rngs
        )
        return grads
    
    #### fisher matrix calculation ####
    def update_importance(self) :
        '''
        Normalize the accumulated importance; called after each task.

        `fisher_params` is divided by `imp_count`. On the first call the
        result is stored as is; on later calls it is blended with the
        previously stored value as
        ``avg_ratio * old + (1 - avg_ratio) * new``.
        Afterwards the result is remembered in `self.old_fisher` and
        `imp_count` is reset to 0.

        When nothing has been accumulated (`imp_count == 0`) the call is a
        no-op, because dividing by zero would fill the tree with nan / inf.
        '''
        print(f"[EWC] updated importance : {self.imp_count}")
        count = self.imp_count
        if count == 0 :
            print("[EWC] no importance accumulated; fisher left unchanged")
            return

        # NOTE(review): the accumulator is not zeroed here, so later calls to
        # _update_fisher add on top of the normalized values unless the state
        # is rebuilt in between - confirm this is intended.
        averaged = tree_util.tree_map(lambda f : f / count, self.train_state.fisher_params)
        if self.first_phase :
            self.first_phase = False
            merged = averaged
        else : 
            keep = self.avg_ratio
            merged = tree_util.tree_map(
                lambda old, new : keep*old + (1-keep)*new, 
                self.old_fisher, 
                averaged
            )

        self.train_state = self.train_state.replace(fisher_params=merged)
        self.old_fisher = self.train_state.fisher_params
        self.imp_count = 0

    def _update_fisher(self, grads) :
        '''
        Add the elementwise square of `grads` to the running `fisher_params`.

        When `fisher_params` is None the squared gradients are stored directly.
        '''
        squared = tree_util.tree_map(jnp.square, grads)
        running = self.train_state.fisher_params
        if running is not None :
            squared = tree_util.tree_map(lambda acc, new : acc + new, running, squared)
        self.train_state = self.train_state.replace(fisher_params=squared)

    def fisher_matrix_accumulation(self, x, cond) :
        '''
        Accumulate one batch into the importance estimate.

        Random timesteps and noise are drawn for the batch, the gradients
        from `_get_grads` are squared and added to `fisher_params`, and
        `imp_count` is incremented. 2-D inputs get an axis inserted at
        position 1 first.
        '''
        x = x[:,None,:] if x.ndim == 2 else x
        cond = cond[:,None,:] if cond.ndim == 2 else cond

        batch = x.shape[0]
        t = np.random.randint(0, self.num_timesteps, (batch,))
        noise = np.random.randn(batch, 1, self.out_dim) # shape (Batch, 1, F)

        grads = self._get_grads(
            self.train_state, x, t, cond, noise=noise, rngs=self.sample_rngs
        )
        self.sample_rngs = update_rngs(self.sample_rngs)

        self._update_fisher(grads)
        self.imp_count += 1

        # TODO : check the fisher matrix calculation



