"""
Utility Functions for Influence Function Computation

This module provides utility functions for working with model parameters,
gradients, and tensor operations in influence function computation.
"""

from __future__ import annotations

import math
from typing import TYPE_CHECKING, List, Optional, Dict, Tuple, Any, Callable
import functools

import numpy as np
import torch
from torch import Tensor

if TYPE_CHECKING:
    from collections.abc import Callable


def vectorize_gradients(
    gradients: Dict[str, Tensor],
    include_batch_dim: bool = True,
    output_tensor: Optional[Tensor] = None,
    device: Optional[str] = "cuda",
) -> Tensor:
    """
    Vectorize gradients into a flattened tensor.

    The gradient tensors are written side by side, in dictionary iteration
    order, into one tensor of shape [batch_size, num_params] (when
    ``include_batch_dim`` is True) or [num_params] (otherwise).

    Args:
        gradients: Dictionary containing gradient tensors to be vectorized.
            With ``include_batch_dim=True`` every tensor must share the same
            leading (batch) dimension.
        include_batch_dim: Whether to include batch dimension in output
        output_tensor: Optional pre-allocated tensor to store results. It is
            filled in place and returned. When it is given, values are copied
            to the device it lives on and ``device`` is not used.
        device: Device on which a new output tensor is allocated when
            ``output_tensor`` is None ("cuda" or "cpu")

    Returns:
        Flattened tensor of gradients

    Raises:
        ValueError: If ``gradients`` is empty and no ``output_tensor`` was
            supplied, or if parameter sizes don't match batch size
    """
    if output_tensor is None:
        if not gradients:
            # Without at least one tensor there is no dtype/batch size to use.
            raise ValueError("Cannot vectorize an empty gradient dictionary.")

        first_gradient = next(iter(gradients.values()))
        if include_batch_dim:
            batch_size = first_gradient.shape[0]
            num_params = 0
            for param in gradients.values():
                if param.shape[0] != batch_size:
                    raise ValueError("Parameter row count doesn't match batch size.")
                # Integer division: the element count is an exact multiple of
                # the batch size, so no float round-trip is needed.
                num_params += param.numel() // batch_size
            size = (batch_size, num_params)
        else:
            size = (sum(param.numel() for param in gradients.values()),)

        output_tensor = torch.empty(
            size=size,
            dtype=first_gradient.dtype,
            device=device,
        )

    # Copy each chunk directly to where the output lives. Using the `device`
    # argument here would break a caller-supplied CPU buffer on a machine
    # without CUDA, because `device` defaults to "cuda".
    target_device = output_tensor.device

    pointer = 0
    for param in gradients.values():
        if include_batch_dim:
            if param.dim() <= 1:
                # One scalar per sample: turn it into a single column.
                num_param = 1
                param_reshaped = param.detach().reshape(-1, 1)
            else:
                num_param = param[0].numel()
                param_reshaped = param.detach().flatten(start_dim=1)
            output_tensor[:, pointer : pointer + num_param] = param_reshaped.to(target_device)
        else:
            # NOTE: unlike the batched branch, this path does not detach the
            # source tensor before copying (kept as in the original code).
            num_param = param.numel()
            output_tensor[pointer : pointer + num_param] = param.reshape(-1).to(target_device)
        pointer += num_param

    return output_tensor


def get_parameter_chunk_sizes(
    param_shape_list: List,
    batch_size: int,
) -> Tuple[int, int]:
    """
    Derive chunking information for parameter projection.

    The total dimension is the sum, over all shapes, of the product of each
    shape's entries. The chunk size is that total capped at ``batch_size``.

    Args:
        param_shape_list: List of parameter shapes (each an iterable of ints)
        batch_size: Batch size for processing

    Returns:
        Tuple of (max_chunk_size, total_dimensions)
    """
    total_dim = 0
    for shape in param_shape_list:
        total_dim += math.prod(shape)

    # Never use a chunk larger than the number of dimensions available.
    max_chunk_size = batch_size if batch_size < total_dim else total_dim
    return max_chunk_size, total_dim


def flatten_parameters(parameters: Dict[str, Tensor]) -> Tensor:
    """
    Concatenate every tensor of a parameter dictionary into one 1-D tensor.

    Tensors are flattened individually and joined in dictionary iteration
    order.

    Args:
        parameters: Dictionary of parameter tensors

    Returns:
        Flattened parameter tensor
    """
    pieces = [tensor.flatten() for tensor in parameters.values()]
    return torch.cat(pieces)


def unflatten_parameters(
    flattened_tensor: Tensor,
    model: torch.nn.Module
) -> Dict[str, Tensor]:
    """
    Split a flat tensor into per-parameter tensors shaped like the model's.

    The model's named parameters are walked in order; each one consumes the
    next ``numel()`` entries of ``flattened_tensor``. Entries beyond the
    model's total parameter count are left unused.

    Args:
        flattened_tensor: Flattened parameter tensor
        model: Model to get parameter structure from

    Returns:
        Dictionary mapping parameter names to the reshaped slices
    """
    result = {}
    offset = 0

    for name, param in model.named_parameters():
        end = offset + param.numel()
        result[name] = flattened_tensor[offset:end].reshape(param.shape)
        offset = end

    return result


