"""
Influence Attribution Computation Module

This module provides functions for computing influence attribution scores in multi-task learning.
"""

import torch
import torch.nn as nn
from typing import List, Tuple, Optional, Union
import torch.nn.functional as F


def compute_model_output(model: Union[torch.Tensor, nn.Module], inputs: torch.Tensor) -> torch.Tensor:
    """
    Apply ``model`` to ``inputs``.

    A plain tensor is interpreted as a weight matrix (or vector) and applied on
    the right via matrix multiplication; any other object is invoked as a
    callable (e.g. an ``nn.Module``).

    Args:
        model: Weight tensor, or a callable model.
        inputs: Input tensor.

    Returns:
        ``inputs @ model`` for a tensor model, otherwise ``model(inputs)``.
    """
    if not isinstance(model, torch.Tensor):
        return model(inputs)
    return inputs @ model


def compute_loss_function(
    model: Union[torch.Tensor, nn.Module], 
    inputs: torch.Tensor, 
    targets: torch.Tensor, 
    regularization_param: torch.Tensor, 
    lambda_reg: float, 
    output_dim: int, 
    num_samples: int, 
    link_function: str, 
    device: torch.device
) -> torch.Tensor:
    """
    Compute a linear or logistic loss, regularized when evaluated on a batch.

    The mode is selected by the shape of ``targets``:

    * Non-scalar ``targets``: the summed per-example loss is divided by the
      number of examples, and ``lambda_reg * sum((model - regularization_param) ** 2)``
      is added. That subtraction requires ``model`` to be a tensor.
    * Scalar (0-d) ``targets``: the unregularized loss of a single example is
      returned, without averaging.

    Per-example losses, with ``z = compute_model_output(model, inputs)``:

    * ``'linear'``: squared error ``(z - y) ** 2``.
    * ``'logistic'``, non-scalar targets whose last dimension exceeds 1
      (one-hot): multinomial NLL ``logsumexp(z) - y . z``.
    * ``'logistic'`` otherwise: binary logistic loss ``log(1 + exp(z)) - y * z``.

    Args:
        model: Parameter tensor (applied as ``inputs @ model``) or callable model.
        inputs: Input features.
        targets: Target values; their shape selects the mode (see above).
        regularization_param: Parameter the model is pulled towards by the penalty.
        lambda_reg: Regularization strength.
        output_dim: Expected number of class scores per example (one-hot case).
        num_samples: Unused; the example count is derived from ``targets``.
        link_function: ``'linear'`` or ``'logistic'``.
        device: Unused; kept for interface compatibility.

    Returns:
        The loss tensor.

    Raises:
        ValueError: If ``link_function`` is unsupported, or if, in the one-hot
            case, the model emits a number of class scores other than ``output_dim``.
    """
    if link_function not in ("linear", "logistic"):
        # Failing here is clearer than returning None and erroring later
        # (e.g. at ``loss.backward()``).
        raise ValueError(f"Unsupported link_function: {link_function!r}")

    outputs = compute_model_output(model, inputs)
    y = targets.squeeze()

    if targets.dim() == 0:
        # Single-example loss: no averaging and no regularization penalty.
        if link_function == "linear":
            return torch.square(outputs - y)
        # softplus(z) == log(1 + exp(z)), but does not overflow for large z.
        return F.softplus(outputs) - y * outputs

    if link_function == "linear":
        data_loss = torch.sum(torch.square(outputs - y)) / len(targets)
    elif targets.shape[-1] > 1:
        # One-hot targets.
        if outputs.shape[-1] != output_dim:
            raise ValueError(
                f"Expected {output_dim} class scores per example, got {outputs.shape[-1]}"
            )
        # The log-partition term is counted once per example; it must not be
        # broadcast across the class axis. logsumexp also avoids exp() overflow.
        num_examples = targets.shape[0] if targets.dim() > 1 else 1
        data_loss = (
            torch.sum(torch.logsumexp(outputs, dim=-1)) - torch.sum(y * outputs)
        ) / num_examples
    else:
        data_loss = torch.sum(F.softplus(outputs) - y * outputs) / len(targets)

    penalty = lambda_reg * torch.sum(torch.square(model - regularization_param))
    return data_loss + penalty


