"""
Utility functions for weighted adaptive pooling operations.

This module provides utilities for computing pooling parameters and performing
weighted adaptive pooling for both 2D (audio) and 3D (visual) feature tensors.
The weighted pooling operations are essential for attention-guided feature extraction
in multimodal learning tasks.
"""

import math
import torch
from torch import nn
import torch.nn.functional as F
from transformers.activations import ACT2FN


def get_size_visual(input_size, output_size=None, kernel=None, stride=None):
    """
    Calculate pooling parameters for 3D visual features (D, H, W).

    Exactly one of ``output_size``, ``kernel`` or ``stride`` may be omitted and is
    derived from the other two; alternatively both ``kernel`` and ``stride`` may be
    omitted, in which case they are derived from ``output_size`` alone.

    The relation used per dimension is ``output = (input - kernel) // stride + 1``.

    Args:
        input_size (tuple): Input tensor dimensions (D, H, W)
        output_size (tuple, optional): Target output dimensions (D', H', W')
        kernel (tuple, optional): Pooling kernel size (kD, kH, kW)
        stride (tuple, optional): Pooling stride (sD, sH, sW)

    Returns:
        tuple: (output_size, kernel_size, stride) with all parameters computed

    Raises:
        ValueError: If input_size is None, if nothing but input_size is given,
            or if all three of output_size, kernel and stride are given.
    """
    if input_size is None:
        raise ValueError("input_size must not be None")

    I_D, I_H, I_W = input_size

    # Case 1: Compute kernel and stride from input and output sizes
    if kernel is None and stride is None:
        if output_size is None:
            raise ValueError("output_size must not be None when both kernel and stride are None")
        O_D, O_H, O_W = output_size
        stride_D = I_D // O_D
        stride_H = I_H // O_H
        stride_W = I_W // O_W
        # The last window absorbs whatever the integer stride leaves over.
        kernel_D = I_D - (O_D - 1) * stride_D
        kernel_H = I_H - (O_H - 1) * stride_H
        kernel_W = I_W - (O_W - 1) * stride_W
        return (O_D, O_H, O_W), (kernel_D, kernel_H, kernel_W), (stride_D, stride_H, stride_W)

    # Case 2: Compute output size from input, kernel, and stride
    if output_size is None and kernel is not None and stride is not None:
        kernel_D, kernel_H, kernel_W = kernel
        stride_D, stride_H, stride_W = stride
        O_D = (I_D - kernel_D) // stride_D + 1
        O_H = (I_H - kernel_H) // stride_H + 1
        O_W = (I_W - kernel_W) // stride_W + 1
        return (O_D, O_H, O_W), kernel, stride

    # Case 3: Compute kernel from input, output, and stride
    if kernel is None and output_size is not None and stride is not None:
        O_D, O_H, O_W = output_size
        stride_D, stride_H, stride_W = stride
        kernel_D = I_D - (O_D - 1) * stride_D
        kernel_H = I_H - (O_H - 1) * stride_H
        kernel_W = I_W - (O_W - 1) * stride_W
        return output_size, (kernel_D, kernel_H, kernel_W), stride

    # Case 4: Compute stride from input, output, and kernel
    if stride is None and output_size is not None and kernel is not None:

        def _solve_stride(in_dim, kernel_dim, out_dim):
            # A single output window leaves the stride unconstrained and the
            # general formula would divide by zero; use the smallest stride
            # for which (in - kernel) // stride + 1 == 1.
            if out_dim == 1:
                return max(in_dim - kernel_dim + 1, 1)
            return (in_dim - kernel_dim) // (out_dim - 1)

        O_D, O_H, O_W = output_size
        kernel_D, kernel_H, kernel_W = kernel
        stride_D = _solve_stride(I_D, kernel_D, O_D)
        stride_H = _solve_stride(I_H, kernel_H, O_H)
        stride_W = _solve_stride(I_W, kernel_W, O_W)
        return output_size, kernel, (stride_D, stride_H, stride_W)

    # Reached only when output_size, kernel and stride were all supplied.
    raise ValueError("Invalid combination of parameters. One of output_size, kernel, or stride must be None.")

