from models.common import *

# Uncorrected Gaussian Actor
class SquashedGaussianMLPActor(nn.Module):
    """Tanh-squashed diagonal-Gaussian policy with three 256-unit ReLU layers.

    The network maps an observation to the mean and log-std of an
    independent Normal per action dimension. A pre-squash action is drawn
    from it (or set to the mean), passed through ``tanh`` and scaled by
    ``act_limit``.

    Side effect: every forward pass stores the post-activation output of each
    hidden layer as a NumPy array in ``ac1_output``, ``ac2_output`` and
    ``ac3_output``.
    """

    def __init__(self, obs_dim, act_dim, d=256, activation=nn.ReLU, act_limit=1.0):
        """Build the layers.

        Args:
            obs_dim: Size of the observation vector.
            act_dim: Size of the action vector.
            d: Accepted for signature compatibility but ignored; the hidden
                width is always 256 (see the note below).
            activation: Accepted for signature compatibility but ignored;
                ``nn.ReLU`` is always used.
            act_limit: Scale applied to the tanh-squashed action.
        """
        super().__init__()
        # NOTE: the hidden width is deliberately pinned here. MLPActorCritic
        # passes its ``hidden_sizes`` tuple in the ``d`` position, so honouring
        # the argument would break that caller and change layer shapes.
        d = 256
        self.hl1 = nn.Linear(obs_dim, d)
        self.ac1 = nn.ReLU()
        self.hl2 = nn.Linear(d, d)
        self.ac2 = nn.ReLU()
        self.hl3 = nn.Linear(d, d)
        self.ac3 = nn.ReLU()
        self.mu_layer = nn.Linear(d, act_dim)
        self.log_std_layer = nn.Linear(d, act_dim)
        self.act_limit = act_limit

    def forward(
        self,
        obs,
        deterministic=False,
    ):
        """Compute an action and its log-probability for ``obs``.

        Args:
            obs: Observation tensor whose last dimension is ``obs_dim``.
                Leading batch dimensions are optional.
            deterministic: If true, use the mean as the pre-squash action
                instead of a reparameterised sample.

        Returns:
            Tuple ``(a, logp_a, pi_action, mu, std)``: the squashed and scaled
            action, its log-probability summed over the action dimension, the
            pre-squash action, and the Gaussian mean and standard deviation.
        """
        x = obs
        x = self.ac1(self.hl1(x))
        # clone() so the stored array never shares memory with the live tensor.
        self.ac1_output = x.clone().detach().cpu().numpy()
        x = self.ac2(self.hl2(x))
        self.ac2_output = x.clone().detach().cpu().numpy()
        x = self.ac3(self.hl3(x))
        self.ac3_output = x.clone().detach().cpu().numpy()

        mu = self.mu_layer(x)
        log_std = self.log_std_layer(x)
        # LOG_STD_MIN / LOG_STD_MAX are not defined in this file; presumably
        # they come from the ``models.common`` star import.
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)
        pi_distribution = Normal(mu, std)

        if deterministic:
            pi_action = mu
        else:
            pi_action = pi_distribution.rsample()

        logp_a = pi_distribution.log_prob(pi_action).sum(dim=-1)
        # Change-of-variables term for tanh, written in the numerically stable
        # form 2 * (log 2 - u - softplus(-2u)) == log(1 - tanh(u)^2).
        # Reduce over the LAST dimension, matching the Gaussian term above, so
        # unbatched and multi-batch-dimension inputs work too.
        tanh_correction = (
            2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))
        ).sum(dim=-1)
        logp_a = logp_a - tanh_correction

        a = torch.tanh(pi_action)
        a = self.act_limit * a

        return a, logp_a, pi_action, mu, std


