"""
Base Classes for Influence Function Computation

This module provides the foundational classes for implementing influence function methods.
"""

import abc
from typing import Any, List, Optional, Dict, Iterator, Tuple
from functools import lru_cache
import concurrent.futures

import numpy as np
import torch
from torch import nn
from torch.utils import data


def _set_attribute(obj, names, val):
    """
    Assign ``val`` to an attribute reached through a chain of names.

    Every name except the last is resolved with ``getattr`` starting from
    ``obj``; the last name is then set on the object reached that way.

    Args:
        obj: Root object the attribute path starts from
        names: Sequence of attribute names forming the path (must be non-empty)
        val: Value stored at the end of the path
    """
    target = obj
    for attr in names[:-1]:
        target = getattr(target, attr)
    # Indexing (not unpacking) so an empty path still fails with IndexError.
    setattr(target, names[-1], val)


def _delete_attribute(obj, names):
    """
    Remove an attribute reached through a chain of names.

    Every name except the last is resolved with ``getattr`` starting from
    ``obj``; the last name is then deleted from the object reached that way.

    Args:
        obj: Root object the attribute path starts from
        names: Sequence of attribute names forming the path (must be non-empty)
    """
    owner = obj
    for attr in names[:-1]:
        owner = getattr(owner, attr)
    # Indexing (not unpacking) so an empty path still fails with IndexError.
    delattr(owner, names[-1])


class BaseObjective(abc.ABC):
    """
    Abstract base class for defining training and test objectives.
    
    This class provides an interface for influence function computation by defining
    how model outputs and losses are computed for both training and test data.
    """

    @abc.abstractmethod
    def train_outputs(self, model: nn.Module, batch: Any) -> torch.Tensor:
        """
        Compute model outputs for training data.
        
        Args:
            model: Neural network model
            batch: Batch of training data
            
        Returns:
            Model outputs (e.g., logits, probabilities)
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def train_loss_on_outputs(self, outputs: torch.Tensor, batch: Any, reduction: str = 'mean') -> torch.Tensor:
        """
        Compute loss from model outputs.
        
        Args:
            outputs: Model outputs
            batch: Batch of training data
            reduction: Loss reduction method ('mean', 'sum', 'none')
            
        Returns:
            Computed loss value
            
        Note:
            For influence function computation, the loss should be mean-reduced.
            For Gauss-Newton Hessian approximation, the loss should be convex
            in the model outputs.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def train_regularization(self, params: torch.Tensor) -> torch.Tensor:
        """
        Compute regularization loss.
        
        Args:
            params: Flattened model parameters
            
        Returns:
            Regularization loss value
        """
        raise NotImplementedError()

    def train_loss(self, model: nn.Module, params: torch.Tensor, batch: Any, reduction: str = 'mean') -> torch.Tensor:
        """
        Compute complete training loss including regularization.
        
        This method combines model outputs, loss computation, and regularization.
        It should not be overridden for most use cases.
        
        Args:
            model: Neural network model
            params: Flattened model parameters
            batch: Batch of training data
            reduction: Loss reduction method
            
        Returns:
            Complete training loss
        """
        outputs = self.train_outputs(model, batch)
        loss = self.train_loss_on_outputs(outputs, batch, reduction=reduction)
        reg_loss = self.train_regularization(params)
        return loss + reg_loss

    @abc.abstractmethod
    def test_loss(self, model: nn.Module, params: torch.Tensor, batch: Any) -> torch.Tensor:
        """
        Compute test loss.
        
        Args:
            model: Neural network model
            params: Flattened model parameters
            batch: Batch of test data
            
        Returns:
            Test loss value
        """
        raise NotImplementedError()


class DataLoaderMixin:
    """Mixin that gives an instance empty per-task train/test loader maps."""

    def __init__(self):
        # Cooperative call: without it, classes that follow this mixin in a
        # multiple-inheritance MRO would never have their __init__ executed.
        super().__init__()
        self.train_loaders = {}  # filled by the owning class
        self.test_loaders = {}  # filled by the owning class


class ModelParamsMixin:
    """Mixin that reserves the ``model`` attribute on an instance."""

    def __init__(self):
        # Cooperative call so that any class following this mixin in a
        # multiple-inheritance MRO is initialised as well.
        super().__init__()
        self.model = None  # replaced by the owning class


