import copy
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from typing import Sequence
from functools import partial
import numpy as np

from agent.networks import get_ensemble, GaussianActor, Critic, Encoder

import pickle

from utils.etc_utils import get_latest_checkpoint
from pathlib import Path

# --------------------------------------------------
# CSF-SAC Agent
# --------------------------------------------------

@jax.jit
def get_kl_loss(mu, logstd):
    """Return KL( N(mu, exp(logstd)^2) || N(0, I) ), averaged over leading axes.

    The divergence is summed over the last axis and then averaged over all
    remaining axes. A 1e-8 constant is added to every element before summing.

    Args:
        mu: Gaussian means.
        logstd: Log standard deviations, same shape as ``mu``.

    Returns:
        Scalar mean KL divergence.
    """
    log_variance = logstd * 2.0
    elementwise = jnp.exp(log_variance) + jnp.square(mu) - 1.0 - log_variance + 1e-8
    per_example = 0.5 * elementwise.sum(axis=-1)
    return per_example.mean()

@jax.jit
def get_alpha(log_alpha):
    """Map the unconstrained log-temperature to the positive SAC temperature."""
    alpha = jnp.exp(log_alpha)
    return alpha

@jax.jit
def get_lambda(log_lambda):
    """Map the unconstrained log-multiplier to a positive Lagrange multiplier."""
    multiplier = jnp.exp(log_lambda)
    return multiplier

