from functools import partial
from typing import Callable, Dict, Sequence, Tuple

# NOTE(review): the name `flax` bound here is not used in the visible code and
# looks like a stray auto-import -- confirm before removing.
from rpp import flax
import distrax
from flax import struct
from flax.training import train_state
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
import numpy as np

from emlp.reps import Rep
from emlp.groups import Group

from jax_rl.networks.pe_mlp import PEMLP, PMLP


@struct.dataclass
class RunningMeanStd:
    """Running statistics of log-compressed disagreement and a derived threshold.

    Raw values ``d`` are compressed as ``x = log1p(scale * d)``. The running
    mean and variance of ``x`` are maintained with the parallel (Chan et al.)
    moment combination. The threshold ``mean + k_std * std`` (log1p domain) is
    mapped back to raw units. On each refresh it is EMA-smoothed in log space
    and rate-limited to within [0.95, 1.05] x the previously cached value.

    As a ``flax.struct.dataclass`` this is an immutable pytree; every method
    that "changes" state returns a new instance via ``replace``.
    """

    count: jnp.ndarray  # number of samples folded in so far
    mean: jnp.ndarray  # running mean of x = log1p(scale * d)
    M2: jnp.ndarray  # running sum of squared deviations of x from the mean
    eps: jnp.ndarray  # floor for divisors and for the variance
    scale: jnp.ndarray  # multiplier applied to raw values before log1p

    k_std: jnp.ndarray  # threshold offset, in standard deviations of x
    beta: jnp.ndarray  # EMA weight of the newest threshold (applied in log space)
    thr_update_steps: jnp.ndarray  # refresh period used by refresh_threshold

    cur_threshold_raw: jnp.ndarray  # cached threshold in raw units; NaN until first refresh

    @staticmethod
    def init(scale: float = 1., k_std: float = 1.5, eps: float = 1e-8,
             beta: float = 0.1, thr_update_steps: int = 1000, dtype=jnp.float32):
        """Create an empty state (no samples, no cached threshold).

        Args:
            scale: multiplier applied to raw values before ``log1p``.
            k_std: threshold offset in standard deviations of the compressed values.
            eps: numerical floor for divisors and for the variance.
            beta: EMA weight given to the newest threshold on each refresh.
            thr_update_steps: refresh period (in steps) for ``refresh_threshold``.
            dtype: dtype of the floating-point fields.
        """
        return RunningMeanStd(
            count=jnp.array(0., dtype),
            mean=jnp.array(0., dtype),
            M2=jnp.array(0., dtype),
            eps=jnp.array(eps, dtype),
            scale=jnp.array(scale, dtype),
            k_std=jnp.array(k_std, dtype),
            beta=jnp.array(beta, dtype),
            thr_update_steps=jnp.array(thr_update_steps, jnp.int32),
            cur_threshold_raw=jnp.array(jnp.nan, dtype),
        )

    def update(self, raw_values: jnp.ndarray):
        """Fold a batch of raw values into the running statistics.

        Args:
            raw_values: raw disagreements d = ||pA - pB||^2, shape (B,). Any
                shape is accepted and flattened. Statistics are tracked on
                x = log1p(scale * d).

        Returns:
            A new ``RunningMeanStd`` with updated ``count``, ``mean`` and ``M2``.
            An empty batch returns ``self`` unchanged. Its batch mean would be
            NaN and would otherwise permanently poison ``mean``/``M2``.
        """
        # Compute in at least float32, but honour a wider state dtype chosen
        # via ``init(dtype=...)`` instead of truncating it to float32.
        work_dtype = jnp.promote_types(self.mean.dtype, jnp.float32)
        x = jnp.ravel(jnp.asarray(raw_values, work_dtype))
        if x.shape[0] == 0:
            # Shapes are static under jit, so this Python branch is trace-safe.
            return self
        x = jnp.log1p(self.scale * x)

        n_batch = jnp.array(x.shape[0], dtype=self.count.dtype)
        batch_mean = jnp.mean(x)
        batch_m2 = jnp.mean((x - batch_mean) ** 2) * n_batch
        delta = batch_mean - self.mean
        total = self.count + n_batch
        safe_total = jnp.maximum(total, self.eps)
        new_mean = self.mean + delta * (n_batch / safe_total)
        # Parallel combination of the stored moments with the batch moments.
        new_m2 = self.M2 + batch_m2 + delta ** 2 * (self.count * n_batch / safe_total)
        return self.replace(count=total, mean=new_mean, M2=new_m2)

    @property
    def var(self):
        """Population variance of x, floored at ``eps``."""
        return jnp.maximum(self.M2 / jnp.maximum(self.count, self.eps), self.eps)

    def _raw_threshold(self):
        """Map ``mean + k_std * std`` (log1p domain) back to raw units.

        Inverts x = log1p(scale * d) via expm1(x) / scale; floored at 1e-6 so
        the value stays positive and its log is finite.
        """
        thr_soft = self.mean + self.k_std * jnp.sqrt(self.var)
        return jnp.maximum(jnp.expm1(thr_soft) / self.scale, 1e-6)

    def get_threshold(self):
        """Return the cached raw threshold, or the live one if never refreshed."""
        cur = self.cur_threshold_raw
        return jnp.where(jnp.isnan(cur), self._raw_threshold(), cur)

    def adjust_k(self, pos_rate: jnp.ndarray,
                 target: float = 0.15, eta: float = 0.05,
                 k_min: float = 0.5, k_max: float = 3.0):
        """Placeholder for adapting ``k_std`` toward a target positive rate.

        Raises:
            NotImplementedError: always; not implemented yet.
        """
        raise NotImplementedError

    def refresh_threshold(self, step):
        """Refresh the cached threshold when ``step % thr_update_steps == 0``.

        Uses ``jax.lax.cond``, so ``step`` may be a traced integer inside jit.
        Returns ``self`` unchanged on non-refresh steps.

        NOTE(review): the condition is true at step 0. If steps start at 0 and
        no data has been folded in yet, the cached threshold is initialised
        from empty statistics (near zero for the default eps). After that it
        can only move +/-5% per refresh. Confirm callers call ``update`` first.
        """
        return jax.lax.cond(
            step % self.thr_update_steps == 0,
            lambda s: s._threshold_update(),
            lambda s: s,
            operand=self,
        )

    def _threshold_update(self):
        """EMA-smooth the threshold in log space and rate-limit the change.

        The first refresh (cached value NaN) adopts the current threshold
        directly. Later refreshes blend
        ``(1 - beta) * log(prev) + beta * log(now)`` and clip the result to
        [0.95 * prev, 1.05 * prev].
        """
        now_raw = self._raw_threshold()
        prev = self.cur_threshold_raw
        is_first = jnp.isnan(prev)

        prev_log = jnp.log(jnp.where(is_first, now_raw, prev))
        now_log = jnp.log(now_raw)
        new_log = jnp.where(is_first, now_log,
                            (1.0 - self.beta) * prev_log + self.beta * now_log)
        new_raw = jnp.exp(new_log)
        lo = jnp.exp(prev_log - 0.051293)  # log(0.95) ~= -0.051293 -> -5%
        hi = jnp.exp(prev_log + 0.048790)  # log(1.05) ~= +0.048790 -> +5%
        new_raw = jnp.where(is_first, now_raw, jnp.clip(new_raw, lo, hi))
        return self.replace(cur_threshold_raw=new_raw)


