
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import random
import yaml
from types import SimpleNamespace
import os
     

# Recursive function to convert nested dictionaries to SimpleNamespace
def dict_to_namespace(config_dict):
    """Recursively convert a (possibly nested) dictionary into SimpleNamespace objects.

    Values that are themselves dictionaries are converted recursively; every
    other value (including lists) is stored unchanged. The input dictionary is
    left untouched: a fresh mapping is built at each level instead of
    overwriting the caller's entries.

    Args:
        config_dict (dict): Mapping whose keys become attribute names.

    Returns:
        SimpleNamespace: Namespace exposing each key of ``config_dict`` as an attribute.
    """
    converted = {
        key: dict_to_namespace(value) if isinstance(value, dict) else value
        for key, value in config_dict.items()
    }
    return SimpleNamespace(**converted)

# Recursive function to convert SimpleNamespace to dictionary for saving
def namespace_to_dict(namespace):
    """Recursively turn SimpleNamespace objects (and dicts) into plain dictionaries.

    Args:
        namespace: A SimpleNamespace, a dict, or any other value.

    Returns:
        A dict with every value converted recursively when the input is a
        SimpleNamespace or a dict; otherwise the input itself, unchanged.
    """
    if isinstance(namespace, SimpleNamespace):
        entries = vars(namespace).items()
    elif isinstance(namespace, dict):
        entries = namespace.items()
    else:
        # Leaf value (number, string, list, ...): nothing to convert.
        return namespace

    return {key: namespace_to_dict(value) for key, value in entries}

# Load YAML config file
def load_config(file_path):
    """Read a YAML file and return its content as nested SimpleNamespace objects.

    Args:
        file_path: Path of the YAML file, opened in text read mode.

    Returns:
        The result of passing the parsed YAML content to ``dict_to_namespace``.
    """
    with open(file_path, 'r') as stream:
        parsed = yaml.safe_load(stream)

    # Conversion needs only the parsed data, so it happens after the file is closed.
    return dict_to_namespace(parsed)

# Save YAML config file
def save_config(config, save_path):
    """Serialize a configuration to a YAML file.

    Args:
        config: Configuration object; converted with ``namespace_to_dict``
            before being dumped.
        save_path: Destination file path. Missing parent directories are created.
    """
    # Convert SimpleNamespace config back to dictionary
    config_dict = namespace_to_dict(config)

    # Ensure the directory exists. A bare filename has an empty dirname, and
    # os.makedirs('') raises FileNotFoundError, so only create a real parent.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Write config to the specified path
    with open(save_path, 'w') as file:
        yaml.dump(config_dict, file)

def utils_l2_norm(named_params):
    """Compute the global L2 norm over a set of named parameters.

    Parameters whose name contains 'original_last_layer_params', 'init_params'
    or 'layer_norm' are ignored, as are parameters with ``requires_grad`` False.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs.

    Returns:
        float: Square root of the sum of the squared per-parameter L2 norms.
    """
    excluded_tags = ('original_last_layer_params', 'init_params', 'layer_norm')

    squared_sum = 0.0
    for name, param in named_params:
        if any(tag in name for tag in excluded_tags):
            continue
        if not param.requires_grad:
            continue
        squared_sum += param.data.norm(2).item() ** 2

    return squared_sum ** 0.5
    
def new_utils_l2_norm(named_params):
    """Collect size-normalized L2 norms for each trainable parameter.

    Parameters whose name contains 'original_last_layer_params', 'init_params'
    or 'layer_norm' are ignored, as are parameters with ``requires_grad`` False.

    NOTE: the values are L2 norms, but the dictionary keys use the
    'model_l1_norm/' prefix. The keys are kept as-is so existing consumers of
    the returned dictionary keep working.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs.

    Returns:
        dict: ``'model_l1_norm/<name>'`` maps to the parameter's L2 norm divided
        by its element count, and ``'model_l1_norm/avg_l1_norm'`` maps to the
        sum of all L2 norms divided by the total element count. Both are 0.0
        when the corresponding element count is zero (instead of raising
        ZeroDivisionError).
    """
    excluded_tags = ('original_last_layer_params', 'init_params', 'layer_norm')

    total_norm = 0
    total_params = 0
    norm_dict = {}

    for name, param in named_params:
        if any(tag in name for tag in excluded_tags):
            continue
        if not param.requires_grad:
            continue

        l2_norm = param.data.pow(2).sum().sqrt().item()
        num_elements = param.numel()

        total_norm += l2_norm
        total_params += num_elements

        # Guard against zero-element parameters.
        norm_dict[f'model_l1_norm/{name}'] = l2_norm / num_elements if num_elements else 0.0

    # No qualifying parameter (or only empty ones): report 0.0 rather than dividing by zero.
    norm_dict['model_l1_norm/avg_l1_norm'] = total_norm / total_params if total_params else 0.0

    return norm_dict