def get_size_audio(input_size, output_size=None, kernel=None, stride=None):
    """
    Calculate pooling parameters for 2D audio features (T, B).

    One of ``output_size``, ``kernel`` or ``stride`` may be omitted and is derived
    from the other two; alternatively both ``kernel`` and ``stride`` may be omitted,
    in which case they are derived from ``output_size`` alone.

    The relation used per dimension is ``output = (input - kernel) // stride + 1``.

    Args:
        input_size (tuple): Input tensor dimensions (T, B) where T=time, B=frequency bands
        output_size (tuple, optional): Target output dimensions (T', B')
        kernel (tuple, optional): Pooling kernel size (kT, kB)
        stride (tuple, optional): Pooling stride (sT, sB)

    Returns:
        tuple: (output_size, kernel_size, stride) with all parameters computed.
        When both kernel and stride are supplied, the returned output size is
        always the one they imply; a supplied output_size is ignored in that
        case (this keeps existing callers that pass all three working).

    Raises:
        ValueError: If input_size is None, or if nothing but input_size is given.
    """
    if input_size is None:
        raise ValueError("input_size must not be None")
    I_T, I_B = input_size

    # Case 1: Compute kernel and stride from input and output sizes
    if kernel is None and stride is None:
        if output_size is None:
            raise ValueError("output_size must not be None when both kernel and stride are None")
        O_T, O_B = output_size
        stride_T = I_T // O_T
        stride_B = I_B // O_B
        # The last window absorbs whatever the integer stride leaves over.
        kernel_T = I_T - (O_T - 1) * stride_T
        kernel_B = I_B - (O_B - 1) * stride_B
        return (O_T, O_B), (kernel_T, kernel_B), (stride_T, stride_B)

    # Case 2: Compute output size from input, kernel, and stride.
    # Kernel and stride take precedence over any supplied output_size.
    if kernel is not None and stride is not None:
        kernel_T, kernel_B = kernel
        stride_T, stride_B = stride
        O_T = (I_T - kernel_T) // stride_T + 1
        O_B = (I_B - kernel_B) // stride_B + 1
        return (O_T, O_B), kernel, stride

    # From here on exactly one of kernel / stride is missing, so the
    # remaining unknown can only be solved if output_size is known.
    if output_size is None:
        raise ValueError("Invalid combination of parameters. One of output_size, kernel, or stride must be None.")

    O_T, O_B = output_size

    # Case 3: Compute kernel from input, output, and stride
    if kernel is None:
        stride_T, stride_B = stride
        kernel_T = I_T - (O_T - 1) * stride_T
        kernel_B = I_B - (O_B - 1) * stride_B
        return output_size, (kernel_T, kernel_B), stride

    # Case 4: Compute stride from input, output, and kernel
    def _solve_stride(in_dim, kernel_dim, out_dim):
        # A single output window leaves the stride unconstrained and the
        # general formula would divide by zero; use the smallest stride
        # for which (in - kernel) // stride + 1 == 1.
        if out_dim == 1:
            return max(in_dim - kernel_dim + 1, 1)
        return (in_dim - kernel_dim) // (out_dim - 1)

    kernel_T, kernel_B = kernel
    stride_T = _solve_stride(I_T, kernel_T, O_T)
    stride_B = _solve_stride(I_B, kernel_B, O_B)
    return output_size, kernel, (stride_T, stride_B)


