from dataclasses import replace
from src.offline_smoothing import *
from src.online_smoothing import *
from src.stats.hmm import * 
from src.variational.sequential_models import *
from jax._src.tree_util import GetAttrKey, DictKey, SequenceKey
# import tensorflow as tf 
from jax.tree_util import tree_flatten,tree_flatten_with_path,tree_unflatten

import jax
from jax import vmap, value_and_grad, numpy as jnp
import multiprocessing
import optax 
from time import time
from jax_tqdm import scan_tqdm
from dataclasses import asdict


def define_frozen_tree(frozen_params, theta_init, theta_true):
    """Copy the frozen blocks of ``theta_true`` into ``theta_init``.

    ``theta_init`` is updated in place (attribute assignment, plus item
    assignment on the ``map`` dicts) and then returned.

    Args:
        frozen_params: Names of the frozen blocks, tested with ``in``.
            Matching depends on the container:

            * list/tuple/set: names are matched exactly;
            * plain string: each name is matched as a *substring*, so e.g.
              ``'transition.noise'`` also freezes the whole
              ``'transition'`` block.

            Recognised names are ``'prior'``, ``'transition'``,
            ``'emission'``, ``'transition.noise'``, ``'emission.noise'``,
            ``'all'``, ``'transition.map'``, ``'emission.map'`` and
            ``'transition.map.chaotic_W.w'``.
        theta_init: Parameters to overwrite.
        theta_true: Parameters providing the frozen values.

    Returns:
        ``theta_init`` (the same object that was passed in).
    """
    if 'prior' in frozen_params:
        theta_init.prior = theta_true.prior

    if 'transition' in frozen_params:
        theta_init.transition = theta_true.transition

    if 'emission' in frozen_params:
        theta_init.emission = theta_true.emission

    # Only the noise scale is copied; it is written onto the existing noise
    # object of the block.
    if 'transition.noise' in frozen_params:
        theta_init.transition.noise.scale = theta_true.transition.noise.scale

    if 'emission.noise' in frozen_params:
        theta_init.emission.noise.scale = theta_true.emission.noise.scale

    if 'all' in frozen_params:
        # Take every block from the true model, cut from the autodiff graph.
        theta_init.prior = jax.tree.map(jax.lax.stop_gradient, theta_true.prior)
        theta_init.transition = jax.tree.map(jax.lax.stop_gradient, theta_true.transition)
        theta_init.emission = jax.tree.map(jax.lax.stop_gradient, theta_true.emission)

    if 'transition.map' in frozen_params:
        theta_init.transition.map['w'] = theta_true.transition.map['w']
    if 'emission.map' in frozen_params:
        theta_init.emission.map['w'] = theta_true.emission.map['w']

    if 'transition.map.chaotic_W.w' in frozen_params:
        # The weight lives under 'linear' or 'chaotic_W' depending on the map.
        tr_map = theta_init.transition.map
        if 'linear' in tr_map and 'w' in tr_map['linear']:
            tr_map['linear']['w'] = theta_true.transition.map['linear']['w']
        elif 'chaotic_W' in tr_map and 'w' in tr_map['chaotic_W']:
            tr_map['chaotic_W']['w'] = theta_true.transition.map['chaotic_W']['w']

    return theta_init


def _is_int_or_bool(x):
    try:
        dt = jnp.asarray(x).dtype
        return jnp.issubdtype(dt, jnp.integer) or jnp.issubdtype(dt, jnp.bool_)
    except Exception:
        return False

def floatize_static(tree):
    """Return a copy of the pytree ``tree`` with int/bool leaves cast to float.

    Leaves for which ``_is_int_or_bool`` is true are converted with
    ``jnp.asarray(leaf, jnp.float64)``; every other leaf is returned as is.
    NOTE(review): with JAX's default 32-bit mode the float64 request is
    presumably downcast to float32 — confirm whether x64 is enabled.
    """
    def _cast_leaf(leaf):
        if not _is_int_or_bool(leaf):
            return leaf
        return jnp.asarray(leaf, jnp.float64)

    return jax.tree.map(_cast_leaf, tree)

