"""
Model Factory Module
Unified management of different types of model creation and configuration
"""

import os
import torch
from models.medclip import MedCLIP  # Original model
from models.medclip_english import EnglishMedCLIP, DEFAULT_CONFIG as ENGLISH_DEFAULT_CONFIG
from models.openclip_vit import build_openclip_vit
from models.text_encoder import DualTextEncoder
from models.region_head import RegionHead
from models.biomedclip_adapter import (
    build_biomedclip_vision_encoder,
    build_biomedclip_text_encoder,
    build_single_roi_processor
)

class ModelFactory:
    """Factory for the model variants supported by this project.

    Every method is a classmethod, so the class is used as a namespace and
    does not need to be instantiated.
    """

    # Maps the public model-type name to the class implementing it.
    MODEL_TYPES = {
        'medclip': MedCLIP,                    # Original bilingual MedCLIP
        'english_medclip': EnglishMedCLIP,     # New English medical CLIP
    }

    @classmethod
    def create_model(cls, model_type, config, **kwargs):
        """
        Create a model of the requested type.

        Args:
            model_type: Model type ('medclip', 'english_medclip')
            config: Model configuration dictionary (None is treated as empty)
            **kwargs: Additional parameters forwarded to the type-specific builder

        Returns:
            Model instance

        Raises:
            ValueError: If model_type is not registered in MODEL_TYPES, or is
                registered but has no builder.
        """
        if model_type not in cls.MODEL_TYPES:
            raise ValueError(f"Unsupported model type: {model_type}. Supported types: {list(cls.MODEL_TYPES.keys())}")

        if config is None:
            config = {}

        builders = {
            'medclip': cls._create_original_medclip,
            'english_medclip': cls._create_english_medclip,
        }
        builder = builders.get(model_type)
        if builder is None:
            # Only reachable when MODEL_TYPES was extended without adding a builder
            raise ValueError(f"Unimplemented model type: {model_type}")
        return builder(config, **kwargs)

    @classmethod
    def _create_original_medclip(cls, config, **kwargs):
        """Create the original MedCLIP model.

        Args:
            config: Configuration dictionary; the 'vision', 'text_en',
                'text_zh', 'region_head', 'projection_dim' and 'temperature'
                keys are read, each falling back to a default when absent.
            **kwargs: Accepted for signature parity with the other builder; unused.

        Returns:
            MedCLIP instance assembled from freshly built components.
        """
        # Build each component
        vision_encoder = build_openclip_vit(config.get('vision', {}))

        text_encoder = DualTextEncoder(
            cfg_en=config.get('text_en', {}),
            cfg_zh=config.get('text_zh', {})
        )

        # `or {}` also covers an explicit `region_head: null` in the config
        region_cfg = config.get('region_head') or {}
        region_head = RegionHead(
            in_dim=region_cfg.get('in_dim', 512),
            out_dim=region_cfg.get('out_dim', 512),
            num_heads=region_cfg.get('num_heads', 8)
        )

        # Create MedCLIP model
        return MedCLIP(
            vision_encoder=vision_encoder,
            text_encoder=text_encoder,
            region_head=region_head,
            projection_dim=config.get('projection_dim', 512),
            temperature=config.get('temperature', 0.07)
        )

    @classmethod
    def _create_english_medclip(cls, config, **kwargs):
        """Create the English medical CLIP model.

        Args:
            config: User configuration, merged on top of ENGLISH_DEFAULT_CONFIG.
            **kwargs: 'initial_stage' selects the training stage set on the
                new model (defaults to 'warmup').

        Returns:
            EnglishMedCLIP instance with its initial training stage set.

        Raises:
            ValueError: If a required configuration entry is missing.
        """
        # Merge default configurations
        merged_config = cls._merge_configs(ENGLISH_DEFAULT_CONFIG, config)

        # Validate required configurations
        cls._validate_english_medclip_config(merged_config)

        # Create model
        model = EnglishMedCLIP(merged_config)

        # Set initial training stage
        model.set_training_stage(kwargs.get('initial_stage', 'warmup'))

        return model

    @classmethod
    def _merge_configs(cls, default_config, user_config):
        """Deep merge configuration dictionaries.

        Values from user_config win; when both sides hold a dict under the
        same key the two dicts are merged recursively.

        Args:
            default_config: Base configuration. It is never mutated.
            user_config: Overrides (None is treated as empty).

        Returns:
            A new dictionary that shares no nested containers with default_config.
        """
        import copy

        # Deep copy so nested default dicts that the user does not override
        # are not shared with (and cannot be mutated through) the result.
        merged = copy.deepcopy(default_config)

        for key, value in (user_config or {}).items():
            if isinstance(merged.get(key), dict) and isinstance(value, dict):
                merged[key] = cls._merge_configs(merged[key], value)
            else:
                merged[key] = value

        return merged

    @classmethod
    def _validate_english_medclip_config(cls, config):
        """Validate English medical CLIP configuration.

        Args:
            config: Merged configuration dictionary.

        Raises:
            ValueError: If 'vision.pretrained_path' or 'text.pretrained_path'
                cannot be resolved in config.
        """
        required_paths = [
            'vision.pretrained_path',
            'text.pretrained_path'
        ]

        for path in required_paths:
            current = config
            for key in path.split('.'):
                # A non-dict intermediate (string, None, ...) means the entry
                # is missing; `in` on a string would be a substring test.
                if not isinstance(current, dict) or key not in current:
                    raise ValueError(f"Missing required configuration: {path}")
                current = current[key]

            # Check if path exists (if it's a local path)
            if isinstance(current, str) and current.startswith('/') and not os.path.exists(current):
                print(f"Warning: Path does not exist {path}: {current}")

    @classmethod
    def get_model_info(cls, model_type):
        """Get model information.

        Args:
            model_type: Model type name.

        Returns:
            Dictionary with 'description', 'features', 'data_format',
            'training_stages' and 'encoders'; empty for unknown types.
        """
        info = {
            'medclip': {
                'description': 'Original bilingual medical image-text matching model',
                'features': ['Bilingual support', 'Region-level features', 'Multi-loss functions'],
                'data_format': 'coarse/fine datasets',
                'training_stages': ['Single-stage training'],
                'encoders': ['OpenCLIP ViT', 'Dual Text Encoder']
            },
            'english_medclip': {
                'description': 'New English medical CLIP, based on BiomedCLIP pre-training',
                'features': ['English monolingual', 'BiomedCLIP initialization', 'LoRA fine-tuning', 'Single ROI processing', 'Four-stage training'],
                'data_format': 'english_medical dataset',
                'training_stages': ['warmup', 'global_alignment', 'region_learning', 'fine_tuning'],
                'encoders': ['BiomedCLIP ViT', 'BiomedCLIP Text']
            }
        }
        return info.get(model_type, {})

    @classmethod
    def recommend_model_type(cls, dataset_type, task_requirements=None):
        """
        Recommend model type based on dataset type and task requirements

        Args:
            dataset_type: Type of dataset
            task_requirements: List of task requirements (currently not consulted)

        Returns:
            Recommended model type: 'medclip' for 'coarse'/'fine' datasets,
            'english_medclip' for everything else (including None).
        """
        if dataset_type in ('coarse', 'fine'):
            return 'medclip'
        # English datasets and the default recommendation share the same answer
        return 'english_medclip'

    @classmethod
    def create_model_from_config_file(cls, config_path, model_type=None):
        """
        Create model from configuration file

        Files ending in .yaml/.yml (case-insensitive) are parsed as YAML,
        everything else as JSON.

        Args:
            config_path: Path to configuration file (str or os.PathLike)
            model_type: Force specified model type, automatically inferred if None

        Returns:
            Model instance

        Raises:
            ValueError: If the file's top level is not a mapping, or the
                resolved model type is unsupported.
        """
        import json

        config_path = os.fspath(config_path)
        is_yaml = config_path.lower().endswith(('.yaml', '.yml'))

        # Load configuration file
        with open(config_path, 'r', encoding='utf-8') as f:
            if is_yaml:
                # Imported lazily so JSON configs work without PyYAML installed
                import yaml
                config = yaml.safe_load(f)
            else:
                config = json.load(f)

        # An empty YAML document loads as None
        if config is None:
            config = {}
        if not isinstance(config, dict):
            raise ValueError(f"Configuration file must contain a mapping at top level: {config_path}")

        # Extract model configuration (`or {}` covers an explicit null section)
        model_config = config.get('model') or {}

        # Determine model type
        if model_type is None:
            model_type = model_config.get('type')
            if model_type is None:
                # Automatically infer based on dataset type
                dataset_type = (config.get('dataset') or {}).get('type')
                model_type = cls.recommend_model_type(dataset_type)
                print(f"Automatically inferred model type: {model_type}")

        return cls.create_model(model_type, model_config)

    @classmethod
    def get_training_stages(cls, model_type):
        """Get model training stages.

        Args:
            model_type: Model type name.

        Returns:
            List of stage names; empty for unknown model types.
        """
        if model_type == 'english_medclip':
            return ['warmup', 'global_alignment', 'region_learning', 'fine_tuning']
        if model_type == 'medclip':
            return ['training']  # Original model has only one stage
        return []