def compute_cross_entropy_loss(model: nn.Module, inputs: torch.Tensor, targets: torch.Tensor, regularization_param: torch.Tensor) -> torch.Tensor:
    """
    Compute the mean cross-entropy loss of ``model`` on ``inputs``.

    ``targets`` may carry extra singleton axes (e.g. class indices stored as
    ``(N, 1)``). They are removed or added only as needed to reach the shape
    ``nn.CrossEntropyLoss`` expects:

    * integer class indices: the logits' shape without the class axis;
    * floating-point class probabilities: the logits' shape.

    Args:
        model: Callable returning logits.
        inputs: Input features.
        targets: Class-index or class-probability targets.
        regularization_param: Unused; accepted for interface compatibility.

    Returns:
        Cross-entropy loss (``nn.CrossEntropyLoss`` with default reduction).
    """
    logits = model(inputs)
    if torch.is_floating_point(targets):
        expected_shape = tuple(logits.shape)
    else:
        # Class axis is 1 for batched logits and 0 for a single unbatched example.
        class_axis = 1 if logits.dim() >= 2 else 0
        expected_shape = tuple(logits.shape[:class_axis]) + tuple(logits.shape[class_axis + 1:])

    # Only reconcile singleton axes. A bare ``squeeze()`` would also remove the
    # batch axis when the batch holds a single example, breaking the shape match.
    same_core_dims = (
        [d for d in targets.shape if d != 1] == [d for d in expected_shape if d != 1]
    )
    if tuple(targets.shape) != expected_shape and same_core_dims:
        targets = targets.reshape(expected_shape)
    return nn.CrossEntropyLoss()(logits, targets)


def compute_influence_attribution(
    models: List[Union[torch.Tensor, nn.Module]], 
    regularization_param: torch.Tensor, 
    input_data: List[torch.Tensor], 
    target_data: List[torch.Tensor], 
    input_dim: int, 
    loss_function: str, 
    computation_mode: int, 
    lambda_reg: List[float], 
    output_dim: int, 
    num_samples: int, 
    device: torch.device, 
    link_function: str
) -> List[torch.Tensor]:
    """
    Compute per-sample influence attribution scores for multi-task learning.

    For every training sample of every task, the gradient of that sample's loss
    is mapped through the inverse of the joint Hessian (all task parameters plus
    the shared regularization parameter), giving one score vector per target task.

    Gradients are taken with respect to detached copies of the parameters, so the
    caller's ``.grad`` fields are neither read nor modified. The parameters must
    therefore be tensors.

    Args:
        models: Per-task parameter tensors.
        regularization_param: Shared regularization parameter.
        input_data: Per-task input tensors (first axis indexes samples).
        target_data: Per-task target tensors (first axis indexes samples).
        input_dim: Ignored; re-derived from ``input_data[0][0].shape[0]``.
        loss_function: Unused.
        computation_mode: 0 uses only the task-parameter blocks of the inverse
            Hessian. Any other value also adds the cross term driven by the
            gradient with respect to ``regularization_param``.
        lambda_reg: Per-task regularization strengths.
        output_dim: Output dimension, forwarded to the loss.
        num_samples: Forwarded to the Hessian computation.
        device: Computation device.
        link_function: Link function type (``'linear'`` or ``'logistic'``).

    Returns:
        One tensor per task of shape ``(num_task_samples, num_tasks, input_dim)``.
        Entry ``[i, t]`` is ``-score / 2`` for sample ``i`` and target task ``t``.
    """
    num_tasks = len(input_data)
    # The passed-in input_dim is overridden by the actual feature dimension.
    input_dim = input_data[0][0].shape[0]

    hessian_ss, hessian_s_reg, hessian_reg_reg = _compute_hessian_matrices(
        models, regularization_param, input_data, target_data, input_dim,
        output_dim, lambda_reg, num_samples, device, link_function
    )
    hessian_inv_ss, hessian_inv_s_reg, _ = _compute_hessian_inverses(
        hessian_ss, hessian_s_reg, hessian_reg_reg, num_tasks, input_dim, device
    )

    attribution_scores = []
    gradient_norm_total = torch.zeros(1).to(device)
    sample_count = 0

    for task_idx, (task_inputs, task_targets) in enumerate(zip(input_data, target_data)):
        task_attribution = torch.zeros([task_inputs.shape[0], num_tasks, input_dim]).to(device)

        for sample_idx, (sample_input, sample_target) in enumerate(zip(task_inputs, task_targets)):
            # Fresh leaves per sample: gradients cannot accumulate across samples
            # (notably on the regularization parameter, which was never reset).
            model_leaf = models[task_idx].detach().requires_grad_(True)
            reg_leaf = regularization_param.detach().requires_grad_(True)
            loss = compute_loss_function(
                model_leaf, sample_input, sample_target.squeeze(),
                reg_leaf, lambda_reg[task_idx], output_dim,
                task_inputs.shape[0], link_function, device
            )
            grad_model, grad_reg = torch.autograd.grad(
                loss, (model_leaf, reg_leaf), allow_unused=True
            )
            if grad_reg is None:
                # The loss does not involve the regularization parameter (e.g. the
                # scalar-target branch of compute_loss_function): its gradient is 0.
                grad_reg = torch.zeros_like(reg_leaf)

            gradient_norm_total += torch.linalg.norm(grad_model)
            sample_count += 1

            for target_task in range(num_tasks):
                score = hessian_inv_ss[target_task][task_idx] @ grad_model
                if computation_mode != 0:
                    score = score + hessian_inv_s_reg[target_task] @ grad_reg
                task_attribution[sample_idx][target_task] = -score / 2

        attribution_scores.append(task_attribution)

    average_norm = gradient_norm_total / max(sample_count, 1)
    print(f"Average gradient norm: {average_norm.mean()}")
    return attribution_scores


