import jax, distrax
import jax.numpy as jnp
from typing import Any, Tuple, Sequence, Optional
from jax.scipy.stats.norm import logpdf
from jaxrl_m.typing import *
import jax.tree_util as tree_util
from flax import traverse_util
from jaxrl_m.dataset import Dataset
import numpy as np

def get_params_shape(params):
    """Return a pytree with the same structure as ``params`` holding each leaf's ``.shape``."""
    def _shape_of(leaf):
        return leaf.shape

    return tree_util.tree_map(_shape_of, params)

def merge_batch(b1:Batch, b2:Batch)->Batch:
    """Concatenate two batches leaf-by-leaf along axis 0.

    ``b1`` and ``b2`` are traversed in parallel with ``tree_map``; every pair
    of matching leaves is joined with ``jnp.concatenate`` (leaf of ``b1``
    first), so both batches must share the same tree structure.
    """
    return tree_util.tree_map(
        lambda left, right: jnp.concatenate((left, right), axis=0),
        b1,
        b2,
    )

####################################
# ensemble dynamics training utils 
####################################


def get_log_prob(state, mean, std):
    """Log-probability of ``state`` under a diagonal multivariate normal.

    Builds ``distrax.MultivariateNormalDiag(loc=mean, scale_diag=std)`` and
    returns its ``log_prob`` evaluated at ``state``.
    """
    gaussian = distrax.MultivariateNormalDiag(loc=mean, scale_diag=std)
    return gaussian.log_prob(state)

def get_log_prob_jnp(state, mean, std):
    """Diagonal-Gaussian log-density of ``state`` written with plain ``jax.numpy``.

    Evaluates ``-0.5 * (d*log(2*pi) + sum(log(std**2)) + sum((state-mean)**2 / std**2))``
    where both sums run over the last axis and ``d = state.shape[-1]``.
    """
    variance = std ** 2
    residual = state - mean
    dim = state.shape[-1]

    # log|Sigma| for a diagonal covariance is the sum of the log-variances.
    log_det_cov = jnp.log(variance).sum(axis=-1)
    # Squared Mahalanobis distance: sum of (x - mu)^2 / sigma^2.
    quad_form = ((residual ** 2) / variance).sum(axis=-1)

    return -0.5 * (dim * jnp.log(2 * jnp.pi) + log_det_cov + quad_form)

def sample_from_norm(
    means: jnp.ndarray,
    log_stds: jnp.ndarray,
    key: PRNGKey,
    temperature: float = 1.0,
) -> jnp.ndarray:
    """Draw one sample per element from ``N(means, (temperature * exp(log_stds))**2)``.

    Args:
        means: Gaussian means; the returned sample has this shape.
        log_stds: Log standard deviations, combined element-wise with ``means``.
        key: JAX PRNG key used for the standard-normal noise.
        temperature: Multiplier applied to the standard deviation. ``1.0``
            samples from the unmodified distribution; ``0.0`` returns ``means``.

    Returns:
        ``means + temperature * exp(log_stds) * eps`` with ``eps ~ N(0, 1)``.
    """
    # Temperature is applied exactly once here; it was previously multiplied
    # in a second time below, scaling the noise by temperature ** 2.
    scaled_stds = jnp.exp(log_stds) * temperature
    noise = jax.random.normal(key, shape=means.shape)
    return means + scaled_stds * noise


def decay_loss(params, weight_decay:float=1.0):
    """Weighted sum of squared entries over every "kernel" parameter.

    ``params`` is flattened with ``traverse_util.flatten_dict``; a leaf is
    included when the last element of its key path contains the substring
    ``"kernel"``. The accumulated sum of squares is multiplied by
    ``weight_decay``.
    """
    squared_norm = 0.0
    for key_path, leaf in traverse_util.flatten_dict(params).items():
        if "kernel" not in key_path[-1]:
            continue
        squared_norm += jnp.sum(leaf ** 2)
    return weight_decay * squared_norm

# weighted mse loss
def msew_loss(
    pred_mean: jnp.ndarray, pred_logstd: jnp.ndarray, gt: jnp.ndarray
) -> jnp.ndarray:
    """Mean squared error weighted by the predicted inverse variance.

    Each squared residual ``(pred_mean - gt)**2`` is multiplied by
    ``exp(-2 * pred_logstd)``; the result is averaged over the last axis and
    then over all remaining entries.
    """
    inverse_variance = jnp.exp(-(2 * pred_logstd))
    squared_error = jnp.square(pred_mean - gt)
    per_sample = jnp.mean(squared_error * inverse_variance, axis=-1)
    return jnp.mean(per_sample)

