from functools import partial
import copy
from typing import Any, Optional, Union, Tuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.core import FrozenDict
from flax.struct import PyTreeNode

import distrax

from wsrl.common.common import JaxRLTrainState, ModuleDict, nonpytree_field
from wsrl.common.typing import Batch, PRNGKey
from wsrl.networks.actor_critic_nets import Policy, Critic, ensemblize
from wsrl.networks.mlp import MLP


class BCAgent(flax.struct.PyTreeNode):
    """Behavior-cloning (BC) agent with a separately trained critic.

    The actor is fitted by maximizing the log-likelihood of dataset actions
    (``update``). A critic can be fitted against next actions proposed by the
    current actor (``update_critic`` / ``update_fqe``); that step leaves the
    actor loss at a constant zero.
    """

    # Train state built in ``create`` from params, target params, per-network
    # optimizers and an rng.
    state: JaxRLTrainState
    # Actor learning-rate schedule; read in ``update`` to report "lr".
    lr_schedule: Any = nonpytree_field()
    # Static critic settings: "discount", "target_update_rate",
    # "critic_ensemble_size" and "critic_subsample_size".
    fqe_config: dict = nonpytree_field()

    @partial(jax.jit, static_argnames="pmap_axis")
    def update(self, batch: Batch, pmap_axis: Optional[str] = None):
        """Run one behavior-cloning step on the actor.

        Args:
            batch: Must provide "observations" and "actions".
            pmap_axis: Forwarded unchanged to ``state.apply_loss_fns``.

        Returns:
            ``(agent, info)`` where ``agent`` is a copy of ``self`` carrying
            the new train state and ``info`` is the dict returned by
            ``apply_loss_fns`` with an extra "lr" entry.
        """

        def actor_loss_fn(params, rng):
            # Only the second half of the split is used (as dropout key).
            _, key = jax.random.split(rng)
            dist = self.state.apply_fn(
                {"params": params},
                batch["observations"],
                temperature=1.0,
                train=True,
                rngs={"dropout": key},
                name="actor",
            )
            pi_actions = dist.mode()
            log_probs = dist.log_prob(batch["actions"])
            mse = ((pi_actions - batch["actions"]) ** 2).sum(-1)
            # Negative log-likelihood of the dataset actions.
            actor_loss = -(log_probs).mean()
            actor_std = dist.stddev().mean(axis=1)

            return actor_loss, {
                "actor_loss": actor_loss,
                "mse": mse.mean(),
                # NOTE: this is the negative log-prob of the mode, reported
                # under the key "entropy"; it is not an entropy estimate.
                "entropy": -dist.log_prob(pi_actions).mean(),
                "log_probs": log_probs,
                "pi_actions": pi_actions,
                "mean_std": actor_std.mean(),
                "max_std": actor_std.max(),
            }

        # Actor-only update: the critic gets a constant zero loss.
        loss_fns = {
            "actor": actor_loss_fn,
            "critic": lambda p, r: (0.0, {}),
        }

        new_state, info = self.state.apply_loss_fns(
            loss_fns, pmap_axis=pmap_axis, has_aux=True
        )

        # Log the learning rate at the pre-update step count.
        info["lr"] = self.lr_schedule(self.state.step)

        return self.replace(state=new_state), info

    @partial(jax.jit, static_argnames="argmax")
    def sample_actions(
        self,
        observations: np.ndarray,
        *,
        seed: Optional[PRNGKey] = None,
        temperature: float = 1.0,
        argmax=False,
    ) -> jnp.ndarray:
        """Pick actions for ``observations`` with the current actor.

        Args:
            observations: Inputs for the actor network.
            seed: PRNG key used for sampling; must be ``None`` when
                ``argmax`` is true.
            temperature: Forwarded to the actor network. Previously this
                argument was accepted but silently dropped.
            argmax: If true return ``dist.mode()``, otherwise a sample.

        Returns:
            The selected actions.
        """
        dist = self.forward_policy(
            observations,
            rng=seed,
            temperature=temperature,
            train=False,
        )
        if argmax:
            assert seed is None, "Cannot specify seed when sampling deterministically"
            return dist.mode()
        return dist.sample(seed=seed)

    def forward_policy(
        self,
        observations: np.ndarray,
        rng: Optional[PRNGKey] = None,
        *,
        grad_params: Optional[dict] = None,  # Replaced Params with dict
        train: bool = True,
        get_intermediates: bool = False,
        temperature: float = 1.0,
    ) -> Union[distrax.Distribution, Tuple[distrax.Distribution, dict]]:
        """
        Forward pass for policy network.
        Pass grad_params to use non-default parameters (e.g. for gradients).

        Args:
            observations: Inputs for the actor network.
            rng: Dropout key; required when ``train`` is true.
            grad_params: Parameters to use instead of ``self.state.params``.
            train: Forwarded to the network; also decides whether the
                dropout rng is supplied.
            get_intermediates: If true, ``capture_intermediates=True`` is
                passed to ``apply_fn`` and its two-element result
                ``(dist, intermediates)`` is returned as is.
            temperature: Forwarded to the actor network. The default of 1.0
                is the same value ``update`` passes during training.

        Returns:
            The action distribution, or ``(dist, intermediates)`` when
            ``get_intermediates`` is true.
        """
        if train:
            assert rng is not None, "Must specify rng when training"
        # A single call covers both modes; only the capture flag differs.
        return self.state.apply_fn(
            {"params": grad_params or self.state.params},
            observations,
            temperature=temperature,
            name="actor",
            rngs={"dropout": rng} if train else {},
            train=train,
            capture_intermediates=bool(get_intermediates),
        )

    def forward_critic(
        self,
        observations: np.ndarray,
        actions: jnp.ndarray,
        rng: Optional[PRNGKey] = None,
        *,
        grad_params: Optional[dict] = None,
        train: bool = True,
        get_intermediates: bool = False,
    ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, dict]]:
        """Forward pass for the critic network.

        Args:
            observations: Observation inputs for the critic.
            actions: Action inputs for the critic.
            rng: Dropout key; required when ``train`` is true.
            grad_params: Parameters to use instead of ``self.state.params``.
            train: Forwarded to the network; also decides whether the
                dropout rng is supplied.
            get_intermediates: If true, ``capture_intermediates=True`` is
                passed to ``apply_fn`` and its two-element result
                ``(qs, intermediates)`` is returned as is.

        Returns:
            The critic output, or ``(qs, intermediates)`` when
            ``get_intermediates`` is true.
        """
        if train:
            assert rng is not None, "Must specify rng when training"
        # A single call covers both modes; only the capture flag differs.
        return self.state.apply_fn(
            {"params": grad_params or self.state.params},
            observations,
            actions,
            name="critic",
            rngs={"dropout": rng} if train else {},
            train=train,
            capture_intermediates=bool(get_intermediates),
        )

    def forward_target_critic(
        self,
        observations: np.ndarray,
        actions: jnp.ndarray,
    ) -> jnp.ndarray:
        """Evaluate the critic with ``state.target_params`` in eval mode."""
        return self.forward_critic(
            observations, actions, train=False, grad_params=self.state.target_params
        )

    @jax.jit
    def get_debug_metrics(self, batch, **kwargs):
        """Return per-sample BC diagnostics without updating anything.

        Args:
            batch: Must provide "observations" and "actions".
            **kwargs: Accepted and ignored.

        Returns:
            Dict with un-reduced "mse", "log_probs" and "pi_actions".
        """
        # Intermediates were previously captured and discarded; a plain
        # eval-mode forward pass yields the same distribution.
        dist = self.forward_policy(batch["observations"], train=False)
        pi_actions = dist.mode()
        log_probs = dist.log_prob(batch["actions"])
        mse = ((pi_actions - batch["actions"]) ** 2).sum(-1)

        return {
            "mse": mse,
            "log_probs": log_probs,
            "pi_actions": pi_actions,
        }

    # -------------------- FQE: Critic for fixed BC policy --------------------
    @partial(jax.jit, static_argnames=("deterministic_target", "pmap_axis"))
    def update_critic(
        self,
        batch: Batch,
        *,
        deterministic_target: bool = True,
        pmap_axis: Optional[str] = None,
    ):
        """Run one TD step on the critic using next actions from the actor.

        Args:
            batch: Must provide "observations", "actions",
                "next_observations", "rewards" and "masks".
            deterministic_target: If true the next action is the mode of the
                actor distribution, otherwise a sample from it.
            pmap_axis: Forwarded unchanged to ``state.apply_loss_fns``.

        Returns:
            ``(agent, info)`` where ``agent`` carries the updated train state
            (after ``target_update``) and ``info`` is the dict returned by
            ``apply_loss_fns``.
        """
        discount = self.fqe_config["discount"]
        tau = self.fqe_config["target_update_rate"]
        critic_ensemble_size = self.fqe_config["critic_ensemble_size"]
        critic_subsample_size = self.fqe_config.get("critic_subsample_size")

        def critic_loss_fn(params, rng):
            # Next action from the actor, evaluated with the stored params
            # (not the differentiated ``params``) in eval mode.
            rng, key = jax.random.split(rng)
            dist = self.state.apply_fn(
                {"params": self.state.params},
                batch["next_observations"],
                name="actor",
                rngs={},
                train=False,
            )
            next_actions = dist.mode() if deterministic_target else dist.sample(seed=key)
            next_actions = jax.lax.stop_gradient(next_actions)

            # Target Q from the target parameters.
            target_next_qs = self.state.apply_fn(
                {"params": self.state.target_params},
                batch["next_observations"],
                next_actions,
                name="critic",
                rngs={},
                train=False,
            )
            # Optional REDQ-style subsampling for target aggregation.
            # Indexing and min(axis=0) below treat axis 0 as the ensemble
            # axis -- presumably (ensemble, batch); confirm against Critic.
            if (
                critic_subsample_size is not None
                and 0 < critic_subsample_size < critic_ensemble_size
            ):
                rng, subsample_key = jax.random.split(rng)
                subsample_idcs = jax.random.choice(
                    subsample_key,
                    critic_ensemble_size,
                    (critic_subsample_size,),
                    replace=False,
                )
                target_subset = target_next_qs[subsample_idcs]
            else:
                target_subset = target_next_qs

            if critic_ensemble_size > 1:
                target_next_q = target_subset.min(axis=0)
            else:
                target_next_q = target_next_qs
            target_q = batch["rewards"] + discount * batch["masks"] * target_next_q

            # Current Q from the parameters being optimized.
            qs = self.state.apply_fn(
                {"params": params},
                batch["observations"],
                batch["actions"],
                name="critic",
                rngs={"dropout": rng},
                train=True,
            )
            if critic_ensemble_size > 1:
                # Add a leading axis so the target broadcasts over axis 0 of qs.
                loss = jnp.mean((qs - target_q[None]) ** 2)
            else:
                loss = jnp.mean((qs - target_q) ** 2)
            q_mean = jnp.mean(qs)

            info = {
                "critic_loss": loss,
                "q": q_mean,
                "target_q": jnp.mean(target_q),
            }
            return loss, info

        loss_fns = {
            "actor": lambda p, r: (0.0, {}),  # freeze actor
            "critic": critic_loss_fn,
        }
        new_state, info = self.state.apply_loss_fns(
            loss_fns, pmap_axis=pmap_axis, has_aux=True
        )
        new_state = new_state.target_update(tau)
        return self.replace(state=new_state), info

    # Backward compatible alias
    update_fqe = update_critic

    @partial(jax.jit, static_argnames=("pmap_axis", "networks_to_update"))
    def get_metrics(
        self,
        batch: Batch,
        *,
        pmap_axis: Optional[str] = None,
        networks_to_update: frozenset[str] = frozenset({"actor"}),
    ) -> dict:
        """
        Compute metrics for the current state of the agent without updating parameters.

        Currently a placeholder: all arguments are ignored and an empty dict
        is returned.
        """
        metrics = {}
        return metrics

    @classmethod
    def create(
        cls,
        rng: PRNGKey,
        observations: FrozenDict,
        actions: jnp.ndarray,
        # Model architecture
        encoder_def: nn.Module,
        shared_encoder: bool = True,
        # Backward-compat: if policy_network_kwargs is None, fall back to network_kwargs
        network_kwargs: dict = {
            "hidden_dims": [256, 256],
        },
        policy_network_kwargs: dict = {
            "hidden_dims": [256, 256, 256],
        },
        policy_kwargs: dict = {
            "tanh_squash_distribution": True,
            "std_parameterization": "fixed",
            "fixed_std": 0.0,
        },
        critic_network_kwargs: dict = {
            "hidden_dims": [256, 256, 256],
        },
        critic_ensemble_size: int = 10,
        critic_subsample_size: Optional[int] = 2,
        # Optimizers
        actor_optimizer_kwargs={
            "learning_rate": 3e-4,
        },
        critic_optimizer_kwargs={
            "learning_rate": 3e-4,
        },
        # Actor LR schedule (for policy)
        warmup_steps: int = 1000,
        decay_steps: int = 1000000,
        # FQE/critic config
        discount: float = 0.99,
        target_update_rate: float = 0.005,
        **kwargs,
    ):
        """Build a ``BCAgent`` with freshly initialized actor and critic.

        Args:
            rng: PRNG key; split for parameter init and for the train state.
            observations: Example observations used to initialize parameters.
            actions: Example actions; ``actions.shape[-1]`` sets the actor's
                ``action_dim``.
            encoder_def: Encoder module handed to both networks.
            shared_encoder: If true the same encoder object is used for both
                networks; otherwise the critic gets a deep copy.
            network_kwargs: Fallback MLP kwargs, used only when
                ``policy_network_kwargs`` is ``None``.
            policy_network_kwargs: MLP kwargs for the actor;
                ``activate_final=True`` is always forced.
            policy_kwargs: Extra kwargs for ``Policy``.
            critic_network_kwargs: MLP kwargs for each critic member.
            critic_ensemble_size: Number of critic members.
            critic_subsample_size: Stored in ``fqe_config`` for target
                subsampling in ``update_critic``.
            actor_optimizer_kwargs: Only "learning_rate" is read; it is the
                peak value of the warmup-cosine schedule.
            critic_optimizer_kwargs: Passed to ``optax.adam``.
            warmup_steps: Warmup length of the actor schedule.
            decay_steps: Decay length of the actor schedule.
            discount: Stored in ``fqe_config``.
            target_update_rate: Stored in ``fqe_config``.
            **kwargs: Accepted and ignored.

        Returns:
            A new agent instance.
        """
        # Choose policy network kwargs
        if policy_network_kwargs is None:
            policy_network_kwargs = network_kwargs
        # Build a new dict rather than assigning into the argument: the
        # argument may be the shared default above or a dict the caller
        # still owns, and neither should be modified.
        policy_network_kwargs = {**policy_network_kwargs, "activate_final": True}

        # Encoders
        if shared_encoder:
            critic_encoder = encoder_def
        else:
            critic_encoder = (
                copy.deepcopy(encoder_def) if encoder_def is not None else None
            )
        encoders = {
            "actor": encoder_def,
            "critic": critic_encoder,
        }

        # Networks
        policy_def = Policy(
            encoders["actor"],
            MLP(**policy_network_kwargs),
            action_dim=actions.shape[-1],
            **policy_kwargs,
        )

        critic_backbone = ensemblize(
            partial(MLP, **critic_network_kwargs), critic_ensemble_size
        )(name="critic_ensemble")
        critic_def = partial(
            Critic,
            encoder=encoders["critic"],
            network=critic_backbone,
        )(name="critic")

        networks = {
            "actor": policy_def,
            "critic": critic_def,
        }

        model_def = ModuleDict(networks)

        # Optimizers: scheduled Adam for the actor, plain Adam for the critic.
        actor_lr_schedule = optax.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=actor_optimizer_kwargs["learning_rate"],
            warmup_steps=warmup_steps,
            decay_steps=decay_steps,
            end_value=0.0,
        )
        txs = {
            "actor": optax.adam(actor_lr_schedule),
            "critic": optax.adam(**critic_optimizer_kwargs),
        }

        rng, init_rng = jax.random.split(rng)
        params = model_def.init(
            init_rng,
            actor=[observations],
            critic=[observations, actions],
        )["params"]

        rng, create_rng = jax.random.split(rng)
        state = JaxRLTrainState.create(
            apply_fn=model_def.apply,
            params=params,
            txs=txs,
            target_params=params,
            rng=create_rng,
        )

        fqe_config = flax.core.FrozenDict(
            dict(
                discount=discount,
                target_update_rate=target_update_rate,
                critic_ensemble_size=critic_ensemble_size,
                critic_subsample_size=critic_subsample_size,
            )
        )

        lr_schedule = actor_lr_schedule

        return cls(state, lr_schedule, fqe_config)
