import builtins
import os
from collections import deque
from functools import partial
from typing import Optional, Sequence, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import optax
# NOTE: this shadows the builtin `print` at module scope; host-side logging
# in this file should use `builtins.print` instead.
from jax.debug import print

from sources.networks import critic, policy, value, discriminator
from sources.utils import Batch, MixBatch, InfoDict, Model, PRNGKey, target_update
from .critic import update_value, update_critic
from .actor import update_actor,update_actor_Q
from .disc import update_discriminator

@partial(jax.jit, static_argnames=['double', 'args', 'cal_log',
                                   'discount', 'tau', 'expectile', 'actor_temperature',
                                   'reward_gap'])
def _update_UNIQ(
    rng: PRNGKey, actor: Model, critic: Model, value: Model, target_critic: Model,
    bad_disc: Model, mix_disc: Model,
    bad_batch: Batch, mix_batch: MixBatch, discount: float, tau: float,
    expectile: float, actor_temperature: float, scale_mix: float, reward_gap: float,
    double: bool, args, cal_log: bool,
) -> Tuple[PRNGKey, Model, Model, Model, Model, InfoDict]:
    """One jitted UNIQ step: update value, actor and critic, then the target critic.

    The discriminators are only evaluated here (under stop_gradient) to build
    per-sample importance weights; they are not updated.

    Args:
        rng: PRNG key. Independent sub-keys are split off for the value, actor
            and critic updates; the remaining key is returned.
        actor, critic, value, target_critic: current networks.
        bad_disc, mix_disc: discriminators whose outputs form the ratio
            bad_disc / mix_disc used to weight the mixed samples.
        bad_batch: batch of bad demonstrations.
        mix_batch: batch of mixed-quality data; carries an `is_bad` field.
        discount, tau, expectile, actor_temperature: hyper-parameters forwarded
            to the sub-updates (`tau` goes to `target_update`).
        scale_mix: accepted for API compatibility; currently unused.
        reward_gap: value targets are clipped to +/- reward_gap / (1 - discount);
            also forwarded to `update_critic`.
        double: forwarded to the sub-updates.
        args: static config. Reads `clip_threshold` and `update_Q_inference`
            here and is forwarded to `update_value` / `update_critic`.
        cal_log: forwarded to the sub-updates.

    Returns:
        (rng, new_actor, new_critic, new_value, new_target_critic, info), where
        `info` merges the critic, value and actor info dicts (in that order).
    """
    def get_ratio(observations, actions):
        # Importance ratio bad_disc / mix_disc, with bad-discriminator outputs
        # below args.clip_threshold zeroed first. No gradient flows through it.
        # NOTE(review): assumes mix_disc outputs are strictly positive; a zero
        # output would yield inf/NaN here -- confirm the discriminator's range.
        bad_out = bad_disc(observations, actions)
        bad_out = jnp.where(bad_out < args.clip_threshold, 0.0, bad_out)
        mix_out = mix_disc(observations, actions)
        return jax.lax.stop_gradient(bad_out / mix_out)

    n_bad = bad_batch.observations.shape[0]

    # Value targets for the actor update, clipped to +/- reward_gap / (1 - discount).
    t_V = jax.lax.stop_gradient(value(mix_batch.observations))
    v_bound = reward_gap / (1 - discount)
    t_V = t_V.clip(min=-v_bound, max=v_bound)

    # Importance ratios for the mixed batch.
    mix_ratio = get_ratio(mix_batch.observations, mix_batch.actions)

    # Combined batch laid out as [bad samples, mixed samples].
    combined_observations = jnp.concatenate((bad_batch.observations, mix_batch.observations), axis=0)
    combined_actions = jnp.concatenate((bad_batch.actions, mix_batch.actions), axis=0)
    # Environment rewards are deliberately zeroed; presumably update_critic
    # derives its reward signal from `reward_weight` instead -- confirm there.
    combined_rewards = jnp.concatenate((bad_batch.rewards, mix_batch.rewards), axis=0) * 0
    combined_masks = jnp.concatenate((bad_batch.masks, mix_batch.masks), axis=0)
    combined_next_observations = jnp.concatenate(
        (bad_batch.next_observations, mix_batch.next_observations), axis=0)
    combined_batch = Batch(observations=combined_observations, actions=combined_actions,
                           rewards=combined_rewards, masks=combined_masks,
                           next_observations=combined_next_observations)

    # 1 for every bad-batch sample, followed by the mixed batch's own labels.
    is_bad = jnp.concatenate((jnp.ones(n_bad), mix_batch.is_bad), axis=0)

    # Normalise the ratios in three steps:
    # 1. clip to the [5%, 95%] quantile band;
    # 2. min-max rescale to [0, 1];
    # 3. divide by the mean so the weights average to 1.
    quantile_low = jnp.quantile(mix_ratio, 0.05)
    quantile_high = jnp.quantile(mix_ratio, 0.95)
    spread = quantile_high - quantile_low
    # When the band is not positive (collapsed to a point, e.g. all bad-disc
    # outputs were zeroed by clip_threshold, or NaN), the rescale would be
    # 0/0 = NaN and poison every loss. Fall back to uniform weights instead.
    has_spread = spread > 0
    safe_spread = jnp.where(has_spread, spread, 1.0)
    ratio = mix_ratio.clip(min=quantile_low, max=quantile_high)
    ratio = jnp.where(has_spread, (ratio - quantile_low) / safe_spread, jnp.ones_like(ratio))
    ratio = ratio / ratio.mean()

    # Bad-batch samples all receive the mean normalised weight.
    visible_bad_ratio = jnp.ones(n_bad) * ratio.mean()
    ratio = jnp.concatenate((visible_bad_ratio, ratio), axis=0)
    reward_weight = -ratio
    # Mixed samples sit after the n_bad bad samples in the combined layout.
    # Slice by n_bad rather than assuming both batches have the same size.
    mix_reward_weight = reward_weight[n_bad:]

    update_a = update_actor_Q if args.update_Q_inference else update_actor

    # Independent keys per sub-update: JAX PRNG keys must not be reused.
    rng, value_key, actor_key, critic_key = jax.random.split(rng, 4)
    new_value, value_info = update_value(target_critic, value, mix_batch, mix_batch.is_bad, expectile,
                                         double, value_key, args, cal_log=cal_log)
    # The actor uses the pre-update target critic and value targets (t_V).
    new_actor, actor_info = update_a(actor_key, actor, target_critic, t_V, mix_batch,
                                     mix_reward_weight, mix_batch.is_bad,
                                     actor_temperature, double, cal_log=cal_log)
    # The critic bootstraps from the freshly updated value network.
    new_critic, critic_info = update_critic(critic, new_value, combined_batch, is_bad, reward_weight, discount,
                                            double, critic_key, reward_gap, args, cal_log=cal_log)
    new_target_critic = target_update(new_critic, target_critic, tau)

    return rng, new_actor, new_critic, new_value, new_target_critic, {
        **critic_info,
        **value_info,
        **actor_info,
    }

