from argparse import Action
import copy
from typing import Any

import flax
import jax
import jax.numpy as jnp
from functools import partial
from jax.scipy.special import logsumexp
import numpy as np
import ml_collections
import optax

# from utils.encoders import encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import ActorVectorField, RewardFlowModel


class EvorAgent(flax.struct.PyTreeNode):
    """Evor agent with action chunking.

    ``network`` holds three modules:

    * ``actor_bc_flow`` -- a behavior-cloning flow policy trained by flow
      matching on dataset actions (flattened action chunks when
      ``action_chunking`` is enabled).
    * ``critic`` -- a flow model over reward-to-go values, conditioned on the
      observation, the action, the flow state ``y_t`` and the flow time ``t``.
    * ``target_critic`` -- a Polyak-averaged copy of ``critic``
      (see ``target_update``).

    Actions are chosen by sampling candidates from the flow policy and scoring
    them with reward-to-go samples drawn from the critic flow.
    """

    rng: Any
    network: Any
    config: Any = nonpytree_field()

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _flatten_actions(self, batch):
        """Return the action input used by the critic and actor losses.

        With ``action_chunking`` the chunk axis is folded into the action axis
        (``reshape`` to ``(batch, -1)``); otherwise only the first action of
        each chunk is used. Assumes ``batch['actions']`` is
        ``(batch, horizon_length, action_dim)`` -- TODO confirm with caller.
        """
        actions = batch['actions']
        if self.config['action_chunking']:
            return jnp.reshape(actions, (actions.shape[0], -1))
        return actions[..., 0, :]

    @staticmethod
    def _explained_variance(target, pred):
        """Return ``1 - Var(target - pred) / Var(target)``, or NaN when ``Var(target) == 0``."""
        target_var = jnp.var(target)
        return jax.lax.cond(
            target_var == 0.0,
            lambda _: jnp.nan,
            lambda _: 1 - jnp.var(target - pred) / target_var,
            operand=None,
        )

    # ------------------------------------------------------------------
    # Losses
    # ------------------------------------------------------------------

    def critic_loss(self, batch, grad_params, rng):
        """Bootstrapped flow-matching loss for the reward-to-go critic.

        The online critic is evaluated at an interpolant ``y_t`` between
        Gaussian noise ``y_0`` and a reward-to-go sample ``y_1`` drawn from the
        target critic flow. Its output is regressed onto
        ``r + discount ** horizon_length * mask * y_prime``, where ``y_prime``
        is the target critic evaluated at the last next observation and an
        action from ``sample_actions_train``.

        Args:
            batch: Dict with 'observations', 'actions', 'next_observations',
                'rewards', 'masks' and 'valid'.
            grad_params: Parameters to differentiate with respect to.
            rng: PRNG key.

        Returns:
            ``(td_loss, info)`` where ``info`` holds scalar diagnostics.
        """
        batch_actions = self._flatten_actions(batch)
        batch_size = batch['observations'].shape[0]
        _, y_0_rng, y_0_prime_rng, t_rng, t_prime_rng, next_action_rng = jax.random.split(rng, 6)

        # ===========================
        # === y prediction ====
        # ===========================
        y_0 = jax.random.normal(y_0_rng, (batch_size, 1))
        y_1 = self.sample_target_reward_to_go(batch['observations'], batch_actions, y_0)
        t = jax.random.uniform(t_rng, (batch_size, 1))
        y_t = (1.0 - t) * y_0 + t * y_1
        y_pred = self.network.select('critic')(
            batch['observations'],
            batch_actions,
            y_t,
            t,
            params=grad_params,
        )

        # ===========================
        # === target ====
        # ===========================
        next_observations = batch['next_observations'][..., -1, :]
        next_actions = self.sample_actions_train(next_observations, rng=next_action_rng)
        if self.config['resample_y_t']:
            # Independent (y_0, t) pair for the next state; the target flow is
            # integrated from this fresh noise (its t=0 state), not from y_t.
            y_0_next = jax.random.normal(y_0_prime_rng, (batch_size, 1))
            t_next = jax.random.uniform(t_prime_rng, (batch_size, 1))
            y_1_next = self.sample_target_reward_to_go(next_observations, next_actions, y_0_next)
            y_t_next = (1.0 - t_next) * y_0_next + t_next * y_1_next
        else:
            # Reuse the current interpolation point and flow time.
            y_t_next, t_next = y_t, t
        y_prime = self.network.select('target_critic')(
            next_observations,
            next_actions,
            y_t_next,
            t_next,
        )

        y_pred = jnp.squeeze(y_pred)
        y_prime = jnp.squeeze(y_prime)
        r = batch['rewards'][..., -1]
        bootstrap_discount = self.config['discount'] ** self.config['horizon_length']
        target = r + bootstrap_discount * batch['masks'][..., -1] * y_prime
        td_loss = (jnp.square(y_pred - target) * batch['valid'][..., -1]).mean()

        # ===========================
        # === Explained Variance ====
        # ===========================
        y_var = jnp.var(target)
        vel_explained_var = self._explained_variance(target, y_pred)

        return td_loss, {
            'td_loss': td_loss,
            'y_pred': jnp.mean(y_pred),
            'y_pred_max': jnp.max(y_pred),
            'y_pred_min': jnp.min(y_pred),
            'y_pred_var': jnp.var(y_pred),
            'target': jnp.mean(target),
            'target_max': jnp.max(target),
            'target_min': jnp.min(target),
            'target_var': y_var,
            'y_prime': jnp.mean(y_prime),
            'y_prime_max': jnp.max(y_prime),
            'y_prime_min': jnp.min(y_prime),
            'r': jnp.mean(r),
            'vel_explained_var': vel_explained_var,
            'vel_var': y_var,
        }

    def critic_pretrain_loss(self, batch, grad_params, rng):
        """Reward-to-go flow matching on (s, a_1..a_k).

        Regresses the critic output at ``y_t = (1 - t) * y_0 + t * y_1`` onto
        the velocity ``y_1 - y_0``, where ``y_0`` is Gaussian noise and ``y_1``
        is ``batch['rewards_to_go']`` at the last chunk step. Entries are
        weighted by ``batch['valid'][..., -1]``.

        Returns:
            ``(reward_flow_loss, info)``.
        """
        batch_actions = self._flatten_actions(batch)
        batch_size = batch['observations'].shape[0]
        _, y_0_rng, t_rng = jax.random.split(rng, 3)

        # Sample flow pairs y_t, vel.
        y_0 = jax.random.normal(y_0_rng, (batch_size, 1))
        y_1 = jnp.expand_dims(batch['rewards_to_go'][..., -1], axis=-1)
        t = jax.random.uniform(t_rng, (batch_size, 1))
        y_t = (1.0 - t) * y_0 + t * y_1
        vel = y_1 - y_0

        pred = self.network.select('critic')(
            batch['observations'],
            batch_actions,  # k-action chunk
            y_t,
            t,
            params=grad_params,
        )
        pred = jnp.squeeze(pred)
        vel = jnp.squeeze(vel)
        reward_flow_loss = (jnp.square(pred - vel) * batch['valid'][..., -1]).mean()

        return reward_flow_loss, {
            'reward_flow_loss': reward_flow_loss,
        }

    def actor_loss(self, batch, grad_params, rng):
        """Compute the Evor actor loss (behavior-cloning flow matching).

        Regresses ``actor_bc_flow`` at ``x_t = (1 - t) * x_0 + t * x_1`` onto
        ``x_1 - x_0``, with ``x_0`` Gaussian noise and ``x_1`` the dataset
        actions. With ``action_chunking`` the squared error is masked by
        ``batch['valid']`` per chunk step before averaging.

        Returns:
            ``(bc_flow_loss, info)``.
        """
        batch_actions = self._flatten_actions(batch)
        batch_size, action_dim = batch_actions.shape
        _, x_rng, t_rng = jax.random.split(rng, 3)

        # BC flow loss.
        x_0 = jax.random.normal(x_rng, (batch_size, action_dim))
        x_1 = batch_actions
        t = jax.random.uniform(t_rng, (batch_size, 1))
        x_t = (1.0 - t) * x_0 + t * x_1
        vel = x_1 - x_0

        pred = self.network.select('actor_bc_flow')(batch['observations'], x_t, t, params=grad_params)

        sq_err = jnp.square(pred - vel)
        if self.config['action_chunking']:
            # Only clone actions at valid chunk indices.
            sq_err = jnp.reshape(
                sq_err, (batch_size, self.config['horizon_length'], self.config['action_dim'])
            ) * batch['valid'][..., None]
        bc_flow_loss = jnp.mean(sq_err)

        # ===========================
        # === Explained Variance ====
        # ===========================
        vel_var = jnp.var(vel)
        actor_explained_var = self._explained_variance(vel, pred)

        return bc_flow_loss, {
            'actor_loss': bc_flow_loss,
            'actor_explained_var': actor_explained_var,
            'vel_var': vel_var,
            'pred_mean': jnp.mean(pred),
            'pred_var': jnp.var(pred),
            'vel_mean': jnp.mean(vel),
        }

    def _combined_loss(self, critic_loss_fn, batch, grad_params, rng):
        """Sum ``critic_loss_fn`` and ``actor_loss``; prefix their info keys.

        Falls back to ``self.rng`` when ``rng`` is None.
        """
        rng = rng if rng is not None else self.rng
        _, actor_rng, critic_rng = jax.random.split(rng, 3)

        critic_loss, critic_info = critic_loss_fn(batch, grad_params, critic_rng)
        actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)

        info = {f'critic/{k}': v for k, v in critic_info.items()}
        info.update({f'actor/{k}': v for k, v in actor_info.items()})
        return critic_loss + actor_loss, info

    @jax.jit
    def total_loss(self, batch, grad_params, rng=None):
        """Compute the total loss (TD critic loss + BC actor loss)."""
        return self._combined_loss(self.critic_loss, batch, grad_params, rng)

    @jax.jit
    def pretrain_total_loss(self, batch, grad_params, rng=None):
        """Compute the pretraining loss (critic flow matching + BC actor loss)."""
        return self._combined_loss(self.critic_pretrain_loss, batch, grad_params, rng)

    # ------------------------------------------------------------------
    # Updates
    # ------------------------------------------------------------------

    def target_update(self, network, module_name):
        """Polyak-average ``module_name`` into its target copy inside ``network``.

        The new target is ``tau * online + (1 - tau) * target``, computed from
        ``self.network`` (the parameters before the current gradient step) and
        written into ``network.params`` in place.
        """
        new_target_params = jax.tree_util.tree_map(
            lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),
            self.network.params[f'modules_{module_name}'],
            self.network.params[f'modules_target_{module_name}'],
        )
        network.params[f'modules_target_{module_name}'] = new_target_params

    @staticmethod
    def _apply_update(agent, batch, pretrain):
        """Take one gradient step on the (pre)training loss and update the target critic."""
        new_rng, rng = jax.random.split(agent.rng)
        loss_method = agent.pretrain_total_loss if pretrain else agent.total_loss

        def loss_fn(grad_params):
            return loss_method(batch, grad_params, rng=rng)

        new_network, info = agent.network.apply_loss_fn(loss_fn=loss_fn)
        agent.target_update(new_network, 'critic')
        return agent.replace(network=new_network, rng=new_rng), info

    @staticmethod
    def _update(agent, batch):
        """Update the agent and return a new agent with information dictionary."""
        return agent._apply_update(agent, batch, pretrain=False)

    @staticmethod
    def _pretrain_update(agent, batch):
        """Pretraining update; returns a new agent with information dictionary."""
        return agent._apply_update(agent, batch, pretrain=True)

    @jax.jit
    def update(self, batch):
        return self._update(self, batch)

    @jax.jit
    def pretrain_update(self, batch):
        return self._pretrain_update(self, batch)

    @jax.jit
    def batch_update(self, batch):
        """Scan ``_update`` over the leading axis of ``batch``; return mean infos."""
        agent, infos = jax.lax.scan(self._update, self, batch)
        return agent, jax.tree_util.tree_map(lambda x: x.mean(), infos)

    # ------------------------------------------------------------------
    # Sampling
    # ------------------------------------------------------------------

    def _integrate_reward_to_go(self, module_name, observations, actions, noises):
        """Euler-integrate the reward-to-go flow of ``module_name``.

        Starts from ``noises`` at t=0 and takes ``critic_flow_steps`` steps of
        size ``1 / critic_flow_steps``.
        """
        num_steps = self.config['critic_flow_steps']
        velocity_field = self.network.select(module_name)
        rtg = noises
        for i in range(num_steps):
            t = jnp.full((*rtg.shape[:-1], 1), i / num_steps)
            rtg = rtg + velocity_field(observations, actions, rtg, t, is_encoded=True) / num_steps
        return rtg

    @jax.jit
    def sample_target_reward_to_go(self, observations, actions, noises):
        """Sample reward-to-go from the target critic flow starting at ``noises``."""
        return self._integrate_reward_to_go('target_critic', observations, actions, noises)

    @jax.jit
    def sample_reward_to_go(self, observations, actions, noises):
        """Sample reward-to-go from the online critic flow starting at ``noises``."""
        return self._integrate_reward_to_go('critic', observations, actions, noises)

    def _select_actions(
        self,
        observations,
        rng,
        beta,
        q_star_beta,
        actor_num_samples,
        actor_type,
        num_rtg_samples,
    ):
        """Propose flow-policy candidates and pick one per observation.

        ``actor_num_samples`` (K) candidates are drawn from the BC flow policy.
        Each is scored with ``num_rtg_samples`` (N) reward-to-go samples from
        the online critic flow. With ``actor_type == 'bon'`` the candidate with
        the highest mean reward-to-go is taken. Otherwise a candidate is sampled
        from a categorical with logits ``q_est / beta``, where
        ``q_est = q_star_beta * log(mean_N(exp(rtg / q_star_beta)))``.

        ``rng`` falls back to ``self.rng`` when None (as in ``total_loss``).
        """
        if rng is None:
            rng = self.rng
        action_dim = self.config['action_dim'] * (
            self.config['horizon_length'] if self.config['action_chunking'] else 1
        )
        rng_actions, rng_rtg, rng_select = jax.random.split(rng, 3)

        # 1) Propose K action candidates per observation from the flow policy.
        batch_shape = observations.shape[: -len(self.config['ob_dims'])]
        noises = jax.random.normal(rng_actions, (*batch_shape, actor_num_samples, action_dim))
        obs_rep = jnp.repeat(observations[..., None, :], actor_num_samples, axis=-2)
        # compute_flow_actions already clips to [-1, 1].
        actions = self.compute_flow_actions(obs_rep, noises)  # (..., K, action_dim)

        # 2) Draw N reward-to-go samples per candidate from the online critic flow.
        rtg_noises = jax.random.normal(rng_rtg, obs_rep.shape[:-1] + (num_rtg_samples, 1))  # (..., K, N, 1)
        obs_tiled = jnp.repeat(obs_rep[..., None, :], num_rtg_samples, axis=-2)  # (..., K, N, obs_dim)
        act_tiled = jnp.repeat(actions[..., None, :], num_rtg_samples, axis=-2)  # (..., K, N, action_dim)
        rtg = self.sample_reward_to_go(obs_tiled, act_tiled, rtg_noises).squeeze(-1)  # (..., K, N)

        # 3) Select one candidate per observation.
        if actor_type == 'bon':
            # Mean over N, pick candidate with highest mean RTG.
            indices = jnp.argmax(rtg.mean(axis=-1), axis=-1)
        else:
            # Q* estimate via log-mean-exp, then sample proportional to exp(q / beta).
            lme = logsumexp(rtg / q_star_beta, axis=-1) - jnp.log(num_rtg_samples)  # (..., K)
            q_est = lme * q_star_beta
            indices = jax.random.categorical(rng_select, q_est / beta, axis=-1)

        # (..., 1, 1) index broadcasts over the K and action axes.
        return jnp.take_along_axis(actions, indices[..., None, None], axis=-2).squeeze(-2)

    @partial(jax.jit, static_argnames=('actor_type', 'actor_num_samples', 'num_rtg_samples'))
    def sample_actions(
        self,
        observations,
        rng=None,
        beta=1.0,
        q_star_beta=1.0,
        actor_num_samples=32,
        actor_type='bon',
        num_rtg_samples=10,
    ):
        """Select actions with explicitly given sampling hyperparameters (e.g. for evaluation)."""
        return self._select_actions(
            observations,
            rng,
            beta=beta,
            q_star_beta=q_star_beta,
            actor_num_samples=actor_num_samples,
            actor_type=actor_type,
            num_rtg_samples=num_rtg_samples,
        )

    @jax.jit
    def sample_actions_train(
        self,
        observations,
        rng=None,
    ):
        """Select actions using the ``train_*`` sampling hyperparameters from ``config``."""
        return self._select_actions(
            observations,
            rng,
            beta=self.config['train_beta'],
            q_star_beta=self.config['train_q_star_beta'],
            actor_num_samples=self.config['train_actor_num_samples'],
            actor_type=self.config['train_actor_type'],
            num_rtg_samples=self.config['num_train_rtg_samples'],
        )

    @jax.jit
    def compute_flow_actions(
        self,
        observations,
        noises,
    ):
        """Compute actions from the BC flow model using the Euler method.

        Integrates ``actor_bc_flow`` from ``noises`` (t=0) over
        ``actor_flow_steps`` steps and clips the result to [-1, 1].
        """
        num_steps = self.config['actor_flow_steps']
        velocity_field = self.network.select('actor_bc_flow')
        actions = noises
        for i in range(num_steps):
            t = jnp.full((*observations.shape[:-1], 1), i / num_steps)
            actions = actions + velocity_field(observations, actions, t, is_encoded=True) / num_steps
        return jnp.clip(actions, -1, 1)

    @classmethod
    def create(
        cls,
        seed,
        ex_observations,
        ex_actions,
        config,
    ):
        """Create a new agent.

        Args:
            seed: Random seed.
            ex_observations: Example observations.
            ex_actions: Example actions.
            config: Configuration dictionary. ``ob_dims`` and ``action_dim``
                are written into it before it is frozen.

        Returns:
            A freshly initialized agent whose target critic starts with the
            same parameters as the critic.
        """
        rng = jax.random.PRNGKey(seed)
        rng, init_rng = jax.random.split(rng, 2)

        ex_times = ex_actions[..., :1]
        ex_returns = ex_times * 0.0
        # NOTE(review): ob_dims is the full shape of ex_observations, and
        # _select_actions treats that many trailing axes as observation axes,
        # so this presumably expects an unbatched example -- confirm with caller.
        ob_dims = ex_observations.shape
        action_dim = ex_actions.shape[-1]
        if config['action_chunking']:
            full_actions = jnp.concatenate([ex_actions] * config['horizon_length'], axis=-1)
        else:
            full_actions = ex_actions
        full_action_dim = full_actions.shape[-1]

        # NOTE: visual encoders are currently disabled; observations are fed to
        # the networks directly.
        critic_def = RewardFlowModel(
            hidden_dims=config['value_hidden_dims'],
            layer_norm=config['layer_norm'],
        )

        actor_bc_flow_def = ActorVectorField(
            hidden_dims=config['actor_hidden_dims'],
            action_dim=full_action_dim,
            layer_norm=config['actor_layer_norm'],
            use_fourier_features=config['use_fourier_features'],
            fourier_feature_dim=config['fourier_feature_dim'],
        )

        network_info = dict(
            actor_bc_flow=(actor_bc_flow_def, (ex_observations, full_actions, ex_times)),
            critic=(critic_def, (ex_observations, full_actions, ex_returns, ex_times)),
            target_critic=(copy.deepcopy(critic_def), (ex_observations, full_actions, ex_returns, ex_times)),
        )
        networks = {k: v[0] for k, v in network_info.items()}
        network_args = {k: v[1] for k, v in network_info.items()}

        network_def = ModuleDict(networks)
        if config['weight_decay'] > 0.0:
            network_tx = optax.adamw(learning_rate=config['lr'], weight_decay=config['weight_decay'])
        else:
            network_tx = optax.adam(learning_rate=config['lr'])
        network_params = network_def.init(init_rng, **network_args)['params']
        network = TrainState.create(network_def, network_params, tx=network_tx)

        # Start the target critic from the online critic's parameters.
        params = network.params
        params['modules_target_critic'] = params['modules_critic']

        config['ob_dims'] = ob_dims
        config['action_dim'] = action_dim

        return cls(rng, network=network, config=flax.core.FrozenDict(**config))