def compute_leave_one_out_attribution(
    models: List[Union[torch.Tensor, nn.Module]], 
    regularization_param: torch.Tensor, 
    input_data: List[torch.Tensor], 
    target_data: List[torch.Tensor], 
    removed_task: int, 
    loo_data: Tuple[torch.Tensor, torch.Tensor], 
    input_dim: int, 
    loss_function: str, 
    computation_mode: int, 
    lambda_reg: List[float], 
    output_dim: int, 
    num_samples: int, 
    device: torch.device, 
    link_function: str
) -> torch.Tensor:
    """
    Compute influence attribution for held-out data of one task.

    The loss of ``loo_data`` under ``models[removed_task]`` is differentiated, and
    the gradient is mapped through the inverse of the joint Hessian, yielding one
    score vector per target task.

    Gradients are taken with respect to detached copies of the parameters, so the
    caller's ``.grad`` fields are neither read nor modified, and repeated calls
    are independent. The parameters must therefore be tensors.

    Args:
        models: Per-task parameter tensors.
        regularization_param: Shared regularization parameter.
        input_data: Per-task input tensors.
        target_data: Per-task target tensors.
        removed_task: Index of the task the held-out data belongs to.
        loo_data: Held-out ``(inputs, targets)``.
        input_dim: Ignored; re-derived from ``input_data[0][0].shape[0]``.
        loss_function: Unused.
        computation_mode: 0 uses only the task-parameter blocks of the inverse
            Hessian. Any other value also adds the cross term driven by the
            gradient with respect to ``regularization_param``.
        lambda_reg: Per-task regularization strengths.
        output_dim: Output dimension, forwarded to the loss.
        num_samples: Forwarded to the Hessian computation.
        device: Computation device.
        link_function: Link function type (``'linear'`` or ``'logistic'``).

    Returns:
        Float32 tensor of shape ``(num_tasks, input_dim)``; row ``t`` is
        ``-score / 2`` for target task ``t``.
    """
    num_tasks = len(input_data)
    # The passed-in input_dim is overridden by the actual feature dimension.
    input_dim = input_data[0][0].shape[0]

    hessian_ss, hessian_s_reg, hessian_reg_reg = _compute_hessian_matrices(
        models, regularization_param, input_data, target_data, input_dim,
        output_dim, lambda_reg, num_samples, device, link_function
    )
    hessian_inv_ss, hessian_inv_s_reg, _ = _compute_hessian_inverses(
        hessian_ss, hessian_s_reg, hessian_reg_reg, num_tasks, input_dim, device
    )

    # Fresh leaves: stale or accumulated .grad values (including on the shared
    # regularization parameter across calls) cannot leak into the scores.
    model_leaf = models[removed_task].detach().requires_grad_(True)
    reg_leaf = regularization_param.detach().requires_grad_(True)
    loss = compute_loss_function(
        model_leaf, loo_data[0], loo_data[1],
        reg_leaf, lambda_reg[removed_task], output_dim, 1, link_function, device
    )
    grad_model, grad_reg = torch.autograd.grad(
        loss, (model_leaf, reg_leaf), allow_unused=True
    )
    if grad_reg is None:
        # The loss does not involve the regularization parameter: its gradient is 0.
        grad_reg = torch.zeros_like(reg_leaf)

    attribution_scores = torch.zeros([num_tasks, input_dim], dtype=torch.float32).to(device)
    for target_task in range(num_tasks):
        score = hessian_inv_ss[target_task][removed_task] @ grad_model
        if computation_mode != 0:
            score = score + hessian_inv_s_reg[target_task] @ grad_reg
        attribution_scores[target_task] = -score / 2

    return attribution_scores


