from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.core import FrozenDict

from minto.networks.architectures.dqn import DQNNet
from minto.sample_collection.replay_buffer import ReplayBuffer, ReplayElement

from flax import struct

@struct.dataclass
class TDErrorNormState:
    """Running statistics of observed TD-errors.

    The fields are read and rebuilt by ``update_td_norm_state``; instances
    are passed into and returned from the jitted learning step.
    """
    # Total number of TD-error samples merged into the statistics so far.
    count: jnp.ndarray
    # Running mean of all TD-errors seen so far.
    mean: jnp.ndarray
    # Running sum of squared deviations from the mean (variance * count).
    m2: jnp.ndarray
    # Square root of the running variance, floored from below when updated.
    running_std: jnp.ndarray

def update_td_norm_state(
    td_errors: jnp.ndarray,
    state: TDErrorNormState,
    eps: float = 1e-8,
) -> TDErrorNormState:
    """Merge a batch of TD-errors into the running mean/variance statistics.

    The batch moments are combined with the stored ones using the pairwise
    (batch-merge) form of the streaming variance update, so the result matches
    the statistics of all samples seen so far.

    Args:
        td_errors: TD-errors of the current batch (any shape; all elements
            are treated as samples).
        state: Statistics accumulated before this batch.
        eps: Lower bound applied to the running variance before the square
            root is taken.

    Returns:
        A new ``TDErrorNormState``; ``state`` itself is returned unchanged
        when ``td_errors`` has no elements.
    """
    n_batch = td_errors.size
    if n_batch == 0:
        # Nothing to merge; also avoids mean/var of an empty array.
        return state

    # Moments of the incoming batch (population variance -> sum of squares).
    mu_batch = jnp.mean(td_errors)
    m2_batch = jnp.var(td_errors) * n_batch

    n_prev = state.count
    n_total = n_prev + n_batch
    shift = mu_batch - state.mean

    # Combine the two groups of samples.
    merged_mean = state.mean + shift * n_batch / n_total
    merged_m2 = state.m2 + m2_batch + shift**2 * n_prev * n_batch / n_total

    # Population variance of everything seen so far, floored at eps.
    variance = merged_m2 / jnp.maximum(n_total, 1.0)
    std = jnp.sqrt(jnp.maximum(variance, eps))

    return TDErrorNormState(
        count=n_total,
        mean=merged_mean,
        m2=merged_m2,
        running_std=std,
    )

def compute_sigma(
    td_errors: jnp.ndarray,
    norm_state: TDErrorNormState,
    eps_min: float = 0.01,
):
    """Return the TD-error scale ``sigma`` and the updated running statistics.

    ``sigma`` is the largest of three values: the running standard deviation
    stored in ``norm_state`` (read *before* this batch is merged in), the
    standard deviation of the current batch, and the floor ``eps_min``.

    Args:
        td_errors: TD-errors of the current batch.
        norm_state: Running statistics accumulated from previous batches.
        eps_min: Smallest value ``sigma`` is allowed to take.

    Returns:
        A ``(sigma, new_state)`` tuple, where ``new_state`` is ``norm_state``
        after merging in ``td_errors``.
    """
    largest_std = jnp.maximum(norm_state.running_std, jnp.std(td_errors))
    sigma = jnp.maximum(largest_std, eps_min)
    return sigma, update_td_norm_state(td_errors, norm_state)

