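"""
Influence Function Computation Modules

This module provides influence-function implementations that differ in how they
handle the inverse Hessian: exact dense inversion (``AutogradInfluenceModule``),
conjugate gradient (``CGInfluenceModule``), the LiSSA stochastic recursion
(``LiSSAInfluenceModule``), and TRAK-style random-projection scoring
(``TrakInfluenceModule``).
"""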

import logging
from typing import Callable, Optional, List, Dict, Any, Tuple
import time
import numpy as np
import scipy.sparse.linalg as sparse_linalg
import torch
from torch import nn
from torch.utils import data
import gc
import sys
import matplotlib.pyplot as plt

from .base import BaseInfluenceModule, BaseObjective
from .projection import random_project

# Default settings for random gradient projections (used by the TRAK module).
# Treat this mapping as read-only; build a modified copy instead, e.g.
# ``{**DEFAULT_PROJECTOR_CONFIG, "proj_dim": 1024}``.
DEFAULT_PROJECTOR_CONFIG = dict(
    proj_dim=512,              # number of columns of the projection matrix
    proj_max_batch_size=32,    # NOTE(review): not read by the code shown here
    proj_seed=0,               # seed used when drawing the projection matrix
    device="cpu",              # device the projection matrix is created on
    use_half_precision=False,  # cast the projection matrix to float16 when True
)


class AutogradInfluenceModule(BaseInfluenceModule):
    """
    Influence module that computes inverse-Hessian vector products using direct matrix inversion.

    The dense risk Hessian is built with ``torch.autograd.functional.hessian``,
    averaged over the training samples of ``task_names``, damped, and inverted
    once at construction time. It provides exact results but scales poorly
    with model size.

    Args:
        model: Neural network model
        objective: Objective implementation
        train_loader: Training data loaders, keyed by task name
        test_loader: Test data loaders, keyed by task name
        device: Computation device
        damp: Damping strength added to the Hessian diagonal for numerical stability
        task_names: Names of the tasks whose training data form the Hessian
        check_eigvals: Whether to verify the damped Hessian is positive definite
        save_hessian_path: Optional path; when given, the accumulated
            (sample-weighted sum, not yet averaged or damped) Hessian is saved
            there with ``torch.save`` for debugging

    Warning:
        This module scales poorly with model parameters. Computing the Hessian takes O(nd²) time
        and inverting it takes O(d³) time, where n is dataset size and d is parameter count.
    """

    def __init__(
            self,
            model: nn.Module,
            objective: BaseObjective,
            train_loader: Dict[str, data.DataLoader],
            test_loader: Dict[str, data.DataLoader],
            device: torch.device,
            damp: float,
            task_names: List[str],
            check_eigvals: bool = False,
            save_hessian_path: Optional[str] = None,
    ):
        """
        Initialize the module and precompute the inverse damped Hessian.

        Args:
            model: Neural network model
            objective: Objective implementation
            train_loader: Training data loaders
            test_loader: Test data loaders
            device: Computation device
            damp: Damping strength
            task_names: List of task names
            check_eigvals: Whether to check eigenvalues of the damped Hessian
            save_hessian_path: Where to dump the raw accumulated Hessian (off by default)

        Raises:
            ValueError: If the training loaders yield no samples, or if
                ``check_eigvals`` is set and the damped Hessian is not
                positive definite.
        """
        super().__init__(
            model=model,
            objective=objective,
            train_loader=train_loader,
            test_loader=test_loader,
            device=device,
        )
        self.task_names = task_names
        self.damp = damp

        params = self._model_make_functional()
        flat_params = self._flatten_params_like(params)
        try:
            hessian, total_samples = self._accumulate_hessian(flat_params, device)

            # Opt-in only: the matrix is d x d and can be very large on disk.
            if save_hessian_path is not None:
                torch.save(hessian, save_hessian_path)

            if total_samples == 0:
                raise ValueError("Training loaders yielded no samples; cannot form the Hessian.")

            # Divide by the same weights that were accumulated so the result is a
            # per-sample average over exactly the tasks in ``task_names``.
            hessian = hessian / total_samples
            hessian += self.damp * torch.eye(hessian.shape[0], device=device, dtype=hessian.dtype)

            if check_eigvals:
                # The Hessian is symmetric, so the symmetric solver gives real eigenvalues
                # directly (it reads only the lower triangle).
                eigenvalues = torch.linalg.eigvalsh(hessian)
                if torch.any(eigenvalues <= 0):
                    raise ValueError("Hessian is not positive definite. Increase damping parameter.")

            self.hessian_inv = torch.inverse(hessian)
        finally:
            # Always give the model its registered parameters back, even if an error occurred
            # above; otherwise it would be left in the functional (parameter-less) state.
            with torch.no_grad():
                self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)

        print(f"Total training samples: {total_samples}")

    def _accumulate_hessian(self, flat_params: torch.Tensor, device: torch.device) -> Tuple[torch.Tensor, int]:
        """
        Sum batch Hessians of the training loss, weighted by batch size.

        Args:
            flat_params: Flattened model parameters to differentiate at
            device: Device for the Hessian buffer

        Returns:
            ``(hessian_sum, total_samples)`` where ``total_samples`` is the sum
            of the batch-size weights used.
        """
        param_dim = flat_params.shape[0]
        hessian = torch.zeros(param_dim, param_dim, device=device, dtype=flat_params.dtype)
        total_samples = 0

        print("Computing Hessian matrix...")
        for task_name in self.task_names:
            self.objective.settn(task_name)
            print(f"Processing task: {task_name}")

            for batch, batch_size in self._loader_wrapper(train=True, task_name=task_name):
                # ``batch`` is bound as a default so the closure cannot see a later batch.
                def loss_function(theta, batch=batch):
                    self._model_reinsert_params(self._reshape_like_params(theta))
                    return self.objective.train_loss(self.model, theta, batch)

                hessian_batch = torch.autograd.functional.hessian(loss_function, flat_params).detach()
                hessian += hessian_batch * batch_size
                total_samples += batch_size

            # Release per-task temporaries before the next task.
            torch.cuda.empty_cache()
            gc.collect()

        return hessian, total_samples

    def inverse_hvp(self, vec: torch.Tensor) -> torch.Tensor:
        """
        Compute inverse Hessian-vector product.

        Args:
            vec: Vector to multiply with inverse Hessian

        Returns:
            ``hessian_inv @ vec`` using the inverse precomputed in ``__init__``
        """
        return torch.matmul(self.hessian_inv, vec)


class CGInfluenceModule(BaseInfluenceModule):
    """
    Influence module using the Conjugate Gradient method for inverse Hessian-vector products.

    The damped, sample-averaged Hessian is never materialized. SciPy's
    conjugate-gradient solver is driven through a ``LinearOperator`` whose
    mat-vec runs one full pass of Hessian-vector products over the training
    data of ``task_names``.

    Args:
        model: Neural network model
        objective: Objective implementation
        train_loader: Training data loaders
        test_loader: Test data loaders
        device: Computation device
        damp: Damping strength added to the averaged Hessian
        task_names: List of task names
        gnh: Whether to use Gauss-Newton Hessian approximation
        max_iter: Maximum CG iterations
        tol: Relative convergence tolerance passed to SciPy's ``cg``
    """

    def __init__(
            self,
            model: nn.Module,
            objective: BaseObjective,
            train_loader: Dict[str, data.DataLoader],
            test_loader: Dict[str, data.DataLoader],
            device: torch.device,
            damp: float,
            task_names: List[str],
            gnh: bool = False,
            max_iter: int = 100,
            tol: float = 1e-6,
            **kwargs
    ):
        """
        Initialize CG influence module.

        Args:
            model: Neural network model
            objective: Objective implementation
            train_loader: Training data loaders
            test_loader: Test data loaders
            device: Computation device
            damp: Damping strength
            task_names: List of task names
            gnh: Whether to use Gauss-Newton Hessian
            max_iter: Maximum CG iterations
            tol: Convergence tolerance
            **kwargs: Accepted for call-site compatibility and ignored
        """
        super().__init__(
            model=model,
            objective=objective,
            train_loader=train_loader,
            test_loader=test_loader,
            device=device,
        )
        self.task_names = task_names
        self.damp = damp
        self.gnh = gnh
        self.max_iter = max_iter
        self.tol = tol

    def _damped_hvp(self, flat_params: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """
        Return ``(H_avg + damp * I) @ v``, with H_avg the sample-averaged training Hessian.

        Raises:
            ValueError: If the training loaders yield no samples.
        """
        total = torch.zeros_like(v)
        n_samples = 0
        for task_name in self.task_names:
            self.objective.settn(task_name)
            for batch, batch_size in self._loader_wrapper(train=True, task_name=task_name):
                hvp_batch = self._hvp_at_batch(batch, flat_params, v, self.gnh)
                # Detach so no autograd graph survives across CG iterations.
                total += hvp_batch.detach() * batch_size
                n_samples += batch_size
        if n_samples == 0:
            raise ValueError("Training loaders yielded no samples; cannot form a Hessian-vector product.")
        return total / n_samples + self.damp * v

    def inverse_hvp(self, vec: torch.Tensor) -> torch.Tensor:
        """
        Compute inverse Hessian-vector product using conjugate gradient.

        Args:
            vec: Vector to multiply with inverse Hessian

        Returns:
            Approximate ``(H_avg + damp * I)^-1 @ vec`` on ``self.device`` with ``vec``'s dtype.
            A warning is logged if CG does not report convergence.
        """
        import inspect  # local: only used to probe which tolerance keyword SciPy's cg accepts

        params = self._model_make_functional()
        flat_params = self._flatten_params_like(params)
        rhs = vec.detach().cpu().numpy()
        dim = rhs.shape[0]

        def matvec(x):
            # SciPy passes numpy arrays shaped (d,) or (d, 1); the HVP code needs torch tensors.
            v = torch.tensor(np.ravel(x), dtype=vec.dtype, device=self.device)
            return self._damped_hvp(flat_params, v).cpu().numpy()

        # Supplying dtype keeps SciPy from probing the operator with an extra (full-pass) mat-vec.
        operator = sparse_linalg.LinearOperator(shape=(dim, dim), matvec=matvec, dtype=rhs.dtype)

        # SciPy >= 1.12 renamed ``tol`` to ``rtol``; 1.14 removed ``tol``.
        tol_kwarg = "rtol" if "rtol" in inspect.signature(sparse_linalg.cg).parameters else "tol"

        try:
            result, info = sparse_linalg.cg(operator, rhs, maxiter=self.max_iter, **{tol_kwarg: self.tol})
        finally:
            with torch.no_grad():
                self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)

        if info != 0:
            logging.getLogger(__name__).warning(
                "Conjugate gradient did not converge (info=%d, max_iter=%d, tol=%g)",
                info, self.max_iter, self.tol,
            )
        return torch.tensor(result, device=self.device, dtype=vec.dtype)


class LiSSAInfluenceModule(BaseInfluenceModule):
    """
    LiSSA (Linear time Stochastic Second-Order Algorithm) influence module.

    Approximates ``(H_avg + damp * I)^-1 v`` with the Richardson / Neumann
    recursion ``h <- h + (v - A h) / scale``. Here ``A = H_avg + damp * I``
    and ``H_avg`` is the sample-averaged training Hessian. The recursion
    converges when ``scale`` exceeds half of A's largest eigenvalue, so
    ``scale`` is raised to at least twice a power-iteration estimate of that
    eigenvalue.

    Args:
        model: Neural network model
        objective: Objective implementation
        train_loader: Training data loaders
        test_loader: Test data loaders
        device: Computation device
        damp: Damping strength
        repeat: Number of LiSSA trials whose results are averaged
        depth: Maximum number of recursion steps per trial
        scale: Minimum scaling factor
        batch_size: Batch size passed to the loader wrapper
        task_names: List of task names
        gnh: Whether to use Gauss-Newton Hessian
        explicit: Flag forwarded to ``_hvp_at_batch``
        enable_monitoring: Whether to print and record convergence information
        convergence_tol: Stop a trial once the residual norm ``||v - A h||`` drops below this
        max_scale_tries: Stored for compatibility; not used by this class
    """

    def __init__(
            self,
            model: nn.Module,
            objective: BaseObjective,
            train_loader: Dict[str, data.DataLoader],
            test_loader: Dict[str, data.DataLoader],
            device: torch.device,
            damp: float,
            repeat: int,
            depth: int,
            scale: float,
            batch_size: int,
            task_names: List[str],
            gnh: bool = False,
            explicit: bool = False,
            enable_monitoring: bool = True,
            convergence_tol: float = 1e-5,
            max_scale_tries: int = 5
    ):
        """
        Initialize LiSSA influence module.

        Args:
            model: Neural network model
            objective: Objective implementation
            train_loader: Training data loaders
            test_loader: Test data loaders
            device: Computation device
            damp: Damping strength
            repeat: Number of LiSSA trials
            depth: Number of recursion steps
            scale: Scaling factor
            batch_size: Batch size for computation
            task_names: List of task names
            gnh: Whether to use Gauss-Newton Hessian
            explicit: Whether to use explicit computation
            enable_monitoring: Whether to enable convergence monitoring
            convergence_tol: Convergence tolerance
            max_scale_tries: Maximum scale adjustment attempts (currently unused)
        """
        super().__init__(
            model=model,
            objective=objective,
            train_loader=train_loader,
            test_loader=test_loader,
            device=device,
        )
        self.task_names = task_names
        self.damp = damp
        self.repeat = repeat
        self.depth = depth
        self.scale = scale
        self.batch_size = batch_size
        self.gnh = gnh
        self.explicit = explicit
        self.enable_monitoring = enable_monitoring
        self.convergence_tol = convergence_tol
        self.max_scale_tries = max_scale_tries
        # (iteration, residual norm) pairs; accumulates across inverse_hvp calls.
        self.convergence_history = []

    def _damped_hvp(self, flat_params: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        """
        Return ``(H_avg + damp * I) @ vec`` using one pass over the training data.

        Raises:
            ValueError: If the training loaders yield no samples.
        """
        total = torch.zeros_like(vec)
        n_samples = 0
        for task_name in self.task_names:
            self.objective.settn(task_name)
            for batch, batch_size in self._loader_wrapper(train=True, task_name=task_name, batch_size=self.batch_size):
                hvp_batch = self._hvp_at_batch(batch, flat_params, vec, self.gnh, self.explicit)
                # Detach so the recursion does not build an ever-growing autograd graph.
                total += hvp_batch.detach() * batch_size
                n_samples += batch_size
        if n_samples == 0:
            raise ValueError("Training loaders yielded no samples; cannot form a Hessian-vector product.")
        return total / n_samples + self.damp * vec

    def _estimate_max_eigenvalue(self, num_power_iter: int = 10,
                                 flat_params: Optional[torch.Tensor] = None) -> float:
        """
        Estimate the largest eigenvalue of the damped Hessian by power iteration.

        Args:
            num_power_iter: Number of power iterations (at least one is run)
            flat_params: Flattened parameters of an already-functional model. When
                omitted, the model is made functional here and restored afterwards.

        Returns:
            ``||A q||`` for the last unit iterate ``q`` (the power-iteration estimate)
        """
        owns_functional = flat_params is None
        if owns_functional:
            flat_params = self._flatten_params_like(self._model_make_functional())
        try:
            vec = torch.randn_like(flat_params)
            vec = vec / torch.norm(vec)
            hvp = None
            for _ in range(max(1, num_power_iter)):
                hvp = self._damped_hvp(flat_params, vec)
                vec = hvp / torch.norm(hvp)
            return torch.dot(vec, hvp).item()
        finally:
            if owns_functional:
                with torch.no_grad():
                    self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)

    def _adjust_scale(self, vec: torch.Tensor, flat_params: torch.Tensor) -> float:
        """
        Pick a scale large enough for the LiSSA recursion to converge.

        Args:
            vec: Input vector (unused; kept for interface compatibility)
            flat_params: Flattened parameters of the functional model

        Returns:
            ``max(self.scale, 2 * estimated max eigenvalue)``
        """
        max_eigenvalue = self._estimate_max_eigenvalue(flat_params=flat_params)
        adjusted_scale = max(self.scale, 2.0 * max_eigenvalue)

        if self.enable_monitoring:
            print(f"Max eigenvalue: {max_eigenvalue:.6f}, Adjusted scale: {adjusted_scale:.6f}")

        return adjusted_scale

    def _log_convergence(self, iteration: int, error: float) -> None:
        """
        Record and print convergence information when monitoring is enabled.

        Args:
            iteration: Current iteration
            error: Current residual norm
        """
        if self.enable_monitoring:
            self.convergence_history.append((iteration, error))
            print(f"LiSSA Iteration {iteration}: Error = {error:.6f}")

    def inverse_hvp(self, vec: torch.Tensor) -> torch.Tensor:
        """
        Compute inverse Hessian-vector product using the LiSSA recursion.

        Args:
            vec: Vector to multiply with inverse Hessian

        Returns:
            Average over ``repeat`` trials of the recursion's estimate of
            ``(H_avg + damp * I)^-1 @ vec``. Because each step uses a full data pass,
            trials differ only if ``_loader_wrapper`` subsamples (unverified here).

        Raises:
            ValueError: If ``repeat`` is less than 1.
        """
        if self.repeat < 1:
            raise ValueError(f"repeat must be >= 1, got {self.repeat}")

        vec = vec.detach()
        params = self._model_make_functional()
        flat_params = self._flatten_params_like(params)
        try:
            adjusted_scale = self._adjust_scale(vec, flat_params)
            result = torch.zeros_like(vec)

            for trial in range(self.repeat):
                if self.enable_monitoring:
                    print(f"LiSSA Trial {trial + 1}/{self.repeat}")

                h_inv_v = vec / adjusted_scale
                for depth_iter in range(self.depth):
                    hvp = self._damped_hvp(flat_params, h_inv_v)

                    # Residual of the current iterate; the Richardson step moves along it.
                    residual = vec - hvp
                    h_inv_v = h_inv_v + residual / adjusted_scale

                    error = torch.norm(residual)
                    self._log_convergence(depth_iter, error.item())
                    if error < self.convergence_tol:
                        if self.enable_monitoring:
                            print(f"Converged at iteration {depth_iter}")
                        break

                result += h_inv_v
        finally:
            with torch.no_grad():
                self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)

        return result / self.repeat

    def get_convergence_info(self) -> Dict:
        """
        Get convergence information.

        Returns:
            Dictionary with the recorded history and the most recent residual norm
            (``None`` if nothing was recorded)
        """
        return {
            'convergence_history': self.convergence_history,
            'final_error': self.convergence_history[-1][1] if self.convergence_history else None
        }


class TrakInfluenceModule(BaseInfluenceModule):
    """
    TRAK (Tracing with the Randomly-projected After Kernel) influence module.

    For each checkpoint, loss gradients are computed, projected by a seeded
    Gaussian matrix of shape ``(grad_dim, proj_dim)``, and scored as
    ``Phi_train (Phi_train^T Phi_train + regularization * I)^-1 Phi_test^T``.
    When ``regularization <= 0`` the score is the plain inner product
    ``Phi_train Phi_test^T``. Scores are averaged over checkpoints.

    Args:
        model: Neural network model
        objective: Objective implementation
        train_loader: Training data loaders, keyed by task name
        test_loader: Test data loaders, keyed by task name
        device: Computation device
        task_names: List of task names
        load_checkpoints: Callable ``idx -> state_dict``; defaults to the current model state
        num_checkpoints: Number of checkpoints to use
        proj_dim: Projection dimension (overridden by ``projector_config["proj_dim"]`` if given)
        seed: Stored for callers; the projector seed is ``projector_config["proj_seed"]``
        layer_names: Stored for callers; not used by this class
        regularization: Ridge strength in the projected kernel
        normalize_grads: Whether to L2-normalize each gradient
        projector_config: Partial or full overrides of ``DEFAULT_PROJECTOR_CONFIG``

    Note:
        One gradient is computed per batch yielded by ``_loader_wrapper``. The
        returned matrix has one row/column per dataset example, so the loaders
        must yield one example per batch; otherwise a RuntimeError is raised.
    """

    def __init__(
        self,
        model: nn.Module,
        objective: BaseObjective,
        train_loader: Dict[str, data.DataLoader],
        test_loader: Dict[str, data.DataLoader],
        device: torch.device,
        task_names: List[str],
        load_checkpoints: Optional[Callable] = None,
        num_checkpoints: int = 1,
        proj_dim: int = 512,
        seed: int = 42,
        layer_names: Optional[List[str]] = None,
        regularization: float = 0.02,
        normalize_grads: bool = True,
        projector_config: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize Trak influence module.

        Args:
            model: Neural network model
            objective: Objective implementation
            train_loader: Training data loaders
            test_loader: Test data loaders
            device: Computation device
            task_names: List of task names
            load_checkpoints: Function to load model checkpoints
            num_checkpoints: Number of checkpoints to use
            proj_dim: Projection dimension
            seed: Random seed (stored only)
            layer_names: Specific layer names to analyze (stored only)
            regularization: Regularization strength
            normalize_grads: Whether to normalize gradients
            projector_config: Configuration for random projections
        """
        super().__init__(
            model=model,
            objective=objective,
            train_loader=train_loader,
            test_loader=test_loader,
            device=device,
        )
        self.task_names = task_names
        self.load_checkpoints = load_checkpoints
        self.num_checkpoints = num_checkpoints
        self.seed = seed
        self.layer_names = layer_names
        self.regularization = regularization
        self.normalize_grads = normalize_grads
        # Build a fresh dict: never alias (and risk mutating) the module-level defaults,
        # and let a partial user config fall back to defaults key by key.
        self.projector_config = {
            **DEFAULT_PROJECTOR_CONFIG,
            "proj_dim": proj_dim,
            **(projector_config or {}),
        }
        self.proj_dim = self.projector_config["proj_dim"]

        # Initialize caches
        self.train_gradients_cache = {}
        self.test_gradients_cache = {}
        self.projectors_cache = {}

    def _load_checkpoint(self, checkpoint_idx: int):
        """
        Load model checkpoint.

        Args:
            checkpoint_idx: Checkpoint index

        Returns:
            Model state dict (the current model's when no loader was supplied)
        """
        if self.load_checkpoints is not None:
            return self.load_checkpoints(checkpoint_idx)
        return self.model.state_dict()

    def _generate_projector(self, input_dim: int) -> torch.Tensor:
        """
        Generate the seeded Gaussian projection matrix.

        Args:
            input_dim: Number of rows; must equal the flattened gradient dimension

        Returns:
            Tensor of shape ``(input_dim, proj_dim)`` on ``projector_config["device"]``
        """
        config = self.projector_config
        # A private generator keeps the projection reproducible without
        # reseeding torch's global RNG as a side effect.
        generator = torch.Generator(device="cpu")
        generator.manual_seed(config["proj_seed"])
        projector = torch.randn(input_dim, config["proj_dim"], generator=generator)
        projector = projector.to(torch.device(config["device"]))

        if config["use_half_precision"]:
            projector = projector.half()

        return projector

    def _project(self, gradients: torch.Tensor, checkpoint_idx: int) -> torch.Tensor:
        """Project ``(n, grad_dim)`` gradients to ``(n, proj_dim)`` with the checkpoint's projector."""
        projector = self.projectors_cache.get(checkpoint_idx)
        if projector is None or projector.shape[0] != gradients.shape[1]:
            projector = self._generate_projector(gradients.shape[1])
            self.projectors_cache[checkpoint_idx] = projector
        # Match the gradients' device/dtype (the projector may be stored in half precision).
        return gradients @ projector.to(device=gradients.device, dtype=gradients.dtype)

    def _compute_gradients(
        self,
        batch_data: Any,
        parameters: Dict[str, torch.Tensor],
        train: bool = True,
        task_name: str = None
    ) -> torch.Tensor:
        """
        Compute the flattened loss gradient for one batch at the given checkpoint.

        Args:
            batch_data: Batch of data
            parameters: Model state dict to evaluate at (loaded into ``self.model``)
            train: Use the train loss when True, the test loss otherwise
            task_name: Task name to activate on the objective, if given

        Returns:
            1-D gradient w.r.t. the flattened parameters (L2-normalized if
            ``normalize_grads``)
        """
        if task_name is not None:
            self.objective.settn(task_name)

        # Load the checkpoint while the model still owns registered parameters.
        self.model.load_state_dict(parameters)
        flat_params = self._flatten_params_like(self._model_make_functional())
        try:
            self._model_reinsert_params(self._reshape_like_params(flat_params))
            loss_fn = self.objective.train_loss if train else self.objective.test_loss
            loss = loss_fn(self.model, flat_params, batch_data)
            (flat_gradients,) = torch.autograd.grad(loss, flat_params)
        finally:
            with torch.no_grad():
                self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)

        flat_gradients = flat_gradients.detach()
        if self.normalize_grads:
            norm = torch.norm(flat_gradients)
            if norm > 0:
                flat_gradients = flat_gradients / norm

        return flat_gradients

    def _stack_gradients(self, task_name: str, checkpoint: Any, train: bool) -> torch.Tensor:
        """Stack per-batch gradients of one task's train or test loader into ``(num_batches, grad_dim)``."""
        return torch.stack([
            self._compute_gradients(batch, checkpoint, train=train, task_name=task_name)
            for batch, _ in self._loader_wrapper(train=train, task_name=task_name)
        ])

    def cache_train_data(self, task_name: str = None) -> None:
        """
        Cache training gradients for every checkpoint.

        Args:
            task_name: Specific task name (optional; defaults to all tasks)
        """
        tasks_to_process = [task_name] if task_name else self.task_names

        for task in tasks_to_process:
            task_cache = self.train_gradients_cache.setdefault(task, {})
            print(f"Caching training gradients for task: {task}")

            for checkpoint_idx in range(self.num_checkpoints):
                checkpoint = self._load_checkpoint(checkpoint_idx)
                task_cache[checkpoint_idx] = self._stack_gradients(task, checkpoint, train=True)

    def influences(
        self,
        use_cache: bool = True,
        threshold: float = None,
        verbose: bool = True,
        precision: torch.dtype = torch.float16
    ) -> torch.Tensor:
        """
        Compute influence scores using the TRAK estimator.

        Args:
            use_cache: Whether to use (and fill) cached training gradients
            threshold: If given, zero out scores whose magnitude is at most this value
            verbose: Whether to print progress
            precision: dtype of the returned matrix (accumulation is done in float32)

        Returns:
            ``(total_train_samples, total_test_samples)`` score matrix. Each task fills
            its own block on the diagonal; scores are averaged over checkpoints.

        Raises:
            RuntimeError: If a task's score block does not match its dataset sizes
                (e.g. loaders yield more than one example per batch).

        Note:
            When ``load_checkpoints`` is set, ``self.model`` is left holding the
            weights of the last checkpoint processed.
        """
        if use_cache:
            for task_name in self.task_names:
                if task_name not in self.train_gradients_cache:
                    self.cache_train_data(task_name)

        total_train_samples = sum(len(loader.dataset) for loader in self.train_loaders.values())
        total_test_samples = sum(len(loader.dataset) for loader in self.test_loaders.values())

        # Row/column offsets of each task's block, in task_names order.
        offsets = {}
        train_offset = test_offset = 0
        for name in self.task_names:
            offsets.setdefault(name, (train_offset, test_offset))
            train_offset += len(self.train_loaders[name].dataset)
            test_offset += len(self.test_loaders[name].dataset)

        # Accumulate in float32; cast to ``precision`` only once at the end.
        scores = torch.zeros(total_train_samples, total_test_samples, device=self.device, dtype=torch.float32)

        if verbose:
            print(f"Computing influence scores: {total_train_samples} train × {total_test_samples} test")

        for checkpoint_idx in range(self.num_checkpoints):
            if verbose:
                print(f"Processing checkpoint {checkpoint_idx + 1}/{self.num_checkpoints}")

            checkpoint = self._load_checkpoint(checkpoint_idx)

            for task_name in self.task_names:
                if verbose:
                    print(f"  Processing task: {task_name}")

                cached = self.train_gradients_cache.get(task_name, {}) if use_cache else {}
                train_gradients = cached.get(checkpoint_idx)
                if train_gradients is None:
                    train_gradients = self._stack_gradients(task_name, checkpoint, train=True)
                test_gradients = self._stack_gradients(task_name, checkpoint, train=False)

                projected_train = self._project(train_gradients, checkpoint_idx)
                projected_test = self._project(test_gradients, checkpoint_idx)

                if self.regularization > 0:
                    kernel = projected_train.t() @ projected_train
                    kernel = kernel + self.regularization * torch.eye(
                        kernel.shape[0], device=kernel.device, dtype=kernel.dtype
                    )
                    # solve(K, Phi_test^T) == K^-1 Phi_test^T, without forming the inverse.
                    task_influences = projected_train @ torch.linalg.solve(kernel, projected_test.t())
                else:
                    task_influences = projected_train @ projected_test.t()

                start_train, start_test = offsets[task_name]
                end_train = start_train + len(self.train_loaders[task_name].dataset)
                end_test = start_test + len(self.test_loaders[task_name].dataset)
                expected_shape = (end_train - start_train, end_test - start_test)
                if tuple(task_influences.shape) != expected_shape:
                    raise RuntimeError(
                        f"Task {task_name!r}: score block has shape {tuple(task_influences.shape)} "
                        f"but dataset sizes require {expected_shape}; loaders must yield one "
                        f"example per batch."
                    )

                scores[start_train:end_train, start_test:end_test] += task_influences.to(
                    device=scores.device, dtype=scores.dtype
                )

        scores = scores / max(self.num_checkpoints, 1)
        influence_scores = scores.to(precision)

        if threshold is not None:
            influence_scores = self._soft_threshold(influence_scores, threshold)

        return influence_scores

    def inverse_hvp(self, vec: torch.Tensor) -> torch.Tensor:
        """
        Placeholder inverse Hessian-vector product.

        TRAK does not compute inverse-HVPs; this returns ``vec`` unchanged
        (i.e. treats the inverse Hessian as the identity).

        Args:
            vec: Vector to multiply with inverse Hessian

        Returns:
            ``vec`` itself
        """
        return vec

    def _soft_threshold(self, influences: torch.Tensor, threshold: float = 0.01) -> torch.Tensor:
        """
        Apply hard thresholding to influence scores.

        Despite the method name, surviving entries are not shrunk: entries with
        ``|x| > threshold`` are kept unchanged and all others are set to zero.

        Args:
            influences: Influence scores
            threshold: Threshold value

        Returns:
            Thresholded influence scores
        """
        return torch.where(
            torch.abs(influences) > threshold,
            influences,
            torch.zeros_like(influences)
        )