from typing import Any

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
import optax
from utils.encoders import GCEncoder, encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import MLP, GCActor, GCDiscreteActor, GCValue, Identity, LengthNormalize


class OTAAgent(flax.struct.PyTreeNode):
    """Option-aware temporally abstracted Value Learning (OTA) agent."""

    rng: Any
    network: Any
    config: Any = nonpytree_field()

    @staticmethod
    def expectile_loss(adv, diff, expectile):
        """Compute the expectile loss."""
        weight = jnp.where(adv >= 0, expectile, (1 - expectile))
        return weight * (diff**2)

    def low_value_loss(self, batch, grad_params):
        """Compute the IVL value loss for the low-level value ensemble.

        Args:
            batch: Mapping that provides 'observations', 'next_observations', 'value_goals', 'rewards' and 'masks'.
            grad_params: Parameters passed to the 'low_value' module for this evaluation.

        Returns:
            A `(loss, info)` pair, where `info` holds the loss and the mean/max/min of the averaged value.
        """
        observations = batch['observations']
        goals = batch['value_goals']
        rewards = batch['rewards']
        expectile = self.config['expectile']
        discounted_masks = self.config['low_discount'] * batch['masks']

        # Target-network values at the next and at the current observation.
        target_value = self.network.select('target_low_value')
        next_v1_t, next_v2_t = target_value(batch['next_observations'], goals)
        v1_t, v2_t = target_value(observations, goals)

        # The advantage that selects the expectile weight uses the minimum of the two next-state targets
        # and the mean of the two current-state targets.
        q = rewards + discounted_masks * jnp.minimum(next_v1_t, next_v2_t)
        adv = q - (v1_t + v2_t) / 2

        v1, v2 = self.network.select('low_value')(observations, goals, params=grad_params)
        v = (v1 + v2) / 2

        # Each ensemble member regresses onto the TD target built from its own target head.
        member_losses = [
            self.expectile_loss(adv, (rewards + discounted_masks * next_v_t) - v_online, expectile).mean()
            for next_v_t, v_online in ((next_v1_t, v1), (next_v2_t, v2))
        ]
        low_value_loss = member_losses[0] + member_losses[1]

        info = {
            'low_value_loss': low_value_loss,
            'low_v_mean': v.mean(),
            'low_v_max': v.max(),
            'low_v_min': v.min(),
        }
        return low_value_loss, info

    def high_value_loss(self, batch, grad_params):
        """Compute the option-aware value loss for the high-level value ensemble.

        The TD target bootstraps from `batch['high_value_option_observations']` and uses the
        'high_value_rewards' / 'high_value_masks' entries of the batch.

        Args:
            batch: Mapping that provides 'observations', 'high_value_option_observations', 'high_value_goals',
                'high_value_rewards' and 'high_value_masks'.
            grad_params: Parameters passed to the 'high_value' module for this evaluation.

        Returns:
            A `(loss, info)` pair, where `info` holds the loss and the mean/max/min of the averaged value.
        """
        observations = batch['observations']
        goals = batch['high_value_goals']
        rewards = batch['high_value_rewards']
        expectile = self.config['expectile']
        discounted_masks = self.config['high_discount'] * batch['high_value_masks']

        # Target-network values at the option observation and at the current observation.
        target_value = self.network.select('target_high_value')
        next_v1_t, next_v2_t = target_value(batch['high_value_option_observations'], goals)
        v1_t, v2_t = target_value(observations, goals)

        # The advantage that selects the expectile weight uses the minimum of the two bootstrap targets
        # and the mean of the two current-state targets.
        q = rewards + discounted_masks * jnp.minimum(next_v1_t, next_v2_t)
        adv = q - (v1_t + v2_t) / 2

        v1, v2 = self.network.select('high_value')(observations, goals, params=grad_params)
        v = (v1 + v2) / 2

        # Each ensemble member regresses onto the TD target built from its own target head.
        member_losses = [
            self.expectile_loss(adv, (rewards + discounted_masks * next_v_t) - v_online, expectile).mean()
            for next_v_t, v_online in ((next_v1_t, v1), (next_v2_t, v2))
        ]
        high_value_loss = member_losses[0] + member_losses[1]

        info = {
            'high_value_loss': high_value_loss,
            'high_v_mean': v.mean(),
            'high_v_max': v.max(),
            'high_v_min': v.min(),
        }
        return high_value_loss, info

    def low_actor_loss(self, batch, grad_params):
        """Compute the advantage-weighted log-likelihood loss of the low-level actor.

        Args:
            batch: Mapping that provides 'observations', 'next_observations', 'low_actor_goals' and 'actions'.
            grad_params: Parameters passed to the 'goal_rep' and 'low_actor' modules.

        Returns:
            A `(loss, info)` pair. `info` additionally contains 'mse' and 'std' when
            `config['discrete']` is false.
        """
        observations = batch['observations']
        subgoals = batch['low_actor_goals']

        # Advantage = averaged low-level value at the next observation minus that at the current one.
        # NOTE(review): the value module is called without `params=grad_params`; presumably it then uses
        # the stored parameters so no gradient reaches it - confirm against the `select` implementation.
        low_value = self.network.select('low_value')
        v1, v2 = low_value(observations, subgoals)
        nv1, nv2 = low_value(batch['next_observations'], subgoals)
        v = (v1 + v2) / 2
        nv = (nv1 + nv2) / 2
        adv = nv - v

        # Exponentiated advantage weight, capped at 100.
        exp_a = jnp.minimum(jnp.exp(adv * self.config['low_alpha']), 100.0)

        # Represent the subgoal from the concatenated [observation; subgoal] input.
        rep_input = jnp.concatenate([observations, subgoals], axis=-1)
        goal_reps = self.network.select('goal_rep')(rep_input, params=grad_params)
        if not self.config['low_actor_rep_grad']:
            # Keep the actor loss from updating the goal representation.
            goal_reps = jax.lax.stop_gradient(goal_reps)

        dist = self.network.select('low_actor')(observations, goal_reps, goal_encoded=True, params=grad_params)
        log_prob = dist.log_prob(batch['actions'])
        actor_loss = -(exp_a * log_prob).mean()

        actor_info = dict(
            actor_loss=actor_loss,
            adv=adv.mean(),
            bc_log_prob=log_prob.mean(),
        )
        if not self.config['discrete']:
            actor_info['mse'] = jnp.mean((dist.mode() - batch['actions']) ** 2)
            actor_info['std'] = jnp.mean(dist.scale_diag)

        return actor_loss, actor_info

    def high_actor_loss(self, batch, grad_params):
        """Compute the advantage-weighted log-likelihood loss of the high-level actor.

        The actor is trained to output the 'goal_rep' representation of `batch['high_actor_targets']`.

        Args:
            batch: Mapping that provides 'observations', 'high_actor_goals' and 'high_actor_targets'.
            grad_params: Parameters passed to the 'high_actor' module.

        Returns:
            A `(loss, info)` pair with the loss, mean advantage, mean log-probability, 'mse' and 'std'.
        """
        observations = batch['observations']
        goals = batch['high_actor_goals']
        subgoal_targets = batch['high_actor_targets']

        # Advantage = averaged high-level value at the target observation minus that at the current one.
        high_value = self.network.select('high_value')
        v1, v2 = high_value(observations, goals)
        nv1, nv2 = high_value(subgoal_targets, goals)
        v = (v1 + v2) / 2
        nv = (nv1 + nv2) / 2
        adv = nv - v

        # Exponentiated advantage weight, capped at 100.
        exp_a = jnp.minimum(jnp.exp(adv * self.config['high_alpha']), 100.0)

        dist = self.network.select('high_actor')(observations, goals, params=grad_params)
        # NOTE(review): 'goal_rep' is called without `params=grad_params` here; presumably the regression
        # target is therefore treated as a constant - confirm against the `select` implementation.
        target = self.network.select('goal_rep')(jnp.concatenate([observations, subgoal_targets], axis=-1))
        log_prob = dist.log_prob(target)

        actor_loss = -(exp_a * log_prob).mean()

        actor_info = dict(
            actor_loss=actor_loss,
            adv=adv.mean(),
            bc_log_prob=log_prob.mean(),
            mse=jnp.mean((dist.mode() - target) ** 2),
            std=jnp.mean(dist.scale_diag),
        )
        return actor_loss, actor_info
        

    @jax.jit
    def total_loss(self, batch, grad_params, rng=None):
        """Sum the four loss terms and collect their statistics under per-term prefixes.

        Args:
            batch: Training batch forwarded to every loss term.
            grad_params: Parameters forwarded to every loss term.
            rng: Optional PRNG key; falls back to `self.rng`. None of the loss terms below consume it.

        Returns:
            A `(loss, info)` pair. `info` keys have the form '<term>/<statistic>' with `<term>` one of
            'low_value', 'high_value', 'low_actor', 'high_actor'.
        """
        rng = self.rng if rng is None else rng

        loss_terms = (
            ('low_value', self.low_value_loss),
            ('high_value', self.high_value_loss),
            ('low_actor', self.low_actor_loss),
            ('high_actor', self.high_actor_loss),
        )

        info = {}
        term_losses = []
        for prefix, term_fn in loss_terms:
            term_loss, term_info = term_fn(batch, grad_params)
            term_losses.append(term_loss)
            info.update({f'{prefix}/{k}': v for k, v in term_info.items()})

        # Accumulate left to right, in the same order the terms are listed above.
        loss = term_losses[0]
        for term_loss in term_losses[1:]:
            loss = loss + term_loss

        return loss, info

    def target_update(self, network, module_name):
        """Write an exponential-moving-average update of a target module into `network`.

        Both the online parameters and the previous target parameters are read from `self.network`;
        only the result is stored into `network.params`, in place.

        Args:
            network: Train state whose `params['modules_target_<module_name>']` entry is overwritten.
            module_name: Name of the online module, e.g. 'low_value'.
        """
        tau = self.config['tau']
        online_key = f'modules_{module_name}'
        target_key = f'modules_target_{module_name}'

        def blend(p, tp):
            # new_target = tau * online + (1 - tau) * old_target
            return p * tau + tp * (1 - tau)

        network.params[target_key] = jax.tree_util.tree_map(
            blend,
            self.network.params[online_key],
            self.network.params[target_key],
        )

    @jax.jit
    def update(self, batch):
        """Run one training step and return `(new_agent, info)`.

        Args:
            batch: Training batch forwarded to `total_loss`.

        Returns:
            A copy of the agent carrying the updated network and a fresh PRNG key, plus the info
            dictionary returned by `apply_loss_fn`.
        """
        new_rng, loss_rng = jax.random.split(self.rng)

        new_network, info = self.network.apply_loss_fn(
            loss_fn=lambda grad_params: self.total_loss(batch, grad_params, rng=loss_rng)
        )

        # Refresh the target copies of both value functions.
        for module_name in ('low_value', 'high_value'):
            self.target_update(new_network, module_name)

        return self.replace(network=new_network, rng=new_rng), info

    @jax.jit
    def sample_actions(
        self,
        observations,
        goals=None,
        seed=None,
        temperature=1.0,
    ):
        """Sample actions by chaining the high-level and low-level actors.

        The high-level actor samples a subgoal representation, which is rescaled to have norm
        `sqrt(rep_dim)` and then handed to the low-level actor as an already-encoded goal.

        Args:
            observations: Observations to act on.
            goals: Goals given to the high-level actor.
            seed: PRNG key; it is split into one key per actor.
            temperature: Temperature forwarded to both actors.

        Returns:
            Sampled actions, clipped to [-1, 1] unless `config['discrete']` is true.
        """
        high_seed, low_seed = jax.random.split(seed)
        high_actor = self.network.select('high_actor')
        low_actor = self.network.select('low_actor')

        # High level: sample a subgoal representation and rescale it to norm sqrt(dim).
        subgoal_reps = high_actor(observations, goals, temperature=temperature).sample(seed=high_seed)
        rep_norm = jnp.linalg.norm(subgoal_reps, axis=-1, keepdims=True)
        subgoal_reps = subgoal_reps / rep_norm * jnp.sqrt(subgoal_reps.shape[-1])

        # Low level: act towards the sampled subgoal representation.
        low_dist = low_actor(observations, subgoal_reps, goal_encoded=True, temperature=temperature)
        actions = low_dist.sample(seed=low_seed)

        if self.config['discrete']:
            return actions
        return jnp.clip(actions, -1, 1)

    @jax.jit
    def sample_values(self, observations, goals):
        """Return the mean of the two low-level value heads, V(s, g)."""
        low_value = self.network.select('low_value')
        head_a, head_b = low_value(observations, goals)
        return (head_a + head_b) / 2

    def _value_to_steps_gcneg(self, v: jnp.ndarray, gamma: float, eps: float = 1e-3) -> jnp.ndarray:
        """
        Convert V(s,g) in reward=0(success) / -1(failure) semantics to "expected steps":
            V = - (1 - gamma^d) / (1 - gamma)  =>  d = log(1 + (1-gamma)*V) / log(gamma)

        The unit of d depends on the "time granularity" of this value:
          - low_value: d is env-step
          - high_value: d is option-step (strictly speaking, the number of failed options K = D_opt - 1; see the high processing below)
        """
        # Theoretical range v ∈ [ -1/(1-γ), 0 ]
        v_min = -1.0 / (1.0 - gamma) + eps
        v_max = -eps  # Numerical stability: avoid boundary problems with log(1) or log(<=0)
        v = jnp.clip(v, v_min, v_max)

        steps = jnp.log1p((1.0 - gamma) * v) / jnp.log(gamma)  # log(gamma)<0，steps>=0
        steps = jnp.maximum(steps, 0.0)
        # Prevent NaN (TTGS also does this)
        return jnp.where(jnp.isnan(steps), 10000.0, steps)

    @jax.jit
    def sample_low_distance(self, observations: jnp.ndarray, goals: jnp.ndarray) -> jnp.ndarray:
        """Estimate the env-step distance from the low-level value function.

        The two 'low_value' heads are averaged and inverted with `_value_to_steps_gcneg` using
        `config['low_discount']`.
        """
        low_discount = float(self.config['low_discount'])
        head_a, head_b = self.network.select('low_value')(observations, goals)
        mean_value = 0.5 * (head_a + head_b)
        return self._value_to_steps_gcneg(mean_value, low_discount)

    @jax.jit
    def sample_high_distance(self, observations: jnp.ndarray, goals: jnp.ndarray) -> jnp.ndarray:
        """Estimate the env-step distance from the high-level value function.

        The averaged 'high_value' estimate is inverted into a count of failed options K, which is
        turned into an option count D_opt = K + 1 and scaled by `config['abstraction_factor']`.
        """
        high_discount = float(self.config['high_discount'])
        steps_per_option = float(self.config['abstraction_factor'])

        head_a, head_b = self.network.select('high_value')(observations, goals)
        mean_value = 0.5 * (head_a + head_b)

        # K = D_opt - 1: number of failed options implied by the value.
        failed_options = self._value_to_steps_gcneg(mean_value, high_discount)

        # d_env ~= n * D_opt, with D_opt = K + 1.
        env_steps = (failed_options + 1.0) * steps_per_option
        return jnp.where(jnp.isnan(env_steps), 10000.0, env_steps)

    @jax.jit
    def sample_distance(self, observations: jnp.ndarray, goals: jnp.ndarray) -> jnp.ndarray:
        """Blend the low- and high-level distance estimates into one env-step distance.

        The low-level estimate is used as-is up to `config['subgoal_steps']`. Beyond that the result
        moves linearly towards the high-level estimate over a window of `config['abstraction_factor']`
        env-steps (at least 1). The high-level estimate is never allowed to fall below the low-level
        one. NaN results are replaced by 10000.
        """
        switch = float(self.config['subgoal_steps'])  # where blending starts (env-steps)
        ramp = jnp.maximum(float(self.config['abstraction_factor']), 1.0)  # blending window (env-steps)

        d_low = self.sample_low_distance(observations, goals)
        d_high = jnp.maximum(self.sample_high_distance(observations, goals), d_low)

        # w = 0 for d_low <= switch, w = 1 for d_low >= switch + ramp, linear in between.
        w = jnp.clip((d_low - switch) / ramp, 0.0, 1.0)

        blended = (1.0 - w) * d_low + w * d_high
        return jnp.where(jnp.isnan(blended), 10000.0, blended)

    @classmethod
    def create(
        cls,
        seed,
        ex_observations,
        ex_actions,
        config,
    ):
        """Create a new agent.

        Args:
            seed: Random seed.
            ex_observations: Example batch of observations.
            ex_actions: Example batch of actions. In discrete-action MDPs, this should contain the maximum action value.
            config: Configuration dictionary. It is left unmodified; the agent stores its own frozen copy in
                which 'low_discount' and 'high_discount' are both set to `config['discount']`.

        Returns:
            A new agent whose target value parameters are initialized from the online value parameters.
        """
        rng = jax.random.PRNGKey(seed)
        rng, init_rng = jax.random.split(rng, 2)

        ex_goals = ex_observations
        if config['discrete']:
            action_dim = ex_actions.max() + 1
        else:
            action_dim = ex_actions.shape[-1]

        use_encoder = config['encoder'] is not None
        encoder_module = encoder_modules[config['encoder']] if use_encoder else None

        # Define (state-dependent) subgoal representation phi([s; g]) that outputs a length-normalized vector.
        goal_rep_seq = [encoder_module()] if use_encoder else []
        goal_rep_seq.append(
            MLP(
                hidden_dims=(*config['value_hidden_dims'], config['rep_dim']),
                activate_final=False,
                layer_norm=config['layer_norm'],
            )
        )
        goal_rep_seq.append(LengthNormalize())
        goal_rep_def = nn.Sequential(goal_rep_seq)

        def make_state_goal_encoder():
            """Build a GCEncoder with a fresh state encoder and the same `goal_rep_def` instance."""
            state_encoder = encoder_module() if use_encoder else Identity()
            return GCEncoder(state_encoder=state_encoder, concat_encoder=goal_rep_def)

        def make_value(gc_encoder):
            """Build one ensembled goal-conditioned value network on top of `gc_encoder`."""
            return GCValue(
                hidden_dims=config['value_hidden_dims'],
                layer_norm=config['layer_norm'],
                ensemble=True,
                gc_encoder=gc_encoder,
            )

        low_value_encoder_def = make_state_goal_encoder()
        target_low_value_encoder_def = make_state_goal_encoder()
        high_value_encoder_def = make_state_goal_encoder()
        target_high_value_encoder_def = make_state_goal_encoder()
        low_actor_encoder_def = make_state_goal_encoder()
        # The high-level actor only gets an encoder (over the concatenated input) when one is configured.
        high_actor_encoder_def = GCEncoder(concat_encoder=encoder_module()) if use_encoder else None

        # Define value and actor networks.
        low_value_def = make_value(low_value_encoder_def)
        target_low_value_def = make_value(target_low_value_encoder_def)
        high_value_def = make_value(high_value_encoder_def)
        target_high_value_def = make_value(target_high_value_encoder_def)

        if config['discrete']:
            low_actor_def = GCDiscreteActor(
                hidden_dims=config['actor_hidden_dims'],
                action_dim=action_dim,
                gc_encoder=low_actor_encoder_def,
            )
        else:
            low_actor_def = GCActor(
                hidden_dims=config['actor_hidden_dims'],
                action_dim=action_dim,
                state_dependent_std=False,
                const_std=config['const_std'],
                gc_encoder=low_actor_encoder_def,
            )

        # The high-level actor outputs vectors of size rep_dim (subgoal representations).
        high_actor_def = GCActor(
            hidden_dims=config['actor_hidden_dims'],
            action_dim=config['rep_dim'],
            state_dependent_std=False,
            const_std=config['const_std'],
            gc_encoder=high_actor_encoder_def,
        )

        network_info = dict(
            goal_rep=(goal_rep_def, (jnp.concatenate([ex_observations, ex_goals], axis=-1))),
            low_value=(low_value_def, (ex_observations, ex_goals)),
            target_low_value=(target_low_value_def, (ex_observations, ex_goals)),
            high_value=(high_value_def, (ex_observations, ex_goals)),
            target_high_value=(target_high_value_def, (ex_observations, ex_goals)),
            low_actor=(low_actor_def, (ex_observations, ex_goals)),
            high_actor=(high_actor_def, (ex_observations, ex_goals)),
        )

        networks = {k: v[0] for k, v in network_info.items()}
        network_args = {k: v[1] for k, v in network_info.items()}

        network_def = ModuleDict(networks)
        network_tx = optax.adam(learning_rate=config['lr'])
        network_params = network_def.init(init_rng, **network_args)['params']
        network = TrainState.create(network_def, network_params, tx=network_tx)

        # Start the target value networks from the online value parameters.
        params = network.params
        params['modules_target_low_value'] = params['modules_low_value']
        params['modules_target_high_value'] = params['modules_high_value']

        # Build the agent's config from a copy so the caller's object is never written to; this also
        # works when `config` is a locked ConfigDict or another immutable mapping.
        agent_config = {
            **config,
            'low_discount': config['discount'],
            'high_discount': config['discount'],
        }

        return cls(rng, network=network, config=flax.core.FrozenDict(**agent_config))