@partial(jax.jit, static_argnames=['args', 'train_bad', 'train_mix', 'cal_log'])
def _update_discriminator(
    rng: PRNGKey, bad_disc: Model, mix_disc: Model, bad_batch: Batch, mix_batch: MixBatch, args,
    train_bad: bool, train_mix: bool, cal_log: bool
) -> Tuple[PRNGKey, Model, Model, jnp.ndarray, jnp.ndarray, InfoDict]:
    """One jitted step for the bad and mixed discriminators.

    Each discriminator is updated only when its (static) train flag is set;
    otherwise it is returned unchanged with an empty info dict and a
    `continue_*` flag of False.

    Args:
        rng: PRNG key. An independent sub-key is split off for each
            discriminator; the remaining key is returned.
        bad_disc: trained with the bad batch as `high_batch` and the mixed batch
            as `low_batch`.
        mix_disc: trained with the mixed batch as `high_batch` and the bad batch
            as `low_batch`.
        bad_batch, mix_batch: data batches; `mix_batch` carries an `is_bad` field.
        args: static config; `args.noise_std` is forwarded as `noise_scale`.
        train_bad, train_mix: whether to update each discriminator this step.
        cal_log: forwarded to `update_discriminator`.

    Returns:
        (rng, new_bad_disc, new_mix_disc, continue_bad, continue_mix, info).
        The continue flags come back from jit as scalar arrays, so callers use
        `.item()`. `info` merges the bad then mixed discriminator info dicts.
    """
    # Independent keys for the two updates: JAX PRNG keys must not be reused.
    rng, bad_key, mix_key = jax.random.split(rng, 3)
    bad_is_bad = jnp.ones(bad_batch.observations.shape[0])

    # In both calls `is_bad` is concatenated in the same order as
    # (high_batch, low_batch).
    if train_bad:
        new_bad_disc, bad_disc_info, continue_bad = update_discriminator(
            bad_key, discriminator=bad_disc, high_batch=bad_batch,
            low_batch=mix_batch, cal_log=cal_log,
            is_bad=jnp.concatenate((bad_is_bad, mix_batch.is_bad), axis=0),
            noise_scale=args.noise_std,
            prefix='bad')
    else:
        new_bad_disc, bad_disc_info, continue_bad = bad_disc, {}, False

    if train_mix:
        new_mix_disc, mix_disc_info, continue_mix = update_discriminator(
            mix_key, discriminator=mix_disc, high_batch=mix_batch,
            low_batch=bad_batch, cal_log=cal_log,
            is_bad=jnp.concatenate((mix_batch.is_bad, bad_is_bad), axis=0),
            noise_scale=args.noise_std,
            prefix='mix')
    else:
        new_mix_disc, mix_disc_info, continue_mix = mix_disc, {}, False

    return rng, new_bad_disc, new_mix_disc, continue_bad, continue_mix, {
        **bad_disc_info,
        **mix_disc_info,
    }