class TransitionPredictor(nn.Module):
    """Predict a transition target from (observation, action) with a PMLP.

    Observations and actions pass through their transforms, are concatenated
    on the last axis, and are fed to ``PMLP`` with ``output_dim`` outputs.
    """

    hidden: Sequence[int]  # presumably hidden-layer sizes; forwarded to PMLP as-is
    output_dim: int
    state_transform: Callable[[jnp.ndarray], jnp.ndarray] = lambda x: x
    action_transform: Callable[[jnp.ndarray], jnp.ndarray] = lambda x: x
    small_init: bool = True  # forwarded to PMLP
    bias_value: float = 0.0  # forwarded to PMLP

    @nn.compact
    def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
        state = self.state_transform(observations)
        actions = self.action_transform(actions)
        x = jnp.concatenate([state, actions], axis=-1)
        return PMLP(
            self.hidden,
            self.output_dim,
            small_init=self.small_init,
            bias_value=self.bias_value,
        )(x)

class EquiTransitionPredictor(nn.Module):
    """Equivariant counterpart of ``TransitionPredictor`` built on ``PEMLP``.

    Observations and actions pass through their transforms, are concatenated
    on the last axis, and are fed to ``PEMLP(rep_in, rep_out, group, ch)``.
    """

    rep_in: Rep  # input representation; presumably must match the concatenated input
    rep_out: Rep  # output representation
    group: Group
    ch: Sequence[int]  # forwarded to PEMLP as-is
    state_transform: Callable[[jnp.ndarray], jnp.ndarray] = lambda x: x
    action_transform: Callable[[jnp.ndarray], jnp.ndarray] = lambda x: x
    small_init: bool = True  # forwarded to PEMLP

    @nn.compact
    def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
        state = self.state_transform(observations)
        actions = self.action_transform(actions)
        x = jnp.concatenate([state, actions], axis=-1)
        return PEMLP(
            self.rep_in,
            self.rep_out,
            self.group,
            self.ch,
            small_init=self.small_init,
        )(x)


class LambdaQNet(nn.Module):
    """Scalar logit head over (observation, action) pairs.

    Observations and actions pass through their transforms, are concatenated
    on the last axis, and are fed to a single-output ``PMLP``. The trailing
    size-1 axis is squeezed away.
    """

    hidden: Sequence[int]  # forwarded to PMLP as-is
    state_transform: Callable[[jnp.ndarray], jnp.ndarray] = lambda x: x
    action_transform: Callable[[jnp.ndarray], jnp.ndarray] = lambda x: x
    small_init: bool = True  # forwarded to PMLP
    bias_value: float = 0.0  # forwarded to PMLP

    @nn.compact
    def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
        s = self.state_transform(observations)
        a = self.action_transform(actions)
        x = jnp.concatenate([s, a], axis=-1)
        logits = PMLP(
            self.hidden, 1, small_init=self.small_init, bias_value=self.bias_value
        )(x)
        # PMLP is built with a single output; drop that trailing axis.
        return logits.squeeze(-1)

class LambdaPiNet(nn.Module):
    """Scalar logit head over observations only.

    Observations pass through ``state_transform`` and are fed to a
    single-output ``PMLP``. The trailing size-1 axis is squeezed away.
    """

    hidden: Sequence[int]  # forwarded to PMLP as-is
    state_transform: Callable[[jnp.ndarray], jnp.ndarray] = lambda x: x
    small_init: bool = True  # forwarded to PMLP
    bias_value: float = 0.0  # forwarded to PMLP

    @nn.compact
    def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
        s = self.state_transform(observations)
        logits = PMLP(
            self.hidden, 1, small_init=self.small_init, bias_value=self.bias_value
        )(s)
        # PMLP is built with a single output; drop that trailing axis.
        return logits.squeeze(-1)