def get_config():
    """Return the default OTA agent and dataset hyperparameters as a `ConfigDict`."""
    placeholder = ml_collections.config_dict.placeholder

    # Agent hyperparameters.
    agent_hparams = {
        'agent_name': 'ota',  # Agent name.
        'lr': 3e-4,  # Learning rate.
        'batch_size': 1024,  # Batch size.
        'actor_hidden_dims': (512, 512, 512),  # Actor network hidden dimensions.
        'value_hidden_dims': (512, 512, 512),  # Value network hidden dimensions.
        'layer_norm': True,  # Whether to use layer normalization.
        'discount': 0.99,  # Discount factor.
        'tau': 0.005,  # Target network update rate.
        'expectile': 0.7,  # IQL expectile.
        'low_alpha': 3.0,  # Low-level AWR temperature.
        'high_alpha': 3.0,  # High-level AWR temperature.
        'subgoal_steps': 25,  # Subgoal steps for actor.
        'abstraction_factor': 5,  # Abstraction factor for high-level value function.
        'rep_dim': 10,  # Goal representation dimension.
        'low_actor_rep_grad': False,  # Whether low-actor gradients flow to goal representation (use True for pixels).
        'const_std': True,  # Whether to use constant standard deviation for the actors.
        'discrete': False,  # Whether the action space is discrete.
        'encoder': placeholder(str),  # Visual encoder name (None, 'impala_small', etc.).
    }

    # Dataset hyperparameters.
    dataset_hparams = {
        'dataset_class': 'HGCDataset',  # Dataset class name.
        'value_p_curgoal': 0.2,  # Probability of using the current state as the value goal.
        'value_p_trajgoal': 0.5,  # Probability of using a future state in the same trajectory as the value goal.
        'value_p_randomgoal': 0.3,  # Probability of using a random state as the value goal.
        'value_geom_sample': True,  # Whether to use geometric sampling for future value goals.
        'actor_p_curgoal': 0.0,  # Probability of using the current state as the actor goal.
        'actor_p_trajgoal': 1.0,  # Probability of using a future state in the same trajectory as the actor goal.
        'actor_p_randomgoal': 0.0,  # Probability of using a random state as the actor goal.
        'actor_geom_sample': False,  # Whether to use geometric sampling for future actor goals.
        'gc_negative': True,  # Whether to use '0 if s == g else -1' (True) or '1 if s == g else 0' (False) as reward.
        'p_aug': 0.0,  # Probability of applying image augmentation.
        'frame_stack': placeholder(int),  # Number of frames to stack.
    }

    return ml_collections.ConfigDict({**agent_hparams, **dataset_hparams})