def _compute_hessian_matrices(
    models: List[Union[torch.Tensor, nn.Module]], 
    regularization_param: torch.Tensor, 
    input_data: List[torch.Tensor], 
    target_data: List[torch.Tensor], 
    input_dim: int, 
    output_dim: int, 
    lambda_reg: List[float], 
    num_samples: int, 
    device: torch.device, 
    link_function: str
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Extract Hessian blocks of the summed multi-task loss.

    The objective is the sum over tasks of ``compute_loss_function`` on each
    task's full data. Every term shares ``regularization_param``. Its full
    Hessian is taken with respect to ``(*models, regularization_param)``, and
    three groups of blocks are copied into freshly allocated buffers.

    Args:
        models: Per-task parameters (differentiated by autograd, so tensors).
        regularization_param: Shared regularization parameter.
        input_data: Per-task input tensors.
        target_data: Per-task target tensors.
        input_dim: Side length of every stored block.
        output_dim: Output dimension, forwarded to the loss.
        lambda_reg: Per-task regularization strengths.
        num_samples: Forwarded to the loss.
        device: Device of the returned buffers.
        link_function: Link function type.

    Returns:
        ``(H_ss, H_s_reg, H_reg_reg)``:

        * ``H_ss[t]`` is the block for task ``t``'s parameters with themselves.
        * ``H_s_reg[t]`` is the mixed block between task ``t``'s parameters and
          ``regularization_param``.
        * ``H_reg_reg`` is the block for ``regularization_param`` alone.

        The buffers are ``input_dim x input_dim`` per block, so the parameters
        are presumably ``input_dim``-vectors.
    """
    task_count = len(input_data)

    def summed_objective(*flat_params):
        shared = flat_params[-1]
        objective = torch.zeros(1).to(device)
        for t in range(task_count):
            objective += compute_loss_function(
                flat_params[t], input_data[t], target_data[t],
                shared, lambda_reg[t], output_dim, num_samples, link_function, device
            )
        return objective

    full_hessian = torch.autograd.functional.hessian(
        summed_objective, (*models, regularization_param)
    )

    diag_blocks = torch.zeros([task_count, input_dim, input_dim]).to(device)
    cross_blocks = torch.zeros([task_count, input_dim, input_dim]).to(device)
    shared_block = torch.zeros([input_dim, input_dim]).to(device)

    for t in range(task_count):
        row = full_hessian[t]
        diag_blocks[t] += row[t]
        cross_blocks[t] += row[-1]
    shared_block += full_hessian[-1][-1]

    return diag_blocks, cross_blocks, shared_block


def _compute_hessian_inverses(
    hessian_ss: torch.Tensor, 
    hessian_s_reg: torch.Tensor, 
    hessian_reg_reg: torch.Tensor, 
    num_tasks: int, 
    input_dim: int, 
    device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute inverse of Hessian matrices using block matrix inversion.
    
    Args:
        hessian_ss: Hessian matrix for model parameters
        hessian_s_reg: Cross Hessian matrix
        hessian_reg_reg: Hessian matrix for regularization parameter
        num_tasks: Number of tasks
        input_dim: Input dimension
        device: Computation device
        
    Returns:
        Tuple of inverse Hessian matrices
    """
    hessian_inv_ss = torch.zeros([num_tasks, num_tasks, input_dim, input_dim]).to(device)
    hessian_inv_s_reg = torch.zeros_like(hessian_s_reg).to(device)
    
    # Compute Schur complement
    schur_complement = hessian_reg_reg.clone().detach().to(device)
    cached_inverses = []
    
    for task_idx in range(num_tasks):
        inv = torch.inverse(hessian_ss[task_idx])
        schur_complement -= hessian_s_reg[task_idx].t() @ inv @ hessian_s_reg[task_idx]
        cached_inverses.append(inv)
    
    hessian_inv_reg_reg = torch.inverse(schur_complement)
    cached_matrices = {}
    
    for task_idx in range(num_tasks):
        cached_matrices[task_idx] = cached_inverses[task_idx] @ hessian_s_reg[task_idx]
        hessian_inv_s_reg[task_idx] = -cached_matrices[task_idx] @ hessian_inv_reg_reg
        hessian_inv_ss[task_idx][task_idx] += cached_inverses[task_idx]
        
        for other_task in range(task_idx + 1):
            hessian_inv_ss[task_idx][other_task] -= hessian_inv_s_reg[task_idx] @ cached_matrices[other_task].t()
            hessian_inv_ss[other_task][task_idx] = hessian_inv_ss[task_idx][other_task].clone().t()
    
    return hessian_inv_ss, hessian_inv_s_reg, hessian_inv_reg_reg 