import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiagonalGaussian:
    """Diagonal Gaussian distribution over parameters for PAC-Bayes framework.

    The distribution is stored as ``mean`` and ``log_std``. ``std`` is a
    property derived from ``log_std`` on every access, so it always reflects
    the current value of ``log_std`` (e.g. after an in-place optimizer
    update) and builds a fresh autograd node each time it is used.
    """

    def __init__(self, mean, log_std=None, std=None):
        """Initialize with either log_std or std.

        Args:
            mean: Tensor of means.
            log_std: Optional tensor of log standard deviations. Takes
                precedence over ``std`` when both are given.
            std: Optional tensor of standard deviations; stored as
                ``torch.log(std)``.

        When neither ``log_std`` nor ``std`` is given, ``log_std`` defaults
        to a tensor shaped like ``mean`` filled with -2.
        """
        self.mean = mean
        if log_std is not None:
            self.log_std = log_std
        elif std is not None:
            self.log_std = torch.log(std)
        else:
            # Default initialization with small std
            self.log_std = torch.zeros_like(mean) - 2

    @property
    def std(self):
        """Standard deviation, recomputed from the current ``log_std``."""
        return torch.exp(self.log_std)

    @std.setter
    def std(self, value):
        # Keep log_std as the single source of truth so the two never diverge.
        self.log_std = torch.log(value)

    def sample(self):
        """Sample parameters from the Gaussian distribution.

        Returns:
            ``mean + noise * std`` with ``noise`` drawn from a standard
            normal of the same shape as ``mean``.
        """
        noise = torch.randn_like(self.mean)
        return self.mean + noise * self.std

    def kl_divergence(self, other):
        """KL(self || other) - MUST be non-negative.

        Args:
            other: Distribution exposing ``mean`` and ``std`` attributes.

        Returns:
            The KL divergence summed over all elements.
        """
        # With p = self and q = other, per dimension:
        # KL(p||q) = 0.5 * (log(σ_q²/σ_p²) + σ_p²/σ_q² + (μ_q-μ_p)²/σ_q² - 1)

        # Small epsilon for numerical stability without changing the formula
        eps = 1e-8

        # Clamp std values to avoid division by very small numbers
        other_std_stable = torch.clamp(other.std, min=eps)
        self_std_stable = torch.clamp(self.std, min=eps)

        # Recompute log_std from clamped std to maintain consistency
        other_log_std_stable = torch.log(other_std_stable)
        self_log_std_stable = torch.log(self_std_stable)

        kl = 0.5 * (
            2 * (other_log_std_stable - self_log_std_stable) +  # log(σ_q²/σ_p²)
            (self_std_stable / other_std_stable) ** 2 +  # σ_p²/σ_q²
            ((other.mean - self.mean) / other_std_stable) ** 2 -  # (μ_q-μ_p)²/σ_q²
            1
        )

        return kl.sum()

    def log_prob(self, x):
        """Log probability of x under this distribution, summed over elements."""
        return -0.5 * (
            ((x - self.mean) / self.std) ** 2 +
            2 * self.log_std +
            math.log(2 * math.pi)
        ).sum()

    def entropy(self):
        """Compute entropy of the distribution, summed over elements."""
        return (self.log_std + 0.5 * math.log(2 * math.pi * math.e)).sum()
    

import torch
from torch.optim import Optimizer

###############################################################################


class COCOB(Optimizer):
    """
    COCOB-Backprop optimizer (Algorithm 2 of [1]), a learning-rate-free
    optimizer based on coin betting.

    References
    ----------
    [1] F. Orabona and T. Tommasi, "Training Deep Networks without Learning
        Rates Through Coin Betting", NeurIPS 2017.
    """

    def __init__(self, params, weight_decay=0, alpha=100):
        """
        Initialize the COCOB optimizer

        Parameters
        ----------
        params :
            Parameters of the model to optimize
        weight_decay : float
            The weight decay applied to the parameters
        alpha : float
            The parameter of the algorithm COCOB-Backprop (see [1])
        """
        defaults = dict(weight_decay=weight_decay)
        super(COCOB, self).__init__(params, defaults)

        assert weight_decay >= 0.0
        assert alpha > 0.0
        self.weight_decay = weight_decay
        self.alpha = alpha
        # The per-parameter state container is the one created by
        # Optimizer.__init__; it is not replaced here.

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single optimization step.

        The update runs with autograd disabled so that the per-parameter
        state never records graph history (otherwise, with weight decay, the
        state would reference the parameter and keep growing a graph).

        Parameters
        ----------
        closure : callable, optional
            Re-evaluates the model and returns the loss; it is called with
            gradients enabled.

        Returns
        -------
        The value returned by ``closure``, or None when no closure is given.

        Raises
        ------
        RuntimeError
            If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # We update all the parameters with Algorithm 2 (COCOB-Backprop)
        # that was introduced in [1]
        for group in self.param_groups:
            weight_decay = group['weight_decay']

            for w in group['params']:

                # We get the gradient
                grad = w.grad

                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError(
                        "COCOB does not support sparse gradients")

                state = self.state[w]

                # We remember the initial weights
                if "w_1" not in state:
                    state["w_1"] = w.detach().clone()

                # We add the weight decay
                if weight_decay != 0:
                    grad = grad.add(w, alpha=weight_decay)

                # We get the negative gradient (line 4)
                grad = -grad
                abs_grad = torch.abs(grad)

                # We update the maximum observed scale (line 6)
                if "L" not in state:
                    state["L"] = torch.tensor(1.0e-9, device=grad.device)
                state["L"] = torch.max(state["L"], abs_grad)

                # We update the sum of the absolute values of the grad (line 7)
                if "G" not in state:
                    state["G"] = torch.tensor(0.0, device=grad.device)
                state["G"] = state["G"] + abs_grad

                # We update the sum of the grad (line 9)
                if "theta" not in state:
                    state["theta"] = torch.tensor(0.0, device=grad.device)
                state["theta"] = state["theta"] + grad

                # We update the reward, floored at zero (line 8)
                if "reward" not in state:
                    state["reward"] = torch.tensor(0.0, device=grad.device)
                state["reward"] = torch.clamp(
                    state["reward"] + (w - state["w_1"]) * grad, min=0)

                # We compute the associated wealth
                wealth = state["reward"] + state["L"]

                # We compute the beta (line 10)
                beta = torch.max(state["G"] + state["L"],
                                 self.alpha * state["L"])
                beta = state["theta"] / (state["L"] * beta)

                # We calculate the parameters (line 10)
                w.copy_(state["w_1"] + beta * wealth)

        return loss


class CReLU(nn.Module):
    """
    Concatenated ReLU (CReLU) activation.

    The output stacks the rectified input and the rectified negated input
    along the last dimension:

        CReLU(x) = [ReLU(x), ReLU(-x)]

    The last dimension therefore doubles in size, and both the positive and
    the negative part of the input are kept.
    """

    def __init__(self) -> None:
        """Create the module; it has no parameters or configuration."""
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply CReLU to the input.

        Args:
            x (torch.Tensor): Input tensor of shape (..., features)
        Returns:
            torch.Tensor: Tensor of shape (..., 2 * features) holding
                          ReLU(x) followed by ReLU(-x) on the last dimension
        """
        # ReLU is elementwise, so rectifying each half and then joining them
        # gives the same values as rectifying the joined tensor.
        positive_part = F.relu(x)
        negative_part = F.relu(-x)
        return torch.cat((positive_part, negative_part), dim=-1)