import copy
from typing import Any, Callable

import flax
import jax
import jax.numpy as jnp
import ml_collections
import optax

from utils.encoders import encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.hlg import cross_entropy_loss_on_scalar, hl_gauss_transform
from utils.networks import ActorVectorField, Value





class DEASAgent(flax.struct.PyTreeNode):
    """Agent with distributional value/critic heads and flow-based actors.

    The agent owns five sub-networks inside a single ``ModuleDict``:

    * ``value`` and ``critic`` (plus ``target_critic``): each outputs
      ``config['num_atoms']`` logits that are converted to scalars through the
      HL-Gauss transforms stored on the agent.
    * ``actor_bc_flow``: a velocity field trained with a flow-matching
      behavioral-cloning loss on flattened action chunks.
    * ``actor_onestep_flow``: a one-step policy distilled from the BC flow and
      additionally trained to maximize the critic value.

    Attributes:
        rng: JAX PRNG key carried by the agent.
        network: ``TrainState`` holding all modules and their parameters.
        config: Frozen configuration (non-pytree field).
        transform_to_probs: Scalar -> categorical-probabilities transform
            returned by ``hl_gauss_transform`` (non-pytree field).
        transform_from_probs: Categorical-probabilities -> scalar transform
            returned by ``hl_gauss_transform`` (non-pytree field).
    """

    rng: Any
    network: Any
    config: Any = nonpytree_field()
    transform_to_probs: Callable = nonpytree_field()
    transform_from_probs: Callable = nonpytree_field()

    @staticmethod
    def compute_mse(pred, target, max_action_sequence, key=''):
        """Compute MSE between ``pred`` and ``target`` over action-chunk prefixes.

        The prefixes have length 1 and every power of two (2, 4, 8, ...) that
        does not exceed ``max_action_sequence``.

        Args:
            pred: Predicted actions; axis 1 is sliced as the sequence axis.
            target: Reference actions, sliced the same way as ``pred``.
            max_action_sequence: Python int upper bound on the prefix length.
            key: Prefix used in the returned metric names.

        Returns:
            Dict mapping ``f'{key}_mse_as{n}'`` to the mean squared error over
            the first ``n`` steps.
        """
        n_action_steps = [1]
        n_action_steps.extend(
            1 << i
            for i in range(1, max_action_sequence.bit_length())
            if (1 << i) <= max_action_sequence
        )
        return {
            f'{key}_mse_as{n_a}': jnp.mean((pred[:, :n_a, ...] - target[:, :n_a, ...]) ** 2)
            for n_a in n_action_steps
        }

    def transform_value_to_probs(self, value):
        """Apply ``transform_to_probs`` element-wise to a 1-D or 2-D array.

        A 1-D input is mapped over its only axis; any other input is mapped
        over its first two axes (assumes a 2-D input in that case).
        """
        per_element = jax.vmap(self.transform_to_probs, in_axes=0, out_axes=0)
        if value.ndim == 1:
            return per_element(value)
        return jax.vmap(per_element, in_axes=0, out_axes=0)(value)

    def value_loss(self, batch, grad_params):
        """Compute the IQL value loss with distributional value function.

        The value head is trained with a cross-entropy loss against the
        aggregated target-critic distribution, weighted asymmetrically by
        ``config['expectile']`` depending on whether the target Q scalar is
        at least the current V scalar.

        Args:
            batch: Dict with at least ``'observations'`` and ``'actions'``.
            grad_params: Parameters to differentiate with respect to.

        Returns:
            Tuple ``(value_loss, metrics)``.
        """
        # Flatten the action chunk to one vector per batch element before
        # querying the target critic.
        flat_actions = jnp.reshape(batch['actions'], (batch['actions'].shape[0], -1))
        q_logits = self.network.select('target_critic')(batch['observations'], actions=flat_actions)
        q_probs = jax.nn.softmax(q_logits, axis=-1)
        qs = self.transform_from_probs(q_probs)

        # Aggregate Q values across the ensemble (axis 0).
        if self.config['q_agg'] == 'min':
            # Pick, per batch element, the ensemble member with the smallest
            # scalar Q and use that member's full distribution as the target.
            min_q_idx = jnp.argmin(qs, axis=0)
            batch_indices = jnp.arange(qs.shape[1])
            q = qs[min_q_idx, batch_indices]
            q_prob = q_probs[min_q_idx, batch_indices]
        else:
            q = jnp.mean(qs, axis=0)
            q_prob = jnp.mean(q_probs, axis=0)

        v_logit = self.network.select('value')(observations=batch['observations'], params=grad_params)
        v_prob = jax.nn.softmax(v_logit, axis=-1)
        v = self.transform_from_probs(v_prob)

        # Expectile-style weight: `expectile` where Q >= V, else `1 - expectile`.
        # The comparison is non-differentiable, so it acts as a fixed weight.
        g_hard = jnp.where(q >= v, self.config['expectile'], 1.0 - self.config['expectile'])

        ce_loss = -jnp.sum(q_prob * jax.nn.log_softmax(v_logit, axis=-1), axis=-1)
        value_loss = (g_hard * ce_loss).mean()

        metrics = {
            # losses
            'value_loss': value_loss,
            # value stats
            'v_mean': v.mean(),
            'v_max': v.max(),
            'v_min': v.min(),
            # prob stats (spread/peakedness)
            'v_prob_mean': v_prob.mean(),
            'v_prob_std_per_sample_mean': jnp.std(v_prob, axis=-1).mean(),
            # target-Q scalar stats
            'target_q_mean': q.mean(),
            'target_q_max': q.max(),
            'target_q_min': q.min(),
        }
        return value_loss, metrics

    def critic_loss(self, batch, grad_params):
        """Compute the IQL critic loss.

        The critic logits are trained with ``cross_entropy_loss_on_scalar``
        against a bootstrapped scalar target built from the value network at
        the last next-observation of the chunk.

        Args:
            batch: Dict with ``'observations'``, ``'actions'``,
                ``'next_observations'``, ``'rewards'`` and ``'masks'``.
            grad_params: Parameters to differentiate with respect to.

        Returns:
            Tuple ``(critic_loss, metrics)``.
        """
        # Bootstrap from the last next-observation along the second-to-last axis.
        next_observations = batch['next_observations'][..., -1, :]
        flat_actions = jnp.reshape(batch['actions'], (batch['actions'].shape[0], -1))

        # The value network is evaluated without `grad_params`, so no gradient
        # flows into the target.
        next_v_logits = self.network.select('value')(observations=next_observations)
        next_v_probs = jax.nn.softmax(next_v_logits, axis=-1)
        next_v = self.transform_from_probs(next_v_probs)

        # Per the original author's notes: batch['rewards'] holds the sum of
        # discounted n-step rewards and config['nstep'] is expected to be 1
        # -- TODO confirm against the dataset code.
        bootstrap_discount = self.config['discount'] ** (self.config['nstep'] * self.config['action_sequence'])
        target_v = batch['rewards'][..., -1] + bootstrap_discount * batch['masks'][..., -1] * next_v

        q_logits = self.network.select('critic')(
            batch['observations'],
            actions=flat_actions,
            params=grad_params,
        )
        critic_loss = (
            cross_entropy_loss_on_scalar(q_logits, target_v, self.transform_to_probs).mean(axis=-1).mean(axis=-1)
        )

        # Scalar Q statistics for logging only (ensemble mean over axis 0).
        q_probs = jax.nn.softmax(q_logits, axis=-1)
        q_vals = self.transform_from_probs(q_probs)
        q = jnp.mean(q_vals, axis=0)

        return critic_loss, {
            'critic_loss': critic_loss,
            'q_mean': q.mean(),
            'q_max': q.max(),
            'q_min': q.min(),
            'batch_rewards': batch['rewards'].mean(),
        }

    def bc_flow_loss(self, batch, grad_params, rng):
        """Compute the BC flow loss.

        Samples Gaussian noise ``x_0`` and a uniform time ``t``, linearly
        interpolates towards the flattened dataset action chunk ``x_1`` and
        regresses the predicted velocity onto ``x_1 - x_0``.

        Args:
            batch: Dict with ``'observations'`` and 3-D ``'actions'``
                (batch, sequence, action_dim).
            grad_params: Parameters to differentiate with respect to.
            rng: PRNG key used for the noise and time samples.

        Returns:
            Tuple ``(bc_flow_loss, metrics)``.
        """
        batch_size, action_sequence, action_dim = batch['actions'].shape

        rng, x_rng, t_rng = jax.random.split(rng, 3)
        x_0 = jax.random.normal(x_rng, (batch_size, action_sequence * action_dim))
        x_1 = batch['actions'].reshape(batch_size, -1)
        t = jax.random.uniform(t_rng, (batch_size, 1))
        x_t = (1 - t) * x_0 + t * x_1
        vel = x_1 - x_0

        pred = self.network.select('actor_bc_flow')(batch['observations'], actions=x_t, times=t, params=grad_params)
        bc_flow_loss = jnp.mean((pred - vel) ** 2)

        return bc_flow_loss, {
            'bc_flow_loss': bc_flow_loss,
        }

    def actor_loss(self, batch, grad_params, rng):
        """Compute the FQL actor loss.

        The one-step actor is pulled towards the multi-step BC flow output
        (distillation) and pushed to maximize the ensemble-mean critic value.

        Args:
            batch: Dict with ``'observations'`` and 3-D ``'actions'``
                (batch, sequence, action_dim).
            grad_params: Parameters to differentiate with respect to.
            rng: PRNG key used for the shared noise sample.

        Returns:
            Tuple ``(actor_loss, metrics)``.
        """
        batch_size, action_sequence, action_dim = batch['actions'].shape

        # Distillation loss: both policies start from the same noise.
        rng, noise_rng = jax.random.split(rng)
        noises = jax.random.normal(noise_rng, (batch_size, action_sequence * action_dim))
        target_flow_actions = self.compute_flow_actions(batch['observations'], noises=noises)
        actor_actions = self.network.select('actor_onestep_flow')(batch['observations'], noises, params=grad_params)
        distill_loss = jnp.mean((actor_actions - target_flow_actions) ** 2)

        # Q loss. The critic is called without `grad_params`, so gradients
        # reach only the actor through its (clipped) actions.
        actor_actions = jnp.clip(actor_actions, -1, 1)
        q_logits = self.network.select('critic')(
            batch['observations'],
            actions=actor_actions,
        )
        q_probs = jax.nn.softmax(q_logits, axis=-1)
        qs = self.transform_from_probs(q_probs)
        q = jnp.mean(qs, axis=0)

        q_loss = -q.mean()
        if self.config['normalize_q_loss']:
            # Rescale by the (gradient-free) mean |Q| so the Q term has a
            # magnitude independent of the value scale.
            lam = jax.lax.stop_gradient(1 / jnp.abs(q).mean())
            q_loss = lam * q_loss

        # Total loss.
        actor_loss = self.config['alpha'] * distill_loss + q_loss

        # Additional metrics for logging.
        mse_dict = self.compute_mse(
            actor_actions.reshape(batch_size, action_sequence, action_dim),
            batch['actions'],
            self.config['action_sequence'],
            key='actor',
        )
        flow_mse_dict = self.compute_mse(
            target_flow_actions.reshape(batch_size, action_sequence, action_dim),
            batch['actions'],
            self.config['action_sequence'],
            key='flow',
        )

        return actor_loss, {
            'actor_loss': actor_loss,
            'distill_loss': distill_loss,
            'q_loss': q_loss,
            'q': q.mean(),
            **mse_dict,
            **flow_mse_dict,
        }

    @jax.jit
    def total_loss(self, batch, grad_params, rng=None):
        """Compute the total loss.

        Sums the BC flow, value, critic and actor losses and collects their
        metrics under the ``bc_flow/``, ``value/``, ``critic/`` and ``actor/``
        prefixes.

        Args:
            batch: Training batch forwarded to every sub-loss.
            grad_params: Parameters to differentiate with respect to.
            rng: Optional PRNG key; defaults to ``self.rng``.

        Returns:
            Tuple ``(loss, info)``.
        """
        info = {}
        rng = rng if rng is not None else self.rng
        # Derive independent keys up front so no key is consumed twice.
        bc_rng, actor_rng = jax.random.split(rng)

        bc_flow_loss, bc_flow_info = self.bc_flow_loss(batch, grad_params, bc_rng)
        for k, v in bc_flow_info.items():
            info[f'bc_flow/{k}'] = v

        value_loss, value_info = self.value_loss(batch, grad_params)
        for k, v in value_info.items():
            info[f'value/{k}'] = v

        critic_loss, critic_info = self.critic_loss(batch, grad_params)
        for k, v in critic_info.items():
            info[f'critic/{k}'] = v

        actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
        for k, v in actor_info.items():
            info[f'actor/{k}'] = v

        loss = bc_flow_loss + value_loss + critic_loss + actor_loss
        return loss, info

    def target_update(self, network, module_name, tau):
        """Update the target network.

        Blends ``modules_{module_name}`` and ``modules_target_{module_name}``
        as ``p * tau + tp * (1 - tau)`` and stores the result in
        ``network.params`` (mutated in place).

        NOTE: both operands are read from ``self.network`` (the state this
        method is called on), not from the ``network`` argument, so when called
        from ``update`` the blend uses the pre-update parameters.

        Args:
            network: Train state whose target parameters are overwritten.
            module_name: Name of the online module (e.g. ``'critic'``).
            tau: Interpolation weight given to the online parameters.
        """
        new_target_params = jax.tree_util.tree_map(
            lambda p, tp: p * tau + tp * (1 - tau),
            self.network.params[f'modules_{module_name}'],
            self.network.params[f'modules_target_{module_name}'],
        )
        network.params[f'modules_target_{module_name}'] = new_target_params

    @jax.jit
    def update(self, batch):
        """Update the agent and return a new agent with information dictionary."""
        new_rng, rng = jax.random.split(self.rng)

        def loss_fn(grad_params):
            return self.total_loss(batch, grad_params, rng=rng)

        new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
        self.target_update(new_network, 'critic', self.config['tau'])
        return self.replace(network=new_network, rng=new_rng), info

    @jax.jit
    def sample_actions(
        self,
        observations,
        rng=None,
        temperature=1.0,
    ):
        """Sample actions from the one-step policy.

        Args:
            observations: Observations whose trailing ``len(config['ob_dims'])``
                axes are the observation dimensions; leading axes are kept as
                batch axes of the result.
            rng: Optional PRNG key; defaults to the agent's own ``rng``.
            temperature: Accepted for interface compatibility; currently unused.

        Returns:
            Actions clipped to ``[-1, 1]`` with shape
            ``(*batch_axes, action_sequence, action_dim)``.
        """
        # The agent stores its key in `rng`; there is no `seed` attribute.
        seed = rng if rng is not None else self.rng
        # Only the first sub-key is needed (same derivation as before).
        noise_key = jax.random.split(seed)[0]
        batch_axes = observations.shape[: -len(self.config['ob_dims'])]
        noises = jax.random.normal(
            noise_key,
            (
                *batch_axes,
                self.config['action_sequence'] * self.config['action_dim'],
            ),
        )
        actions = self.network.select('actor_onestep_flow')(observations, noises)
        actions = jnp.clip(actions, -1, 1)
        actions = actions.reshape(
            *batch_axes,
            self.config['action_sequence'],
            self.config['action_dim'],
        )
        return actions

    @jax.jit
    def compute_flow_actions(
        self,
        observations,
        noises,
    ):
        """Compute actions from the BC flow model using the Euler method.

        Args:
            observations: Observations (encoded once up front when
                ``config['encoder']`` is set).
            noises: Initial flattened actions to integrate from.

        Returns:
            Flattened actions after ``config['flow_steps']`` Euler steps,
            clipped to ``[-1, 1]``.
        """
        if self.config['encoder'] is not None:
            # Encode once and reuse the features for every integration step.
            observations = self.network.select('actor_bc_flow_encoder')(observations)
        actions = noises
        # Euler method.
        for i in range(self.config['flow_steps']):
            t = jnp.full((*observations.shape[:-1], 1), i / self.config['flow_steps'])
            vels = self.network.select('actor_bc_flow')(observations, actions, t, is_encoded=True)
            actions = actions + vels / self.config['flow_steps']
        actions = jnp.clip(actions, -1, 1)
        return actions

    @classmethod
    def create(
        cls,
        seed,
        ex_observations,
        ex_actions,
        config,
    ):
        """Create a new agent.

        Prints a tabulated summary of the network as a side effect and writes
        ``ob_dims``, ``action_dim`` and ``action_sequence`` into ``config``.

        Args:
            seed: Random seed.
            ex_observations: Example batch of observations.
            ex_actions: Example batch of single-step actions; its last axis is
                the action dimension. It is tiled ``config['horizon_length']``
                times along the last axis to form a flattened example chunk
                (presumably shaped (batch, action_dim) -- TODO confirm with
                the caller).
            config: Configuration dictionary.

        Returns:
            A new ``DEASAgent``.
        """
        rng = jax.random.PRNGKey(seed)
        rng, init_rng = jax.random.split(rng, 2)

        ex_times = ex_actions[..., :1]
        action_sequence = config['horizon_length']
        action_dim = ex_actions.shape[-1]

        # Build an example flattened action chunk by repeating the single-step
        # example along the feature axis.
        flattened_ex_actions = jnp.concatenate([ex_actions] * config['horizon_length'], axis=-1)

        # Define encoders.
        encoders = dict()
        if config['encoder'] is not None:
            encoder_module = encoder_modules[config['encoder']]
            encoders['value'] = encoder_module()
            encoders['critic'] = encoder_module()
            encoders['actor_bc_flow'] = encoder_module()
            encoders['actor_onestep_flow'] = encoder_module()

        # Define networks.
        value_def = Value(
            hidden_dims=config['value_hidden_dims'],
            num_ensembles=1,
            encoder=encoders.get('value'),
            output_dim=config['num_atoms'],
        )
        critic_def = Value(
            hidden_dims=config['critic_hidden_dims'],
            num_ensembles=config['num_critic_ensembles'],
            layer_norm=config['layer_norm'],
            encoder=encoders.get('critic'),
            output_dim=config['num_atoms'],
            use_zero_output=True,
        )
        actor_bc_flow_def = ActorVectorField(
            hidden_dims=config['actor_hidden_dims'],
            action_dim=action_sequence * action_dim,
            layer_norm=config['actor_layer_norm'],
            encoder=encoders.get('actor_bc_flow'),
        )
        actor_onestep_flow_def = ActorVectorField(
            hidden_dims=config['actor_hidden_dims'],
            action_dim=action_sequence * action_dim,
            layer_norm=config['actor_layer_norm'],
            encoder=encoders.get('actor_onestep_flow'),
        )

        network_info = dict(
            value=(value_def, (ex_observations,)),
            critic=(critic_def, (ex_observations, flattened_ex_actions)),
            target_critic=(copy.deepcopy(critic_def), (ex_observations, flattened_ex_actions)),
            actor_bc_flow=(actor_bc_flow_def, (ex_observations, flattened_ex_actions, ex_times)),
            actor_onestep_flow=(actor_onestep_flow_def, (ex_observations, flattened_ex_actions)),
        )
        if encoders.get('actor_bc_flow') is not None:
            # Add actor_bc_flow_encoder to ModuleDict to make it separately callable.
            network_info['actor_bc_flow_encoder'] = (encoders.get('actor_bc_flow'), (ex_observations,))
        networks = {k: v[0] for k, v in network_info.items()}
        network_args = {k: v[1] for k, v in network_info.items()}

        network_def = ModuleDict(networks)
        modulate_fn = flax.linen.tabulate(network_def, jax.random.PRNGKey(0))
        print(modulate_fn(**network_args))
        network_tx = optax.adam(learning_rate=config['lr'])
        network_params = network_def.init(init_rng, **network_args)['params']
        network = TrainState.create(network_def, network_params, tx=network_tx)

        # Start the target critic from the same parameters as the critic.
        params = network.params
        params['modules_target_critic'] = params['modules_critic']

        config['ob_dims'] = ex_observations.shape[1:]
        config['action_dim'] = action_dim
        config['action_sequence'] = action_sequence

        # The smoothing width is expressed in units of one bin width.
        transform_to_probs, transform_from_probs = hl_gauss_transform(
            min_value=config['v_min'],
            max_value=config['v_max'],
            num_bins=config['num_atoms'],
            sigma=config['sigma'] * (config['v_max'] - config['v_min']) / config['num_atoms'],
        )

        return cls(
            rng,
            network=network,
            config=flax.core.FrozenDict(**config),
            transform_to_probs=transform_to_probs,
            transform_from_probs=transform_from_probs,
        )