class BaseInfluenceModule(DataLoaderMixin, ModelParamsMixin, abc.ABC):
    """
    Abstract base class for influence function computation modules.
    
    This class provides the foundation for implementing various influence function
    computation methods such as Autograd, CG, LiSSA, and Trak.
    """

    def __init__(
            self,
            model: nn.Module,
            objective: BaseObjective,
            train_loader: Dict[str, data.DataLoader],
            test_loader: Dict[str, data.DataLoader],
            device: torch.device
    ):
        """
        Initialize the influence module.
        
        Args:
            model: Neural network model
            objective: Objective implementation
            train_loader: Training data loaders
            test_loader: Test data loaders
            device: Computation device
        """
        super().__init__()
        self.model = model
        self.objective = objective
        self.train_loaders = train_loader
        self.test_loaders = test_loader
        self.device = device

    @abc.abstractmethod
    def inverse_hvp(self, vec: torch.Tensor) -> torch.Tensor:
        """
        Compute inverse Hessian-vector product.
        
        Args:
            vec: Vector to multiply with inverse Hessian
            
        Returns:
            Inverse Hessian-vector product
        """
        raise NotImplementedError()

    def train_loss_grad(self, train_indices: List[int], task_name: str) -> torch.Tensor:
        """
        Compute gradient of training loss with respect to parameters.
        
        Args:
            train_indices: Indices of training samples
            task_name: Task name
            
        Returns:
            Gradient of training loss
        """
        self.objective.settn(task_name)
        grad = 0.0
        count = 0
        
        for batch, batch_size in self._loader_wrapper(train=True, task_name=task_name, subset=train_indices):
            def loss_fn(params):
                self._model_reinsert_params(self._reshape_like_params(params))
                return self.objective.train_loss(self.model, params, batch)
            
            grad_batch = torch.autograd.grad(loss_fn(self._flatten_params_like(self._model_make_functional())), 
                                           self._flatten_params_like(self._model_make_functional()))[0]
            grad += grad_batch * batch_size
            count += batch_size
        
        return grad / count

    def test_loss_grad(self, test_indices: List[int], task_name: str) -> torch.Tensor:
        """
        Compute gradient of test loss with respect to parameters.
        
        Args:
            test_indices: Indices of test samples
            task_name: Task name
            
        Returns:
            Gradient of test loss
        """
        self.objective.settn(task_name)
        grad = 0.0
        count = 0
        
        for batch, batch_size in self._loader_wrapper(train=False, task_name=task_name, subset=test_indices):
            def loss_fn(params):
                self._model_reinsert_params(self._reshape_like_params(params))
                return self.objective.test_loss(self.model, params, batch)
            
            grad_batch = torch.autograd.grad(loss_fn(self._flatten_params_like(self._model_make_functional())), 
                                           self._flatten_params_like(self._model_make_functional()))[0]
            grad += grad_batch * batch_size
            count += batch_size
        
        return grad / count

    def compute_stest(self, test_indices: List[int], task_name: str) -> torch.Tensor:
        """
        Compute s_test vector for influence function computation.
        
        Args:
            test_indices: Indices of test samples
            task_name: Task name
            
        Returns:
            s_test vector
        """
        test_grad = self.test_loss_grad(test_indices, task_name)
        return self.inverse_hvp(test_grad)

    def influences(
            self,
            train_indices: List[int],
            train_task: str,
            test_indices: List[int],
            test_task: str,
            stest: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute influence scores.
        
        Args:
            train_indices: Indices of training samples
            train_task: Training task name
            test_indices: Indices of test samples
            test_task: Test task name
            stest: Pre-computed s_test vector (optional)
            
        Returns:
            Influence scores
        """
        if stest is None:
            stest = self.compute_stest(test_indices, test_task)
        
        influences = []
        for train_idx in train_indices:
            train_grad = self.train_loss_grad([train_idx], train_task)
            influence = -torch.dot(stest, train_grad)
            influences.append(influence)
        
        return torch.tensor(influences)

    def _model_params(self, with_names: bool = True):
        """Get model parameters."""
        if with_names:
            return dict(self.model.named_parameters())
        else:
            return list(self.model.parameters())

    def _model_make_functional(self):
        """Convert model to functional form."""
        params = self._model_params()
        return params

    def _model_reinsert_params(self, params, register: bool = False):
        """Reinsert parameters into model."""
        for name, param in self._model_params().items():
            if name in params:
                param.data = params[name].data

    def _flatten_params_like(self, params_like):
        """Flatten parameters into a single vector."""
        return torch.cat([p.flatten() for p in params_like.values()])

    def _reshape_like_params(self, vec):
        """Reshape vector back to parameter structure."""
        params = self._model_params()
        reshaped = {}
        start_idx = 0
        
        for name, param in params.items():
            param_size = param.numel()
            reshaped[name] = vec[start_idx:start_idx + param_size].reshape(param.shape)
            start_idx += param_size
        
        return reshaped

    def _transfer_to_device(self, batch):
        """Transfer batch to device."""
        if isinstance(batch, (tuple, list)):
            return tuple(self._transfer_to_device(item) for item in batch)
        elif isinstance(batch, dict):
            return {key: self._transfer_to_device(value) for key, value in batch.items()}
        elif isinstance(batch, torch.Tensor):
            return batch.to(self.device)
        else:
            return batch

    def _loader_wrapper(self, train: bool = True, task_name: str = None, batch_size: Optional[int] = None, 
                       subset: Optional[List[int]] = None, sample_n_batches: int = -1):
        """
        Wrapper for data loader iteration.
        
        Args:
            train: Whether to use training data
            task_name: Task name
            batch_size: Batch size
            subset: Subset of indices
            sample_n_batches: Number of batches to sample
            
        Yields:
            Tuple of (batch, batch_size)
        """
        loaders = self.train_loaders if train else self.test_loaders
        
        if task_name is None:
            task_name = list(loaders.keys())[0]
        
        loader = loaders[task_name]
        
        if subset is not None:
            dataset = loader.dataset
            subset_dataset = data.Subset(dataset, subset)
            loader = data.DataLoader(subset_dataset, batch_size=loader.batch_size, shuffle=False)
        
        if batch_size is not None:
            loader = data.DataLoader(loader.dataset, batch_size=batch_size, shuffle=False)
        
        batch_count = 0
        for batch in loader:
            if sample_n_batches > 0 and batch_count >= sample_n_batches:
                break
            
            batch = self._transfer_to_device(batch)
            batch_size = batch[0].size(0) if isinstance(batch, (tuple, list)) else batch.size(0)
            
            yield batch, batch_size
            batch_count += 1

    def _joint_loader_wrapper(
        self,
        train: bool = True,
        batch_size: Optional[int] = None,
        subset: Optional[List[int]] = None,
        sample_n_batches: int = -1
    ) -> Iterator[Tuple[Dict[str, Any], int]]:
        """
        Wrapper for joint data loader iteration across all tasks.
        
        Args:
            train: Whether to use training data
            batch_size: Batch size
            subset: Subset of indices
            sample_n_batches: Number of batches to sample
            
        Yields:
            Tuple of (batch_dict, batch_size)
        """
        loaders = self.train_loaders if train else self.test_loaders
        
        # Create joint dataset
        all_datasets = []
        task_names = list(loaders.keys())
        
        for task_name in task_names:
            dataset = loaders[task_name].dataset
            if subset is not None:
                dataset = data.Subset(dataset, subset)
            all_datasets.append(dataset)
        
        joint_dataset = data.ConcatDataset(all_datasets)
        joint_loader = data.DataLoader(joint_dataset, batch_size=batch_size or 32, shuffle=False)
        
        batch_count = 0
        for batch in joint_loader:
            if sample_n_batches > 0 and batch_count >= sample_n_batches:
                break
            
            batch = self._transfer_to_device(batch)
            batch_size = batch[0].size(0) if isinstance(batch, (tuple, list)) else batch.size(0)
            
            # Split batch by task
            batch_dict = {}
            start_idx = 0
            for i, task_name in enumerate(task_names):
                task_dataset_size = len(all_datasets[i])
                end_idx = start_idx + task_dataset_size
                batch_dict[task_name] = batch[start_idx:end_idx]
                start_idx = end_idx
            
            yield batch_dict, batch_size
            batch_count += 1

    def _random_loader_wrapper(self, train: bool = True, batch_size: Optional[int] = None, 
                              subset: Optional[List[int]] = None, sample_n_batches: int = -1):
        """
        Wrapper for random data loader iteration.
        
        Args:
            train: Whether to use training data
            batch_size: Batch size
            subset: Subset of indices
            sample_n_batches: Number of batches to sample
            
        Yields:
            Tuple of (batch, batch_size)
        """
        loaders = self.train_loaders if train else self.test_loaders
        task_name = list(loaders.keys())[0]
        
        for batch, batch_size in self._loader_wrapper(train, task_name, batch_size, subset, sample_n_batches):
            yield batch, batch_size

    def _loss_grad_loader_wrapper(self, train: bool = True, task_name: str = None, **kwargs):
        """Wrapper for loss gradient computation."""
        for batch, batch_size in self._loader_wrapper(train, task_name, **kwargs):
            yield batch, batch_size

    def _loss_grad(self, indices: List[int], train: bool = True, task_name: str = None):
        """Compute loss gradient for given indices."""
        if train:
            return self.train_loss_grad(indices, task_name)
        else:
            return self.test_loss_grad(indices, task_name)

    def _hvp_at_batch(self, batches, flat_params, vec, gnh: bool = False, explicit: bool = True):
        """
        Compute Hessian-vector product at a batch.
        
        Args:
            batches: Batch data
            flat_params: Flattened parameters
            vec: Vector to multiply
            gnh: Whether to use Gauss-Newton Hessian
            explicit: Whether to use explicit computation
            
        Returns:
            Hessian-vector product
        """
        def f(theta_):
            self._model_reinsert_params(self._reshape_like_params(theta_))
            return self.objective.train_loss(self.model, theta_, batches)

        def out_f(theta_):
            self._model_reinsert_params(self._reshape_like_params(theta_))
            return self.objective.train_outputs(self.model, batches)

        def loss_f(out_):
            return self.objective.train_loss_on_outputs(out_, batches)

        def reg_f(theta_):
            return self.objective.train_regularization(theta_)

        if explicit:
            if gnh:
                # Gauss-Newton Hessian
                outputs = out_f(flat_params)
                jacobian = torch.autograd.grad(outputs, flat_params, create_graph=True)[0]
                hessian = torch.autograd.functional.hessian(loss_f, outputs)
                return torch.matmul(jacobian.t(), torch.matmul(hessian, torch.matmul(jacobian, vec)))
            else:
                # Full Hessian
                hessian = torch.autograd.functional.hessian(f, flat_params)
                return torch.matmul(hessian, vec)
        else:
            # Implicit computation
            grad = torch.autograd.grad(f(flat_params), flat_params, create_graph=True)[0]
            return torch.autograd.grad(torch.dot(grad, vec), flat_params)[0]

    def validate_task_names(self, task_names: List[str]) -> None:
        """
        Validate task names.
        
        Args:
            task_names: List of task names to validate
        """
        available_tasks = list(self.train_loaders.keys())
        for task_name in task_names:
            if task_name not in available_tasks:
                raise ValueError(f"Task '{task_name}' not found. Available tasks: {available_tasks}")

    def _compute_task_hessian(self, task_name: str, flat_params: torch.Tensor) -> torch.Tensor:
        """
        Compute Hessian for a specific task.
        
        Args:
            task_name: Task name
            flat_params: Flattened parameters
            
        Returns:
            Task-specific Hessian
        """
        self.objective.settn(task_name)
        hessian = 0.0
        
        for batch, batch_size in self._loader_wrapper(train=True, task_name=task_name):
            hessian_batch = self._hvp_at_batch(batch, flat_params, torch.eye(flat_params.size(0), device=self.device))
            hessian += hessian_batch * batch_size
        
        return hessian

    @lru_cache(maxsize=32)
    def _compute_batch_hessian(self, batch_key: str, params_key: str) -> torch.Tensor:
        """
        Compute Hessian for a specific batch with caching.
        
        Args:
            batch_key: Batch identifier
            params_key: Parameters identifier
            
        Returns:
            Batch-specific Hessian
        """
        # Implementation would depend on specific caching strategy
        pass

    def compute_multi_task_hessian(self, use_parallel: bool = False) -> torch.Tensor:
        """
        Compute Hessian across all tasks.
        
        Args:
            use_parallel: Whether to use parallel computation
            
        Returns:
            Multi-task Hessian
        """
        flat_params = self._flatten_params_like(self._model_make_functional())
        
        if use_parallel:
            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = []
                for task_name in self.train_loaders.keys():
                    future = executor.submit(self._compute_task_hessian, task_name, flat_params)
                    futures.append(future)
                
                hessians = [future.result() for future in futures]
                return sum(hessians)
        else:
            hessian = 0.0
            for task_name in self.train_loaders.keys():
                hessian += self._compute_task_hessian(task_name, flat_params)
            return hessian


class LiSSAInfluenceModule(BaseInfluenceModule):
    """
    LiSSA (Linear time Stochastic Second-Order Algorithm) influence module.
    
    This module implements the LiSSA algorithm for efficient influence function computation.
    """

    def __init__(self, *args, enable_monitoring: bool = False, **kwargs):
        """
        Initialize LiSSA influence module.
        
        Args:
            *args: Positional arguments
            enable_monitoring: Whether to enable convergence monitoring
            **kwargs: Keyword arguments
        """
        super().__init__(*args, **kwargs)
        self.enable_monitoring = enable_monitoring
        self.convergence_history = []

    def _log_convergence(self, iteration: int, error: float) -> None:
        """
        Log convergence information.
        
        Args:
            iteration: Current iteration
            error: Current error
        """
        if self.enable_monitoring:
            self.convergence_history.append((iteration, error))
            print(f"Iteration {iteration}: Error = {error:.6f}") 