"""
Influence Analysis Module

This module provides tools for computing influence scores in multi-task learning scenarios.
"""

import os
import gc
import time
from typing import Optional, List, Tuple, Callable, Dict, Any
import torch
import torch.nn.functional as F

# Import torchinfluence components if available. The dependency is optional:
# the InfluenceAnalyzer.run_* methods check TORCHINFLUENCE_AVAILABLE and raise
# ImportError when it is missing.
try:
    from torchinfluence.torch_influence.modules import (
        AutogradInfluenceModule,
        CGInfluenceModule,
        LiSSAInfluenceModule,
        TrakInfluenceModule
    )
    from torchinfluence.torch_influence.base import BaseObjective
    TORCHINFLUENCE_AVAILABLE = True
except ImportError:
    TORCHINFLUENCE_AVAILABLE = False
    # The objective classes in this module subclass BaseObjective when the
    # module is imported. Without a placeholder base, importing this module
    # would fail with NameError even though torchinfluence is meant to be optional.
    BaseObjective = object
    print("Warning: torchinfluence not available. Some functionality may be limited.")

# Process-wide switch: explicitly leave PyTorch's deterministic-algorithms mode off.
torch.use_deterministic_algorithms(False)


class MultiTaskObjective(BaseObjective):
    """
    Cross-entropy objective for multi-task (classification-style) influence computation.

    The wrapped model returns a mapping keyed by task name. Every loss method
    scores the entry selected via ``set_task_name``. No regularization term is
    applied.
    """

    def set_task_name(self, task_name: str) -> None:
        """
        Choose which task output the loss methods score.

        Args:
            task_name: Key into the model's per-task output mapping
        """
        self.task_name = task_name

    def _active_task_output(self, model: torch.nn.Module, inputs: torch.Tensor) -> torch.Tensor:
        """Run ``model`` on ``inputs`` and return the squeezed output of the active task."""
        outputs_by_task = model(inputs)
        return outputs_by_task[self.task_name].squeeze()

    def train_outputs(self, model: torch.nn.Module, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """
        Forward the features of a training batch through the model.

        Args:
            model: Network whose output is indexed by task name
            batch: ``(features, labels)``; only the features are used here

        Returns:
            Squeezed output of the active task
        """
        return self._active_task_output(model, batch[0])

    def train_loss_on_outputs(self, outputs: torch.Tensor, batch: Tuple[torch.Tensor, torch.Tensor], reduction: str = 'mean') -> torch.Tensor:
        """
        Cross-entropy between precomputed outputs and the batch labels.

        Args:
            outputs: Result of ``train_outputs`` for this batch
            batch: ``(features, labels)``; the labels are squeezed before use
            reduction: Reduction mode forwarded to ``F.cross_entropy``

        Returns:
            Cross-entropy loss
        """
        labels = batch[1].squeeze()
        return F.cross_entropy(outputs, labels, reduction=reduction)

    def train_regularization(self, params: torch.Tensor) -> torch.Tensor:
        """
        Regularization term added to the training loss.

        Args:
            params: Model parameters (unused)

        Returns:
            Always the float ``0.0``; this objective is unregularized
        """
        return 0.0

    def test_loss(self, model: torch.nn.Module, params: torch.Tensor, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """
        Mean cross-entropy of the active task on an evaluation batch.

        Args:
            model: Network whose output is indexed by task name
            params: Model parameters (accepted for interface compatibility, unused)
            batch: ``(features, labels)``

        Returns:
            Mean cross-entropy loss
        """
        predictions = self._active_task_output(model, batch[0])
        labels = batch[1].squeeze()
        return F.cross_entropy(predictions, labels)


class MultiTaskRegressionObjective(BaseObjective):
    """
    Mean-squared-error objective for multi-task regression influence computation.

    The wrapped model returns a mapping keyed by task name. Every loss method
    scores the entry selected via ``set_task_name``. No regularization term is
    applied.
    """

    def set_task_name(self, task_name: str) -> None:
        """
        Choose which task output the loss methods score.

        Args:
            task_name: Key into the model's per-task output mapping
        """
        self.task_name = task_name

    def _active_task_output(self, model: torch.nn.Module, inputs: torch.Tensor) -> torch.Tensor:
        """Run ``model`` on ``inputs`` and return the squeezed output of the active task."""
        outputs_by_task = model(inputs)
        return outputs_by_task[self.task_name].squeeze()

    def train_outputs(self, model: torch.nn.Module, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """
        Forward the features of a training batch through the model.

        Args:
            model: Network whose output is indexed by task name
            batch: ``(features, targets)``; only the features are used here

        Returns:
            Squeezed output of the active task
        """
        return self._active_task_output(model, batch[0])

    def train_loss_on_outputs(self, outputs: torch.Tensor, batch: Tuple[torch.Tensor, torch.Tensor], reduction: str = 'mean') -> torch.Tensor:
        """
        Squared error between precomputed outputs and the batch targets.

        Args:
            outputs: Result of ``train_outputs`` for this batch
            batch: ``(features, targets)``; the targets are squeezed before use
            reduction: Reduction mode forwarded to ``F.mse_loss``

        Returns:
            MSE loss
        """
        targets = batch[1].squeeze()
        return F.mse_loss(outputs, targets, reduction=reduction)

    def train_regularization(self, params: torch.Tensor) -> torch.Tensor:
        """
        Regularization term added to the training loss.

        Args:
            params: Model parameters (unused)

        Returns:
            Always the float ``0.0``; this objective is unregularized
        """
        return 0.0

    def test_loss(self, model: torch.nn.Module, params: torch.Tensor, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """
        Mean squared error of the active task on an evaluation batch.

        Args:
            model: Network whose output is indexed by task name
            params: Model parameters (accepted for interface compatibility, unused)
            batch: ``(features, targets)``

        Returns:
            Mean MSE loss
        """
        predictions = self._active_task_output(model, batch[0])
        targets = batch[1].squeeze()
        return F.mse_loss(predictions, targets)


class InfluenceAnalyzer:
    """
    Handler class for computing influence scores using different methods.

    Wraps the model and the per-task train/validation dataloaders of a
    multi-task trainer. Dispatches to the torchinfluence backends (Autograd,
    LiSSA, TRAK) and optionally saves the scores under
    ``results/influence_scores``.
    """

    def __init__(self, trainer):
        """
        Initialize influence analyzer.

        Args:
            trainer: Multi-task trainer instance. It must expose
                ``trainer.trainer.MTLmodel.model`` plus the attributes
                ``task_num``, ``task_names``, ``training_size``, ``val_size``,
                ``batch_size``, ``seed``, ``device``, ``num_epochs``,
                ``train_dataloaders``, ``val_dataloaders`` and ``dataset``.
        """
        self.model = trainer.trainer.MTLmodel.model
        self.task_num = trainer.task_num
        self.task_names = trainer.task_names
        self.training_size = trainer.training_size
        self.val_size = trainer.val_size
        self.batch_size = trainer.batch_size
        self.seed = trainer.seed
        self.device = trainer.device
        self.num_epochs = trainer.num_epochs
        # Start of the wall-clock interval reported by _save_scores.
        self.begin_time = time.time()
        self.train_dataloaders = trainer.train_dataloaders
        self.val_dataloaders = trainer.val_dataloaders
        self.dataset = trainer.dataset

    def _save_scores(self, scores: Any, method: str = '') -> None:
        """
        Save computed influence scores with ``torch.save``.

        Also prints the wall-clock time elapsed since this analyzer was constructed.

        Args:
            scores: Computed influence scores (a tensor, or the nested lists
                produced by ``run_trak``)
            method: Method name appended to the file name as ``_{method}.pt``
        """
        os.makedirs('results/influence_scores', exist_ok=True)
        end_time = time.time()
        print(f'Time used: {end_time - self.begin_time}')

        if self.dataset == 'face':
            save_path = f'results/influence_scores/seed_{self.seed}_size_{self.training_size}_val_{self.val_size}_tasks_{self.task_num}_epoch_{self.num_epochs}_bs_{self.batch_size}_{method}.pt'
        else:
            save_path = f'results/influence_scores/{self.dataset}_seed_{self.seed}_epoch_{self.num_epochs}_bs_{self.batch_size}_{method}.pt'

        torch.save(scores, save_path)

    def _compute_scores(self, module) -> torch.Tensor:
        """
        Compute influence scores for all (training task, validation task) pairs.

        ``scores[i][j]`` holds the result of ``module.influences`` for the
        training samples of task ``i`` against the validation samples of
        task ``j``.

        For 'office-home' and 'office-31', ``training_size`` and ``val_size``
        are indexed per task, and rows shorter than the largest training set
        are left zero-padded. For every other dataset they are single values
        shared by all tasks.

        Args:
            module: Influence computation module exposing
                ``influences(train_idxs, train_task, test_idxs, test_task)``

        Returns:
            Tensor of shape ``(task_num, task_num, max_training_size)``
        """
        if self.dataset in ['office-home', 'office-31']:
            max_training_size = torch.tensor(self.training_size).max().item()
            scores = torch.zeros(self.task_num, self.task_num, max_training_size)
            for i, task in enumerate(self.task_names):
                for j, task2 in enumerate(self.task_names):
                    influences = module.influences(
                        torch.arange(self.training_size[i]),
                        task,
                        torch.arange(self.val_size[j]),
                        task2
                    )
                    scores[i][j][:self.training_size[i]] = influences
        else:  # face, nyu, xtreme
            scores = torch.zeros(self.task_num, self.task_num, self.training_size)
            for i, task in enumerate(self.task_names):
                for j, task2 in enumerate(self.task_names):
                    influences = module.influences(
                        torch.arange(self.training_size),
                        task,
                        torch.arange(self.val_size),
                        task2
                    )
                    scores[i][j] = influences
        return scores

    def run_autograd(self, save: bool = True) -> Optional[torch.Tensor]:
        """
        Run influence computation using AutogradInfluenceModule.

        Args:
            save: Whether to save the computed scores (suffix ``autograd``)

        Returns:
            Computed influence scores if save=False, None otherwise

        Raises:
            ImportError: If torchinfluence is not installed.
        """
        if not TORCHINFLUENCE_AVAILABLE:
            raise ImportError("torchinfluence is required for autograd influence computation")

        self.model.eval()
        module = AutogradInfluenceModule(
            model=self.model.to(self.device),
            objective=MultiTaskObjective(),
            train_loader=self.train_dataloaders,
            test_loader=self.val_dataloaders,
            device=self.device,
            tns=self.task_names,
            damp=0.01,
            check_eigvals=True
        )

        torch.cuda.empty_cache()
        gc.collect()

        scores = self._compute_scores(module)
        if save:
            self._save_scores(scores, 'autograd')
        return scores if not save else None

    def run_lissa(
        self,
        save: bool = True,
        damp: float = 0.05,
        scale: float = 50.0,
        depth: int = 50,
        repeat: int = 1
    ) -> Optional[torch.Tensor]:
        """
        Run influence computation using LiSSAInfluenceModule.

        Args:
            save: Whether to save the computed scores
            damp: Damping parameter for numerical stability (default: 0.05)
            scale: Scaling factor for LiSSA algorithm (default: 50.0)
            depth: Number of recursion steps (default: 50)
            repeat: Number of trials to average over (default: 1)

        Returns:
            Computed influence scores if save=False, None otherwise

        Raises:
            ImportError: If torchinfluence is not installed.
        """
        if not TORCHINFLUENCE_AVAILABLE:
            raise ImportError("torchinfluence is required for LiSSA influence computation")

        self.model.eval()

        # NOTE(review): this uses the MSE (regression) objective, whereas
        # run_autograd and run_trak use the cross-entropy objective. Confirm
        # this is intended for the datasets LiSSA is run on.
        module = LiSSAInfluenceModule(
            model=self.model.to(self.device),
            objective=MultiTaskRegressionObjective(),
            train_loader=self.train_dataloaders,
            test_loader=self.val_dataloaders,
            device=self.device,
            tns=self.task_names,
            damp=damp,
            repeat=repeat,
            depth=depth,
            scale=scale,
            batch_size=32,
            enable_monitoring=True,
            convergence_tol=1e-5,
            max_scale_tries=5
        )

        torch.cuda.empty_cache()
        gc.collect()

        scores = self._compute_scores(module)
        if save:
            # Hyperparameters go into the filename for tracking. _save_scores
            # already inserts '_' before the method, so the file ends up with
            # '__lissa_...'. This is kept so existing result paths stay valid.
            method_suffix = f'_lissa_d{damp}_s{scale}_r{depth}_t{repeat}'
            self._save_scores(scores, method_suffix)
        return scores if not save else None

    def _split_trak_influences(
        self,
        influences: torch.Tensor,
        tasks_to_process: List[str],
    ) -> Tuple[List[List[torch.Tensor]], List[str]]:
        """
        Slice a concatenated TRAK score matrix into per-(train task, test task) blocks.

        Rows are consumed in ``self.task_names`` order, taking
        ``len(self.train_dataloaders[task].dataset)`` rows per task. Columns
        are consumed the same way using the validation datasets. Each block is
        averaged over its test columns.

        Args:
            influences: 2-D score matrix (training samples x test samples)
            tasks_to_process: Training tasks whose rows are returned. Row
                offsets still advance over every task in ``self.task_names``.

        Returns:
            ``(influences_lists, task_names_list)``. ``influences_lists[k][j]``
            holds one averaged score per training sample of the k-th processed
            task against test task ``j``. ``task_names_list`` repeats each
            processed training-task name once per test task.
        """
        val_lengths = [len(self.val_dataloaders[t].dataset) for t in self.task_names]
        influences_lists = []
        task_names_list = []
        row_start = 0
        for train_task in self.task_names:
            row_end = row_start + len(self.train_dataloaders[train_task].dataset)
            if train_task in tasks_to_process:
                rows = influences[row_start:row_end]
                per_test_task = []
                col_start = 0
                for test_length in val_lengths:
                    per_test_task.append(rows[:, col_start:col_start + test_length].mean(dim=1))
                    col_start += test_length
                influences_lists.append(per_test_task)
                task_names_list.extend([train_task] * len(val_lengths))
            row_start = row_end
        return influences_lists, task_names_list

    def run_trak(
        self,
        save: bool = True,
        num_projections: int = 100,
        normalize_grads: bool = True,
        seed: int = 42,
        task_name: Optional[str] = None,
        load_checkpoints: Optional[Callable] = None,
        num_checkpoints: int = 1
    ) -> Tuple[List[List[torch.Tensor]], List[List[str]]]:
        """
        Run TRAK influence computation.

        Args:
            save: Whether to save the computed scores (suffix ``trak``)
            num_projections: Number of projections for TRAK. NOTE: currently
                not forwarded to TrakInfluenceModule.
            normalize_grads: Whether to normalize gradients. NOTE: currently
                not forwarded to TrakInfluenceModule.
            seed: Random seed. NOTE: currently not forwarded to TrakInfluenceModule.
            task_name: If given (non-empty), only this training task's scores
                are returned and saved. Otherwise all training tasks are used.
            load_checkpoints: Function to load checkpoints, forwarded to
                TrakInfluenceModule
            num_checkpoints: Number of checkpoints to use, forwarded to
                TrakInfluenceModule

        Returns:
            ``(influences_lists, [task_names_list])``. ``influences_lists[k][j]``
            holds, for the k-th processed training task, one score per
            training sample averaged over the validation samples of test task
            ``j``. If the TRAK module cannot be built or the computation
            fails, the error is printed and ``([], [])`` is returned.

        Raises:
            ImportError: If torchinfluence is not installed.
            ValueError: If ``task_name`` is not one of ``self.task_names``.
        """
        import traceback

        if not TORCHINFLUENCE_AVAILABLE:
            raise ImportError("torchinfluence is required for Trak influence computation")
        if task_name and task_name not in self.task_names:
            raise ValueError(f"Unknown task_name {task_name!r}; expected one of {list(self.task_names)}")
        tasks_to_process = [task_name] if task_name else list(self.task_names)

        print("Initializing Trak module...")

        # Place the model on the target device, as run_autograd/run_lissa do.
        model = self.model.to(self.device)
        model.eval()

        try:
            trak_module = TrakInfluenceModule(
                model=model,
                objective=MultiTaskObjective(),
                train_loader=self.train_dataloaders,
                test_loader=self.val_dataloaders,
                device=self.device,
                tns=self.task_names,
                load_checkpoints=load_checkpoints,
                num_checkpoints=num_checkpoints
            )
        except Exception as e:
            print(f"Error initializing Trak module: {str(e)}")
            print(f"Full error: {str(e.__class__.__name__)}: {str(e)}")
            traceback.print_exc()
            return [], []

        # Bound before the try so a failure inside it cannot leave them unset.
        influences_lists = []
        task_names_list = []
        print("\nComputing influences")
        try:
            influences = trak_module.influences()
            influences_lists, task_names_list = self._split_trak_influences(influences, tasks_to_process)
            if save:
                self._save_scores(influences_lists, 'trak')
        except Exception as e:
            print("Error computing influences for TRAK")
            print(f"Full error: {str(e.__class__.__name__)}: {str(e)}")
            traceback.print_exc()
        finally:
            # Release cached GPU memory whether or not the computation succeeded.
            torch.cuda.empty_cache()
            gc.collect()

        if not influences_lists:
            return [], []
        return influences_lists, [task_names_list]

    # Alias for backward compatibility
    run = run_autograd


def _flatten_param_grads(grads, params) -> torch.Tensor:
    """Concatenate per-parameter gradients into one flat vector, substituting zeros for ``None`` entries."""
    return torch.cat([
        (torch.zeros_like(p) if g is None else g).reshape(-1)
        for g, p in zip(grads, params)
    ])


def _flat_grad(output: torch.Tensor, params, **grad_kwargs) -> torch.Tensor:
    """Return d(output)/d(params) as one flat vector; parameters ``output`` does not reach contribute zeros."""
    grads = torch.autograd.grad(output, params, allow_unused=True, **grad_kwargs)
    return _flatten_param_grads(grads, params)


def compute_exact_influence(
    model: torch.nn.Module,
    train_loaders: Dict[str, torch.utils.data.DataLoader],
    test_loader: torch.utils.data.DataLoader,
    task: str,
    device: torch.device,
    damping: float = 0.01
) -> torch.Tensor:
    """
    Compute exact influence scores using an explicitly materialized Hessian.

    For every batch ``z`` yielded by ``train_loaders[task]`` the score is::

        grad L(z)^T (H + damping * I)^{-1} sum_t grad L(t)

    Here ``L`` is the mean cross-entropy of ``model(inputs)[task]``, ``t``
    ranges over the batches of ``test_loader``, and ``H`` is the Hessian of
    ``L`` averaged over the training batches.

    Only parameters with ``requires_grad=True`` are used. Parameters the task
    loss does not reach (e.g. other task heads of a multi-task model) get zero
    gradient and Hessian entries; the damping term keeps the linear system
    solvable.

    The Hessian is ``P x P`` for ``P`` trainable parameters and costs one extra
    backward pass per parameter per training batch, so this is only practical
    for very small models.

    Args:
        model: Neural network model whose output is indexed by task name
        train_loaders: Training data loaders keyed by task name
        test_loader: Test data loader yielding ``(inputs, targets)``
        task: Task name
        device: Computation device
        damping: Damping added to the Hessian diagonal for numerical stability

    Returns:
        1-D tensor with one score per training batch (per sample when the
        loader uses ``batch_size=1``)

    Raises:
        ValueError: If the model has no trainable parameters or
            ``train_loaders[task]`` yields no batches.
    """
    model.eval()
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        raise ValueError("model has no trainable parameters")

    # Single pass over the training data: per-batch gradients plus the
    # accumulated Hessian (previously each batch overwrote the last one).
    train_grads = []
    hessian = None
    num_batches = 0
    for inputs, targets in train_loaders[task]:
        inputs, targets = inputs.to(device), targets.to(device)
        loss = F.cross_entropy(model(inputs)[task], targets)

        # Keep the graph so each gradient entry can be differentiated again.
        grad = _flat_grad(loss, params, create_graph=True)
        train_grads.append(grad.detach())

        if hessian is None:
            hessian = torch.zeros(grad.numel(), grad.numel(), device=device, dtype=grad.dtype)
        if grad.requires_grad:
            for i in range(grad.numel()):
                hessian[i] += _flat_grad(grad[i], params, retain_graph=True)
        num_batches += 1

    if num_batches == 0:
        raise ValueError(f"train_loaders[{task!r}] yielded no batches")

    train_gradients = torch.stack(train_grads)
    hessian /= num_batches
    hessian += damping * torch.eye(hessian.shape[0], device=device, dtype=hessian.dtype)

    # sum_t g^T H^-1 g_t == g^T H^-1 (sum_t g_t): sum the test gradients first
    # so one linear solve suffices (more stable than forming H^-1 explicitly).
    test_grad_sum = torch.zeros_like(train_gradients[0])
    for test_inputs, test_targets in test_loader:
        test_inputs, test_targets = test_inputs.to(device), test_targets.to(device)
        test_loss = F.cross_entropy(model(test_inputs)[task], test_targets)
        test_grad_sum += _flat_grad(test_loss, params)

    s_test = torch.linalg.solve(hessian, test_grad_sum)
    return train_gradients @ s_test