def utils_l1_norm(named_params):
    """Collect per-element mean absolute values (L1 norm / size) for each trainable parameter.

    Parameters whose name contains 'original_last_layer_params', 'init_params'
    or 'layer_norm' are ignored, as are parameters with ``requires_grad`` False.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs.

    Returns:
        dict: ``'model_l1_norm/<name>'`` maps to the parameter's L1 norm divided
        by its element count, and ``'model_l1_norm/avg_l1_norm'`` maps to the
        total L1 norm divided by the total element count. Both are 0.0 when the
        corresponding element count is zero (instead of raising
        ZeroDivisionError).
    """
    excluded_tags = ('original_last_layer_params', 'init_params', 'layer_norm')

    total_norm = 0
    total_params = 0
    l1_norm_dict = {}

    for name, param in named_params:
        if any(tag in name for tag in excluded_tags):
            continue
        if not param.requires_grad:
            continue

        l1_norm = param.data.abs().sum().item()
        num_elements = param.numel()

        total_norm += l1_norm
        total_params += num_elements

        # Guard against zero-element parameters.
        l1_norm_dict[f'model_l1_norm/{name}'] = l1_norm / num_elements if num_elements else 0.0

    # No qualifying parameter (or only empty ones): report 0.0 rather than dividing by zero.
    l1_norm_dict['model_l1_norm/avg_l1_norm'] = total_norm / total_params if total_params else 0.0

    return l1_norm_dict


def set_seed(seed: int):
    """Seed every random number generator used by this module.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU, current CUDA device and
    all CUDA devices), then sets the cuDNN flags for deterministic behavior.

    Args:
        seed (int): Seed value applied to every generator.
    """
    # Python built-in and NumPy generators
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, current CUDA device, then all CUDA devices (multi-GPU)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN: benchmarking is disabled to avoid nondeterministic behavior
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True



def ext_soft_histogram(x, num_bins, min_val, max_val, sigma):
    """
    Build a differentiable, normalized histogram via Gaussian soft binning.

    Args:
        x (Tensor): 1D or 2D tensor containing the values. If 2D, the first dim is the batch size.
        num_bins (int): Number of histogram bins.
        min_val (float): Center of the first bin.
        max_val (float): Center of the last bin.
        sigma (float): Standard deviation of the Gaussian kernel used for soft assignment.

    Returns:
        Tensor: Normalized histogram.
            - If input is 1D, returns shape (num_bins,)
            - If input is 2D, returns shape (batch_size, num_bins)
    """
    eps = 1e-8

    # A 1D input gets a temporary leading batch axis so one code path serves both cases.
    added_batch_axis = x.dim() == 1
    if added_batch_axis:
        x = x.unsqueeze(0)

    # Bin centers evenly spaced over [min_val, max_val], shaped (1, 1, num_bins) for broadcasting.
    centers = torch.linspace(min_val, max_val, steps=num_bins, device=x.device)
    centers = centers.view(1, 1, num_bins)

    # Gaussian affinity of every value to every bin: (batch, N, num_bins).
    scaled_offset = (x.unsqueeze(-1) - centers) / sigma
    affinity = torch.exp(-0.5 * scaled_offset ** 2)

    # Each value spreads a total mass of 1 across the bins.
    affinity = affinity / (affinity.sum(dim=-1, keepdim=True) + eps)

    # Accumulate over the sample axis, then normalize into a probability distribution.
    counts = affinity.sum(dim=1)
    distribution = counts / (counts.sum(dim=-1, keepdim=True) + eps)

    # Drop the temporary batch axis for 1D input.
    return distribution.squeeze(0) if added_batch_axis else distribution


