from models.common import *
import numpy as np
from torch.distributions import Categorical, Normal, Independent, MixtureSameFamily

import torch
import torch.nn as nn
import torch.nn.functional as F
# Bounds applied (via torch.clamp) to the policy's log standard deviation.
# They keep every component std within [0.1, 0.8]; earlier revisions of this
# file used the much wider range [-20, 2] for the log-std instead.
LOG_STD_MIN, LOG_STD_MAX = np.log(0.1), np.log(0.8)

# Echo the active bounds at import time. RED / ENDC are presumably terminal
# colour escape codes provided by `models.common` -- confirm there.
print(f"{RED} {LOG_STD_MIN}---{LOG_STD_MAX} {ENDC}")

class FastMixturePolicy(nn.Module):
    """
    Gaussian-mixture actor producing tanh-squashed actions scaled by ``act_limit``.

    - Enforces a floor on the mixing probabilities. The softmax output is
      clamped to at least ``0.5 / n_components`` per component and then
      renormalised, so the effective floor is slightly below that value.
    - Selects the component with a Gumbel-Softmax draw followed by ``argmax``
      (the Gumbel-max trick). This is an exact sample from the mixing
      distribution, but the selected index carries no gradient.
    - Applies the exact tanh log-Jacobian correction ``log(1 - tanh(u)^2)``.
      The constant ``log(act_limit)`` term(s) from the final scaling are not
      included in the returned log-probabilities.
    - Accepts a batch of observations ``[B, obs_dim]`` or a single
      observation ``[obs_dim]``. For a single observation, the batch
      dimension is dropped from every returned tensor.
    """
    def __init__(
        self, obs_dim, act_dim, hidden_sizes, activation,
        act_limit, n_components=5, dropout_rate=0.1, gumbel_tau=1.0
    ):
        super().__init__()
        self.act_dim      = act_dim
        self.act_limit    = act_limit
        self.n_components = n_components
        self.gumbel_tau   = gumbel_tau

        # Shared encoder: Linear -> LayerNorm -> activation -> Dropout per layer
        layers = []
        in_size = obs_dim
        for size in hidden_sizes:
            layers += [
                nn.Linear(in_size, size),
                nn.LayerNorm(size),
                activation(),
                nn.Dropout(dropout_rate)
            ]
            in_size = size
        self.encoder = nn.Sequential(*layers)

        # Heads: per-component means / log-stds and the mixing logits
        self.mu_head     = nn.Linear(in_size, n_components * act_dim)
        self.logstd_head = nn.Linear(in_size, n_components * act_dim)
        self.mix_head    = nn.Linear(in_size, n_components)

    def forward(self, obs, deterministic=False, with_logprob=True,
                manual_indices=None, bias_config=None):
        """Sample (or deterministically pick) a squashed, scaled action.

        Args:
            obs: observation tensor, ``[B, obs_dim]`` or a single ``[obs_dim]``.
            deterministic: if True, use the most probable component and its
                mean instead of sampling.
            with_logprob: if False, skip log-probability computation and
                return ``None`` for both log-probs.
            manual_indices: optional component indices that override the
                selected ones. Pass ``[B]`` for batched input, or a scalar for a
                single observation.
            bias_config: optional dict. If it contains ``"std"``, component
                stds are clamped to at most that value.

        Returns:
            ``(a, logp_mixture, logp_component, pi_action, sel_mu, sel_std, info)``.
            ``a = tanh(pi_action) * act_limit``. ``logp_mixture`` is the
            tanh-corrected log-density under the full mixture.
            ``logp_component`` is the same density under the selected
            component only. ``info`` holds intermediate tensors for
            diagnostics.
        """
        unbatched = obs.dim() == 1
        if unbatched:
            # Run the batched path on a batch of one, then strip that dim.
            obs = obs.unsqueeze(0)
            if manual_indices is not None:
                manual_indices = torch.as_tensor(
                    manual_indices, device=obs.device
                ).reshape(1)

        outputs = self._forward_batched(
            obs, deterministic, with_logprob, manual_indices, bias_config
        )
        if unbatched:
            outputs = self._drop_batch_dim(outputs)
        return outputs

    def _forward_batched(self, obs, deterministic, with_logprob,
                         manual_indices, bias_config):
        """Core of ``forward`` for a 2-D ``[B, obs_dim]`` observation batch."""
        x = self.encoder(obs)                  # [B, H]
        B = x.size(0)

        # Component params: [B, K, A]
        mus = self.mu_head(x).view(B, self.n_components, self.act_dim)

        logstd = torch.clamp(
            self.logstd_head(x), LOG_STD_MIN, LOG_STD_MAX
        ).view(B, self.n_components, self.act_dim)
        stds = torch.exp(logstd)

        # Optional std bias: cap every component's std
        if bias_config and "std" in bias_config:
            print(f"Applying std bias: {bias_config}")
            stds = torch.clamp(stds, max=bias_config["std"], min=1e-6)

        # Mixing weights
        logits = self.mix_head(x)
        p_raw, mixing_probs = self._mixing_probs(logits)

        # Sample / select component index
        if not deterministic:
            # argmax of a Gumbel-Softmax draw == categorical sample (Gumbel-max)
            g = F.gumbel_softmax(torch.log(mixing_probs + 1e-8),
                                 tau=self.gumbel_tau, hard=False)
            indices = g.argmax(-1)
        else:
            indices = mixing_probs.argmax(-1)

        if manual_indices is not None:
            indices = manual_indices

        # Gather selected component: [B, A]
        idxs = torch.arange(B, device=obs.device)
        sel_mu  = mus[idxs, indices]
        sel_std = stds[idxs, indices]

        # Raw (pre-squash) action
        if deterministic:
            pi_action = sel_mu
        else:
            pi_action = Normal(sel_mu, sel_std).rsample()

        # Squash + scale
        a = torch.tanh(pi_action) * self.act_limit

        if not with_logprob:
            info = dict(mixing_logits=logits,
                        mixing_probs=mixing_probs,
                        indices=indices,
                        selected_mu=sel_mu,
                        selected_std=sel_std,
                        all_mus=mus,
                        all_stds=stds)
            return a, None, None, pi_action, sel_mu, sel_std, info

        # Mixture log-prob before squash
        comp_dist = Independent(Normal(mus, stds), 1)
        mix_dist  = MixtureSameFamily(
            mixture_distribution=Categorical(probs=mixing_probs),
            component_distribution=comp_dist
        )
        logp_raw = mix_dist.log_prob(pi_action)                   # [B]

        # Tanh Jacobian correction
        tanh_corr = self._tanh_log_det(pi_action)                 # [B]
        logp_mixture = logp_raw - tanh_corr

        # Selected-component log-prob
        sel_dist       = Independent(Normal(sel_mu, sel_std), 1)
        logp_comp_raw  = sel_dist.log_prob(pi_action)             # [B]
        logp_component = logp_comp_raw - tanh_corr

        info = dict(mixing_logits=logits,
                    mixing_probs=mixing_probs,
                    indices=indices,
                    selected_mu=sel_mu,
                    selected_std=sel_std,
                    mix_mean=mix_dist.mean,
                    mix_var=mix_dist.variance,
                    all_mus=mus,
                    all_stds=stds,
                    p_raw=p_raw,
                    logp_comp_raw=logp_comp_raw,
                    logp_component=logp_component,
                    logp_raw=logp_raw,
                    logp_mixture=logp_mixture,)

        return a, logp_mixture, logp_component, pi_action, sel_mu, sel_std, info

    def _mixing_probs(self, logits):
        """Return ``(p_raw, mixing_probs)``: softmax, then floored and renormalised."""
        p_raw = F.softmax(logits, dim=-1)
        min_p = 0.5 / self.n_components
        p_clamped = torch.clamp(p_raw, min=min_p)
        return p_raw, p_clamped / p_clamped.sum(-1, keepdim=True)

    @staticmethod
    def _tanh_log_det(u):
        """Sum over the last dim of ``log(1 - tanh(u)^2)``, in a stable form."""
        return (2 * (np.log(2) - u - F.softplus(-2 * u))).sum(-1)

    @staticmethod
    def _drop_batch_dim(outputs):
        """Remove the size-1 batch dim added for a single (1-D) observation."""
        *tensors, info = outputs
        squeezed = [t.squeeze(0) if t is not None else None for t in tensors]
        squeezed_info = {key: value.squeeze(0) for key, value in info.items()}
        return (*squeezed, squeezed_info)

    def set_training_mode(self, mode: bool):
        """Set training mode for SB3 compatibility"""
        self.train(mode)
        return self

