"""
Implements simple agent for a forward state predictive model.
"""


from functools import partial
import os
from typing import Optional, Tuple, Any, Callable

import flax

from flax.training.train_state import TrainState

import jax
import jax.numpy as jnp
import numpy as np
import optax
import time


from koopman.data import Batch
from koopman.networks.dynamics_models import dynamics_models
from koopman.networks.state_networks import state_encoders
from koopman.networks.common import InfoDict, Params
from koopman.utils.common import value_and_multi_grad, recursive_sum, recursive_len, masked_mse_loss

def get_num_params(params) -> int:
    """Return the total number of scalar entries across all leaves of *params*.

    Args:
        params: any pytree (e.g. a nested dict of arrays as produced by
            ``Module.init``).

    Returns:
        The sum over all leaves of the product of the leaf's shape, as a
        Python ``int``. An empty pytree yields ``0``; a scalar leaf counts as
        one entry.
    """
    # `jax.tree_util.tree_leaves` is the stable location of this helper; the
    # top-level `jax.tree_leaves` alias is deprecated.
    # `np.prod(())` is the float 1.0, hence the explicit int() per leaf.
    return sum(
        int(np.prod(np.shape(leaf)))
        for leaf in jax.tree_util.tree_leaves(params)
    )

class PredictiveLearner(object):
    """Encoder / latent-dynamics / decoder stack trained for forward prediction.

    The learner keeps one ``TrainState`` (parameters plus optimizer state) per
    sub-network in ``self._train_states`` under the keys ``'state_encoder'``,
    ``'state_decoder'``, ``'dynamics_model'`` and, when ``pred_reward`` is
    truthy, ``'reward_model'``.
    """

    def __init__(self,
                 seed: int,
                 observations: jnp.ndarray,
                 actions: jnp.ndarray,
                 encoder_lr: float = 3e-4,
                 dynamics_lr: float = 3e-4,
                 state_emb_dim: int = 512,
                 action_emb_dim: Optional[int] = None,
                 dynamics_model_type: Optional[str] = None,
                 discretize: int = 1,
                 koopman_real_init_type: str = 'learnable',
                 koopman_real_init_value: float = -0.5,
                 koopman_im_init_type: str = 'increasing_freq',
                 optimizer_type: str = 'adam',
                 train_seq_length: int = 100,
                 pred_reward: int = 0,
                 use_state_prediction_loss: int = 1
                 ):
        """Build the networks, initialise their parameters and optimizers.

        Args:
            seed: seed for the ``jax.random.PRNGKey`` used for initialisation.
            observations: example observations used to initialise the encoder;
                the size of the last axis becomes the decoder's
                ``observation_dim``.
            actions: example actions used to initialise the dynamics (and
                reward) model. A leading axis is added before use, so this is
                presumably an un-batched action sequence -- TODO confirm.
            encoder_lr: learning rate for the encoder and decoder optimizers.
            dynamics_lr: learning rate for the dynamics and reward optimizers.
            state_emb_dim: encoder output size and dynamics-model state size.
            action_emb_dim: forwarded to every dynamics model except
                ``'regular'``.
            dynamics_model_type: one of ``'regular'``, ``'dense_koopman'``,
                ``'diagonal_koopman'``, ``'transformer'``, ``'gru'``,
                ``'dssm'``.
            discretize: accepted for interface compatibility; not read by
                this class.
            koopman_real_init_type: forwarded to the diagonal Koopman model.
            koopman_real_init_value: forwarded to the diagonal Koopman model.
            koopman_im_init_type: forwarded to the diagonal Koopman model.
            optimizer_type: ``'adam'`` or ``'sgd'``.
            train_seq_length: accepted for interface compatibility; not read
                by this class.
            pred_reward: truthy to also create and train a reward model.
            use_state_prediction_loss: truthy to include the latent-space
                prediction loss during training.

        Raises:
            KeyError: if ``optimizer_type`` is not ``'adam'`` or ``'sgd'``.
            ValueError: if ``dynamics_model_type`` is ``None`` or unknown.
        """
        self.dynamics_model_type = dynamics_model_type

        optimizer_class = {'adam': optax.adam, 'sgd': optax.sgd}[optimizer_type]

        self._train_states = {}
        self.pred_reward = pred_reward
        self.use_state_prediction_loss = use_state_prediction_loss

        rng = jax.random.PRNGKey(seed)
        rng, encoder_key, decoder_key, dynamics_model_key, reward_model_key = jax.random.split(rng, 5)

        # --- state encoder ---------------------------------------------------
        self.state_encoder = state_encoders.EncoderState(out_dim=state_emb_dim)
        state_rep, state_encoder_params = self.state_encoder.init_with_output(encoder_key, observations)
        self._train_states['state_encoder'] = TrainState.create(
            params=state_encoder_params, tx=optimizer_class(encoder_lr), apply_fn=None
        )

        print(dynamics_model_type)

        if dynamics_model_type is None:
            raise ValueError("use a dynamics model ....")

        # --- state decoder (shares the encoder learning rate) ----------------
        self.state_decoder = state_encoders.DecoderState(observation_dim=observations.shape[-1])
        _, state_decoder_params = self.state_decoder.init_with_output(decoder_key, state_rep)
        self._train_states['state_decoder'] = TrainState.create(
            params=state_decoder_params, tx=optimizer_class(encoder_lr), apply_fn=None
        )

        # --- latent dynamics model -------------------------------------------
        self.dynamics_model = self._build_dynamics_model(
            dynamics_model_type,
            state_emb_dim,
            action_emb_dim,
            koopman_real_init_type,
            koopman_real_init_value,
            koopman_im_init_type,
        )
        # Initialised with a leading axis of size one on the actions and on
        # the first encoded state only.
        dynamics_model_params = self.dynamics_model.init(
            dynamics_model_key, actions[np.newaxis], state_rep[0][np.newaxis]
        )
        self._train_states['dynamics_model'] = TrainState.create(
            params=dynamics_model_params, tx=optimizer_class(dynamics_lr), apply_fn=None
        )

        model_keys = ['state_encoder', 'state_decoder', 'dynamics_model']
        total_params = get_num_params(state_encoder_params) + get_num_params(state_decoder_params) \
                        + get_num_params(dynamics_model_params)

        # --- optional reward model (shares the dynamics learning rate) -------
        if pred_reward:
            self.reward_model = dynamics_models.BatchedRewardModel()
            reward_model_params = self.reward_model.init(
                reward_model_key, state_rep[np.newaxis], actions[np.newaxis]
            )
            self._train_states['reward_model'] = TrainState.create(
                params=reward_model_params, tx=optimizer_class(dynamics_lr), apply_fn=None
            )
            model_keys += ['reward_model']
            total_params += get_num_params(reward_model_params)

        print('Total number of parameters: ', total_params)

        self.optimizer_class = optimizer_class
        self._model_keys = tuple(model_keys)
        self.rng = rng
        self.step = 1

    @staticmethod
    def _build_dynamics_model(dynamics_model_type,
                              state_emb_dim,
                              action_emb_dim,
                              koopman_real_init_type,
                              koopman_real_init_value,
                              koopman_im_init_type):
        """Instantiate the dynamics model selected by ``dynamics_model_type``.

        Raises:
            ValueError: if ``dynamics_model_type`` is not a supported name.
        """
        if dynamics_model_type == 'regular':
            print("using regular dynamics model")
            return dynamics_models.BatchedRegularDynamicsModel(
                state_dim=state_emb_dim
            )
        if dynamics_model_type == 'dense_koopman':
            print("using dense koopman dynamics model")
            return dynamics_models.BatchedDenseKoopmanDynamicsModel(
                state_dim=state_emb_dim,
                action_emb_dim=action_emb_dim
            )
        if dynamics_model_type == 'diagonal_koopman':
            print(f"using koopman dynamics model with real init type {koopman_real_init_type} and imaginary init type {koopman_im_init_type}")
            return dynamics_models.BatchedDiagonalKoopmanDynamicsModel(
                state_dim=state_emb_dim, action_emb_dim=action_emb_dim,
                real_init_type=koopman_real_init_type,
                real_init_value=koopman_real_init_value,
                im_init_type=koopman_im_init_type
            )
        if dynamics_model_type == 'transformer':
            print("using transformer dynamics model")
            return dynamics_models.BatchedTransformerDynamicsModel(
                state_dim=state_emb_dim, action_emb_dim=action_emb_dim
            )
        if dynamics_model_type == 'gru':
            print("using GRU dynamics model")
            return dynamics_models.BatchedGRUDynamicsModel(
                state_dim=state_emb_dim, action_emb_dim=action_emb_dim
            )
        if dynamics_model_type == 'dssm':
            print("using DSSM dynamics model")
            return dynamics_models.BatchedDSSMDynamicsModel(
                state_dim=state_emb_dim, action_emb_dim=action_emb_dim
            )
        print(dynamics_model_type)
        raise ValueError('Dynamics model not implemented')

    @staticmethod
    def _add_batch_axes(observations, actions_seq):
        """Bring the inputs to the batched layout the jitted predictors use.

        ``actions_seq`` is returned with three axes. Accepted layouts:

        * 1-D actions: one action for one observation. A leading axis is
          added to the observation only if it is itself 1-D, so an already
          batched single observation keeps working.
        * 2-D actions with an observation that has a leading axis of the same
          length: one single-step action per observation.
        * any other 2-D actions: one action sequence for one observation;
          both get a leading axis of size one.
        * actions with three or more axes: returned unchanged.

        Returns:
            The ``(observations, actions_seq)`` pair after reshaping.
        """
        if actions_seq.ndim == 1:
            actions_seq = actions_seq[None, None, :]
            if observations.ndim == 1:
                observations = observations[None, :]
        elif actions_seq.ndim == 2:
            # The rank check stops a lone observation whose feature size
            # happens to equal the sequence length from being read as a batch.
            is_batch = observations.ndim > 1 and observations.shape[0] == actions_seq.shape[0]
            if is_batch:
                actions_seq = actions_seq[:, None, :]
            else:
                actions_seq = actions_seq[None, :, :]
                observations = observations[None, :]
        return observations, actions_seq

    def predict_next_state(self, observations: jnp.ndarray, actions_seq: jnp.ndarray):
        """Predict decoded future observations for the given actions.

        Args:
            observations: starting observation(s).
            actions_seq: a single action, a batch of single actions, a single
                action sequence, or an already batched array of sequences.

        Returns:
            The decoder output with every size-one axis removed by
            ``jnp.squeeze`` (this also drops a sequence axis of length one).
        """
        observations, actions_seq = self._add_batch_axes(observations, actions_seq)
        next_states = self._predict_next_state(
            observations,
            actions_seq,
            self._train_states['state_encoder'].params,
            self._train_states['state_decoder'].params,
            self._train_states['dynamics_model'].params
        )
        return jnp.squeeze(next_states)

    @partial(jax.jit, static_argnames=('self',))
    def _predict_next_state(self,
                            observations: jnp.ndarray,
                            actions: jnp.ndarray,
                            state_encoder_params: Params,
                            state_decoder_params: Params,
                            dynamics_model_params: Params
                            ):
        """Jitted encode -> roll out dynamics -> decode pipeline."""
        state_rep = self.state_encoder.apply(state_encoder_params, observations)
        next_state_rep = self.dynamics_model.apply(dynamics_model_params, actions, state_rep)
        next_state = self.state_decoder.apply(state_decoder_params, next_state_rep)
        return next_state

    def predict_next_state_and_reward(self, observations: jnp.ndarray, actions_seq: jnp.ndarray):
        """Predict decoded future observations and rewards for the actions.

        Accepts the same input layouts as ``predict_next_state``.

        Returns:
            ``(next_states, rewards)``, each passed through ``jnp.squeeze``.

        Raises:
            KeyError: if the learner was built without a reward model
                (``pred_reward`` falsy), since no ``'reward_model'`` train
                state exists in that case.
        """
        observations, actions_seq = self._add_batch_axes(observations, actions_seq)
        next_states, rewards = self._predict_next_state_and_reward(
            observations,
            actions_seq,
            self._train_states['state_encoder'].params,
            self._train_states['state_decoder'].params,
            self._train_states['dynamics_model'].params,
            self._train_states['reward_model'].params
        )
        return jnp.squeeze(next_states), jnp.squeeze(rewards)

    @partial(jax.jit, static_argnames=('self',))
    def _predict_next_state_and_reward(self,
                            observations: jnp.ndarray,
                            actions: jnp.ndarray,
                            state_encoder_params: Params,
                            state_decoder_params: Params,
                            dynamics_model_params: Params,
                            reward_model_params: Params,
                            ):
        """Jitted encode -> roll out -> decode pipeline that also predicts rewards."""
        state_rep = self.state_encoder.apply(state_encoder_params, observations)
        next_state_rep = self.dynamics_model.apply(dynamics_model_params, actions, state_rep)
        next_state = self.state_decoder.apply(state_decoder_params, next_state_rep)
        # The reward model is fed the initial latent state followed by all but
        # the last predicted latent state, paired with the actions.
        reward_inputs = jnp.concatenate(
            [state_rep[:, None, :], next_state_rep[:, :-1]], axis=1
        )
        rewards = self.reward_model.apply(reward_model_params, reward_inputs, actions)
        return next_state, rewards

    def update(self, batch: Batch) -> InfoDict:
        """Run one training step on *batch* and return its logging info.

        Increments ``self.step``, replaces ``self.rng`` and
        ``self._train_states`` with the values returned by ``_train_step``,
        and adds the wall-clock duration of the call to the returned dict
        under ``'time_per_training_step'``.
        """
        start_time = time.time()
        self.step += 1
        self.rng, self._train_states, info = self._train_step(
            self.rng, self._train_states, batch
        )
        info['time_per_training_step'] = time.time() - start_time
        return info

    @partial(jax.jit, static_argnames=('self',))
    def _train_step(self,
                    rng,
                    train_states,
                    batch):
        """Compute all losses and apply one gradient step to every model.

        Args:
            rng: PRNG key; it is split once and the new key is returned.
            train_states: dict of ``TrainState`` keyed like ``self.model_keys``.
            batch: object exposing ``observations``, ``actions``,
                ``seq_masks`` and ``rewards``. The indexing below treats axis
                0 as the batch axis and axis 1 as the time axis.

        Returns:
            ``(rng, new_train_states, info)`` where ``info`` holds
            ``'state_prediction_loss'``, ``'observation_prediction_loss'``
            and ``'reward_loss'``.
        """
        observations_seq = batch.observations
        actions_seq = batch.actions
        seq_masks = batch.seq_masks
        rewards_seq = batch.rewards

        # The rollout starts from the first observation of every sequence.
        observations = observations_seq[:, 0]
        # Mask for time steps 1.., with a trailing axis to broadcast over features.
        next_step_masks = seq_masks[:, 1:][:, :, None]

        def loss_fn(train_params, rng):
            loss_collection = {}

            state_reps = self.state_encoder.apply(train_params['state_encoder'], observations)
            predicted_next_states_reps = self.dynamics_model.apply(
                train_params['dynamics_model'], actions_seq, state_reps)

            # Latent loss: predicted latents vs. the encoding of the observed
            # next steps. The targets are wrapped in stop_gradient; the last
            # prediction has no target and is dropped.
            if self.use_state_prediction_loss:
                target_state_reps = jax.lax.stop_gradient(
                    self.state_encoder.apply(train_params['state_encoder'], observations_seq[:, 1:])
                )
                state_prediction_loss = masked_mse_loss(
                    predicted_next_states_reps[:, :-1], target_state_reps, next_step_masks
                )
            else:
                state_prediction_loss = 0.0

            reward_loss = 0.0
            if self.pred_reward:
                rewards_seq_pred = self.reward_model.apply(
                    train_params['reward_model'],
                    jnp.concatenate([state_reps[:, None, :], predicted_next_states_reps[:, :-1]], axis=1),
                    actions_seq
                )
                # NOTE(review): a bare squeeze would also drop a batch or time
                # axis of size one -- confirm those never occur in training.
                reward_loss += masked_mse_loss(jnp.squeeze(rewards_seq_pred), rewards_seq, seq_masks)

            # Observation loss: decoded predictions vs. the observed next steps.
            predicted_observations_seq = self.state_decoder.apply(
                train_params['state_decoder'], predicted_next_states_reps
            )
            observation_prediction_loss = masked_mse_loss(
                predicted_observations_seq[:, :-1], observations_seq[:, 1:], next_step_masks
            )

            total_loss = state_prediction_loss + observation_prediction_loss + reward_loss
            loss_collection['dynamics_model'] = total_loss
            loss_collection['state_encoder'] = total_loss
            loss_collection['state_decoder'] = observation_prediction_loss
            loss_collection['reward_model'] = reward_loss

            # Only the values that are logged are exposed as aux output,
            # instead of the whole local namespace.
            aux = {
                'state_prediction_loss': state_prediction_loss,
                'observation_prediction_loss': observation_prediction_loss,
                'reward_loss': reward_loss,
            }
            return tuple(loss_collection[key] for key in self.model_keys), aux

        train_params = {key: train_states[key].params for key in self.model_keys}

        rng, split_rng = jax.random.split(rng)

        (_, aux_values), grads = value_and_multi_grad(
            loss_fn, len(self.model_keys), has_aux=True
        )(train_params, split_rng)

        # grads[i] is indexed by loss position; each model is updated only
        # with the gradient of its own loss w.r.t. its own parameters.
        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }

        info = {
            'state_prediction_loss': aux_values['state_prediction_loss'],
            'observation_prediction_loss': aux_values['observation_prediction_loss'],
            'reward_loss': aux_values['reward_loss'],
        }

        return rng, new_train_states, info

    @property
    def model_keys(self):
        """Tuple of train-state keys, in the order the losses are returned."""
        return self._model_keys

    @property
    def train_states(self):
        """Dict mapping each model key to its current ``TrainState``."""
        return self._train_states

    @property
    def train_params(self):
        """Dict mapping each model key to its current parameters."""
        return {key: self.train_states[key].params for key in self.model_keys}

    def _save_param(self, save_path: str, params):
        """Serialize *params* with ``flax.serialization`` and write to *save_path*."""
        parent_dir = os.path.dirname(save_path)
        # `os.makedirs('')` raises, so only create a directory when the path
        # actually has a directory component.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(save_path, 'wb') as f:
            f.write(flax.serialization.to_bytes(params))

    def _load_params(self, load_path: str, model):
        """Return a copy of *model* whose params are restored from *load_path*.

        ``model.params`` serves as the structural template for deserialization.
        """
        with open(load_path, 'rb') as f:
            params = flax.serialization.from_bytes(model.params, f.read())
        return model.replace(params=params)

    # TODO: fix the save/load functions
    # (only the parameters are persisted; optimizer state and the step
    # counter are not part of the checkpoint).
    ########## Saving and loading ########################
    def save(self, save_path: str):
        """Write the parameters of every train state to ``<save_path>/<name>.ckpt``."""
        for name, train_state in self._train_states.items():
            pathname = os.path.join(save_path, '{}.ckpt'.format(name))
            self._save_param(pathname, train_state.params)

    def load(self, load_path: str):
        """Restore the parameters of every train state from ``<load_path>/<name>.ckpt``."""
        # Iterate over a snapshot because entries are replaced in the loop.
        for name, train_state in list(self._train_states.items()):
            pathname = os.path.join(load_path, '{}.ckpt'.format(name))
            self._train_states[name] = self._load_params(pathname, train_state)