def ext_histogram_divergence_loss(x, num_bins, min_val, max_val, sigma):
    """
    KL divergence between the soft histogram of x and a uniform histogram.

    Args:
        x (Tensor): 1D or 2D tensor of values.
            - If 2D, the first dim is the batch size.
        num_bins (int): Number of histogram bins.
        min_val (float): Minimum value of the histogram range.
        max_val (float): Maximum value of the histogram range.
        sigma (float): Standard deviation of the Gaussian kernel used for soft
            assignment; values below 1e-4 are replaced by 1e-4.

    Returns:
        Tensor: Scalar divergence loss. When the histogram is 1D this is the sum
                over bins; otherwise the per-row sums (dim=1) are averaged.
    """
    eps = 1e-8

    # Floor the kernel width before it is used as a divisor in the soft histogram.
    sigma = 1e-4 if sigma < 1e-4 else sigma

    hist = ext_soft_histogram(x, num_bins, min_val, max_val, sigma)

    # Uniform reference distribution with the same shape as the histogram.
    uniform = torch.full_like(hist, 1.0 / num_bins)

    # Per-bin KL terms; eps keeps log() finite for empty bins.
    smoothed = hist + eps
    per_bin = smoothed * torch.log(smoothed / uniform)

    if per_bin.dim() != 1:
        # Batched input: one divergence per row, averaged over the batch.
        return per_bin.sum(dim=1).mean()
    return per_bin.sum()


def ext_soft_max(x, beta=50.0):
    """
    Computes a softmax-weighted sum that approximates max(x).

    NOTE: the absolute value of the weighted sum is returned, so for inputs
    whose maximum is negative the result is |max(x)| rather than max(x).

    Args:
        x (Tensor): 1D or 2D tensor.
            - If 2D, the first dim is the batch size.
        beta (float): Parameter controlling the sharpness of the approximation.

    Returns:
        Tensor: Approximated maximum value.
            - For 1D input, returns a scalar.
            - For 2D input, returns a tensor of shape (batch_size,)

    Raises:
        ValueError: If x is neither 1D nor 2D (previously None was returned silently).
    """
    if x.dim() == 1:
        weights = torch.softmax(beta * x, dim=0)
        return torch.abs(torch.sum(x * weights))
    if x.dim() == 2:
        weights = torch.softmax(beta * x, dim=1)
        return torch.abs(torch.sum(x * weights, dim=1))
    raise ValueError(f"ext_soft_max expects a 1D or 2D tensor, got {x.dim()}D")


def ext_soft_min(x, beta=50.0):
    """
    Computes a softmax-weighted sum that approximates min(x).

    NOTE: the absolute value of the weighted sum is returned, so for inputs
    whose minimum is negative the result is |min(x)| rather than min(x).

    Args:
        x (Tensor): 1D or 2D tensor.
            - If 2D, the first dim is the batch size.
        beta (float): Parameter controlling the sharpness of the approximation.

    Returns:
        Tensor: Approximated minimum value.
            - For 1D input, returns a scalar.
            - For 2D input, returns a tensor of shape (batch_size,)

    Raises:
        ValueError: If x is neither 1D nor 2D (previously None was returned silently).
    """
    if x.dim() == 1:
        # Negating beta puts the softmax weight on the smallest entries.
        weights = torch.softmax(-beta * x, dim=0)
        return torch.abs(torch.sum(x * weights))
    if x.dim() == 2:
        weights = torch.softmax(-beta * x, dim=1)
        return torch.abs(torch.sum(x * weights, dim=1))
    raise ValueError(f"ext_soft_min expects a 1D or 2D tensor, got {x.dim()}D")

def linear_per_sample(x, weight, bias):
    """
    Apply a distinct linear map (weight and bias) to every sample of the batch.

    Args:
        x (Tensor): Input tensor of shape (b, in_features).
        weight (Tensor): Weights tensor of shape (b, out_features, in_features).
        bias (Tensor): Bias tensor of shape (b, out_features).

    Returns:
        Tensor: Output tensor of shape (b, out_features).
    """
    # Treat each sample as a column vector: (b, in_features, 1).
    columns = x.unsqueeze(2)
    # Batched matrix-vector product: (b, out_features, 1).
    projected = torch.bmm(weight, columns)
    # Drop the trailing singleton axis and add the per-sample bias.
    return projected.squeeze(2) + bias


