"""
Multi-Task Learning Trainer

This module provides a trainer class for multi-task learning experiments.
"""

import os
import time
from typing import Dict, List, Optional, Union, Any, Tuple
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import json
from copy import deepcopy

# Import LibMTL components if available
try:
    from LibMTL.LibMTL import Trainer
    from LibMTL.LibMTL.loss import CELoss, L1Loss, KLDivLoss
    from LibMTL.LibMTL.metrics import AccMetric
    LIBMTL_AVAILABLE = True
except ImportError:
    LIBMTL_AVAILABLE = False
    print("Warning: LibMTL not available. Some functionality may be limited.")

from .models import get_model
from .dataloader import load_face_dataset, create_face_dataloaders
from .utils import set_seed


class MultiTaskTrainer:
    """
    Multi-Task Learning Trainer for handling training process and results.

    This class provides a high-level interface for training multi-task learning models
    with various datasets and architectures. It reads its settings from a config dict,
    builds (optionally sample-masked) data loaders, wraps a LibMTL ``Trainer`` and
    forwards train/test/save/load/eval calls to it.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize multi-task trainer with configuration.

        Args:
            config: Dictionary containing training configuration parameters. Every key
                read below is required; a missing key raises ``KeyError``.
        """
        # Training parameters
        self.num_epochs = config['num_epochs']
        self.batch_size = config['batch_size']
        self.seed = config['seed']

        # Optimization parameters
        self.optim_param = config['optim_param']
        self.scheduler_param = config['scheduler_param']
        self.kwargs = config['kwargs']

        # Dataset parameters. ``training_size`` may be a single int shared by all
        # tasks or a per-task list (both forms are handled by the mask helpers).
        self.dataset = config['dataset']
        self.training_size = config['training_size']
        self.val_size = config['val_size']
        self.test_size = config['test_size']
        self.num_classes = config['num_classes']

        # Task parameters
        self.random_select = config['random_select']
        self.task_num = config['task_num']
        if self.random_select:
            # Use the first ``task_num`` task names; labels are their positions.
            self.task_names = config['task_names'][:self.task_num]
            self.task_label = torch.arange(len(self.task_names))
        else:
            # Explicit task subset; labels are indices into the full task_names list.
            self.tasks = config['tasks']
            self.task_names = self.tasks
            self.task_label = [config['task_names'].index(task) for task in self.tasks]

        # Other parameters
        self.adversarial = config['adversarial']
        self.adversarial_size = config['adversarial_size']
        self.no_stratify = config['no_stratify']
        self.grouping = config['grouping']
        # config['device'] is presumably a CUDA device index; falls back to CPU
        # when CUDA is unavailable.
        self.device = torch.device(f"cuda:{config['device']}" if torch.cuda.is_available() else "cpu")

        print(f"Starting training on {self.device}")

    def prepare_dataloaders(
        self,
        random_removal: bool = False,
        early_stopping: bool = False,
        is_feature_extractor: bool = True,
        remove_ratio: float = 0.05,
        task: Optional[str] = None,
        remove: bool = False,
        method: str = 'autograd'
    ) -> bool:
        """
        Prepare data loaders for training, validation and testing.

        Args:
            random_removal: Whether to randomly remove samples
            early_stopping: Currently unused here; kept for interface compatibility
            is_feature_extractor: Currently unused here; kept for interface compatibility
            remove_ratio: Fraction of each task's training samples to remove
            task: Currently unused here; kept for interface compatibility
            remove: Whether to remove samples at all; when False no mask is built
            method: Method for influence computation (forwarded to mask creation,
                which does not currently use it)

        Returns:
            bool: True if any task has no samples left after removal, False otherwise
            (always False when ``remove`` is False).
        """
        # NOTE(review): the removal mask is drawn from NumPy's global RNG *before*
        # set_seed() runs below, so it is not fixed by self.seed unless the caller
        # seeded NumPy earlier — confirm this is intended.
        mask = self._create_removal_mask(random_removal, remove_ratio, method, remove)

        has_empty_task = False
        if mask is not None:
            task_masks = mask if isinstance(mask, list) else [mask]
            mask_sum = [int(task_mask.sum()) for task_mask in task_masks]
            print(f"Mask sum: {mask_sum}")
            has_empty_task = self._check_empty_tasks(mask)

        # Set random seed for reproducibility
        set_seed(self.seed)

        # Prepare dataset-specific dataloaders
        self._prepare_dataset_loaders(mask)

        return has_empty_task

    def _create_removal_mask(
        self,
        random_removal: bool,
        remove_ratio: float,
        method: str,
        remove: bool
    ) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
        """
        Create mask for data removal.

        Args:
            random_removal: Whether to use random removal
            remove_ratio: Ratio of data to remove
            method: Method for influence computation (not used by the current helpers)
            remove: Whether to remove data

        Returns:
            A list with one boolean keep-mask per task (True = keep), or None when
            ``remove`` is False.
        """
        if not remove:
            return None

        if random_removal:
            return self._create_random_removal_mask(remove_ratio)
        return self._create_fixed_size_mask(random_removal, remove_ratio, method)

    def _num_training_samples(self, task_idx: int) -> int:
        """Return the training-set size of task ``task_idx`` (per-task list or shared int)."""
        if isinstance(self.training_size, (list, tuple)):
            return self.training_size[task_idx]
        return self.training_size

    @staticmethod
    def _sample_keep_mask(num_samples: int, remove_ratio: float) -> torch.Tensor:
        """
        Draw one random boolean keep-mask.

        Args:
            num_samples: Length of the mask
            remove_ratio: Fraction of entries to set to False

        Returns:
            Bool tensor of length ``num_samples`` with ``int(remove_ratio * num_samples)``
            distinct, uniformly chosen entries set to False.
        """
        import numpy as np

        num_to_remove = int(remove_ratio * num_samples)
        keep = np.ones(num_samples, dtype=bool)
        # replace=False -> distinct indices; numpy raises ValueError if
        # num_to_remove exceeds num_samples (i.e. remove_ratio > 1).
        remove_indices = np.random.choice(num_samples, num_to_remove, replace=False)
        keep[remove_indices] = False
        return torch.tensor(keep)

    def _create_random_removal_mask(self, remove_ratio: float) -> List[torch.Tensor]:
        """
        Create random removal mask.

        Works with both a shared int ``training_size`` and a per-task list.

        Args:
            remove_ratio: Ratio of data to remove

        Returns:
            List of removal masks for each task
        """
        return [
            self._sample_keep_mask(self._num_training_samples(task_idx), remove_ratio)
            for task_idx in range(self.task_num)
        ]

    def _create_fixed_size_mask(
        self,
        random_removal: bool,
        remove_ratio: float,
        method: str
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Create fixed size removal mask.

        Args:
            random_removal: Whether to use random removal (not used here)
            remove_ratio: Ratio of data to remove
            method: Method for influence computation (not used here)

        Returns:
            Removal mask
        """
        if isinstance(self.training_size, list):
            return self._create_variable_size_mask(random_removal, remove_ratio, method)
        return self._create_random_removal_mask(remove_ratio)

    def _create_variable_size_mask(
        self,
        random_removal: bool,
        remove_ratio: float,
        method: str
    ) -> List[torch.Tensor]:
        """
        Create variable size removal mask.

        Args:
            random_removal: Whether to use random removal (not used here)
            remove_ratio: Ratio of data to remove
            method: Method for influence computation (not used here)

        Returns:
            List of removal masks for each task
        """
        # Per-task sizes are resolved by _num_training_samples, so the random
        # sampler already handles the variable-size case.
        return self._create_random_removal_mask(remove_ratio)

    def _check_empty_tasks(self, mask: Union[torch.Tensor, List[torch.Tensor]]) -> bool:
        """
        Check if any tasks have no samples left after removal.

        Args:
            mask: A single keep-mask or a list of per-task keep-masks

        Returns:
            True if any task's mask keeps no samples
        """
        task_masks = mask if isinstance(mask, list) else [mask]
        return any(int(task_mask.sum()) == 0 for task_mask in task_masks)

    def _prepare_dataset_loaders(self, mask: Optional[Union[torch.Tensor, List[torch.Tensor]]]):
        """
        Prepare dataset-specific data loaders.

        Args:
            mask: Optional removal mask

        Raises:
            ValueError: If ``self.dataset`` is not a supported dataset name.
        """
        if self.dataset == 'face':
            self._prepare_face_dataloaders(mask)
        else:
            raise ValueError(f"Unknown dataset: {self.dataset}")

    def _prepare_face_dataloaders(self, mask: Optional[Union[torch.Tensor, List[torch.Tensor]]]):
        """
        Prepare face dataset data loaders.

        Sets ``self.train_dataloaders``, ``self.val_dataloaders`` and
        ``self.test_dataloaders`` from ``create_face_dataloaders``.

        Args:
            mask: Optional removal mask
        """
        # Load face dataset
        train_dataset, val_dataset, test_dataset = load_face_dataset(
            image_size=64,
            normalize=True,
            data_root='./data'
        )

        # Create data loaders
        self.train_dataloaders, self.val_dataloaders, self.test_dataloaders = create_face_dataloaders(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
            task_names=self.task_names,
            task_labels=self.task_label,
            training_size=self.training_size,
            val_size=self.val_size,
            test_size=self.test_size,
            batch_size=self.batch_size,
            seed=self.seed,
            mask=mask,
            adversarial=self.adversarial,
            adversarial_size=self.adversarial_size,
            no_stratify=self.no_stratify
        )

    def prepare_model(
        self,
        weighting: str = 'EW',
        model_name: str = 'FeatureEncoder',
        checkpoint_freq: int = 49,
        **kwargs
    ):
        """
        Prepare model for training.

        Requires ``prepare_dataloaders`` to have been called first (reads
        ``self.train_dataloaders``).

        Args:
            weighting: Task weighting strategy
            model_name: Name of the model to use
            checkpoint_freq: Checkpoint frequency (currently not forwarded anywhere)
            **kwargs: Additional arguments forwarded to the LibMTL ``Trainer``

        Raises:
            ImportError: If LibMTL could not be imported.
        """
        if not LIBMTL_AVAILABLE:
            raise ImportError("LibMTL is required for model preparation")

        # Create model
        model = get_model(model_name)

        # Create trainer
        self.trainer = Trainer(
            task_dict=self.train_dataloaders,
            weighting=weighting,
            architecture=model,
            **kwargs
        )

        print(f"Model prepared with {weighting} weighting strategy")

    def train_model(
        self,
        early_stopping: bool = False,
        save_test_acc: bool = False,
        save: bool = True,
        num_checkpoints: int = 1,
        weighting: str = 'EW',
        model_name: str = 'FeatureEncoder',
        **kwargs
    ) -> Dict:
        """
        Train the multi-task model.

        Builds a fresh trainer via ``prepare_model`` and then runs
        ``self.trainer.train`` for ``self.num_epochs`` epochs.

        Args:
            early_stopping: Whether to use early stopping
            save_test_acc: Whether to save test accuracy
            save: Whether to save model
            num_checkpoints: Number of checkpoints to save
            weighting: Task weighting strategy
            model_name: Name of the model to use
            **kwargs: Additional arguments forwarded to ``prepare_model``

        Returns:
            Training results dictionary (whatever ``self.trainer.train`` returns)
        """
        # Prepare model
        self.prepare_model(weighting=weighting, model_name=model_name, **kwargs)

        # Train model
        results = self.trainer.train(
            epochs=self.num_epochs,
            early_stopping=early_stopping,
            save_test_acc=save_test_acc,
            save=save,
            num_checkpoints=num_checkpoints
        )

        return results

    def test(self, test_dataloaders: Optional[Dict[str, DataLoader]] = None) -> List[float]:
        """
        Test the trained model.

        Args:
            test_dataloaders: Optional test data loaders; defaults to the loaders
                built by ``prepare_dataloaders``

        Returns:
            List of test accuracies for each task
        """
        if test_dataloaders is None:
            test_dataloaders = self.test_dataloaders

        return self.trainer.test(test_dataloaders)

    def save_model(self, path: str):
        """
        Save the trained model by delegating to the wrapped trainer.

        Args:
            path: Path to save the model
        """
        self.trainer.save_model(path)

    def load_model(self, path: str):
        """
        Load a trained model by delegating to the wrapped trainer.

        Args:
            path: Path to the model file
        """
        self.trainer.load_model(path)

    def eval(self):
        """Set model to evaluation mode (delegates to the wrapped trainer)."""
        self.trainer.eval()

    @staticmethod
    def pretty_print(results: Dict):
        """
        Pretty print training results, one ``key: value`` line per entry.

        Args:
            results: Training results dictionary
        """
        print("Training Results:")
        for key, value in results.items():
            print(f"  {key}: {value}")

    def _get_model_components(self):
        """
        Get model components for analysis.

        Assumes ``prepare_model`` and ``prepare_dataloaders`` have run (reads
        ``self.trainer.MTLmodel.model`` and the train/val loaders).
        """
        return {
            'model': self.trainer.MTLmodel.model,
            'task_names': self.task_names,
            'task_num': self.task_num,
            'training_size': self.training_size,
            'val_size': self.val_size,
            'batch_size': self.batch_size,
            'seed': self.seed,
            'device': self.device,
            'num_epochs': self.num_epochs,
            'train_dataloaders': self.train_dataloaders,
            'val_dataloaders': self.val_dataloaders,
            'dataset': self.dataset
        }


def load_config(config_path: str = 'config.json') -> Dict[str, Any]:
    """
    Load configuration from JSON file.

    The file is decoded explicitly as UTF-8, the encoding JSON requires, rather than
    with the platform's locale-dependent default encoding.

    Args:
        config_path: Path to configuration file

    Returns:
        Configuration dictionary

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def save_config(config: Dict[str, Any], config_path: str = 'config.json') -> None:
    """
    Write ``config`` to ``config_path`` as JSON indented by two spaces.

    Any existing file at that path is overwritten; no trailing newline is written.

    Args:
        config: JSON-serialisable configuration dictionary
        config_path: Destination file path
    """
    with open(config_path, mode='w') as out_file:
        json.dump(config, out_file, indent=2)