"""
Utility Functions for Multi-Task Learning Influence Analysis

This module provides utility functions for data processing, model training, and analysis.
"""

import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset, TensorDataset
from torchvision import models, datasets, transforms
from torchvision.transforms import functional as TF
from typing import Dict, List, Optional, Tuple, Union, Any
from copy import deepcopy
import matplotlib.pyplot as plt


def set_seed(seed: int) -> None:
    """
    Seed every random number generator this project relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and,
    when available, all CUDA devices) with the same value. Also sets two CUDA
    related environment variables for the current process.

    Args:
        seed: Random seed value
    """
    import random

    # cuBLAS needs a fixed workspace configuration to be deterministic when
    # torch.use_deterministic_algorithms(True) is enabled by the caller.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    # Allocator tuning only (no effect on randomness). NOTE(review): this is
    # presumably only honoured if set before CUDA is initialised — verify.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Explicitly seed every visible GPU, not just the current one.
        torch.cuda.manual_seed_all(seed)


def process_with_resnet(
    data_loaders: Dict[str, DataLoader],
    feature_extractor: nn.Module,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    dataset_name: str = 'FaceDataset',
) -> Dict[str, DataLoader]:
    """
    Run every loader through a frozen feature extractor and cache the outputs.

    The extractor is switched to eval mode (and left there) and each loader is
    iterated once under ``torch.no_grad()``. Extracted features are moved back
    to CPU and stored, together with the unchanged targets, in a
    ``TensorDataset``.

    Args:
        data_loaders: Mapping task name -> loader yielding ``(inputs, targets)``
            batches.
        feature_extractor: Pre-trained model (e.g. a ResNet backbone). It is
            not moved here; the caller must already have placed it on
            ``device``.
        device: Device the inputs are moved to before the forward pass.
            Defaults to CUDA when available, otherwise CPU.
        dataset_name: Currently unused; kept for backward compatibility.

    Returns:
        Mapping task name -> non-shuffled DataLoader over the cached features,
        reusing the source loader's ``batch_size`` and ``num_workers``.
    """
    feature_extractor.eval()
    processed_loaders = {}

    for task_name, loader in data_loaders.items():
        feature_batches = []
        label_batches = []

        with torch.no_grad():
            for inputs, targets in loader:
                features = feature_extractor(inputs.to(device))
                feature_batches.append(features.cpu())
                label_batches.append(targets)

        processed_dataset = TensorDataset(
            torch.cat(feature_batches, dim=0),
            torch.cat(label_batches, dim=0),
        )
        processed_loaders[task_name] = DataLoader(
            processed_dataset,
            batch_size=loader.batch_size,
            # Keep rows in the order they were extracted.
            shuffle=False,
            num_workers=loader.num_workers,
        )

    return processed_loaders


def process_with_bert(
    data_loaders: Dict[str, DataLoader], 
    bert_model: nn.Module, 
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
) -> Dict[str, DataLoader]:
    """
    Replace each loader's inputs with the model's first-token embedding.

    The model is put in eval mode and called once per batch under
    ``torch.no_grad()``. From the returned object's ``last_hidden_state``,
    position 0 along the second axis (the [CLS] token) is kept, moved to CPU,
    and paired with the original targets in a ``TensorDataset``.

    Args:
        data_loaders: Mapping task name -> loader yielding ``(inputs, targets)``.
        bert_model: Pre-trained BERT-style model whose output exposes
            ``last_hidden_state``.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Mapping task name -> non-shuffled DataLoader over the embeddings,
        reusing the source loader's ``batch_size`` and ``num_workers``.
    """
    bert_model.eval()

    def _embed(source: DataLoader) -> TensorDataset:
        # Collect per-batch [CLS] embeddings and labels, then stack them once.
        cls_chunks, label_chunks = [], []
        with torch.no_grad():
            for batch_inputs, batch_labels in source:
                hidden = bert_model(batch_inputs.to(device)).last_hidden_state
                cls_chunks.append(hidden[:, 0, :].cpu())
                label_chunks.append(batch_labels)
        return TensorDataset(
            torch.cat(cls_chunks, dim=0),
            torch.cat(label_chunks, dim=0),
        )

    return {
        name: DataLoader(
            _embed(source),
            batch_size=source.batch_size,
            shuffle=False,
            num_workers=source.num_workers,
        )
        for name, source in data_loaders.items()
    }