def var_loss(pred_logstd: jnp.ndarray) -> jnp.ndarray:
    """Mean predicted log-variance (``2 * pred_logstd``), averaged over the last axis and then over the rest."""
    per_sample = jnp.mean(2 * pred_logstd, axis=-1)
    return jnp.mean(per_sample)

def nll_loss(pred_means: jnp.ndarray, pred_logstds, gt: jnp.ndarray) -> jnp.ndarray:
    """Mean Gaussian negative log-likelihood of ``gt``.

    Args:
        pred_means: Predicted means.
        pred_logstds: Predicted log standard deviations.
        gt: Target values.

    Returns:
        ``-mean(log N(gt; pred_means, exp(pred_logstds)**2))`` over all elements.
    """
    # ``logpdf``'s third argument is the scale (standard deviation), so the
    # log-std is exponentiated directly; exp(2 * logstd) would be the variance.
    pred_stds = jnp.exp(pred_logstds)
    return -logpdf(gt, pred_means, pred_stds).mean()

def l1_loss(pred: jnp.ndarray, gt: jnp.ndarray) -> jnp.ndarray:
    """Mean absolute error between ``pred`` and ``gt`` over all elements."""
    abs_error = jnp.abs(pred - gt)
    return jnp.mean(abs_error)

def l2_loss(pred: jnp.ndarray, gt: jnp.ndarray) -> jnp.ndarray:
    """Mean squared error between ``pred`` and ``gt`` over all elements."""
    squared_error = jnp.square(pred - gt)
    return jnp.mean(squared_error)

def soft_clip(
    x: jnp.ndarray,
    min_val: Optional[float] = None,
    max_val: Optional[float] = None,
) -> jnp.ndarray:
    """Smooth version of ``clip`` that uses softplus so gradients are preserved.

    The upper bound (if given) is applied first as
    ``max_val - softplus(max_val - x)``, then the lower bound (if given) as
    ``min_val + softplus(x - min_val)``. A bound left as ``None`` is skipped.
    """
    clipped = x
    if max_val is not None:
        headroom = max_val - clipped
        clipped = max_val - jax.nn.softplus(headroom)
    if min_val is not None:
        excess = clipped - min_val
        clipped = min_val + jax.nn.softplus(excess)
    return clipped


##############################################
# reward tuning
##############################################

def get_iql_normalization(dataset):
    """Return ``(max episode return - min episode return) / 1000``.

    Rewards are accumulated in order; whenever the matching ``dones_float``
    entry is truthy the running sum is recorded as one episode return and
    reset. Rewards after the last truthy flag are not counted.
    """
    episode_returns = []
    running = 0
    for reward, is_terminal in zip(dataset['rewards'], dataset['dones_float']):
        running += reward
        if not is_terminal:
            continue
        episode_returns.append(running)
        running = 0
    spread = max(episode_returns) - min(episode_returns)
    return spread / 1000
    
def get_tuned_dataset(dataset, reward_tune):
    """Return ``dataset`` with its ``'rewards'`` transformed according to ``reward_tune``.

    Supported modes:
        * ``'iql_locomotion'``: divide by ``get_iql_normalization(dataset)``.
        * ``'normalize'``: subtract the reward mean and divide by the reward std.
        * ``'cql_antmaze'``: ``(r - 0.5) * 4.0``.
        * ``'iql_antmaze'``: ``r - 1.0``.
        * ``'antmaze'``: ``(r - 0.25) * 2.0``.

    ``'no'`` and any other value return ``dataset`` itself, unchanged.
    Transformed datasets are produced through ``dataset.copy({...})``.
    """
    if reward_tune == 'iql_locomotion':
        scale = get_iql_normalization(dataset)
        return dataset.copy({'rewards': dataset['rewards'] / scale})

    if reward_tune == 'normalize':
        reward_mean = dataset['rewards'].mean()
        reward_std = dataset['rewards'].std()
        return dataset.copy({'rewards': (dataset['rewards'] - reward_mean) / reward_std})

    if reward_tune == 'cql_antmaze':
        return dataset.copy({'rewards': (dataset['rewards'] - 0.5) * 4.0})

    if reward_tune == 'iql_antmaze':
        return dataset.copy({'rewards': dataset['rewards'] - 1.0})

    if reward_tune == 'antmaze':
        return dataset.copy({'rewards': (dataset['rewards'] - 0.25) * 2.0})

    # 'no' and unrecognised modes fall through without touching the dataset.
    return dataset