def create_parameter_generator(model: torch.nn.Module) -> Callable:
    """
    Build a zero-argument callable that returns the model's flat parameters.

    The model's named parameters are read each time the returned callable is
    invoked, not when this factory is called.

    Args:
        model: Neural network model

    Returns:
        Callable that returns the flattened parameter tensor
    """
    def generator() -> Tensor:
        """Return the model's named parameters as one flat tensor."""
        return flatten_parameters(dict(model.named_parameters()))

    return generator


def unflatten_parameters_layerwise(
    tensors: Tuple[Tensor, ...],
    model: torch.nn.Module,
) -> Dict[str, Tensor]:
    """
    Pair flat tensors with model parameters and reshape them one by one.

    The i-th tensor is matched with the i-th named parameter of the model.
    Only the leading ``numel()`` entries of each tensor are used. Pairing
    stops as soon as either the tensors or the parameters run out.

    Args:
        tensors: Tuple of flattened tensors
        model: Model to get layer structure from

    Returns:
        Dictionary of unflattened parameters by layer
    """
    unflattened = {}

    for (name, param), flat in zip(model.named_parameters(), tensors):
        unflattened[name] = flat[: param.numel()].reshape(param.shape)

    return unflattened


def create_flattened_function(
    model: torch.nn.Module,
    param_index: int = 0
) -> Callable:
    """
    Create a decorator that flattens a dictionary argument before the call.

    The decorated function is called with the positional argument at
    ``param_index`` replaced by its flattened tensor whenever that argument
    is a dict. In every other case (no such positional argument, or it is
    not a dict) the call is forwarded unchanged.

    Args:
        model: Neural network model (accepted for interface compatibility;
            not used by the wrapper)
        param_index: Index of the positional argument to flatten

    Returns:
        Function decorator for flattened parameter handling
    """
    def flatten_function_wrapper(function: Callable) -> Callable:
        """
        Wrap ``function`` so a dict argument is flattened before the call.

        Args:
            function: Function to wrap

        Returns:
            Wrapped function
        """
        @functools.wraps(function)
        def flattened_function(*args: Any, **kwargs: Any) -> Tensor:
            """Execute function with flattened parameters."""
            # Only look at args[param_index] when that position was actually
            # supplied. Indexing blindly raised IndexError when the target
            # argument was omitted or passed by keyword.
            has_target = -len(args) <= param_index < len(args)
            if has_target and isinstance(args[param_index], dict):
                new_args = list(args)
                new_args[param_index] = flatten_parameters(args[param_index])
                return function(*new_args, **kwargs)

            return function(*args, **kwargs)

        return flattened_function

    return flatten_function_wrapper


def create_partial_parameter_function(
    full_parameters: Dict[str, Tensor],
    layer_names: List[str],
    param_index: int = 0,
) -> Callable:
    """
    Create a decorator that injects a subset of parameters into a call.

    On every call of the decorated function, the entries of
    ``full_parameters`` whose names appear in ``layer_names`` are collected
    (names missing from ``full_parameters`` are skipped) and passed in place
    of the positional argument at ``param_index``. When the call has no
    positional arguments, the subset is passed as the first positional
    argument instead.

    Args:
        full_parameters: Full parameter dictionary
        layer_names: Names of layers to include
        param_index: Index of parameter argument

    Returns:
        Function decorator for partial parameter handling
    """
    def partial_parameter_wrapper(function: Callable) -> Callable:
        """
        Wrap ``function`` so it receives only the selected parameters.

        Args:
            function: Function to wrap

        Returns:
            Wrapped function
        """
        @functools.wraps(function)
        def partial_function(*args, **kwargs: Dict[str, Any]) -> torch.Tensor:
            """Execute function with partial parameters."""
            # Rebuilt per call, so later changes to full_parameters are seen.
            partial_params = {}
            for name in layer_names:
                if name in full_parameters:
                    partial_params[name] = full_parameters[name]

            if not args:
                return function(partial_params, **kwargs)

            # Raises IndexError if param_index is outside the supplied args.
            call_args = list(args)
            call_args[param_index] = partial_params
            return function(*call_args, **kwargs)

        return partial_function

    return partial_parameter_wrapper


def compute_gradient_norm(gradients: Dict[str, Tensor]) -> float:
    """
    Compute the global L2 norm over a dictionary of gradients.

    Each tensor's L2 norm is squared and accumulated as a Python float; the
    result is the square root of that sum. ``None`` entries are skipped.

    Args:
        gradients: Dictionary of gradient tensors

    Returns:
        L2 norm of gradients
    """
    squared_sum = 0.0
    for grad in gradients.values():
        if grad is None:
            continue
        squared_sum += grad.data.norm(2).item() ** 2

    return squared_sum ** 0.5