def get_config():
    """Return the default Evor agent configuration as an ``ml_collections.ConfigDict``.

    ``ob_dims`` and ``action_dim`` are placeholders filled in by
    ``EvorAgent.create``. ``horizon_length`` is also left as a placeholder to
    be set elsewhere. Every ``eval_*_sweep`` entry is a tuple of values to
    sweep over.
    """
    config = ml_collections.ConfigDict(
        dict(
            agent_name='evor',  # Agent name.
            ob_dims=ml_collections.config_dict.placeholder(list),  # Observation dimensions (will be set automatically).
            action_dim=ml_collections.config_dict.placeholder(int),  # Action dimension (will be set automatically).
            lr=3e-4,  # Learning rate.
            batch_size=256,  # Batch size.
            actor_hidden_dims=(512, 512, 512, 512),  # Actor network hidden dimensions.
            value_hidden_dims=(512, 512, 512, 512),  # Value network hidden dimensions.
            layer_norm=True,  # Whether to use layer normalization.
            actor_layer_norm=False,  # Whether to use layer normalization for the actor.
            discount=0.99,  # Discount factor.
            tau=0.005,  # Target network update rate.
            actor_flow_steps=10,  # Number of flow steps.
            critic_flow_steps=10,  # Number of flow steps.
            horizon_length=ml_collections.config_dict.placeholder(int),  # will be set
            action_chunking=True,  # False means n-step return
            use_fourier_features=False,
            fourier_feature_dim=64,
            weight_decay=0.0,
            # ---- train ----
            train_actor_type='q_star',  # 'q_star', 'bon'
            train_actor_num_samples=32,
            num_train_rtg_samples=1,
            resample_y_t=False,
            # ---- train: q_star ----
            train_q_star_beta=1.0,
            train_beta=0.001,
            # ---- eval ----
            # Trailing commas matter: ('q_star') is a str and (32) an int, not tuples.
            eval_actor_type_sweep=('q_star',),
            eval_actor_num_samples_sweep=(32,),  # n_bc_candidates values to sweep
            eval_num_rtg_samples_sweep=(50,),  # for actor_type="q_star" only
            # ---- eval: q_star ----
            eval_beta_sweep=(0.001,),
            eval_q_star_beta_sweep=(1.0,),
        )
    )
    return config