class MLPActorCritic(nn.Module):
    """Bundle of a squashed-Gaussian policy (``pi``) and two Q-functions."""

    def __init__(
        self,
        obs_dim,
        act_dim,
        act_limit,
        hidden_sizes=(256, 256),
        activation=nn.ReLU,
    ):
        """Build the policy and both Q-functions.

        Args:
            obs_dim: Size of the observation vector.
            act_dim: Size of the action vector.
            act_limit: Scale applied to the policy's tanh-squashed action.
            hidden_sizes: Forwarded to both Q-functions, and to the actor in
                its third positional slot (named ``d`` on the actor).
            activation: Activation class forwarded to all three sub-modules.
        """
        super().__init__()
        # build policy and value functions
        self.pi = SquashedGaussianMLPActor(
            obs_dim, act_dim, hidden_sizes, activation, act_limit
        )
        # MLPQFunction is not defined in this file; presumably it comes from
        # the ``models.common`` star import.
        self.q1 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)
        self.q2 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)

    def act(self, obs, deterministic=False):
        """Return only the squashed action for ``obs`` as a NumPy array.

        Runs the policy without gradient tracking.
        """
        with torch.no_grad():
            # The actor's forward takes (obs, deterministic) and returns a
            # five-tuple; only the first element (the action) is needed.
            a, *_ = self.pi(obs, deterministic)
            return a.cpu().numpy()

    def act_extended(self, obs, deterministic=False):
        """Return every policy output for ``obs`` as NumPy arrays.

        Returns:
            Dict with keys ``"a"``, ``"logp_a"``, ``"pi"``, ``"mu"`` and
            ``"std"``, in the order the policy's forward pass returns them.
        """
        with torch.no_grad():
            a, logp_a, pi, mu, std = self.pi(
                obs,
                deterministic=deterministic,
            )
            return {
                "a": a.cpu().numpy(),
                "logp_a": logp_a.cpu().numpy(),
                "pi": pi.cpu().numpy(),
                "mu": mu.cpu().numpy(),
                "std": std.cpu().numpy(),
            }