def build_mask(params, frozen):
    """Build a trainability mask with the same tree structure as ``params``.

    Mask leaves are ``True`` for trainable leaves and ``False`` for frozen
    ones.

    Args:
        params: Parameter pytree.
        frozen: Which sub-trees to freeze. Accepted forms:

            * ``"all"``, or any iterable containing ``"all"``: freeze
              everything;
            * an iterable of dotted paths such as ``["transition.noise"]``;
            * a string of comma-separated dotted paths (``""`` freezes
              nothing);
            * ``None``: freeze nothing.

            Each path is resolved token by token, using ``obj[tok]`` on
            dicts and ``getattr(obj, tok)`` otherwise.

    Returns:
        A pytree of Python bools with ``params``' structure.
    """
    if frozen is None:
        frozen = []
    elif isinstance(frozen, str):
        # A bare string would otherwise be iterated character by character.
        frozen = [tok.strip() for tok in frozen.split(",") if tok.strip()]
    else:
        frozen = list(frozen)

    leaves, treedef = tree_flatten(params)

    if "all" in frozen:
        return tree_unflatten(treedef, [False] * len(leaves))

    def get_subtree(root, dotted):
        # Walk a dotted path through dicts (by key) and objects (by attribute).
        obj = root
        for tok in dotted.split("."):
            obj = obj[tok] if isinstance(obj, dict) else getattr(obj, tok)
        return obj

    # Leaves are matched by object identity: flattening a sub-tree yields the
    # very leaf objects stored in ``params``.
    # NOTE(review): a leaf object shared between a frozen and a trainable
    # position would be frozen in both.
    frozen_ids = {
        id(leaf)
        for key in frozen
        for leaf in jax.tree_util.tree_leaves(get_subtree(params, key))
    }
    return tree_unflatten(treedef, [id(leaf) not in frozen_ids for leaf in leaves])



def mae_maps(true, guess):
    """Mean absolute error between the ``.w`` map weights of two parameter sets.

    Returns:
        ``(transition_mae, emission_mae)``. Each entry is
        ``mean(|true.<block>.map.w - guess.<block>.map.w|)``.
    """
    def _mean_abs_diff(reference, estimate):
        return jnp.mean(jnp.abs(jnp.asarray(reference) - jnp.asarray(estimate)))

    transition_err = _mean_abs_diff(true.transition.map.w, guess.transition.map.w)
    emission_err = _mean_abs_diff(true.emission.map.w, guess.emission.map.w)
    return transition_err, emission_err






    