class MISL_SAC_Agent:
    """SAC agent with a learned skill encoder ``phi`` and vector-valued critics.

    ``phi`` is trained with a contrastive objective on ``phi(s') - phi(s)``.
    That difference is used as a vector reward for an ensemble of critics whose
    outputs (successor features) are projected onto the skill ``z`` to obtain a
    scalar Q-value. The actor is a Gaussian SAC policy conditioned on the
    concatenation of observation and skill.
    """

    def __init__(self, args, obs_dim: int, action_dim: int, option_dim:int, phi_dims: Sequence[int],
                 actor_dims: Sequence[int], critic_dims: Sequence[int], seed: int = 0,
                 lr: float = 1e-4, lag_lr: float = 1e-4, tau: float = 0.005, gamma: float = 0.99,
                 target_entropy: float = None, n_critics: int = 2, env_normalizer=None):
        """Build the networks, initialize their parameters and optimizers.

        Args:
            args: Config namespace. Reads ``obs_type``, ``pixel_shape``,
                ``pixel_dim``, ``pixel_latent_dim``, ``with_local_state``,
                ``local_for_critic``, ``batch_size`` and (in other methods)
                ``discrete``.
            obs_dim: Flat observation dimension.
            action_dim: Action dimension.
            option_dim: Skill (option) dimension.
            phi_dims, actor_dims, critic_dims: Hidden sizes for each network.
            seed: PRNG seed.
            lr: Adam learning rate for phi, actor and critic.
            lag_lr: Adam learning rate for the log-alpha / log-lambda scalars.
            tau: Polyak coefficient for the critic target.
            gamma: Discount factor.
            target_entropy: SAC entropy target. ``None`` uses ``-|A| / 2``.
            n_critics: Number of critics in the ensemble.
            env_normalizer: Object exposing ``mean`` and ``var``, used by
                :meth:`obs_preprocess`.
        """
        self.args = args
        self.gamma = gamma
        self.tau = tau
        self.rng = jax.random.PRNGKey(seed)
        # Honor an explicit target; otherwise fall back to -|A| / 2.
        if target_entropy is None:
            target_entropy = -np.prod(action_dim).item() / 2.
        self.target_entropy = target_entropy
        self.option_dim = option_dim
        # Both are read by `obs_preprocess`.
        self.env_normalizer = env_normalizer
        # NOTE(review): assumed epsilon; align with the normalizer's own epsilon if it has one.
        self.rms_epsilon = 1e-8

        use_encoder = args.obs_type in ('pixels', 'hybrid')
        if args.obs_type == 'states':
            pixel_shape = None
            self.pixel_dim = None
        else:
            pixel_shape = args.pixel_shape
            self.pixel_dim = args.pixel_dim

        self.phi = Encoder(phi_dims, use_encoder=use_encoder,
                           pix_latent_dim=args.pixel_latent_dim,
                           pixel_shape=pixel_shape, pixel_dim=self.pixel_dim)

        self.actor = GaussianActor(actor_dims, action_dim,
                                   use_encoder=use_encoder,
                                   pix_latent_dim=args.pixel_latent_dim,
                                   pixel_dim=self.pixel_dim,
                                   pixel_shape=pixel_shape,
                                   use_local=args.with_local_state)

        self.critic = get_ensemble(n_critics, Critic, methods=['__call__'])(
            critic_dims, option_dim,
            use_encoder=use_encoder,
            use_local=args.local_for_critic,
            pix_latent_dim=args.pixel_latent_dim,
            pixel_dim=self.pixel_dim,
            pixel_shape=pixel_shape,
        )

        dummy_s = jnp.zeros((1, obs_dim + option_dim))
        dummy_obs = jnp.zeros((1, obs_dim))
        dummy_act = jnp.zeros((1, action_dim))

        # Split into 8 (only three keys are used) so the seed -> init mapping is stable.
        self.rng, k_phi, k_actor, k_c, *_ = jax.random.split(self.rng, 8)
        self.phi_params = self.phi.init(k_phi, dummy_obs)
        self.target_phi_params = copy.deepcopy(self.phi_params)
        self.actor_params = self.actor.init(k_actor, dummy_s, k_actor)

        self.critic_params = self.critic.init(k_c, dummy_s, dummy_act)
        self.critic_target_params = self.critic_params.copy()

        self.log_alpha = jnp.log(0.01)
        self.log_lambda = jnp.log(30.)
        # Number of random unit vectors used as negatives in the phi loss.
        self.num_negative_z = self.args.batch_size

        self.opt_phi = optax.adam(lr)
        self.opt_actor = optax.adam(lr)
        self.opt_critic = optax.adam(lr)
        self.opt_alpha = optax.adam(lag_lr)
        self.opt_lambda = optax.adam(lag_lr)

        self.opt_phi_state = self.opt_phi.init(self.phi_params)
        self.opt_target_phi_state = self.opt_phi.init(self.target_phi_params)
        self.opt_lambda_state = self.opt_lambda.init(self.log_lambda)

        self.opt_actor_state = self.opt_actor.init(self.actor_params)
        self.opt_critic_state = self.opt_critic.init(self.critic_params)
        self.opt_alpha_state = self.opt_alpha.init(self.log_alpha)
        self.dual_slack = 1e-3

        # Advance the RNG once more (the extra keys are unused) so that the
        # random stream for a given seed matches earlier behavior.
        self.rng = jax.random.split(self.rng, 6)[0]


    def obs_preprocess(self, s: jnp.ndarray) -> jnp.ndarray:
        """Standardize observations with ``self.env_normalizer``'s mean/var.

        For ``'hybrid'`` observations only the columns from ``pixel_dim`` onward
        are normalized (via ``.at``, so ``s`` must be a JAX array); the leading
        pixel columns are passed through unchanged.
        """
        if self.args.obs_type == 'hybrid':
            return s.at[:, self.pixel_dim:].set((s[:, self.pixel_dim:] - self.env_normalizer.mean) / np.sqrt(self.env_normalizer.var + self.rms_epsilon))
        return (s - self.env_normalizer.mean) / np.sqrt(self.env_normalizer.var + self.rms_epsilon)


    @partial(jax.jit, static_argnums=0)
    def input_preprocess(self, s: jnp.ndarray, z: jnp.ndarray) -> jnp.ndarray:
        """Concatenate observation and skill along the last axis."""
        return jnp.concatenate([s, z], axis=-1)


    @partial(jax.jit, static_argnums=(0,))
    def _skill_preprocess(self, delta):
        """Turn a phi-space displacement into a skill vector.

        Discrete: one-hot of the argmax. Continuous: unit-norm direction
        (with 1e-8 added to the norm to avoid division by zero).
        """
        if self.args.discrete:
            z_skill = jnp.eye(self.option_dim)[jnp.argmax(delta, axis=-1)]
        else:
            z_skill = delta / (jnp.linalg.norm((delta), axis=-1, keepdims=True) + 1e-8)
        return z_skill

    @partial(jax.jit, static_argnums=(0,))
    def _reward_function(self, delta, z):
        """Project a phi-space displacement ``delta`` onto skill ``z``.

        In the discrete case ``z`` is centered and rescaled by
        ``option_dim / (option_dim - 1)`` before the dot product (for a one-hot
        ``z`` the active entry becomes 1 and the others ``-1/(option_dim-1)``).
        """
        if self.args.discrete:
            mask = (z - z.mean(axis=-1, keepdims=True)) * self.option_dim / (self.option_dim - 1 if self.option_dim != 1 else 1)
            r = (delta * mask).sum(axis=-1)
        else:
            r = (delta * z).sum(axis=-1)
        return r


    def learn(self, jax_batch, train_itr, train_only_local_encoder=False):
        """Run one :meth:`update` step on ``jax_batch`` and store the results.

        ``train_itr`` and ``train_only_local_encoder`` are accepted for interface
        compatibility but are not used here.

        Returns:
            The metrics dict produced by :meth:`update`.
        """
        # Only advances the stream; the second key is not needed.
        self.rng, _ = jax.random.split(self.rng)

        opt_states = (self.opt_phi_state, self.opt_actor_state,
                      self.opt_critic_state, self.opt_alpha_state, self.opt_lambda_state)
        params = (self.phi_params, self.target_phi_params, self.actor_params,
                  self.critic_params, self.critic_target_params, self.log_alpha, self.log_lambda)

        new_params, new_opt_states, metrics, self.rng = self.update(
            jax_batch, params, opt_states, self.rng)

        (self.opt_phi_state, self.opt_actor_state,
         self.opt_critic_state, self.opt_alpha_state, self.opt_lambda_state) = new_opt_states

        (self.phi_params, self.target_phi_params, self.actor_params,
         self.critic_params, self.critic_target_params, self.log_alpha, self.log_lambda) = new_params

        return metrics


    @partial(jax.jit, static_argnums=(0,))
    def _get_action(self, actor_param, s, z, key):
        """Sample a stochastic action for a batch of one; returns the single row."""
        a, _ = self.actor.apply(actor_param, self.input_preprocess(s, z), key)
        return a[0]


    @partial(jax.jit, static_argnums=(0,))
    def _get_eval_action(self, actor_param, s, z):
        """Evaluation action via the actor's ``eval_action`` method; returns the single row."""
        a = self.actor.apply(actor_param, self.input_preprocess(s, z), method='eval_action')
        return a[0]


    def get_action(self, s, z, eval=False):
        """Return an action for one unbatched observation ``s`` and skill ``z``.

        A batch axis is added, inputs are cast to float32, and the result is a
        NumPy array. With ``eval=True`` the actor's ``eval_action`` is used and
        no RNG is consumed.
        """
        s = jnp.asarray(s[None]).astype(jnp.float32)
        z = jnp.asarray(z[None]).astype(jnp.float32)
        if eval:
            return np.array(self._get_eval_action(self.actor_params, s, z))

        self.rng, key = jax.random.split(self.rng)
        return np.array(self._get_action(self.actor_params, s, z, key))

    @partial(jax.jit, static_argnums=0)
    def update(self, batch, params, opt_states, rng):
        """Run one jitted training step for phi, the critics, the actor and alpha.

        Args:
            batch: dict with at least ``'obs'``, ``'next_obs'``, ``'act'``,
                ``'done'``, ``'options'`` and ``'next_options'``.
            params: ``(phi, target_phi, actor, critic, critic_target,
                log_alpha, log_lambda)`` as packed by :meth:`learn`.
            opt_states: optimizer states for ``(phi, actor, critic, alpha, lambda)``.
            rng: PRNG key.

        Returns:
            ``(new_params, new_opt_states, metrics, rng)`` with the same tuple
            layouts as the inputs. ``target_phi``, ``log_lambda`` and the lambda
            optimizer state are passed through unchanged.
        """
        (opt_phi_state, opt_actor_state, opt_critic_state, opt_alpha_state, opt_lambda_state) = opt_states
        (phi_params, target_phi_params, actor_params, critic_params, critic_target_params, log_alpha, log_lambda) = params

        # Seven keys are split (some unused) so the returned `rng` is unchanged.
        rng, key, _key2, key_rand, key_rand2, _key_prior, _subkey = jax.random.split(rng, 7)
        metrics = {}
        s = batch['obs']
        s2 = batch['next_obs']
        a = batch['act']
        done = batch['done'][:, None]  # assumes 'done' is [B]; broadcast against [B, dim]
        z = batch['options']
        z2 = batch['next_options']

        sz = self.input_preprocess(s, z)
        s2z2 = self.input_preprocess(s2, z2)

        # Random unit vectors serving as negative skills in the contrastive loss.
        neg_z = jax.random.normal(key_rand, shape=(self.num_negative_z, z.shape[-1]))
        neg_z = neg_z / (jnp.linalg.norm(neg_z, axis=-1, keepdims=True) + 1e-8)

        # ---- phi ----
        def phi_loss_fn(phi_p):
            ps = self.phi.apply(phi_p, s)
            ps2 = self.phi.apply(phi_p, s2)
            delta = ps2 - ps  # [B, d]
            pos = (delta * z).sum(axis=1)
            neg_logits = jnp.sum(delta[:, None] * neg_z[None], axis=-1)  # [B, num_negative_z]
            lse = jax.nn.logsumexp(neg_logits, axis=-1)  # [B]
            loss = -jnp.mean(pos - 5.0 * lse)
            return loss, (jnp.mean(pos), jnp.mean(lse))

        (_, (mean_pos, mean_lse)), grads_phi = jax.value_and_grad(
            phi_loss_fn, has_aux=True)(phi_params)
        updates_phi, opt_phi_state = self.opt_phi.update(grads_phi, opt_phi_state)
        phi_params = optax.apply_updates(phi_params, updates_phi)

        # ---- critic target (uses the freshly updated phi for the reward) ----
        r = self.phi.apply(phi_params, s2) - self.phi.apply(phi_params, s)  # [B, dim]
        a2, _ = self.actor.apply(actor_params, s2z2, key_rand2)

        sf_t = self.critic.apply(critic_target_params, s2z2, a2)  # [n_critic, B, dim]
        q_t = (sf_t * z2[None, :, :]).sum(axis=-1)  # [n_critic, B]
        # Pick, per sample, the successor features of the critic with the lowest Q.
        min_idxs = jnp.argmin(q_t, axis=0)  # [B]
        sf_min = sf_t[min_idxs, jnp.arange(sf_t.shape[1])]  # [B, dim]
        sf_backup = r + (1. - done) * self.gamma * sf_min  # [B, dim]

        # ---- critic ----
        def critic_loss_fn(c_p):
            q = self.critic.apply(c_p, sz, a)
            loss = 0.5 * ((q - sf_backup) ** 2).mean()
            return loss, loss

        grads_c, q_loss = jax.grad(critic_loss_fn, has_aux=True)(critic_params)
        updates_c, opt_critic_state = self.opt_critic.update(grads_c, opt_critic_state)
        critic_params = optax.apply_updates(critic_params, updates_c)

        # ---- actor ----
        def actor_loss_fn(actor_p, log_a):
            a_pred, logp = self.actor.apply(actor_p, sz, key)
            sf = self.critic.apply(critic_params, sz, a_pred)  # [n_critic, B, dim]
            # The actor acts at (s, z), so project onto the current skill z,
            # then take the pessimistic minimum over critics (as for the target).
            q = (sf * z[None, :, :]).sum(axis=-1).min(axis=0)  # [B]
            alpha = get_alpha(log_a)
            loss = (alpha * logp.squeeze() - q).mean()
            return loss, (alpha, q.mean(), logp.mean())

        (actor_loss, (alpha, actor_q_min, mean_logp)), grads_actor = jax.value_and_grad(
            actor_loss_fn, has_aux=True)(actor_params, log_alpha)

        # ---- temperature ----
        def alpha_loss_fn(log_a, batch_logp):
            return -get_alpha(log_a) * (batch_logp + self.target_entropy).mean()

        grads_alpha = jax.grad(alpha_loss_fn)(log_alpha, mean_logp)

        updates_a, opt_actor_state = self.opt_actor.update(grads_actor, opt_actor_state)
        updates_alpha, opt_alpha_state = self.opt_alpha.update(grads_alpha, opt_alpha_state)
        actor_params = optax.apply_updates(actor_params, updates_a)
        log_alpha = optax.apply_updates(log_alpha, updates_alpha)

        critic_target_params = optax.incremental_update(critic_params, critic_target_params, self.tau)

        metrics.update({
            'critic_loss': q_loss,
            'phi_dot': mean_pos, 'phi_neg_loss': mean_lse,
            'actor_loss': actor_loss, 'alpha': alpha, 'actor_q_min': actor_q_min, 'actor_logp': mean_logp,
        })

        new_states = (opt_phi_state, opt_actor_state,
                      opt_critic_state, opt_alpha_state, opt_lambda_state)
        new_params = (phi_params, target_phi_params, actor_params,
                      critic_params, critic_target_params, log_alpha, log_lambda)

        return new_params, new_states, metrics, rng


    def save_checkpoint(self, epoch, log_dir: str):
        """Pickle RNG, network parameters and log-alpha/lambda to ``{log_dir}/{epoch}_checkpoint.pkl``.

        Optimizer states and ``target_phi_params`` are not saved.
        """
        log_dir = Path(log_dir)
        log_dir.mkdir(parents=True, exist_ok=True)
        path = log_dir / f'{epoch}_checkpoint.pkl'

        checkpoint = {
            'rng': self.rng,
            'phi_params': self.phi_params,
            'actor_params': self.actor_params,
            'critic_params': self.critic_params,
            'critic_target_params': self.critic_target_params,
            'log_alpha': self.log_alpha,
            'log_lambda': self.log_lambda,
        }

        with open(path, 'wb') as f:
            pickle.dump(checkpoint, f)
        print(f"[Checkpoint Saved] → {path}")


    def load_checkpoint(self, epoch, log_dir: str):
        """Restore state written by :meth:`save_checkpoint`.

        ``epoch == -1`` loads the path returned by ``get_latest_checkpoint``.
        """
        if epoch == -1:
            path = get_latest_checkpoint(log_dir)
        else:
            path = Path(log_dir) / f'{epoch}_checkpoint.pkl'
        # SECURITY: pickle.load can execute arbitrary code; only load trusted checkpoints.
        with open(path, 'rb') as f:
            checkpoint = pickle.load(f)

        self.rng                     = checkpoint['rng']
        self.phi_params              = checkpoint['phi_params']
        self.actor_params            = checkpoint['actor_params']
        self.critic_params           = checkpoint['critic_params']
        self.critic_target_params    = checkpoint['critic_target_params']
        self.log_alpha               = checkpoint['log_alpha']
        self.log_lambda              = checkpoint['log_lambda']

        print(f"[Checkpoint Loaded] ← {path}")