def conv2d_per_sample(x, weight, bias=None, stride=1, padding=0, dilation=1):
    """
    Run a 2D convolution in which every sample of the batch uses its own kernel and bias.

    Args:
        x (Tensor): Input tensor of shape (b, c_in, h, w).
        weight (Tensor): Weights tensor of shape (b, c_out, c_in, kH, kW).
        bias (Tensor, optional): Bias tensor of shape (b, c_out). Default is None.
        stride (int or tuple): Stride for the convolution.
        padding (int or tuple): Padding for the convolution.
        dilation (int or tuple): Dilation for the convolution.

    Returns:
        Tensor: Output tensor of shape (b, c_out, h_out, w_out).
    """
    _, c_in, h, w = x.shape
    # The batch size used for all reshapes is taken from the weight tensor.
    b, c_out, _, kH, kW = weight.shape

    # Fold the batch into the channel axis: input (1, b*c_in, h, w), kernels (b*c_out, c_in, kH, kW).
    folded_input = x.reshape(1, b * c_in, h, w)
    folded_kernels = weight.reshape(b * c_out, c_in, kH, kW)

    # groups=b makes each sample's channels convolve only with that sample's kernels.
    folded_out = F.conv2d(
        folded_input,
        folded_kernels,
        bias=None,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=b,
    )

    # Unfold (1, b*c_out, h_out, w_out) back to (b, c_out, h_out, w_out).
    result = folded_out.reshape(b, c_out, *folded_out.shape[-2:])

    if bias is None:
        return result
    # Broadcast the bias over the spatial axes: (b, c_out) -> (b, c_out, 1, 1).
    return result + bias[:, :, None, None]

def conv_output_dim(input_size, kernel_size, stride, padding):
    """
    Output size of one spatial dimension after a convolution.

    Parameters:
      input_size (int): height or width of the input
      kernel_size (int): size of the convolution kernel (assumed square)
      stride (int): stride of the convolution
      padding (int): amount of zero-padding applied on each side

    Returns:
      int: floor((input_size - kernel_size + 2 * padding) / stride) + 1
    """
    # Span left for sliding the kernel once padding has been added on both sides.
    sliding_span = input_size - kernel_size + 2 * padding
    return sliding_span // stride + 1

def pooling_output_dim(input_size, pool_size):
    """
    Output size of one spatial dimension after pooling with stride equal to pool_size.

    Parameters:
      input_size (int): height or width of the input
      pool_size (int): size of the pooling kernel

    Returns:
      int: number of whole pooling windows that fit in input_size
    """
    # Floor division: a trailing partial window is dropped.
    whole_windows = input_size // pool_size
    return whole_windows

def calculate_network_dims(input_height, input_width, input_channels, layers):
    """
    Trace the tensor shape through a stack of convolution (+ optional pooling) layers.

    Parameters:
      input_height (int): initial height of the image
      input_width (int): initial width of the image
      input_channels (int): initial number of channels (e.g., 3 for RGB)
      layers (list of dicts): Each dictionary contains parameters for the layer:
          - 'kernel_size': int (convolution kernel size)
          - 'stride': int (convolution stride)
          - 'padding': int (convolution padding)
          - 'filters': int (number of output filters/channels)
          - 'pool_size': int, optional (pooling kernel size, applied after the conv)

    Returns:
      list: one (filters, conv_height, conv_width) tuple per layer, holding the
      convolution output shape before any pooling, followed by a final int
      entry with the flattened size (height * width * channels) after the last layer.
    """
    height, width, channels = input_height, input_width, input_channels
    shapes = []

    for spec in layers:
        kernel = spec['kernel_size']
        stride = spec['stride']
        pad = spec['padding']
        # The layer's filter count becomes the channel count seen by the next layer.
        channels = spec['filters']
        pool = spec.get('pool_size')

        # Convolution output; recorded before pooling is applied.
        height = conv_output_dim(height, kernel, stride, pad)
        width = conv_output_dim(width, kernel, stride, pad)
        shapes.append((channels, height, width))

        # Optional pooling shrinks the dims fed to the next layer.
        if pool is not None:
            height = pooling_output_dim(height, pool)
            width = pooling_output_dim(width, pool)

    # The last entry is the flattened size (a plain int, not a tuple).
    shapes.append(height * width * channels)
    return shapes