class UNIQ(object):
    """
    UNIQ (Uncertainty-based Negative-sampling for Imitation with Quality) algorithm implementation.
    Learns from mixed-quality demonstrations by identifying and learning from good behaviors.

    Holds the following, all advanced by the jitted `_update_UNIQ` /
    `_update_discriminator` steps:
    - an actor;
    - a double critic and its target copy;
    - a value network;
    - two discriminators (`bad_disc`, `mix_disc`);
    - the PRNG key threaded through every update.
    """
    def __init__(self, seed: int,
                 observations: jnp.ndarray, actions: jnp.ndarray,
                 actor_lr: float, critic_lr: float, value_lr: float, disc_lr: float,
                 hidden_dims: Sequence[int], discount: float, expectile: float,
                 actor_temperature: float, dropout_rate: float,
                 layernorm: bool, tau: float, double_q: bool = True,
                 opt_decay_schedule: Optional[str] = 'cosine',
                 max_steps: Optional[int] = None,
                 value_dropout_rate: Optional[float] = None,
                 reward_gap: float = 2.0,
                 weight_decay: float = 0.0,
                 args = None):
        """
        Build all networks, their optimisers and the training-state trackers.

        `observations` / `actions` are sample inputs used only to initialise
        network parameters; `actions.shape[1]` sets the action dimension.

        Optimisers:
        - actor: Adam with a cosine-decayed learning rate over `max_steps` when
          `opt_decay_schedule == "cosine"`, otherwise AdamW(actor_lr, weight_decay);
        - critic / value: AdamW with `weight_decay`;
        - discriminators: AdamW(disc_lr, weight_decay=0.01).
        The target critic is initialised from the same key as the critic.

        Raises:
            ValueError: if `args` is None, or if the cosine schedule is requested
                without a positive `max_steps`.
        """
        # Fail fast with clear messages instead of an AttributeError on
        # `args.update_Q_inference` or a TypeError deep inside optax.
        if args is None:
            raise ValueError(
                "UNIQ requires an `args` config object (the training steps read "
                "update_Q_inference, clip_threshold, noise_std and eval_interval)")
        if opt_decay_schedule == "cosine" and (max_steps is None or max_steps <= 0):
            raise ValueError(
                "opt_decay_schedule='cosine' requires a positive max_steps, "
                f"got {max_steps!r}")

        self.expectile = expectile
        self.reward_gap = reward_gap
        self.tau = tau
        self.discount = discount
        self.actor_temperature = actor_temperature
        self.double_q = double_q
        self.args = args
        # Adaptive discriminator training: rolling windows of the last 100
        # `continue_*` flags; see update_discriminator.
        self.train_bad = True
        self.train_mix = True
        self.bad_arr = deque(maxlen=100)
        self.mix_arr = deque(maxlen=100)

        # builtins.print: the module-level `print` is jax.debug.print, which
        # re-formats its message with str.format.
        if args.update_Q_inference:
            builtins.print('update actor Q inference')
        else:
            builtins.print('update actor BC')

        rng = jax.random.PRNGKey(seed)
        rng, actor_key, critic_key, value_key, bad_disc_key, mix_disc_key = jax.random.split(rng, 6)

        action_dim = actions.shape[1]

        #---- actor ----#
        builtins.print('actor with tanh squash = True')
        actor_def = policy.NormalTanhPolicy(
            hidden_dims,
            action_dim,
            log_std_scale=1e-3,
            log_std_min=-5.0,
            dropout_rate=dropout_rate,
            state_dependent_std=False,
            tanh_squash_distribution=True)

        if opt_decay_schedule == "cosine":
            builtins.print("Using cosine decay schedule")
            # Negative peak value: scale_by_schedule multiplies the Adam update,
            # so the sign turns it into a descent step.
            schedule_fn = optax.cosine_decay_schedule(-actor_lr, max_steps)
            optimiser = optax.chain(optax.scale_by_adam(),
                                    optax.scale_by_schedule(schedule_fn))
        else:
            builtins.print(f"Using AdamW with weight decay {weight_decay}")
            optimiser = optax.adamw(learning_rate=actor_lr, weight_decay=weight_decay)

        actor_net = Model.create(actor_def,
                                 inputs=[actor_key, observations],
                                 tx=optimiser)

        #---- critic ----#
        critic_def = critic.DoubleCritic(hidden_dims)
        critic_net = Model.create(critic_def,
                                  inputs=[critic_key, observations, actions],
                                  tx=optax.adamw(learning_rate=critic_lr,
                                                 weight_decay=weight_decay))

        #---- bad discriminator ----#
        bad_disc_def = discriminator.Discriminator(hidden_dims)
        bad_disc_net = Model.create(bad_disc_def,
                                    inputs=[bad_disc_key, observations, actions],
                                    tx=optax.adamw(learning_rate=disc_lr,
                                                   weight_decay=0.01))

        #---- mixed discriminator ----#
        mix_disc_def = discriminator.Discriminator(hidden_dims)
        mix_disc_net = Model.create(mix_disc_def,
                                    inputs=[mix_disc_key, observations, actions],
                                    tx=optax.adamw(learning_rate=disc_lr,
                                                   weight_decay=0.01))

        #---- target critic ----#
        # Same key as the critic, so it starts with identical parameters.
        target_critic_net = Model.create(
            critic_def, inputs=[critic_key, observations, actions])

        #---- value critic -----#
        value_def = value.ValueCritic(hidden_dims,
                                      layer_norm=layernorm,
                                      dropout_rate=value_dropout_rate)
        value_net = Model.create(value_def,
                                 inputs=[value_key, observations],
                                 tx=optax.adamw(learning_rate=value_lr,
                                                weight_decay=weight_decay))

        self.actor = actor_net
        self.critic = critic_net
        self.value = value_net
        self.target_critic = target_critic_net
        self.bad_disc = bad_disc_net
        self.mix_disc = mix_disc_net
        self.rng = rng

    def update_discriminator(self, bad_batch: Batch, mixed_batch: MixBatch, step: int) -> InfoDict:
        """
        Run one update of the discriminators and refresh their train/stop flags.

        - bad_disc: identifies bad demonstrations
        - mix_disc: identifies mixed-quality demonstrations

        Each discriminator's returned `continue_*` flag is appended to a rolling
        window of the last 100 flags. Training of that discriminator stays
        enabled while fewer than 20% of the flags in its window are truthy.
        While a discriminator is disabled, the step reports False for it, so the
        window drains and training eventually resumes.
        Metrics are computed when `step` is a multiple of `args.eval_interval`.
        """
        self.rng, self.bad_disc, self.mix_disc, continue_bad, continue_mix, info = _update_discriminator(
            self.rng, self.bad_disc, self.mix_disc, bad_batch, mixed_batch, self.args,
            train_bad=self.train_bad, train_mix=self.train_mix,
            cal_log=step % self.args.eval_interval == 0)

        self.bad_arr.append(continue_bad.item())
        self.mix_arr.append(continue_mix.item())

        self.train_bad = bool(np.mean(self.bad_arr) < 0.2)
        self.train_mix = bool(np.mean(self.mix_arr) < 0.2)
        return info

    def update(self, bad_batch: Batch, mixed_batch: MixBatch, scale_mix: float, step: int) -> InfoDict:
        """
        Main training step. It:
        1. updates the value network;
        2. updates the actor;
        3. updates the critic;
        4. soft-updates the target critic with rate `tau`.

        The discriminators are used but not modified. `scale_mix` is forwarded
        unchanged. Metrics are computed when `step` is a multiple of
        `args.eval_interval`.
        """
        new_rng, new_actor, new_critic, new_value, new_target_critic, info = _update_UNIQ(
            self.rng, self.actor, self.critic, self.value, self.target_critic,
            self.bad_disc, self.mix_disc,
            bad_batch, mixed_batch, self.discount, self.tau, self.expectile,
            self.actor_temperature, scale_mix, self.reward_gap, self.double_q, self.args,
            cal_log=step % self.args.eval_interval == 0)

        self.rng = new_rng
        self.actor = new_actor
        self.critic = new_critic
        self.value = new_value
        self.target_critic = new_target_critic

        return info

    def load_discriminator(self, load_dir: str) -> bool:
        """Load both discriminators from `load_dir/bad_disc` and `load_dir/mix_disc`.

        Returns True if both paths exist and were loaded. Returns False, leaving
        the current discriminators untouched, if either path is missing.
        """
        bad_path = os.path.join(load_dir, 'bad_disc')
        mix_path = os.path.join(load_dir, 'mix_disc')
        if os.path.exists(bad_path) and os.path.exists(mix_path):
            self.bad_disc = self.bad_disc.load(bad_path)
            self.mix_disc = self.mix_disc.load(mix_path)
            return True
        return False

    def save_discriminator(self, save_dir: str):
        """Save both discriminators under `save_dir`, creating it if needed."""
        builtins.print(f'save discriminator to {save_dir}')
        os.makedirs(save_dir, exist_ok=True)
        self.bad_disc.save(os.path.join(save_dir, 'bad_disc'))
        self.mix_disc.save(os.path.join(save_dir, 'mix_disc'))

    def sample_actions(self,
                       observations: np.ndarray,
                       random_tempurature: float = 1.0,
                       training: bool = False) -> np.ndarray:
        """
        Sample actions from the policy for the given observations.

        `random_tempurature` is passed to `policy.sample_actions` as the
        sampling temperature. The PRNG key is advanced, and the actions are
        returned as a NumPy array clipped to [-1, 1].
        """
        rng, actions = policy.sample_actions(self.rng,
                                             self.actor.apply_fn,
                                             self.actor.params,
                                             observations,
                                             random_tempurature,
                                             training=training)
        self.rng = rng
        actions = np.asarray(actions)
        return np.clip(actions, -1, 1)

    def load(self, save_dir: str):
        """Load actor, critic, value and target critic from sub-paths of `save_dir`.

        If `save_dir` does not exist, prints a message and leaves the models
        unchanged.
        """
        if os.path.exists(save_dir):
            builtins.print(f"Loading model from {save_dir}")
            self.actor = self.actor.load(os.path.join(save_dir, 'actor'))
            self.critic = self.critic.load(os.path.join(save_dir, 'critic'))
            self.value = self.value.load(os.path.join(save_dir, 'value'))
            self.target_critic = self.target_critic.load(os.path.join(save_dir, 'target_critic'))
        else:
            builtins.print(f"Model not found in {save_dir}")

    def save(self, save_dir: str):
        """Save actor, critic, value and target critic under `save_dir`, creating it if needed."""
        builtins.print(f"Saving model to {save_dir}")
        os.makedirs(save_dir, exist_ok=True)
        self.actor.save(os.path.join(save_dir, 'actor'))
        self.critic.save(os.path.join(save_dir, 'critic'))
        self.value.save(os.path.join(save_dir, 'value'))
        self.target_critic.save(os.path.join(save_dir, 'target_critic'))