def split_dataset(dataset:Dataset,holdout_ratio:float):
    """Randomly split ``dataset`` into a training part and a holdout part.

    All fields are shuffled with one shared permutation drawn from
    ``np.random.permutation`` (NumPy's global random state). The first
    ``int(n * (1 - holdout_ratio))`` shuffled rows form the training set and
    the rest the holdout set, where ``n`` is the length of ``'rewards'``.

    Returns:
        ``(train_dataset, holdout_dataset)``, each built via ``dataset.copy``.
    """
    def _select(source, selector):
        # Index every field of ``source`` with ``selector`` and wrap the result.
        return dataset.copy({name: values[selector] for name, values in source.items()})

    num_rows = dataset['rewards'].shape[0]
    order = np.random.permutation(num_rows)
    shuffled = _select(dataset, order)

    cut = int(num_rows * (1 - holdout_ratio))
    train_part = _select(shuffled, slice(None, cut))
    holdout_part = _select(shuffled, slice(cut, None))
    return train_part, holdout_part

##############################################
# termination fn
##############################################

def termination_fn_halfcheetah(obs, act, next_obs):
    """Flag a transition as done when any ``next_obs`` entry leaves the open interval (-100, 100).

    All three inputs must be 2-D. Returns a boolean array of shape ``(N, 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    within_bounds = jnp.all((next_obs > -100) & (next_obs < 100), axis=-1)
    return jnp.logical_not(within_bounds)[:, None]

def termination_fn_neorl_halfcheetah(obs, act, next_obs):
    """Termination function that never terminates.

    All three inputs must be 2-D. Returns an all-``False`` boolean array of
    shape ``(N, 1)`` where ``N = obs.shape[0]``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    # Allocate the mask directly with an explicit bool dtype: building it from
    # a Python list gives a float array when the batch is empty.
    return jnp.zeros((obs.shape[0], 1), dtype=bool)