def clip_gradients(
    gradients: Dict[str, Tensor],
    max_norm: float
) -> Dict[str, Tensor]:
    """
    Rescale gradients so their global L2 norm does not exceed ``max_norm``.

    The scaling coefficient is ``max_norm / (total_norm + 1e-6)``. If it is
    not below 1, the input dictionary itself is returned untouched; otherwise
    a new dictionary with every non-``None`` gradient multiplied by the
    coefficient is returned (``None`` entries are carried over).

    Args:
        gradients: Dictionary of gradient tensors
        max_norm: Maximum allowed norm

    Returns:
        Clipped gradients
    """
    # The small epsilon guards against division by zero for all-zero grads.
    clip_coef = max_norm / (compute_gradient_norm(gradients) + 1e-6)

    # Written as `not (<)` so a NaN coefficient also takes the no-op path.
    if not (clip_coef < 1):
        return gradients

    return {
        name: (grad if grad is None else grad * clip_coef)
        for name, grad in gradients.items()
    }


def compute_parameter_count(model: torch.nn.Module) -> int:
    """
    Count every scalar element across all parameters of a model.

    Args:
        model: Neural network model

    Returns:
        Total number of parameters
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total


def compute_trainable_parameter_count(model: torch.nn.Module) -> int:
    """
    Count the scalar elements of the parameters that require gradients.

    Args:
        model: Neural network model

    Returns:
        Number of trainable parameters
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def get_model_device(model: torch.nn.Module) -> torch.device:
    """
    Get the device of a model.

    The device is taken from the model's first parameter. Modules without
    parameters (for example normalisation layers created without affine
    weights) fall back to their first buffer.

    Args:
        model: Neural network model

    Returns:
        Device of the model's first parameter, or of its first buffer when
        it has no parameters

    Raises:
        ValueError: If the model has neither parameters nor buffers, so no
            device can be determined
    """
    # A default of None avoids leaking StopIteration out of next() when the
    # module has no parameters.
    first_tensor = next(model.parameters(), None)
    if first_tensor is None:
        first_tensor = next(model.buffers(), None)
    if first_tensor is None:
        raise ValueError(
            "Cannot determine device: model has no parameters or buffers."
        )
    return first_tensor.device


def move_model_to_device(model: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    """
    Transfer a model to the requested device via ``model.to``.

    Args:
        model: Neural network model
        device: Target device

    Returns:
        The value returned by ``model.to(device)``
    """
    moved_model = model.to(device)
    return moved_model


def create_parameter_dict(model: torch.nn.Module) -> Dict[str, Tensor]:
    """
    Collect a model's named parameters into a dictionary.

    The values are the model's own parameter objects, not copies.

    Args:
        model: Neural network model

    Returns:
        Dictionary of parameter tensors keyed by parameter name
    """
    return {name: param for name, param in model.named_parameters()}


def update_model_parameters(
    model: torch.nn.Module,
    parameters: Dict[str, Tensor]
) -> None:
    """
    Overwrite model parameters with values from a dictionary.

    For each named parameter of the model that has an entry in
    ``parameters``, the parameter's data is replaced by a clone of the
    supplied tensor's data. Model parameters without an entry are left
    untouched, and entries with names unknown to the model are not used.

    Args:
        model: Neural network model
        parameters: New parameter values
    """
    for name, param in model.named_parameters():
        if name not in parameters:
            continue
        # Clone so the model does not share storage with the caller's tensor.
        param.data = parameters[name].data.clone()


def compute_hessian_vector_product(
    model: torch.nn.Module,
    loss_function: Tensor,
    vector: Tensor,
    damping: float = 0.0
) -> Tensor:
    """
    Compute Hessian-vector product.

    The product is obtained by double backpropagation: the flattened
    gradient of the loss is dotted with ``vector`` and that scalar is
    differentiated again with respect to the model parameters.

    Args:
        model: Neural network model whose parameters are differentiated
        loss_function: The loss to differentiate. Despite the name, this
            value is passed directly to ``torch.autograd.grad`` as the
            output, so it must be a loss tensor with an autograd graph
            attached, not a callable.
        vector: 1-D vector to multiply with the Hessian; it is dotted with
            the concatenated flattened parameter gradients
        damping: Damping parameter; when positive, ``damping * vector`` is
            added to the product

    Returns:
        Hessian-vector product as a flat tensor, ordered like
        ``model.parameters()``
    """
    # Materialise once: parameters() is a generator and is needed twice.
    params = list(model.parameters())

    # First-order gradients, keeping the graph so they can be differentiated.
    first_order = torch.autograd.grad(loss_function, params, create_graph=True)
    flat_gradient = torch.cat([g.flatten() for g in first_order])

    grad_dot_vector = torch.dot(flat_gradient, vector)

    if grad_dot_vector.requires_grad:
        # allow_unused: a parameter that the gradient does not depend on has
        # a zero Hessian row; autograd reports it as None instead of raising.
        second_order = torch.autograd.grad(grad_dot_vector, params, allow_unused=True)
    else:
        # The gradient does not depend on any parameter: the Hessian is zero.
        second_order = (None,) * len(params)

    hvp = torch.cat([
        torch.zeros_like(param).flatten() if piece is None else piece.flatten()
        for param, piece in zip(params, second_order)
    ])

    # Add damping
    if damping > 0:
        hvp += damping * vector

    return hvp