def load_config(config_path: str = 'config.json') -> Dict[str, Any]:
    """
    Load configuration from a JSON file.

    The file is decoded as UTF-8 explicitly. Relying on the platform's locale
    encoding would mis-decode non-ASCII values on some systems.

    Args:
        config_path: Path to configuration file

    Returns:
        The parsed JSON document (expected to be a dictionary).

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def save_config(config: Dict[str, Any], config_path: str = 'config.json') -> None:
    """
    Save configuration to a JSON file (2-space indented, UTF-8).

    The config is serialized before the file is opened. A non-serializable
    value therefore raises without truncating or corrupting an existing file
    at ``config_path``.

    Args:
        config: Configuration dictionary
        config_path: Path to save configuration

    Raises:
        TypeError: If ``config`` contains values JSON cannot encode.
    """
    serialized = json.dumps(config, indent=2)
    with open(config_path, 'w', encoding='utf-8') as f:
        f.write(serialized)


def plot_training_curves(
    train_losses: List[float],
    val_losses: Optional[List[float]] = None,
    train_accuracies: Optional[List[float]] = None,
    val_accuracies: Optional[List[float]] = None,
    save_path: Optional[str] = None
) -> None:
    """
    Plot loss curves (left) and accuracy curves (right) side by side.

    The accuracy subplot is always created. Its legend is only drawn when at
    least one accuracy series was supplied, which avoids matplotlib's
    "No artists with labels found" warning.

    Args:
        train_losses: Training losses, one value per epoch.
        val_losses: Validation losses.
        train_accuracies: Training accuracies.
        val_accuracies: Validation accuracies.
        save_path: If given, the figure is also written there at 300 dpi.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Losses
    loss_ax.plot(train_losses, label='Train Loss')
    if val_losses is not None:
        loss_ax.plot(val_losses, label='Val Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.set_title('Training and Validation Loss')

    # Accuracies
    if train_accuracies is not None:
        acc_ax.plot(train_accuracies, label='Train Acc')
    if val_accuracies is not None:
        acc_ax.plot(val_accuracies, label='Val Acc')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    if acc_ax.get_lines():
        acc_ax.legend()
    acc_ax.set_title('Training and Validation Accuracy')

    fig.tight_layout()

    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()


def _flatten_influence_grads(
    grads: Tuple[Optional[torch.Tensor], ...],
    params: List[torch.Tensor],
) -> torch.Tensor:
    """Concatenate per-parameter gradients into one flat vector, using zeros for ``None``."""
    # ``None`` marks a parameter outside the loss graph (e.g. another task's
    # head); zero-filling keeps every vector the same length and ordering.
    return torch.cat([
        (torch.zeros_like(p) if g is None else g).reshape(-1)
        for g, p in zip(grads, params)
    ])


def _influence_task_loss(
    model: nn.Module,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    task: str,
    device: torch.device,
) -> torch.Tensor:
    """Cross-entropy of ``model(inputs)[task]`` against ``targets`` on ``device``."""
    inputs, targets = inputs.to(device), targets.to(device)
    return F.cross_entropy(model(inputs)[task], targets)


def compute_exact_influence(
    model: nn.Module,
    train_loaders: Dict[str, DataLoader],
    test_loader: DataLoader,
    task: str,
    device: torch.device,
    damping: float = 0.01
) -> torch.Tensor:
    """
    Compute exact (explicit-Hessian) influence scores of training batches.

    For each training batch ``j`` the score is
    ``g_j^T (H + damping * I)^{-1} g_test``, where:

    - ``g_j`` is the gradient of batch ``j``'s mean cross-entropy loss.
    - ``H`` is the exact Hessian of the loss, averaged over all training
      batches.
    - ``g_test`` is the sum of the test batches' loss gradients.

    All gradients are taken w.r.t. every parameter with
    ``requires_grad=True``. Parameters not used by ``model(x)[task]`` get zero
    gradients. The Hessian is a dense ``P x P`` matrix, so this is only
    feasible for very small models.

    NOTE(review): no leading minus sign is applied. The up-weighting influence
    of Koh & Liang (2017) is usually defined with one; confirm the intended
    convention. ``.grad`` buffers of the model are left untouched.

    Args:
        model: Multi-task model whose output is indexable by ``task``.
        train_loaders: Training data loaders keyed by task name.
        test_loader: Test data loader.
        task: Task name (key into ``train_loaders`` and the model output).
        device: Computation device.
        damping: Added to the Hessian diagonal for numerical stability.

    Returns:
        Tensor of shape ``(num_train_batches,)``, one score per training
        batch (not per sample), in loader iteration order.

    Raises:
        ValueError: If the model has no trainable parameters or the training
            loader is empty.
    """
    model.eval()
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        raise ValueError("model has no parameters that require gradients")

    train_loader = train_loaders[task]

    # 1) One flattened loss gradient per training batch.
    train_gradients = []
    for inputs, targets in train_loader:
        loss = _influence_task_loss(model, inputs, targets, task, device)
        grads = torch.autograd.grad(loss, params, allow_unused=True)
        train_gradients.append(_flatten_influence_grads(grads, params).detach())
    if not train_gradients:
        raise ValueError(f"training loader for task '{task}' is empty")
    train_gradients = torch.stack(train_gradients)
    num_batches, num_params = train_gradients.shape

    # 2) Exact Hessian, accumulated row by row and averaged over batches.
    # NOTE: if the loader shuffles, this pass sees different batches than
    # pass 1; the averaged Hessian still covers the same data.
    hessian = torch.zeros(
        num_params, num_params, dtype=train_gradients.dtype, device=device
    )
    for inputs, targets in train_loader:
        loss = _influence_task_loss(model, inputs, targets, task, device)
        grads = torch.autograd.grad(
            loss, params, create_graph=True, allow_unused=True
        )
        flat_grad = _flatten_influence_grads(grads, params)
        if not flat_grad.requires_grad:
            # Gradient is constant w.r.t. the parameters: zero Hessian term.
            continue
        for i in range(num_params):
            row = torch.autograd.grad(
                flat_grad[i], params, retain_graph=True, allow_unused=True
            )
            hessian[i] += _flatten_influence_grads(row, params)
    hessian /= num_batches
    hessian += damping * torch.eye(
        num_params, dtype=hessian.dtype, device=hessian.device
    )

    # 3) Scores are linear in the test gradient, so one solve on the summed
    # test gradient equals summing per-test-batch solves.
    test_gradient = torch.zeros(
        num_params, dtype=train_gradients.dtype, device=train_gradients.device
    )
    for inputs, targets in test_loader:
        loss = _influence_task_loss(model, inputs, targets, task, device)
        grads = torch.autograd.grad(loss, params, allow_unused=True)
        test_gradient += _flatten_influence_grads(grads, params).detach()

    # 4) solve() is numerically preferable to forming the explicit inverse.
    return train_gradients @ torch.linalg.solve(hessian, test_gradient)