class TRDQN:
    """Trust-Region DQN agent.

    Inspired by "Human-level Atari 200x faster": the trust-region component
    (A1, Eq. 1) and the TD-error normalization scheme (B1).

    The agent keeps online parameters, target parameters, an Adam optimizer
    state and running TD-error statistics. Samples whose online Q-value has
    drifted more than ``alpha * sigma`` away from the target network's
    Q-value are excluded from the loss.
    """

    def __init__(
        self,
        key: jax.random.PRNGKey,
        observation_dim,
        n_actions,
        features: list,
        architecture_type: str,
        learning_rate: float,
        gamma: float,
        update_horizon: int,
        data_to_update: int,
        target_update_frequency: int,
        adam_eps: float = 1e-8,
        layer_norm: bool = False,
        alpha: float = 1.0,
        episilon: float = 0.01,
    ):
        """Build the network, optimizer and bookkeeping state.

        Args:
            key: PRNG key used to initialize the network parameters.
            observation_dim: Shape of the zero observation used to initialize
                the network.
            n_actions: Number of actions; forwarded to ``DQNNet``.
            features: Forwarded to ``DQNNet``.
            architecture_type: Forwarded to ``DQNNet``.
            learning_rate: Adam learning rate.
            gamma: Discount factor.
            update_horizon: Exponent applied to ``gamma`` when bootstrapping
                (n-step horizon).
            data_to_update: A gradient step is taken on every step that is a
                multiple of this value.
            target_update_frequency: The target parameters are refreshed on
                every step that is a multiple of this value.
            adam_eps: Adam epsilon.
            layer_norm: Forwarded to ``DQNNet``.
            alpha: Trust-region width multiplier.
            episilon: Lower bound on ``sigma``. The parameter name keeps its
                original spelling so existing keyword callers still work.
        """
        print("Layer Norm?", layer_norm)
        print("TRDQN alpha:", alpha)
        print("TRDQN epsilon:", episilon)

        self.network = DQNNet(features, architecture_type, n_actions, layer_norm=layer_norm)
        self.params = self.network.init(
            key, jnp.zeros(observation_dim, dtype=jnp.float32)
        )

        self.optimizer = optax.adam(learning_rate, eps=adam_eps)
        self.optimizer_state = self.optimizer.init(self.params)
        self.target_params = self.params

        self.gamma = gamma
        self.update_horizon = update_horizon
        self.data_to_update = data_to_update
        self.target_update_frequency = target_update_frequency

        # Metrics summed over gradient steps; averaged and reset on every
        # target update.
        self.cumulated_info = {
            "loss": 0,
            "grad_norm": 0,
            "param_norm": 0,
            "online_fraction": 0,
            "q_value": 0,
            "target": 0,
            "churn": 0,
            "mask_fraction": 0,
            "sigma": 0,
        }
        # Metrics stored per gradient step; key "<name>_all" collects the
        # metric called "<name>".
        self.listed_info = {
            "online_fraction_all": []
        }

        # TRDQN specific
        self.alpha = alpha
        self.epsilon = episilon
        self.td_error_norm_state = TDErrorNormState(
            count=jnp.array(0.0),
            mean=jnp.array(0.0),
            m2=jnp.array(0.0),
            running_std=jnp.array(1.0),
        )

    def update_online_params(self, step: int, replay_buffer: ReplayBuffer):
        """Take one gradient step if ``step`` is a multiple of ``data_to_update``.

        Samples a batch from ``replay_buffer``, updates the online
        parameters, optimizer state and running TD-error statistics, and
        accumulates the returned metrics for later logging.
        """
        if step % self.data_to_update != 0:
            return

        batch_samples = replay_buffer.sample()

        (
            self.params,
            self.optimizer_state,
            self.td_error_norm_state,
            info,
        ) = self.learn_on_batch(
            self.params,
            self.target_params,
            self.optimizer_state,
            batch_samples,
            self.td_error_norm_state,
        )

        # Accumulate the metrics: running sums for the averaged ones,
        # per-step scalar values for the listed ones.
        for name, value in info.items():
            if name in self.cumulated_info:
                self.cumulated_info[name] += value
            history_key = f"{name}_all"
            if history_key in self.listed_info:
                self.listed_info[history_key].append(value.item())

    def update_target_params(self, step: int):
        """Refresh the target parameters on multiples of ``target_update_frequency``.

        Returns:
            ``(True, logs)`` when the target was refreshed, where ``logs``
            holds the accumulated metrics divided by
            ``target_update_frequency / data_to_update`` plus the per-step
            lists; ``(False, {})`` otherwise. The accumulators are reset
            whenever logs are returned.
        """
        if step % self.target_update_frequency != 0:
            return False, {}

        self.target_params = self.params.copy()

        # Average the cumulated info over the number of gradient steps taken
        # between two target updates.
        updates_per_period = self.target_update_frequency / self.data_to_update
        logs = {
            name: total / updates_per_period
            for name, total in self.cumulated_info.items()
        }
        logs.update(self.listed_info)

        # Reset; fresh lists are created so the ones handed out in `logs`
        # are not mutated by later steps.
        self.cumulated_info = dict.fromkeys(self.cumulated_info, 0)
        self.listed_info = {name: [] for name in self.listed_info}

        return True, logs

    @partial(jax.jit, static_argnames="self")
    def learn_on_batch(
        self,
        params: FrozenDict,
        params_target: FrozenDict,
        optimizer_state,
        batch_samples: ReplayElement,
        td_norm_state: TDErrorNormState,
    ):
        """Run one jitted optimization step on a batch.

        Args:
            params: Online parameters (differentiated).
            params_target: Target parameters (used for the trust-region mask).
            optimizer_state: Current optimizer state.
            batch_samples: Batch of transitions.
            td_norm_state: Running TD-error statistics before this batch.

        Returns:
            ``(params, optimizer_state, new_td_norm_state, info)`` where
            ``info`` contains the metrics from ``loss_on_batch`` plus
            ``grad_norm``, ``param_norm`` (norm of the updated parameters)
            and ``loss``.
        """
        # loss_on_batch returns (loss, (info, new_td_norm_state))
        (loss, (info, new_td_norm_state)), grad_loss = jax.value_and_grad(
            self.loss_on_batch, has_aux=True
        )(params, params_target, batch_samples, td_norm_state)

        updates, optimizer_state = self.optimizer.update(grad_loss, optimizer_state)
        params = optax.apply_updates(params, updates)

        info["grad_norm"] = optax.global_norm(grad_loss)
        info["param_norm"] = optax.global_norm(params)
        info["loss"] = loss

        return params, optimizer_state, new_td_norm_state, info

    def loss_on_batch(
        self,
        params: FrozenDict,
        params_target: FrozenDict,
        samples: ReplayElement,
        td_norm_state: TDErrorNormState,
    ):
        """Compute the trust-region-masked TD loss for a batch.

        - The TD target bootstraps from the ONLINE network (MEME-style),
          with gradients stopped through the bootstrap value.
        - A sample is dropped from the loss when
          ``|Q_online(s,a) - Q_target(s,a)| > alpha * sigma``.
        - ``sigma`` is the max of the batch TD-error std, the running std
          and ``self.epsilon``. It only sets the trust-region width: the
          TD-errors themselves are NOT divided by ``sigma``.

        Returns:
            ``(loss, (info, new_td_norm_state))``, the layout expected by
            ``jax.value_and_grad(..., has_aux=True)``.
        """
        # Convert numpy from replay buffer to jax arrays (defensive, jit-friendly)
        states = jnp.asarray(samples.state)            # [B, ...]
        next_states = jnp.asarray(samples.next_state)  # [B, ...]
        actions = jnp.asarray(samples.action)          # [B]
        rewards = jnp.asarray(samples.reward)          # [B]
        terminals = jnp.asarray(samples.is_terminal)   # [B], bool

        # Q(s, ·) for online and target at current states
        q_online_s = self.network.apply(params, states)         # [B, A]
        q_target_s = self.network.apply(params_target, states)  # [B, A]

        # Q(s', ·) for online at next states (for bootstrapping); no gradient
        # flows through the bootstrap value.
        q_online_sp = jax.lax.stop_gradient(
            self.network.apply(params, next_states)
        )  # [B, A]

        # Q(s, a_t) from online and target
        action_index = actions[:, None]
        qa_online = jnp.take_along_axis(
            q_online_s, action_index, axis=1
        ).squeeze(-1)  # [B]
        qa_target = jnp.take_along_axis(
            q_target_s, action_index, axis=1
        ).squeeze(-1)  # [B]

        # Online bootstrap target: y = r + γ^n (1 - terminal) max_a' Q_online(s', a')
        q_next_online = jnp.max(q_online_sp, axis=1)  # [B]

        targets = (
            rewards
            + (1.0 - terminals.astype(jnp.float32))
            * (self.gamma ** self.update_horizon)
            * q_next_online
        )  # [B]

        # TD-errors using ONLINE bootstrap
        td_errors = targets - qa_online  # [B]

        # Compute σ using MEME-style statistics (batch + running std)
        sigma, new_td_norm_state = compute_sigma(
            td_errors, td_norm_state, eps_min=self.epsilon
        )

        # Trust-region mask: equation (1) only
        # |Q_online(s,a) - Q_target(s,a)| > alpha * sigma  => mask out
        diff = jnp.abs(qa_online - qa_target)     # [B]
        keep_mask = diff <= (self.alpha * sigma)  # True => keep

        # Masked MSE over the raw (un-normalized) TD-errors, averaged over
        # the kept samples only.
        per_sample_loss = jnp.square(td_errors)  # [B]
        masked_loss = per_sample_loss * keep_mask.astype(per_sample_loss.dtype)
        denom = jnp.maximum(jnp.sum(keep_mask), 1.0)  # avoid divide-by-zero
        loss = jnp.sum(masked_loss) / denom           # scalar

        # --- Logging info over batch ---

        # Policy churn per sample: half the L1 distance between the softmax
        # policies of the online and target Q-values.
        pis_online = jax.nn.softmax(q_online_s, axis=-1)
        pis_target = jax.nn.softmax(q_target_s, axis=-1)
        churn_per_sample = 0.5 * jnp.sum(jnp.abs(pis_online - pis_target), axis=-1)  # [B]

        info = {
            # In this simplified MEME-style setup, we always bootstrap from online.
            "online_fraction": 1.0,
            "q_value": qa_online.mean(),
            "target": targets.mean(),
            "churn": churn_per_sample.mean(),
            # Fraction of samples KEPT by the trust-region mask.
            "mask_fraction": keep_mask.mean(),
            "sigma": sigma,
        }

        # aux for value_and_grad: (info, new_td_norm_state)
        return loss, (info, new_td_norm_state)

    @partial(jax.jit, static_argnames="self")
    def best_action(
        self, params: FrozenDict, state: jnp.ndarray, key: jax.random.PRNGKey
    ):
        """Return the greedy action for a single state.

        ``key`` is unused; it is accepted so the call signature stays
        unchanged for existing callers.
        """
        return jnp.argmax(self.network.apply(params, state))

    def get_model(self):
        """Return the online parameters wrapped in a ``{"params": ...}`` dict."""
        return {"params": self.params}