class ModelManager:
    """Model manager, responsible for model loading, saving, and state management"""

    def __init__(self, model):
        """Wrap *model* and start with an empty training history.

        Args:
            model: The model whose checkpoints and training log are managed.
        """
        self.model = model
        self.model_type = self._detect_model_type(model)
        self.training_history = []  # Simplified training history to avoid non-serializable objects

    def _detect_model_type(self, model):
        """Detect model type.

        Returns:
            'english_medclip', 'medclip', or 'unknown' when the model is an
            instance of neither class (or the defining module is not loaded).
        """
        # Dynamic import to avoid circular dependencies
        import sys
        if 'models.medclip_english' in sys.modules:
            from models.medclip_english import EnglishMedCLIP
            if isinstance(model, EnglishMedCLIP):
                return 'english_medclip'

        if 'models.medclip' in sys.modules:
            from models.medclip import MedCLIP
            if isinstance(model, MedCLIP):
                return 'medclip'

        return 'unknown'

    def save_checkpoint(self, save_path, epoch=None, optimizer_state=None, **kwargs):
        """Save checkpoint (Fix: Ensure all data is serializable)

        Args:
            save_path: Destination handed to torch.save.
            epoch: Epoch number stored under 'epoch'.
            optimizer_state: Optional optimizer state dict stored under
                'optimizer_state_dict'.
            **kwargs: Extra entries to store; values that cannot be pickled
                are skipped with a warning.

        Raises:
            Exception: Any failure is printed with its traceback and re-raised.
        """
        try:
            checkpoint = {
                'model_type': self.model_type,
                'model_state_dict': self.model.state_dict(),
                'epoch': epoch,
                'training_history': self._serialize_training_history()  # Serialize history records
            }

            # Add model-specific information
            if hasattr(self.model, 'config'):
                checkpoint['config'] = self.model.config
            if hasattr(self.model, 'current_stage'):
                checkpoint['current_stage'] = self.model.current_stage

            # Add optimizer state
            if optimizer_state is not None:
                checkpoint['optimizer_state_dict'] = optimizer_state

            # Add other information (ensure serializable)
            for key, value in kwargs.items():
                if self._is_serializable(value):
                    checkpoint[key] = value
                else:
                    print(f"[Warning] Skipping non-serializable value for key '{key}': {type(value)}")

            torch.save(checkpoint, save_path)
            print(f"Checkpoint saved to {save_path}")

        except Exception as e:
            print(f"[Error] Failed to save checkpoint: {e}")
            import traceback
            traceback.print_exc()
            raise

    def _is_serializable(self, obj):
        """Check if object is serializable (i.e. pickle.dumps succeeds)."""
        import pickle
        try:
            pickle.dumps(obj)
        except Exception:
            # Not a bare except: KeyboardInterrupt/SystemExit must propagate
            return False
        return True

    @staticmethod
    def _tensor_to_python(tensor):
        """Convert a tensor to a Python scalar (one element) or nested list."""
        if tensor.numel() == 1:
            return tensor.item()
        # Tensor.tolist() avoids the NumPy round trip, which does not support
        # every torch dtype (e.g. bfloat16).
        return tensor.detach().cpu().tolist()

    def _serialize_value(self, key, value):
        """Convert one history value into plain, picklable data.

        Returns:
            Tuple (keep, converted); keep is False when the value has to be
            dropped because it cannot be serialized.
        """
        # Tensors are checked first: they are picklable, so testing
        # picklability first would store them unconverted.
        if isinstance(value, torch.Tensor):
            return True, self._tensor_to_python(value)

        if isinstance(value, dict):
            # Recursively process dictionaries
            converted = {}
            for sub_key, sub_value in value.items():
                keep, sub_converted = self._serialize_value(sub_key, sub_value)
                if keep:
                    converted[sub_key] = sub_converted
            return True, converted

        if self._is_serializable(value):
            return True, value

        if key == 'timestamp':
            # If it's a timestamp, convert to serializable format
            if hasattr(value, 'item'):
                return True, value.item()
            return True, str(value)

        print(f"[Warning] Skipping non-serializable training history value for key '{key}': {type(value)}")
        return False, None

    def _serialize_training_history(self):
        """Serialize training history, removing non-serializable objects.

        Returns:
            A new list of dictionaries in which tensors are replaced by
            Python scalars/lists and unserializable values are omitted.
            self.training_history itself is not modified.
        """
        serialized_history = []
        for entry in self.training_history:
            serialized_entry = {}
            for key, value in entry.items():
                keep, converted = self._serialize_value(key, value)
                if keep:
                    serialized_entry[key] = converted
            serialized_history.append(serialized_entry)
        return serialized_history

    def load_checkpoint(self, load_path, map_location='cpu', strict=True):
        """Load checkpoint

        Args:
            load_path: Source handed to torch.load.
            map_location: Forwarded to torch.load.
            strict: When True, a model-type mismatch raises; the flag is also
                forwarded to load_state_dict.

        Returns:
            The loaded checkpoint dictionary.

        Raises:
            ValueError: If strict is True and the saved model type differs
                from this manager's model type.
        """
        # NOTE(security): torch.load unpickles data; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(load_path, map_location=map_location)

        # Verify model type
        saved_model_type = checkpoint.get('model_type', 'unknown')
        if saved_model_type != self.model_type and strict:
            raise ValueError(f"Model type mismatch: current={self.model_type}, saved={saved_model_type}")

        # Load model state
        self.model.load_state_dict(checkpoint['model_state_dict'], strict=strict)

        # Restore training state
        if hasattr(self.model, 'current_stage') and 'current_stage' in checkpoint:
            self.model.current_stage = checkpoint['current_stage']

        self.training_history = checkpoint.get('training_history', [])

        print(f"Checkpoint loaded from {load_path}")
        return checkpoint

    def log_training_step(self, epoch, step, loss_dict, lr=None):
        """Log training step (Fix: Use serializable data)

        Args:
            epoch: Current epoch.
            step: Current step.
            loss_dict: Mapping of loss name to value; tensor values are
                converted to Python scalars/lists before being stored.
            lr: Optional learning rate.
        """
        import time
        import datetime

        # Convert tensors in loss_dict to python values
        serialized_losses = {
            key: self._tensor_to_python(value) if isinstance(value, torch.Tensor) else value
            for key, value in loss_dict.items()
        }

        log_entry = {
            'epoch': epoch,
            'step': step,
            'losses': serialized_losses,
            'lr': lr,
            'timestamp': time.time(),  # Use serializable timestamp
            'datetime': datetime.datetime.now().isoformat()  # Human-readable time
        }

        if hasattr(self.model, 'current_stage'):
            log_entry['stage'] = self.model.current_stage

        self.training_history.append(log_entry)

    def get_training_summary(self):
        """Get training summary

        Returns:
            Empty dict when nothing has been logged; otherwise the number of
            logged steps, the latest epoch and losses, the model type and,
            when the model exposes it, the current stage.
        """
        if not self.training_history:
            return {}

        latest = self.training_history[-1]

        summary = {
            'total_steps': len(self.training_history),
            'latest_epoch': latest.get('epoch'),
            'latest_losses': latest.get('losses'),
            'model_type': self.model_type
        }

        if hasattr(self.model, 'current_stage'):
            summary['current_stage'] = self.model.current_stage

        return summary

    def switch_training_stage(self, new_stage):
        """Switch training stage (only applicable to models exposing set_training_stage)"""
        if hasattr(self.model, 'set_training_stage'):
            self.model.set_training_stage(new_stage)
            print(f"Switched to training stage: {new_stage}")
        else:
            print(f"Model {self.model_type} does not support stage switching")