def termination_fn_hopper(obs, act, next_obs):
    """Hopper termination check computed from ``next_obs``.

    A transition is *not* done only when all of the following hold:
        * every entry of ``next_obs`` is finite,
        * every entry from column 1 onward has absolute value below 100,
        * the height (column 0) is above 0.7,
        * the absolute angle (column 1) is below 0.2.

    All three inputs must be 2-D. Returns a boolean array of shape ``(N, 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    height = next_obs[:, 0]
    angle = next_obs[:, 1]

    all_finite = jnp.isfinite(next_obs).all(axis=-1)
    # abs() must wrap the values, not the comparison result; otherwise entries
    # below -100 slip through the bound check.
    bounded = (jnp.abs(next_obs[:, 1:]) < 100).all(axis=-1)
    not_done = all_finite & bounded & (height > .7) & (jnp.abs(angle) < .2)

    return jnp.logical_not(not_done)[:, None]

def termination_fn_neorl_hopper(obs, act, next_obs):
    """Hopper termination check for the layout where column 1 is ``z`` and column 2 is the angle.

    A transition is healthy when ``0.7 < z < inf``, ``-0.2 < angle < 0.2`` and
    every entry from column 3 onward lies strictly inside (-100, 100). Done is
    the negation of healthy.

    All three inputs must be 2-D. Returns a boolean array of shape ``(N, 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    z_low, z_high = 0.7, float('inf')
    angle_low, angle_high = -0.2, 0.2
    state_low, state_high = -100.0, 100.0

    # Width-1 slices keep every mask at shape (N, 1).
    z = next_obs[:, 1:2]
    angle = next_obs[:, 2:3]
    rest = next_obs[:, 3:]

    rest_ok = jnp.all(
        jnp.logical_and(state_low < rest, rest < state_high),
        axis=-1,
        keepdims=True,
    )
    z_ok = jnp.logical_and(z_low < z, z < z_high)
    angle_ok = jnp.logical_and(angle_low < angle, angle < angle_high)

    healthy = jnp.logical_and(jnp.logical_and(rest_ok, z_ok), angle_ok)
    return jnp.logical_not(healthy).reshape(-1, 1)

def termination_fn_halfcheetahveljump(obs, act, next_obs):
    """Termination function that never terminates.

    All three inputs must be 2-D. Returns an all-``False`` boolean array of
    shape ``(len(obs), 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    batch_size = len(obs)
    return jnp.zeros((batch_size, 1), dtype=bool)

def termination_fn_walker2d(obs, act, next_obs):
    """Walker2d termination check computed from ``next_obs``.

    A transition is *not* done only when every entry of ``next_obs`` lies
    strictly inside (-100, 100), the height (column 0) is in (0.8, 2.0) and
    the angle (column 1) is in (-1.0, 1.0).

    All three inputs must be 2-D. Returns a boolean array of shape ``(N, 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    height = next_obs[:, 0]
    angle = next_obs[:, 1]

    in_range = jnp.all((next_obs > -100) & (next_obs < 100), axis=-1)
    height_ok = (height > 0.8) & (height < 2.0)
    angle_ok = (angle > -1.0) & (angle < 1.0)

    alive = in_range & height_ok & angle_ok
    return jnp.logical_not(alive)[:, None]

def termination_fn_neorl_walker2d(obs, act, next_obs):
    """Walker2d termination check for the layout where column 1 is ``z`` and column 2 is the angle.

    A transition is healthy when ``0.8 < z < 2.0``, ``-1.0 < angle < 1.0`` and
    every entry from column 3 onward lies strictly inside (-100, 100). Done is
    the negation of healthy.

    All three inputs must be 2-D. Returns a boolean array of shape ``(N, 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    z_low, z_high = 0.8, 2.0
    angle_low, angle_high = -1.0, 1.0
    state_low, state_high = -100.0, 100.0

    # Width-1 slices keep every mask at shape (N, 1).
    z = next_obs[:, 1:2]
    angle = next_obs[:, 2:3]
    rest = next_obs[:, 3:]

    rest_ok = jnp.all(
        jnp.logical_and(state_low < rest, rest < state_high),
        axis=-1,
        keepdims=True,
    )
    z_ok = jnp.logical_and(z_low < z, z < z_high)
    angle_ok = jnp.logical_and(angle_low < angle, angle < angle_high)

    healthy = jnp.logical_and(jnp.logical_and(rest_ok, z_ok), angle_ok)
    return jnp.logical_not(healthy).reshape(-1, 1)

def termination_fn_pen(obs, act, next_obs):
    """Flag a transition as done when the third component of ``next_obs[:, 24:27]`` is below 0.075.

    All three inputs must be 2-D. Returns a boolean array of shape ``(N, 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    object_position = next_obs[:, 24:27]
    dropped = object_position[:, 2] < 0.075
    return dropped[:, None]

def terminaltion_fn_door(obs, act, next_obs):
    """Termination function that never terminates.

    All three inputs must be 2-D. Returns an all-``False`` boolean array of
    shape ``(N, 1)`` where ``N = obs.shape[0]``.

    NOTE: the function name is misspelled ("terminaltion"); it is kept as-is
    because other code refers to it by this name.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    # Allocate the mask directly with an explicit bool dtype: building it from
    # a Python list gives a float array when the batch is empty.
    return jnp.zeros((obs.shape[0], 1), dtype=bool)

def terminaltion_fn_hammer(obs, act, next_obs):
    """Termination function that never terminates.

    All three inputs must be 2-D. Returns an all-``False`` boolean array of
    shape ``(N, 1)`` where ``N = obs.shape[0]``.

    NOTE: the function name is misspelled ("terminaltion"); it is kept as-is
    because other code refers to it by this name.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    # Allocate the mask directly with an explicit bool dtype: building it from
    # a Python list gives a float array when the batch is empty.
    return jnp.zeros((obs.shape[0], 1), dtype=bool)

def get_termination_fn(task):
    """Return the termination function matching ``task``.

    Matching is by case-sensitive substring, checked in the order below, so
    more specific keys must precede the keys they contain
    (``'halfcheetahvel'`` is tested before ``'halfcheetah'``).

    Args:
        task: Task name string.

    Returns:
        One of the termination functions defined in this module.

    Raises:
        ValueError: If ``task`` contains none of the known keys.
    """
    if 'halfcheetahvel' in task:
        return termination_fn_halfcheetahveljump
    elif 'halfcheetah' in task:
        return termination_fn_halfcheetah
    elif 'HalfCheetah-v3' in task:
        return termination_fn_neorl_halfcheetah
    elif 'hopper' in task:
        return termination_fn_hopper
    elif 'Hopper-v3' in task:
        return termination_fn_neorl_hopper
    elif 'walker2d' in task:
        return termination_fn_walker2d
    elif 'Walker2d-v3' in task:
        return termination_fn_neorl_walker2d
    elif 'pen' in task:
        return termination_fn_pen
    elif 'door' in task:
        return terminaltion_fn_door
    elif 'hammer' in task:
        return terminaltion_fn_hammer
    else:
        # Raising a function object (the previous ``raise jnp.zeros``) only
        # produced an unrelated TypeError; report the unknown task explicitly.
        raise ValueError(f"No termination function is defined for task {task!r}")