import jax
from jax import numpy as jnp

from flax import struct

from flax.training import common_utils
from typing import Any, Callable
from functools import partial

from .utils import *

import tqdm


@struct.dataclass
class GradPCA:
    """Gradient-subspace scorer built from a PCA/SVD of model Jacobians.

    An instance stores a set of directions (``grad_eigenvecs``) in the space of
    the selected parameters together with a mean Jacobian (``grad_mean``).
    ``compute_score`` centers the Jacobian of new inputs, projects it onto the
    stored directions and returns ``<proj, proj> / <jac, jac>`` as computed by
    the ``jac_dot`` helper.

    Use :meth:`create` to build an instance; the dataclass constructor only
    stores the given values.
    """

    # Number of classes; static (not traced by JAX).
    num_classes: int = struct.field(pytree_node=False)
    # Train state handed to the Jacobian helpers; ``.params`` is read from it.
    train_state: Any
    # Pytree subtracted from every Jacobian before projection.
    grad_mean: Any
    # Pytree of directions returned by ``normalize``; indexed along axis 0.
    grad_eigenvecs: list
    # Number of retained directions; static.
    num_eigenvecs: int = struct.field(pytree_node=False)
    # One of 'block_structure', 'batch' or 'GradOrth'; static.
    method: str = struct.field(pytree_node=False)
    # Forwarded as ``aggregate`` to the Jacobian / dot-product helpers; static.
    aggregate : bool = struct.field(pytree_node=False)
    param_keys: Any = struct.field(pytree_node=False) # either None or set of keys
    # Optional DICE mask pytree (None when DICE is disabled).
    mask: Any

    def __call__(self,x):
        """Shorthand for ``self.compute_score(x)``."""
        return self.compute_score(x)

    @classmethod
    def create(cls, num_classes, state,
               data_loader,
               method = 'block_structure',
               param_keys = (('classifier',),),
               aggregate = True ,
               eps = 0.99,
               dice = False,
               dice_p = 0.8,
               random_batch_size = 1000):
        """Fit the gradient subspace and return a ready-to-use ``GradPCA``.

        Args:
            num_classes: number of classes.
            state: train state; ``state.params`` is read by the helpers.
            data_loader: object providing ``get_loader_by_class(framework=...)``
                (used by 'block_structure') and
                ``get_random_batch(size, framework=...)`` (used by 'batch' and
                'GradOrth').
            method: 'block_structure' (per-class mean Jacobians), 'batch'
                (Jacobians of one random batch) or 'GradOrth' (SVD of the
                features of one random batch, applied to loss Jacobians).
            param_keys: selects the parameter subset the Jacobians are taken
                with respect to.
            aggregate: forwarded to the Jacobian/SVD helpers. Forced to True
                for 'GradOrth'.
            eps: fraction of the cumulative spectrum that must be reached when
                choosing how many directions to keep.
            dice: if True, mask Jacobians with a top-magnitude parameter mask
                (ignored by 'GradOrth').
            dice_p: fraction of parameter entries dropped by the DICE mask.
            random_batch_size: batch size requested for 'batch' / 'GradOrth'.

        Raises:
            ValueError: if ``method`` is not one of the supported names.
        """
        mask = None

        if method == 'block_structure':
            dataset = data_loader.get_loader_by_class(framework='jax')
            grad_vecs = cls._compute_jacobian_by_class(state,dataset,param_keys,num_classes,aggregate=aggregate)
            mask, global_mean, orth_fim_top_eigv, spectrum_cut = cls._fit_centered_pca(
                state, grad_vecs, param_keys, num_classes, aggregate, eps, dice, dice_p)
            del grad_vecs  # its buffers may have been donated inside the helper

        elif method == 'batch':
            dataset = data_loader.get_random_batch(random_batch_size, framework='jax')
            grad_vecs = jacobian(state, dataset['image'], param_keys, aggregate=aggregate )
            mask, global_mean, orth_fim_top_eigv, spectrum_cut = cls._fit_centered_pca(
                state, grad_vecs, param_keys, num_classes, aggregate, eps, dice, dice_p)
            del grad_vecs  # its buffers may have been donated inside the helper

        elif method == 'GradOrth':
            aggregate = True

            dataset = data_loader.get_random_batch(random_batch_size, framework='jax')
            batch = {'image':  dataset['image'],
                     'label': common_utils.onehot(dataset['label'], num_classes=num_classes)}

            batch_features = features(state, batch['image'])

            eigvals, eigvecs = block_svd(batch_features, aggregate)#always aggregate  since features are 1d
            del batch_features

            spectrum_cut = cls._select_eigen_cut(eigvals, eps, num_classes, aggregate)
            eigvecs = eigvecs[..., :spectrum_cut]

            grad_vecs = loss_jacobian(state, batch, params_keys=param_keys)
            # No centering for this method: the stored mean is all zeros with
            # the per-sample shape of each Jacobian leaf.
            global_mean = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape[1:]), grad_vecs)

            orth_fim_top_eigv = cls._project_and_normalize(grad_vecs, eigvecs, spectrum_cut, aggregate)

        else:
            raise ValueError("Unknown method.")

        return cls(num_classes,
                   state,
                   global_mean,
                   orth_fim_top_eigv,
                   spectrum_cut,
                   method,
                   aggregate,
                   param_keys,
                   mask)

    @classmethod
    def _fit_centered_pca(cls, state, grad_vecs, param_keys, num_classes, aggregate, eps, dice, dice_p):
        """Center ``grad_vecs``, run the block SVD and build the projected basis.

        Shared by the 'block_structure' and 'batch' methods of :meth:`create`.

        Returns:
            ``(mask, global_mean, orth_basis, spectrum_cut)`` where ``mask`` is
            None unless ``dice`` is True.
        """
        mask = None
        global_mean = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis = 0),grad_vecs)

        if dice:
            mask, global_mean, grad_vecs = cls._apply_dice(state,param_keys,global_mean,grad_vecs,dice_p)

        # The first argument is donated so the uncentered buffers can be reused.
        grad_vecs = jax.tree_util.tree_map(jax.jit(jnp.subtract,donate_argnums=0),grad_vecs,global_mean)

        eigvals, eigvecs = block_svd(grad_vecs,aggregate=aggregate)
        spectrum_cut = cls._select_eigen_cut(eigvals, eps, num_classes, aggregate)
        eigvecs = eigvecs[..., :spectrum_cut]
        orth_fim_top_eigv = cls._project_and_normalize(grad_vecs, eigvecs, spectrum_cut, aggregate)

        return mask, global_mean, orth_fim_top_eigv, spectrum_cut

    @staticmethod
    def _select_eigen_cut(eigvals, eps, num_classes, aggregate):
        """Return how many leading eigen-directions to keep, as a Python int.

        With ``aggregate`` the result is the first position at which the
        cumulative sum of ``eigvals`` divided by its total reaches ``eps``.
        Otherwise it is the first position at which that ratio has reached
        ``eps`` for exactly ``num_classes`` rows of ``eigvals`` simultaneously.
        """
        eigvals_sum = jnp.sum(eigvals, axis=-1)
        eigvals_cumsum = jnp.cumsum(eigvals, axis=-1).T
        reached = eigvals_cumsum / eigvals_sum >= eps

        if aggregate:
            first = jnp.argwhere(reached).flatten()[0]
        else:
            _cuts = jnp.sum(reached, axis=-1)
            first = jnp.argwhere(_cuts == num_classes).flatten()[0]

        # Convert to a plain int: the value is used as a slice bound and is
        # stored in a static (non-pytree) dataclass field.
        return int(first) + 1

    @staticmethod
    def _project_and_normalize(jacobian, eigvecs, num_vecs, aggregate):
        """Contract every Jacobian leaf with ``eigvecs`` and normalize the result.

        The leading axis of each leaf is contracted against ``eigvecs``; the
        new leading axis of the output enumerates the retained directions.
        The projected pytree is then passed to ``normalize``.
        """
        if not aggregate:
            f = lambda x: jnp.einsum('jik...,ijh->hik...', x, eigvecs)
        else:
            f = lambda x: jnp.einsum('jk...,jh->hk...', x, eigvecs)
        proj = jax.tree_util.tree_map(jax.jit(f), jacobian)
        return normalize(proj, num_vecs=num_vecs, aggregate=aggregate)

    @staticmethod
    def _apply_dice(state,param_keys,global_mean,grad_vecs,dice_p):
        """Build a top-magnitude mask over the selected params and apply it.

        For every selected parameter leaf, entries whose absolute value is at
        least the k-th largest (k = ``int(size * (1 - dice_p))``, clamped to
        ``[1, size]``) get mask value 1, all others 0. The mask is multiplied
        into ``global_mean`` and ``grad_vecs``.

        Returns:
            ``(mask, masked_global_mean, masked_grad_vecs)``.
        """
        def make_mask(x):
            magnitude = jnp.abs(x)
            flat = magnitude.flatten()
            x_size = flat.shape[0]
            # Clamp k: k == 0 would index element [-0] (the minimum) and keep
            # everything instead of almost nothing; k > size would be out of range.
            k = min(max(int(x_size*(1.-dice_p)), 1), x_size)
            top_k_val = jnp.sort(flat)[-k]
            return jnp.where(magnitude >= top_k_val, jnp.ones(x.shape), jnp.zeros(x.shape))

        mask = jax.tree_util.tree_map(make_mask, get_params_subset(state.params,param_keys))
        global_mean = jax.tree_util.tree_map(jnp.multiply, mask, global_mean)
        grad_vecs = jax.tree_util.tree_map(jnp.multiply, mask, grad_vecs)

        return mask, global_mean, grad_vecs

    @partial(jax.jit, static_argnames=['parallelize_over'])
    def compute_score(self, x, parallelize_over='eigvecs'):
        """Score a batch ``x`` by how much of its Jacobian lies in the subspace.

        Args:
            x: batch of inputs; ``x.shape[0]`` is used as the batch size.
            parallelize_over: 'eigvecs' loops over samples in Python and vmaps
                over the stored directions; 'batch' loops over directions and
                vmaps over samples.

        Returns:
            Per-sample ratio ``jac_dot(proj, proj) / jac_dot(jac, jac)`` where
            ``jac`` is the centered (and optionally masked) Jacobian and
            ``proj`` its projection onto the stored directions.

        Raises:
            ValueError: if ``parallelize_over`` is not 'eigvecs' or 'batch'.
        """
        if parallelize_over not in ('eigvecs', 'batch'):
            raise ValueError(
                f"parallelize_over must be 'eigvecs' or 'batch', got {parallelize_over!r}")

        # ``create`` stores the method as 'GradOrth'; the legacy spelling
        # 'gradORTH' is still accepted for hand-constructed instances.
        if self.method in ('GradOrth', 'gradORTH'):
            # Uniform target, one row per sample, mirroring the (batch, classes)
            # label layout that ``create`` passes to ``loss_jacobian``.
            # NOTE(review): confirm against ``loss_jacobian``'s expected label shape.
            batch = {'image': x,
                     'label': jnp.full((x.shape[0], self.num_classes), 1.0/self.num_classes)}
            jac = loss_jacobian(self.train_state, batch, self.param_keys)
        else:
            jac = jacobian(self.train_state, x, self.param_keys, aggregate =self.aggregate )

        jac = jax.tree_util.tree_map(jnp.subtract,jac,self.grad_mean)

        if self.mask is not None:
             jac = jax.tree_util.tree_map(jnp.multiply,jac,self.mask)

        jac_proj = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape), jac)

        _jac_dot = partial(jac_dot, aggregate=self.aggregate)
        _jac_broadcast = partial(jac_broadcast, aggregate=self.aggregate)

        def proj_f(tree,scalar):
            # Scale every leaf of ``tree`` by ``scalar`` (broadcast by the helper).
            return jax.tree_util.tree_map(lambda x: x*_jac_broadcast(scalar,x), tree)

        if parallelize_over == 'eigvecs':
            for i in range(x.shape[0]):
                grad = jax.tree_util.tree_map(lambda x: x[i], jac)
                dot_product = jax.vmap(_jac_dot,(None,0),0)(grad,self.grad_eigenvecs)

                proj_of_v = jax.vmap(proj_f,(0,0),0)(self.grad_eigenvecs,dot_product)

                # Sum the per-direction components to get this sample's projection.
                jac_proj_of_v = jax.tree_util.tree_map(lambda x: jnp.sum(x,axis=0), proj_of_v)
                jac_proj = jax.tree_util.tree_map(lambda x,y: x.at[i].set(y), jac_proj, jac_proj_of_v)

        else:  # parallelize_over == 'batch'
            for i in range(self.num_eigenvecs):
                v = jax.tree_util.tree_map(lambda x: x[i], self.grad_eigenvecs)

                dot_product = jax.vmap(_jac_dot,(None,0),0)(v,jac)
                proj_on_v = jax.vmap(proj_f,(None,0),0)(v,dot_product)

                jac_proj = jax.tree_util.tree_map(jnp.add, proj_on_v, jac_proj)

        jac_norm =  jax.vmap(_jac_dot,(0,0),0)(jac,jac)
        jac_proj = jax.vmap(_jac_dot,(0,0),0)(jac_proj,jac_proj)
        return jac_proj/jac_norm


    @staticmethod
    def _compute_jacobian_by_class(state,dataset_by_class_func,params_keys,num_classes,aggregate =False):
        """Compute the Jacobian averaged over the batches of each class.

        Args:
            state: train state; ``state.params`` provides the leaf shapes.
            dataset_by_class_func: callable mapping a class index to an
                iterable of batches for that class.
            params_keys: selects the parameter subset.
            num_classes: number of classes to iterate over.
            aggregate: forwarded to ``mean_batch_jacobian``; also decides
                whether leaves carry an extra ``num_classes`` axis.

        Returns:
            Pytree with leaves indexed as (class, out_head, ...) when
            ``aggregate`` is False, and (class, ...) otherwise.

        Raises:
            ValueError: if a class yields no batches (the mean is undefined).
        """
        params_subset = get_params_subset(state.params,params_keys)

        if not aggregate :
            def jac_by_class_shape(x):
                return (num_classes, num_classes) + tuple(x.shape)
            def class_jac_shape(x):
                return (num_classes,) + tuple(x.shape)
        else:
            def jac_by_class_shape(x):
                return (num_classes,) + tuple(x.shape)
            def class_jac_shape(x):
                return x.shape

        jacobian_by_class = jax.tree_util.tree_map(lambda x: jnp.zeros(jac_by_class_shape(x)),
                                                   params_subset)

        # Built once so the same jitted callables are reused for every class;
        # the running-sum buffer is donated in both.
        accumulate = jax.jit(jnp.add,donate_argnums=1)
        scale = jax.jit(jnp.divide,donate_argnums=0)

        for c in tqdm.tqdm(range(num_classes)):
            # Running sum with the structure of the selected params.
            class_jac = jax.tree_util.tree_map(lambda x: jnp.zeros(class_jac_shape(x)), params_subset)

            num_batches = 0
            for batch in dataset_by_class_func(c):
                batch_jac = mean_batch_jacobian(state, batch, params_keys, aggregate =aggregate )
                class_jac = jax.tree_util.tree_map(accumulate, batch_jac, class_jac)
                num_batches += 1

            if num_batches == 0:
                raise ValueError(f"No batches found for class {c}; cannot average its Jacobian.")

            class_jac = jax.tree_util.tree_map(lambda x: scale(x, num_batches), class_jac)
            jacobian_by_class = jax.tree_util.tree_map(partial(replace,c=c), jacobian_by_class, class_jac)

        return jacobian_by_class #pytree with leaves indexed as (class,out_head,...)
    