def get_config():
    """Build the default configuration for the DEAS agent.

    Returns:
        An ``ml_collections.ConfigDict`` of default hyperparameters. Entries
        created with ``placeholder`` have no default value; ``ob_dims``,
        ``action_dim`` and ``action_sequence`` are filled in by
        ``DEASAgent.create``.
    """
    placeholder = ml_collections.config_dict.placeholder
    defaults = {
        # Agent name.
        'agent_name': 'deas',
        # Observation dimensions (will be set automatically).
        'ob_dims': placeholder(list),
        # Action dimension (will be set automatically).
        'action_dim': placeholder(int),
        # Action sequence length (will be set automatically).
        'action_sequence': placeholder(int),
        # Number of steps in an action chunk.
        'horizon_length': 10,
        # Learning rate.
        'lr': 3e-4,
        # Batch size.
        'batch_size': 256,
        # Hidden layer sizes of the actor, value and critic networks.
        'actor_hidden_dims': (512, 512, 512, 512),
        'value_hidden_dims': (256, 256, 256, 256),
        'critic_hidden_dims': (128, 128, 128, 128),
        # Whether to use layer normalization.
        'layer_norm': True,
        # Whether to use layer normalization for the actor.
        'actor_layer_norm': True,
        # Q aggregation method. "All use mean"
        'q_agg': 'mean',
        # Number of critic ensembles.
        'num_critic_ensembles': 2,
        # Discount factor.
        'discount': 0.99,
        # Number of steps for n-step return.
        'nstep': 1,
        # Target network update rate.
        'tau': 0.005,
        # IQL expectile.
        'expectile': 0.9,
        # FQL alpha.
        'alpha': 10.0,
        # Number of flow steps.
        'flow_steps': 10,
        # Whether to normalize the Q loss.
        'normalize_q_loss': True,
        # Visual encoder name (None, 'impala_small', etc.).
        'encoder': placeholder(str),
        # Maximum gradient norm.
        'max_grad_norm': 0.0,
        # HLG: number of atoms.
        'num_atoms': 101,
        # HLG: minimum value.
        'v_min': placeholder(float),
        # HLG: maximum value.
        'v_max': placeholder(float),
        # HLG: sigma for HL-Gauss transform.
        'sigma': 0.75,
    }
    return ml_collections.ConfigDict(defaults)