def weighted_adaptive_avg_pool2d_unfold(input, output_size=None, kernel=None, stride=None, weights=None, temperature=0.01):
    """
    Perform weighted adaptive average pooling on 2D audio features with padding.

    The input (and the weight map) is zero-padded at the end of the T and B axes,
    split into sliding windows with ``Tensor.unfold``, and each window is reduced
    to a single value by a softmax-weighted sum, where the softmax is taken over
    the window's weights divided by ``temperature``.

    Args:
        input (torch.Tensor): Input tensor of shape [N, D, T, B] where N=batch, D=features, T=time, B=bands
        output_size (tuple, optional): Target output dimensions (T', B'); forwarded to
            ``get_size_audio`` together with ``kernel`` and ``stride``
        kernel (tuple): Pooling kernel size (kT, kB). Required.
        stride (tuple): Pooling stride (sT, sB). Required.
        weights (torch.Tensor, optional): Attention weights of shape [N, T, B]. If None, uniform weights are used
        temperature (float): Temperature parameter for softmax normalization of weights

    Returns:
        torch.Tensor: Pooled features of shape [N, D, T', B'] where T', B' are output dimensions
    """
    N, D, T, B = input.shape
    stride_T, stride_B = stride
    kernel_T, kernel_B = kernel

    # Build the uniform default BEFORE padding: F.pad cannot take None, and
    # creating the ones on the unpadded extent means the zero-padded cells
    # end up with a lower weight than every real cell.
    if weights is None:
        weights = torch.ones((N, T, B), device=input.device, dtype=input.dtype)

    # Calculate padding needed to ensure proper alignment with stride.
    # NOTE(review): the test uses (size + kernel) rather than (size - kernel);
    # left unchanged so existing output shapes stay the same - confirm intent.
    padded_T = 0
    padded_B = 0
    if (T + kernel_T) % stride_T != 0:
        padded_T = stride_T - ((T + kernel_T) % stride_T)
    if (B + kernel_B) % stride_B != 0:
        padded_B = stride_B - ((B + kernel_B) % stride_B)

    # Apply the same trailing zero padding to input and weights
    # (F.pad lists the last dimension first).
    input = F.pad(input, (0, padded_B, 0, padded_T), "constant", 0)
    weights = F.pad(weights, (0, padded_B, 0, padded_T), "constant", 0)
    N, D, T, B = input.shape

    # Compute output dimensions from the padded extent
    output_size, kernel, stride = get_size_audio((T, B), output_size, kernel, stride)
    out_T, out_B = output_size

    # Unfold input and weight tensors to create sliding windows
    input_unf = input.unfold(2, kernel_T, stride_T).unfold(3, kernel_B, stride_B)
    weights_unf = weights.unfold(1, kernel_T, stride_T).unfold(2, kernel_B, stride_B)

    # Reshape for computation: flatten the two kernel dimensions into one
    input_unf = input_unf.contiguous().view(N, D, out_T, out_B, -1)
    weights_unf = weights_unf.contiguous().view(N, out_T, out_B, -1)

    # Apply temperature-scaled softmax to normalize attention weights per window
    weights_unf = F.softmax(weights_unf / temperature, dim=-1)

    # Weighted sum over each window; unsqueeze(1) broadcasts the weights over D
    weighted_sum = (input_unf * weights_unf.unsqueeze(1)).sum(dim=-1)

    return weighted_sum