class FastMixturePolicyActorCritic(nn.Module):
    """
    Actor-critic container: a ``FastMixturePolicy`` actor (``self.pi``) plus
    two Q-function critics (``self.q1``, ``self.q2``) built with
    ``MLPQFunction``.

    ``alpha`` (entropy temperature) is stored as ``self.alpha`` for use by the
    training code; nothing inside this module reads it.

    NOTE(review): an earlier docstring described per-component surrogate
    value functions V_k(s) = E[Q(s,a) - log pi^m(a|s) + alpha * log pi_k^b(a|s)];
    no such networks are constructed here.
    """
    def __init__(
        self,
        obs_dim,
        act_dim,
        act_limit,
        hidden_sizes=(256, 256),
        activation=nn.ReLU,
        alpha=0.2,  # Entropy temperature parameter
        dropout_rate=0.1,
        n_components=5
    ):
        super().__init__()
        # Previously accepted and silently discarded; keep it accessible.
        self.alpha = alpha

        logger.info(f"obs_dim: {obs_dim}")

        # Actor
        self.pi = FastMixturePolicy(
            obs_dim,
            act_dim,
            hidden_sizes,
            activation,
            act_limit,
            n_components=n_components,
            dropout_rate=dropout_rate,
        )

        # Q-functions for critic
        self.q1 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)
        self.q2 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)

    def act(self, obs, deterministic=False):
        """Return the squashed action as a NumPy array, without gradients."""
        with torch.no_grad():
            a, *_ = self.pi(obs, deterministic=deterministic, with_logprob=False)
            return a.cpu().numpy()

    def act_extended(self, obs, deterministic=False, bias_config=None, **kwargs):
        """Act without gradients and return the action plus diagnostics.

        Args:
            obs: observation tensor forwarded to ``self.pi``.
            deterministic: forwarded to ``self.pi``.
            bias_config: forwarded to ``self.pi``. A ``"std"`` entry caps the
                component stds at that value.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            Dict of NumPy arrays. ``"logp_a"`` duplicates ``"logp_mixture"``,
            and ``"activation_outputs"`` is always ``None``.
        """
        with torch.no_grad():
            a, logp_mixture, logp_component, pi_action, mu, std, pi_info = self.pi(
                obs,
                deterministic=deterministic,
                with_logprob=True,
                bias_config=bias_config,  # was previously dropped
            )

        return {
            "a": a.cpu().numpy(),
            "logp_mixture": logp_mixture.cpu().numpy(),
            "logp_a": logp_mixture.cpu().numpy(),
            "logp_component": logp_component.cpu().numpy(),
            "pi": pi_action.cpu().numpy(),
            "mu": mu.cpu().numpy(),
            "std": std.cpu().numpy(),
            "activation_outputs": None,
            "mixing_logits": pi_info["mixing_logits"].cpu().numpy(),
            "mixing_probs": pi_info["mixing_probs"].cpu().numpy(),
            "indices": pi_info["indices"].cpu().numpy(),
        }