# Correlated Gaussian Actor
class SquashedCorrelatedGaussianMLPActor(nn.Module):
    """Tanh-squashed Gaussian policy with a full (correlated) covariance.

    Each hidden stage is ``Linear -> LayerNorm -> activation -> Dropout``.
    Two linear heads then produce the mean and the parameters of a
    lower-triangular factor ``L``; actions are drawn from
    ``MultivariateNormal(mu, scale_tril=L)``, squashed with ``tanh`` and
    scaled by ``act_limit``.

    Side effect: every forward pass stores the output of each activation
    module (taken before dropout, rounded to 3 decimals, as a NumPy array) in
    ``self.activation_outputs`` under the keys ``"ac{idx}_output"``.
    """

    def __init__(self, obs_dim, act_dim, act_limit, hidden_sizes, activation,  dropout_rate=0.3, **kwargs):
        """Build the hidden stack and the two output heads.

        Args:
            obs_dim: Size of the observation vector.
            act_dim: Size of the action vector.
            act_limit: Scale applied to the tanh-squashed action.
            hidden_sizes: Widths of the hidden layers. May be empty, in which
                case the heads read the observation directly.
            activation: Activation class, instantiated once per hidden layer.
            dropout_rate: Probability passed to each ``nn.Dropout``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.act_dim = act_dim
        self.act_limit = act_limit

        self.layers = nn.ModuleList()
        # Index of each activation module inside ``self.layers`` -> the key
        # its output is recorded under in ``self.activation_outputs``.
        self.activation_layer_names = {}
        in_size = obs_dim

        for idx, size in enumerate(hidden_sizes):
            self.layers.append(nn.Linear(in_size, size))
            # LayerNorm sits between the linear layer and the activation.
            self.layers.append(nn.LayerNorm(size))
            self.layers.append(activation())
            self.activation_layer_names[len(self.layers) - 1] = f"ac{idx}_output"
            self.layers.append(nn.Dropout(dropout_rate))
            in_size = size

        self.activation_outputs = {}

        # Output heads (no LayerNorm or Dropout). ``in_size`` is the width of
        # the last hidden layer, or ``obs_dim`` when there are none.
        self.mu_layer = nn.Linear(in_size, act_dim)
        self.tril_params_size = act_dim * (act_dim + 1) // 2
        self.cov_layer = nn.Linear(in_size, self.tril_params_size)

    def apply_bias(self, x, layer_name, layer_config):
        """Return ``x`` with constant offsets added to selected neurons.

        Args:
            x: 2-D activation tensor indexed as ``[batch, neuron]``.
            layer_name: Name of the layer; used only in the debug log.
            layer_config: Iterable of dicts with keys ``"neuron"`` (column
                index) and ``"value"`` (offset to add).

        Returns:
            ``x`` itself when ``layer_config`` is empty, otherwise a modified
            clone; the input tensor is never changed in place.
        """
        if not layer_config:
            return x
        biased_x = x.clone()
        for neuron_info in layer_config:
            neuron_idx = neuron_info["neuron"]
            value = neuron_info["value"]
            biased_x[:, neuron_idx] += value
            logger.debug(f"{CYAN}Applying bias to layer: {layer_name}, neuron: {neuron_idx} with value: {value}{ENDC}")

        return biased_x

    def _compute_L_matrix(self, x, bias_config=None):
        """Build the lower-triangular scale factor ``L`` from features ``x``.

        Args:
            x: Feature tensor of shape ``[batch, features]``.
            bias_config: Optional mapping; when it contains ``"std"``, that
                value caps the diagonal of ``L``.

        Returns:
            Tensor of shape ``[batch, act_dim, act_dim]`` with a strictly
            positive diagonal and every entry clamped to ``[-10, 10]``.
        """
        tril_params = self.cov_layer(x)  # [batch, tril_params_size]

        L = torch.zeros(x.size(0), self.act_dim, self.act_dim, device=x.device, dtype=x.dtype)

        # Build the index tensors on the same device as L so indexing never
        # needs a cross-device copy.
        tril_indices = torch.tril_indices(
            row=self.act_dim, col=self.act_dim, offset=0, device=x.device
        )
        L[:, tril_indices[0], tril_indices[1]] = tril_params

        diag_indices = torch.arange(self.act_dim, device=x.device)

        # softplus keeps the diagonal positive; the epsilon keeps it off zero.
        diagonal_values = F.softplus(L[:, diag_indices, diag_indices]) + 1e-4
        if bias_config and "std" in bias_config:
            logger.debug(f"{RED}max_std: {bias_config['std']}{ENDC}")
            max_std = bias_config["std"]
            diagonal_values = torch.clamp(diagonal_values, min=1e-6, max=max_std)

        L[:, diag_indices, diag_indices] = diagonal_values
        # Bound every entry so the covariance cannot blow up.
        L = torch.clamp(L, min=-10, max=10)

        return L

    def _policy_outputs(self, x, bias_config=None, deterministic=False):
        """Turn final hidden features into the policy's output tuple.

        Returns:
            ``(a, logp_a, pi, mu, std, cov0)`` where ``a`` is the squashed and
            scaled action, ``logp_a`` its log-probability, ``pi`` the
            pre-squash action, ``mu`` the mean, ``std`` the diagonal of ``L``
            (not the square root of the covariance diagonal), and ``cov0`` the
            covariance ``L @ L^T`` of the FIRST batch element only.
        """
        mu = self.mu_layer(x)  # [batch, act_dim]
        L = self._compute_L_matrix(x, bias_config)  # [batch, act_dim, act_dim]

        # Covariance: Sigma = L @ L^T. L is passed to the distribution
        # directly as its Cholesky factor.
        cov = L @ L.transpose(-1, -2)
        dist = MultivariateNormal(mu, scale_tril=L)

        pi = mu if deterministic else dist.rsample()  # [batch, act_dim]

        logp_pi = dist.log_prob(pi)  # [batch]
        # Change-of-variables term for the tanh squashing, in the stable form
        # 2 * (log 2 - u - softplus(-2u)) == log(1 - tanh(u)^2).
        logp_a = logp_pi - (2 * (np.log(2) - pi - F.softplus(-2 * pi))).sum(axis=-1)

        a = torch.tanh(pi) * self.act_limit  # [batch, act_dim]

        std = torch.diagonal(L, dim1=-2, dim2=-1)  # [batch, act_dim]
        return a, logp_a, pi, mu, std, cov[0]

    def forward_with_tree(self, obs, apply_tree=None):
        """Forward pass in which ``apply_tree`` rewrites the last activation.

        After the last activation module, the first batch row of ``obs`` and
        of the activation are packed into a ``pandas.Series`` (keys
        ``"obs_{i}"`` and ``"{layer_name}_{i}"``) and handed to
        ``apply_tree``. Its return value replaces the activation and is
        converted to a float32 tensor on ``obs.device``. Only row 0 is read,
        so this appears intended for batch size 1 — TODO confirm with callers.

        An action is always sampled; there is no deterministic mode here.

        Args:
            obs: Observation tensor of shape ``[batch, obs_dim]``.
            apply_tree: Callable taking the ``pandas.Series`` described above
                and returning the replacement activation.

        Returns:
            The same six-tuple as ``forward``.

        Raises:
            TypeError: If ``apply_tree`` is not callable.
        """
        if not callable(apply_tree):
            raise TypeError("forward_with_tree requires a callable `apply_tree`")

        x = obs
        logger.debug(f"Applying tree")
        # Keys are inserted in increasing order, so the maximum is the last
        # activation module. ``default`` covers a network with no hidden stage.
        last_activation_idx = max(self.activation_layer_names, default=None)

        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i not in self.activation_layer_names:
                continue
            layer_name = self.activation_layer_names[i]

            if i == last_activation_idx:
                logger.debug(f"{MAGENTA}Applying tree to layer: {layer_name}{ENDC}")
                with torch.no_grad():
                    values = {}

                    for idx in range(obs.shape[1]):
                        values[f"obs_{idx}"] = obs[0][idx].detach().cpu().numpy()

                    for idx in range(x.shape[1]):
                        values[f"{layer_name}_{idx}"] = x[0][idx].detach().cpu().numpy()

                    biased_x = apply_tree(pd.Series(values))
                    x = torch.as_tensor(biased_x, dtype=torch.float32).to(obs.device)

            self.activation_outputs[layer_name] = x.clone().detach().cpu().numpy().round(3)

        return self._policy_outputs(x, None, deterministic=False)

    def forward(self, obs, deterministic=False, bias_config=None, **kwargs):
        """Compute an action and its statistics for ``obs``.

        Args:
            obs: Observation tensor of shape ``[batch, obs_dim]``.
            deterministic: If true, use the mean as the pre-squash action
                instead of a reparameterised sample.
            bias_config: Optional mapping. Keys equal to an activation name
                (``"ac{idx}_output"``) hold a list of ``{"neuron", "value"}``
                dicts applied via ``apply_bias``; the key ``"std"`` caps the
                diagonal of ``L``. Any falsy value disables both.
            **kwargs: Accepted and ignored.

        Returns:
            ``(a, logp_a, pi, mu, std, cov0)``: squashed and scaled action,
            its log-probability, pre-squash action, mean, diagonal of ``L``,
            and the covariance of the first batch element only.
        """
        x = obs
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i not in self.activation_layer_names:
                continue
            layer_name = self.activation_layer_names[i]
            if bias_config and layer_name in bias_config:
                x = self.apply_bias(x, layer_name, bias_config.get(layer_name, []))

            # Recorded after any bias is applied and before the dropout layer.
            self.activation_outputs[layer_name] = x.clone().detach().cpu().numpy().round(3)

        return self._policy_outputs(x, bias_config, deterministic=deterministic)

   
class CorreclatedMLPActorCritic(nn.Module):
    """Bundle of a correlated squashed-Gaussian policy and two Q-functions."""

    def __init__(
        self,
        obs_dim,
        act_dim,
        act_limit,
        hidden_sizes=(256, 256),
        activation=nn.ReLU,
        dropout_rate=0.1
    ):
        """Build the policy and both Q-functions.

        Args:
            obs_dim: Size of the observation vector.
            act_dim: Size of the action vector.
            act_limit: Scale applied to the policy's tanh-squashed action.
            hidden_sizes: Hidden-layer widths shared by all three sub-modules.
            activation: Activation class shared by all three sub-modules.
            dropout_rate: Dropout probability forwarded to the policy only.
        """
        super().__init__()

        logger.info(f"obs_dim: {obs_dim}")

        # build policy and value functions
        self.pi = SquashedCorrelatedGaussianMLPActor(
            obs_dim,
            act_dim,
            act_limit,
            hidden_sizes,
            activation,
            dropout_rate=dropout_rate,
        )
        # MLPQFunction is not defined in this file; presumably it comes from
        # the ``models.common`` star import.
        self.q1 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)
        self.q2 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)

    def act(self, obs, deterministic=False):
        """Return only the squashed action for ``obs`` as a NumPy array.

        Runs the policy without gradient tracking and without a bias config.
        """
        with torch.no_grad():
            # Keyword call: the policy's third positional parameter is
            # ``bias_config``, so nothing should be passed there by position.
            a, *_ = self.pi(obs, deterministic=deterministic)
            return a.cpu().numpy()

    def act_extended(self, obs, deterministic=False, bias_config=None, **kwargs):
        """Return every policy output for ``obs`` as NumPy arrays.

        Args:
            obs: Observation tensor passed to the policy.
            deterministic: Forwarded to the policy.
            bias_config: Forwarded to the policy.
            **kwargs: Accepted and ignored.

        Returns:
            Dict with keys ``"a"``, ``"logp_a"``, ``"pi"``, ``"mu"``,
            ``"std"``, ``"cov"`` and ``"activation_outputs"``. The last is a
            snapshot of the activations recorded during this call.
        """
        with torch.no_grad():
            a, logp_a, pi, mu, std, cov, *_ = self.pi(
                obs,
                deterministic=deterministic,
                bias_config=bias_config,
            )

            return {
                "a": a.cpu().numpy(),
                "logp_a": logp_a.cpu().numpy(),
                "pi": pi.cpu().numpy(),
                "mu": mu.cpu().numpy(),
                "std": std.cpu().numpy(),
                "cov": cov.cpu().numpy(),
                # Shallow copy: the policy rewrites its own dict on every
                # forward pass, so returning it directly would make results
                # from earlier calls silently show the latest activations.
                "activation_outputs": dict(self.pi.activation_outputs),
            }