def weighted_adaptive_avg_pool3d_unfold(input, output_size=None, kernel=None, stride=None, weights=None, temperature=0.01):
    """
    Perform weighted adaptive average pooling on 3D visual features with padding.

    The input (and the weight map) is zero-padded at the end of the D, H and W axes,
    split into sliding windows with ``Tensor.unfold``, and each window is reduced
    to a single value by a softmax-weighted sum, where the softmax is taken over
    the window's weights divided by ``temperature``.

    Args:
        input (torch.Tensor): Input tensor of shape [N, C, D, H, W] where N=batch, C=channels, D=depth, H=height, W=width
        output_size (tuple, optional): Target output dimensions (D', H', W'); forwarded to
            ``get_size_visual`` together with ``kernel`` and ``stride``
        kernel (tuple): Pooling kernel size (kD, kH, kW). Required.
        stride (tuple): Pooling stride (sD, sH, sW). Required.
        weights (torch.Tensor, optional): Attention weights of shape [N, D, H, W]. If None, uniform weights are used
        temperature (float): Temperature parameter for softmax normalization of weights

    Returns:
        torch.Tensor: Pooled features of shape [N, C, D', H', W'] where D', H', W' are output dimensions
    """
    N, C, D, H, W = input.shape
    stride_D, stride_H, stride_W = stride
    kernel_D, kernel_H, kernel_W = kernel

    # Build the uniform default BEFORE padding: F.pad cannot take None, and
    # creating the ones on the unpadded extent means the zero-padded cells
    # end up with a lower weight than every real cell.
    # The layout is (N, D, H, W) so it matches the input's spatial axes.
    if weights is None:
        weights = torch.ones((N, D, H, W), device=input.device, dtype=input.dtype)

    # Calculate padding needed to ensure proper alignment with stride.
    # NOTE(review): the test uses (size + kernel) rather than (size - kernel);
    # left unchanged so existing output shapes stay the same - confirm intent.
    padded_D = 0
    padded_H = 0
    padded_W = 0
    if (D + kernel_D) % stride_D != 0:
        padded_D = stride_D - ((D + kernel_D) % stride_D)
    if (H + kernel_H) % stride_H != 0:
        padded_H = stride_H - ((H + kernel_H) % stride_H)
    if (W + kernel_W) % stride_W != 0:
        padded_W = stride_W - ((W + kernel_W) % stride_W)

    # Apply the same trailing zero padding to input and weights
    # (F.pad lists the last dimension first).
    input = F.pad(input, (0, padded_W, 0, padded_H, 0, padded_D), "constant", 0)
    weights = F.pad(weights, (0, padded_W, 0, padded_H, 0, padded_D), "constant", 0)
    N, C, D, H, W = input.shape

    # Compute output dimensions from the padded extent
    output_size, kernel, stride = get_size_visual((D, H, W), output_size, kernel, stride)
    out_D, out_H, out_W = output_size

    # Unfold input and weight tensors to create sliding windows
    input_unf = input.unfold(2, kernel_D, stride_D).unfold(3, kernel_H, stride_H).unfold(4, kernel_W, stride_W)
    weights_unf = weights.unfold(1, kernel_D, stride_D).unfold(2, kernel_H, stride_H).unfold(3, kernel_W, stride_W)

    # Reshape for computation: flatten the three kernel dimensions into one
    input_unf = input_unf.contiguous().view(N, C, out_D, out_H, out_W, -1)
    weights_unf = weights_unf.contiguous().view(N, out_D, out_H, out_W, -1)

    # Apply temperature-scaled softmax to normalize attention weights per window
    weights_unf = F.softmax(weights_unf / temperature, dim=-1)

    # Weighted sum over each window; unsqueeze(1) broadcasts the weights over C
    weighted_sum = (input_unf * weights_unf.unsqueeze(1)).sum(dim=-1)

    return weighted_sum