class SVITrainer:
    """Online SVI trainer for a variational smoother and model parameters.

    It jointly learns a variational smoother ``q`` (parameters ``phi``) and
    the parameters ``theta`` of the model ``p``, using score-function ELBO
    gradients computed while streaming over the observations.

    Only the configuration wired up in ``fit`` is supported:

    * ``training_mode`` of the form ``'streaming,<num_grad_steps>[,...]'``;
    * ``num_epochs == 1``;
    * an ``elbo_mode`` containing ``'score'``.
    """

    def __init__(self,
                 p: HMM,
                 theta_true,
                 theta_star,
                 q: BackwardSmoother,
                 optimizer,
                 learning_rate,
                 learning_rate_model,
                 optim_options,
                 num_epochs,
                 seq_length,
                 num_samples=1,
                 frozen_params='',
                 num_seqs=1,
                 training_mode='offline',
                 elbo_mode='autodiff_on_batch'):
        """Set up parameter adapters, both optimizers, the ELBO and ``update``.

        Args:
            p: Model; ``p.format_params`` is applied to ``theta_true`` and
                ``theta_star``.
            theta_true: True model parameters. Used as the source of frozen
                values, handed to each ELBO step as ``'theta_true'``, and
                used for the map-error diagnostics in ``fit``.
            theta_star: Initial model parameters (modified in place by
                ``define_frozen_tree``).
            q: Variational smoother.
            optimizer: Name of an optax optimizer factory for ``phi``
                (looked up with ``getattr(optax, optimizer)``).
            learning_rate: Learning rate for ``phi``.
            learning_rate_model: Learning rate for ``theta``. ``None``
                falls back to ``learning_rate``.
            optim_options: String selecting an optional schedule for
                ``phi``: ``'linear_sched'``, ``'warmup_cosine'`` or
                ``'gamma'``. For ``'gamma'`` the last comma-separated token
                must read ``'gamma_<value>'``.
            num_epochs: Number of passes over the data.
            seq_length: Number of time steps per sequence.
            num_samples: Monte Carlo sample count passed to the ELBO.
            frozen_params: Frozen parameter names, consumed by
                ``define_frozen_tree`` and ``build_mask``.
            num_seqs: Number of sequences processed in parallel.
            training_mode: e.g. ``'streaming,<num_grad_steps>'``; may also
                contain ``'difference'`` and ``'recompute'``.
            elbo_mode: Must contain ``'score'``. May also contain
                ``'share_params'``, ``'monitor'``, ``'resampling'``,
                ``'mcmc'`` and ``'bptt_depth_<k>'``.

        Raises:
            NotImplementedError: if ``elbo_mode`` lacks ``'score'``, or if
                ``'share_params'`` is requested for an unsupported ``q``.
        """
        self.num_epochs = num_epochs
        self.q = q
        self.seq_length = seq_length
        self.q.print_num_params()
        self.p = p

        # Copy the frozen blocks of the true model into the initial guess,
        # then turn integer/boolean leaves into floats.
        theta_star = define_frozen_tree(frozen_params, theta_star, theta_true)
        theta_star = floatize_static(theta_star)

        self.theta_star = theta_star
        self.formatted_theta_star = self.p.format_params(theta_star,
                                                         precompute=['prec'])

        self.theta_true = theta_true
        self.formatted_theta_true = self.p.format_params(theta_true,
                                                         precompute=['prec'])

        self.frozen_params = frozen_params
        self.num_seqs = num_seqs
        self.elbo_mode = elbo_mode
        self.training_mode = training_mode

        # Adapters between the full parameter object and the part that is
        # optimised. With 'share_params' the prior/transition come from
        # ``theta_star``.
        # NOTE(review): ``fit`` also applies these adapters to theta, which
        # looks unintended in 'share_params' mode — confirm.
        if 'share_params' in elbo_mode:
            print('Using transition and prior from true model.')
            if isinstance(q, LinearGaussianHMM):
                def build_params(params):
                    return HMM.Params(prior=theta_star.prior,
                                      transition=theta_star.transition,
                                      emission=params)

                def extract_params(params):
                    return params.emission
            elif isinstance(q, JohnsonSmoother):
                def build_params(params):
                    return JohnsonParams(prior=theta_star.prior,
                                         transition=theta_star.transition,
                                         net=params)

                def extract_params(params):
                    return params.net
            elif isinstance(q, NeuralBackwardSmoother):
                def build_params(params):
                    return NeuralBackwardSmoother.Params(prior=params[0],
                                                         backwd=theta_star.transition,
                                                         state=params[1],
                                                         filt=params[2])

                def extract_params(params):
                    return params.prior, params.state, params.filt
            else:
                raise NotImplementedError(
                    f"'share_params' is not supported for {type(q).__name__}.")
        else:
            def build_params(params):
                return params

            def extract_params(params):
                return params

        self._build_params = build_params
        self._extract_params = extract_params

        if isinstance(p, LinearGaussianHMM) and isinstance(q, LinearGaussianHMM):
            print('Monitor ELBO is analytical.')
            monitor_elbo = LinearGaussianELBO(p, q)
            self.monitor_elbo = lambda _, obs_seq, compute_up_to, theta, phi: monitor_elbo(obs_seq,
                                                                                          compute_up_to,
                                                                                          theta,
                                                                                          phi)
        else:
            self.monitor_elbo = None
            print('Not monitoring the ELBO.')

        # Defaults for the non-streaming case, so every attribute read below
        # (and in ``timesteps``) exists whatever the training mode.
        streaming = False
        self.online_difference = False
        self.num_grad_steps = 1
        self.online_batch_size = 1
        self.reset = False
        if 'streaming' in training_mode:
            # Observations arrive one at a time; the second comma-separated
            # token is the number of gradient steps per time index.
            self.online_difference = 'difference' in training_mode
            self.num_grad_steps = int(training_mode.split(',')[1])
            streaming = True

        self.streaming = streaming
        self.monitor = 'monitor' in elbo_mode

        # Options for the ELBO estimator (not used for the closed form).
        self.elbo_options = {}
        if self.training_mode != 'closed_form' and 'score' in elbo_mode:
            for option in ['resampling', 'mcmc']:
                self.elbo_options[option] = option in elbo_mode
            if 'bptt_depth' in elbo_mode:
                # '..bptt_depth_<k>..' -> k
                self.elbo_options['bptt_depth'] = int(elbo_mode.split('bptt_depth')[1].split('_')[1])
            self.elbo_options['normalizer'] = exp_and_normalize
            self.elbo_options['streaming'] = bool(self.streaming)

        def optimizer_update_fn(params, updates):
            """Apply optax ``updates`` to ``params`` and return the new params."""
            return optax.apply_updates(params, updates)

        self.optimizer_update_fn = optimizer_update_fn

        self.trainable_params = jax.tree.map(lambda x: x == '', self.frozen_params)

        # ---- Variational (phi) optimizer and its optional schedule ----
        clip_norm = 5
        schedule = None
        # ``learning_rate`` may become an optax schedule below; keep the
        # scalar for the model-learning-rate fallback.
        base_learning_rate = learning_rate

        if 'linear_sched' in optim_options:
            learning_rate = optax.linear_schedule(learning_rate,
                                                  end_value=10*learning_rate,
                                                  transition_begin=2000,
                                                  transition_steps=num_epochs * (seq_length / self.online_batch_size))
        elif 'warmup_cosine' in optim_options:
            print('Setting up warmup cosine schedule. Starting at {} and ending at {}.'.format(learning_rate / 10,
                                                                                               learning_rate))
            # NOTE(review): optax's cosine part needs decay_steps > warmup_steps;
            # with seq_length <= 100000 this likely raises — confirm.
            learning_rate = optax.warmup_cosine_decay_schedule(
                init_value=learning_rate / 10,
                peak_value=learning_rate,
                warmup_steps=100000,
                decay_steps=seq_length,
                end_value=learning_rate)
        elif 'gamma' in optim_options:
            gamma = float(optim_options.split(',')[-1].split('_')[1])
            schedule = optax.scale_by_schedule(lambda t: (t+1)**(-gamma))

        base_optimizer = optax.apply_if_finite(getattr(optax, optimizer)(learning_rate),
                                               max_consecutive_errors=10)
        if schedule is not None:
            optimizer = optax.chain(base_optimizer, schedule)
        else:
            optimizer = base_optimizer
        # The chained optimizer carries the optional 'gamma' decay.
        self.optimizer_phi = optimizer

        # ---- Model (theta) optimizer ----
        # Adam with global-norm clipping; the learning rate stays at 0 for
        # the first ``freeze_until`` updates (none while freeze_ratio == 0).
        steps_per_t = self.num_grad_steps
        freeze_ratio = 0
        freeze_until = int(freeze_ratio * self.num_epochs * self.seq_length * steps_per_t)
        mdl_lr = float(learning_rate_model if learning_rate_model is not None else base_learning_rate)
        theta_lr_sched = optax.join_schedules(
            schedules=[
                optax.constant_schedule(0.0),
                optax.constant_schedule(mdl_lr),
            ],
            boundaries=[freeze_until],
        )
        base_tx = optax.chain(optax.clip_by_global_norm(clip_norm),
                              optax.adam(theta_lr_sched))

        # Leaves masked False are routed to set_to_zero (never updated).
        self.mask = build_mask(self.theta_star, frozen_params)
        labels = jax.tree.map(lambda b: "train" if b else "frozen", self.mask)
        tx = {
            "train": base_tx,
            "frozen": optax.set_to_zero(),
        }
        self.optimizer_theta = optax.multi_transform(tx, labels)

        # ---- ELBO ----
        if 'score' not in self.elbo_mode:
            print('ELBO mode not suitable for gradient accumulation.')
            raise NotImplementedError

        print('USING SCORE ELBO.')
        print('Using full gradients.')
        self.elbo = OnlineELBOScoreGradients(self.p,
                                             self.q,
                                             num_samples=num_samples,
                                             **self.elbo_options)

        def elbo_and_grads_batch(key, ys, params):
            """Offline ELBO and negated phi-gradient, both divided by len(ys)."""
            carry, aux = self.elbo.batch_compute(key,
                                                 ys,
                                                 self.formatted_theta_star,
                                                 params)
            # ``postprocess`` yields three outputs (see ``update``); only the
            # ELBO and the phi-gradient are used here.
            elbo, grad = self.elbo.postprocess(carry)[:2]
            num_obs = len(ys)
            elbo = elbo / num_obs
            neg_grad = jax.tree.map(lambda x: -x / num_obs, grad)
            return (elbo, neg_grad), aux

        self.elbo_step = self.elbo.step
        self.elbo_batch = elbo_and_grads_batch
        self.get_montecarlo_keys = get_keys

        # Initial ELBO carry, replicated ``num_seqs`` times (in_axes=None).
        params = self._extract_params(self.q.get_random_params(jax.random.PRNGKey(0)))
        params_model = self.theta_star
        self.init_carry = jax.vmap(self.elbo.init_carry,
                                   axis_size=self.num_seqs,
                                   in_axes=None)(self._build_params(params),
                                                 self._build_params(params_model))

        def update(key,
                   elbo_carry,
                   strided_ys,
                   timesteps,
                   params,
                   params_model):
            """Run one online ELBO step at the last index of ``timesteps``.

            Returns:
                ``(elbo, neg_grad_phi, neg_grad_theta, elbo_carry, aux)``.
                The ELBO is divided by ``t + 1``. The gradients are negated
                because optax optimizers descend. ``elbo_carry`` is the new
                carry, unless ``'recompute'`` is in the training mode, in
                which case the input carry is returned.

            Raises:
                NotImplementedError: when not streaming.
            """
            if not self.streaming:
                raise NotImplementedError('Only streaming updates are implemented.')

            # The second half of the split is unused, but splitting keeps
            # the step key (hence the Monte Carlo samples) unchanged.
            key_t, _ = jax.random.split(key, 2)
            t = timesteps[-1]
            # NOTE(review): this writes into the incoming carry dict, so the
            # ``elbo_carry`` post-processed below also holds 'theta_true'.
            elbo_carry['theta_true'] = self.formatted_theta_true
            input_t = {'t': t,
                       'key': key_t,
                       'ys_bptt': strided_ys[-1],
                       'T': t,
                       'phi': params,
                       'theta': params_model}
            new_carry, aux = self.elbo.step(elbo_carry, input_t)

            # Gradients read from the carries after (t) and before (t-1) the step.
            elbo_t, grad_t, grad_model_t = self.elbo.postprocess(new_carry)
            _, grad_tm1, grad_model_tm1 = self.elbo.postprocess(elbo_carry)

            if self.online_difference:
                neg_grad = jax.tree.map(lambda x, y: -(x - y), grad_t, grad_tm1)
                neg_grad_model = jax.tree.map(lambda x, y: -(x - y), grad_model_t, grad_model_tm1)
            else:
                neg_grad = jax.tree.map(lambda x: -x, grad_t)
                neg_grad_model = jax.tree.map(lambda x: -x, grad_model_t)

            elbo = elbo_t / (t + 1)
            if 'recompute' not in self.training_mode:
                elbo_carry = new_carry

            return elbo, neg_grad, neg_grad_model, elbo_carry, aux

        self.update = update

    def timesteps(self, seq_length, key):
        """Yield slices of ``arange(seq_length)`` of at most ``online_batch_size``.

        With ``key=None`` the slices start at 0, 1, 2, ... when streaming,
        and at multiples of ``online_batch_size`` otherwise. With a key, the
        start offsets (multiples of ``online_batch_size``) are randomly
        permuted.
        """
        all_timesteps = jnp.arange(0, seq_length)
        if key is None:
            if self.streaming:
                cnts = range(0, seq_length, 1)
            else:
                cnts = range(0, seq_length, self.online_batch_size)
        else:
            cnts = jax.random.permutation(key, jnp.arange(0, seq_length, self.online_batch_size))

        for cnt in cnts:
            yield all_timesteps[cnt:cnt+self.online_batch_size]

    def fit(self,
            key_params,
            key_montecarlo,
            data,
            args=None):
        """Stream once over ``data``, updating phi and theta at every time step.

        Args:
            key_params: PRNG key for the initial variational parameters.
            key_montecarlo: PRNG key from which the per-step keys are drawn.
            data: Observations. ``self.elbo.preprocess`` is vmapped over the
                leading axis (presumably one entry per sequence — confirm).
            args: Forwarded to ``q.get_random_params``.

        Returns:
            A 7-tuple ``(params, params_model, elbos, aux, params_history,
            params_model_history, map_errors)``.

            * ``elbos`` and ``map_errors`` are stacked over
              (epoch, time step, inner gradient step).
            * The parameter histories hold the values used at each
              (epoch, time step), i.e. before that step's updates.

        Raises:
            NotImplementedError: unless streaming with ``num_epochs == 1``.
        """
        from functools import partial

        if not (self.streaming and self.num_epochs == 1):
            raise NotImplementedError(
                'Only streaming on a single sequence for one epoch is implemented.')

        # Variational parameters and their optimizer state.
        params = self._extract_params(self.q.get_random_params(key_params, args))
        opt_state = self.optimizer_phi.init(params)

        # Model parameters being learned and their optimizer state.
        params_model = self.theta_star
        opt_state_model = self.optimizer_theta.init(self.theta_star)

        seq_length = self.seq_length
        keys = get_keys(key_montecarlo,
                        seq_length // self.online_batch_size,
                        self.num_epochs)
        strided_ys = jax.vmap(self.elbo.preprocess)(data)

        def _step(step_carry, x):
            """Process one time index: ``num_grad_steps`` joint phi/theta updates."""
            params, params_model, opt_state, opt_state_model, elbo_carry = step_carry
            key, timesteps = x

            strided_ys_on_timesteps = strided_ys[:, timesteps]
            # Second half unused; the split keeps the key stream unchanged.
            key, _ = jax.random.split(key, 2)

            def inner_step(carry, key):
                inner_carry, params, params_model, opt_state, opt_state_model = carry

                # At t == 0 the carry is threaded through the inner steps.
                # For t > 0 each inner step restarts from the carry held at
                # the start of this time index.
                start_carry = jax.lax.cond(timesteps[-1] == 0,
                                           lambda c: c,
                                           lambda c: elbo_carry,
                                           inner_carry)

                elbo, neg_grad, neg_grad_model, new_carry, aux = jax.vmap(
                    self.update, in_axes=(0, 0, 0, None, None, None))(
                        jax.random.split(key, len(strided_ys_on_timesteps)),
                        start_carry,
                        strided_ys_on_timesteps,
                        timesteps,
                        self._build_params(params),
                        self._build_params(params_model))

                # Variational update: average over sequences, then step phi.
                neg_grad = self._extract_params(neg_grad)
                neg_grad = jax.tree.map(partial(jnp.mean, axis=0), neg_grad)
                updates, opt_state = self.optimizer_phi.update(neg_grad,
                                                               opt_state,
                                                               params)
                new_params = self.optimizer_update_fn(params, updates)

                # Model update: same averaging; frozen leaves get zero updates.
                neg_grad_model = self._extract_params(neg_grad_model)
                neg_grad_model = jax.tree.map(partial(jnp.mean, axis=0), neg_grad_model)
                updates_model, opt_state_model = self.optimizer_theta.update(
                    neg_grad_model, opt_state_model, params_model)
                new_params_model = optax.apply_updates(params_model, updates_model)

                # Map-weight error against the true model (only meaningful
                # for the linear Gaussian model, whose maps expose ``.w``).
                map_errors = mae_maps(self.theta_true, new_params_model)

                return ((new_carry, new_params, new_params_model, opt_state, opt_state_model),
                        (jnp.mean(elbo, axis=0), aux, map_errors))

            (elbo_carry, next_step_params, next_step_params_model, opt_state, opt_state_model), \
                (elbos, aux, map_errors) = jax.lax.scan(
                    inner_step,
                    init=(elbo_carry, params, params_model, opt_state, opt_state_model),
                    xs=jax.random.split(key, self.num_grad_steps))

            return ((next_step_params, next_step_params_model, opt_state, opt_state_model, elbo_carry),
                    (elbos, aux, params, params_model, map_errors))

        all_timesteps = jnp.expand_dims(jnp.arange(0, seq_length), axis=-1)

        @scan_tqdm(len(all_timesteps))
        def step(carry, x):
            # x = (progress-bar index, key, timesteps); drop the index.
            return _step(carry, x[1:])

        print('Streaming on a single sequence only once.')

        def _epoch_step(carry, x):
            """Scan ``step`` over every time index of one epoch."""
            _, keys_epoch = x
            return jax.lax.scan(step,
                                init=carry,
                                xs=(jnp.arange(0, len(all_timesteps)),
                                    keys_epoch,
                                    all_timesteps))

        (params, params_model, opt_state, opt_state_model, elbo_carry), \
            (elbos_epochs, aux_results, params_epochs, params_models_epochs, mse_values_epochs) = jax.lax.scan(
                _epoch_step,
                init=(params, params_model, opt_state, opt_state_model, self.init_carry),
                xs=(jnp.arange(0, self.num_epochs), keys))

        return params, params_model, elbos_epochs, aux_results, params_epochs, params_models_epochs, mse_values_epochs

    def multi_fit(self,
                  key,
                  data,
                  num_fits,
                  log_dir='',
                  args=None):
        """Run ``fit`` ``num_fits`` times from independent keys; keep the best.

        A fit's score is its final ELBO, i.e. the last entry of the ELBO
        history returned by ``fit``. The mean wall-clock time per fit and
        the index of the best fit are written with
        ``save_dict(..., 'training_info', log_dir)``. A
        ``tensorboard_logs`` sub-directory is created in ``log_dir``, but
        TensorBoard logging itself is disabled.

        Returns:
            ``(best, all_fits)``. Every entry of ``all_fits`` (and ``best``)
            is a ``(params, params_model)`` pair.
        """
        import os

        print('Starting training...')

        tensorboard_subdir = os.path.join(log_dir, 'tensorboard_logs')
        os.makedirs(tensorboard_subdir, exist_ok=True)

        fit_keys = jax.random.split(key, num_fits)
        params_per_fit, elbos_per_fit, fit_times = [], [], []

        for fit_nb, fit_key in enumerate(fit_keys):
            print(f'Starting fit {fit_nb}/{num_fits-1}...')
            key_params, key_montecarlo = jax.random.split(fit_key, 2)

            time0 = time()
            params, params_model, elbos, *_ = self.fit(key_params,
                                                       key_montecarlo,
                                                       data,
                                                       args)
            # JAX dispatches asynchronously; wait before reading the clock.
            jax.block_until_ready((params, params_model, elbos))
            fit_times.append(time() - time0)

            final_elbo = float(jnp.ravel(elbos)[-1])
            print(f'Fit {fit_nb}: final ELBO {final_elbo:.3f}')
            params_per_fit.append((params, params_model))
            elbos_per_fit.append(final_elbo)

        best_optim = int(jnp.argmax(jnp.array(elbos_per_fit)))
        print(f'Best fit is {best_optim}.')
        training_info = dict()
        training_info['avg_time'] = float(jnp.mean(jnp.array(fit_times)))
        training_info['best_fit'] = best_optim
        save_dict(training_info, 'training_info', log_dir)

        return params_per_fit[best_optim], params_per_fit