def create_model_from_config(config, model_type=None):
    """
    Convenience function to create model from configuration

    Args:
        config: Configuration dictionary or path (str or os.PathLike) to a
            configuration file
        model_type: Model type. If None it is taken from the model
            configuration's 'type' entry, falling back to a recommendation
            based on the dataset type (same order as
            ModelFactory.create_model_from_config_file).

    Returns:
        Tuple of (model, model_manager)
    """
    if isinstance(config, (str, os.PathLike)):
        # Configuration file path (normalised to str for the file loader)
        model = ModelFactory.create_model_from_config_file(os.fspath(config), model_type)
    else:
        # Configuration dictionary
        model_config = config.get('model', config)  # Compatibility handling

        if model_type is None and isinstance(model_config, dict):
            # An explicit type in the model configuration wins over inference
            model_type = model_config.get('type')

        if model_type is None:
            # `or {}` covers an explicit `dataset: null`
            dataset_type = (config.get('dataset') or {}).get('type')
            model_type = ModelFactory.recommend_model_type(dataset_type)

        model = ModelFactory.create_model(model_type, model_config)

    # Create model manager
    manager = ModelManager(model)

    return model, manager


# Usage examples and test functions
def test_model_factory():
    """Print-based smoke test of ModelFactory's metadata and recommendation helpers."""
    print("Testing ModelFactory...")

    # Dump the descriptive metadata of every registered model type
    for name in ModelFactory.MODEL_TYPES:
        details = ModelFactory.get_model_info(name)
        stages = ModelFactory.get_training_stages(name)
        print(f"\n{name}:")
        print(f"  Description: {details.get('description')}")
        print(f"  Features: {details.get('features')}")
        print(f"  Training stages: {stages}")

    # Dataset type -> model type the recommendation is expected to yield
    expectations = {
        'english_medical': 'english_medclip',
        'coarse': 'medclip',
        'fine': 'medclip',
    }

    print("\nModel recommendations:")
    for dataset_type, expected in expectations.items():
        recommended = ModelFactory.recommend_model_type(dataset_type)
        if recommended == expected:
            status = "✓"
        else:
            status = "✗"
        print(f"  {dataset_type} -> {recommended} {status}")

    print("\nModelFactory test completed!")


# Run the print-based factory smoke test when this file is executed as a script
if __name__ == "__main__":
    test_model_factory()