def weighted_adaptive_avg_pool2d_unfold_old(input, output_size=None, kernel=None, stride=None, weights=None, temperature=0.01):
    """
    Legacy version of 2D weighted adaptive pooling without padding.

    This is the original implementation without automatic padding support.
    Kept for backward compatibility but the padded version is recommended.

    The pooling parameters are resolved by ``get_size_audio`` from the input's
    (T, B) extent; the input and weights are then split into sliding windows and
    each window is reduced by a softmax-weighted sum.

    Args:
        input (torch.Tensor): Input tensor of shape [N, D, T, B]
        output_size (tuple, optional): Target output dimensions (T', B')
        kernel (tuple): Pooling kernel size (kT, kB)
        stride (tuple): Pooling stride (sT, sB)
        weights (torch.Tensor, optional): Attention weights of shape [N, T, B].
            If None, uniform weights are used
        temperature (float): Temperature parameter for softmax normalization

    Returns:
        torch.Tensor: Pooled features of shape [N, D, T', B']
    """
    N, D, T, B = input.shape
    output_size, kernel, stride = get_size_audio((T, B), output_size, kernel, stride)
    out_T, out_B = output_size
    stride_T, stride_B = stride
    kernel_T, kernel_B = kernel

    # Uniform default, created directly on the input's device and dtype
    if weights is None:
        weights = torch.ones((N, T, B), device=input.device, dtype=input.dtype)

    # Sliding windows over (T, B) for both the features and the weights
    input_unf = input.unfold(2, kernel_T, stride_T).unfold(3, kernel_B, stride_B)
    weights_unf = weights.unfold(1, kernel_T, stride_T).unfold(2, kernel_B, stride_B)

    # Flatten the two kernel dimensions into one trailing axis
    input_unf = input_unf.contiguous().view(N, D, out_T, out_B, -1)
    weights_unf = weights_unf.contiguous().view(N, out_T, out_B, -1)

    # Temperature-scaled softmax normalizes the weights inside each window
    weights_unf = F.softmax(weights_unf / temperature, dim=-1)

    # Weighted sum over each window; unsqueeze(1) broadcasts the weights over D
    weighted_sum = (input_unf * weights_unf.unsqueeze(1)).sum(dim=-1)

    return weighted_sum


def weighted_adaptive_avg_pool3d_unfold_old(input, output_size=None, kernel=None, stride=None, weights=None, temperature=0.01):
    """
    Legacy version of 3D weighted adaptive pooling without padding.

    This is the original implementation without automatic padding support.
    Kept for backward compatibility but the padded version is recommended.

    The pooling parameters are resolved by ``get_size_visual`` from the input's
    (D, H, W) extent; the input and weights are then split into sliding windows
    and each window is reduced by a softmax-weighted sum.

    Args:
        input (torch.Tensor): Input tensor of shape [N, C, D, H, W]
        output_size (tuple, optional): Target output dimensions (D', H', W')
        kernel (tuple): Pooling kernel size (kD, kH, kW)
        stride (tuple): Pooling stride (sD, sH, sW)
        weights (torch.Tensor, optional): Attention weights of shape [N, D, H, W].
            If None, uniform weights are used
        temperature (float): Temperature parameter for softmax normalization

    Returns:
        torch.Tensor: Pooled features of shape [N, C, D', H', W']
    """
    N, C, D, H, W = input.shape

    output_size, kernel, stride = get_size_visual((D, H, W), output_size, kernel, stride)
    out_D, out_H, out_W = output_size
    stride_D, stride_H, stride_W = stride
    kernel_D, kernel_H, kernel_W = kernel

    # Uniform default in (N, D, H, W) layout so it lines up with the input's
    # spatial axes, created directly on the input's device and dtype.
    if weights is None:
        weights = torch.ones((N, D, H, W), device=input.device, dtype=input.dtype)

    # Sliding windows over (D, H, W) for both the features and the weights
    input_unf = input.unfold(2, kernel_D, stride_D).unfold(3, kernel_H, stride_H).unfold(4, kernel_W, stride_W)
    weights_unf = weights.unfold(1, kernel_D, stride_D).unfold(2, kernel_H, stride_H).unfold(3, kernel_W, stride_W)

    # Flatten the three kernel dimensions into one trailing axis
    input_unf = input_unf.contiguous().view(N, C, out_D, out_H, out_W, -1)
    weights_unf = weights_unf.contiguous().view(N, out_D, out_H, out_W, -1)

    # Temperature-scaled softmax normalizes the weights inside each window
    weights_unf = F.softmax(weights_unf / temperature, dim=-1)

    # Weighted sum over each window; unsqueeze(1) broadcasts the weights over C
    weighted_sum = (input_unf * weights_unf.unsqueeze(1)).sum(dim=-1)

    return weighted_sum
