"""
MLflow Export-Import Tools

This module provides comprehensive tools for exporting and importing MLflow experiments,
models, and runs between different tracking servers. It supports pattern-based filtering
and bulk operations for efficient data migration.
"""

import os
import json
import time
import re
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, asdict
import click

import mlflow
from mlflow.entities import Experiment, Run
from mlflow.exceptions import RestException, MlflowException
from mlflow.tracking import MlflowClient

logger = logging.getLogger(__name__)


@dataclass
class ExportConfig:
    """Configuration for export operations.

    Field order defines the positional constructor signature, so new fields
    should be appended at the end. Not every flag is necessarily honored by
    every export path; check ``MLflowExportImportTool`` before relying on one.
    """
    output_dir: str  # root directory the export is written under
    export_models: bool = True  # export registered models
    export_experiments: bool = True  # export experiments (and their runs)
    export_runs: bool = True  # NOTE(review): confirm the exporter consults this flag
    export_artifacts: bool = True  # download run artifacts alongside run metadata
    export_permissions: bool = False  # NOTE(review): no permission export is visible in this module
    export_deleted_runs: bool = False  # NOTE(review): confirm how the exporter applies this
    export_latest_versions_only: bool = False  # model-version selection; see the exporter for exact semantics
    use_threads: bool = True  # when False, experiments are processed with a single worker
    max_workers: int = 4  # thread-pool size used when use_threads is True
    run_start_time: Optional[str] = None  # optional run filter; check how the exporter interprets the value
    stages: Optional[List[str]] = None  # optional model-version stages; NOTE(review): confirm which code path reads it


@dataclass
class ImportConfig:
    """Configuration for import operations.

    Field order defines the positional constructor signature, so new fields
    should be appended at the end. Not every flag is necessarily honored by
    every import path; check ``MLflowExportImportTool`` before relying on one.
    """
    input_dir: str  # directory previously produced by an export
    import_models: bool = True  # import registered models
    import_experiments: bool = True  # import experiments (and their runs)
    import_permissions: bool = False  # NOTE(review): no permission import is visible in this module
    import_source_tags: bool = True  # NOTE(review): confirm the importer consults this flag
    use_src_user_id: bool = False  # NOTE(review): confirm the importer consults this flag
    delete_existing_models: bool = False  # delete a same-named destination model before importing it
    use_threads: bool = True  # when False, experiments are processed with a single worker
    max_workers: int = 4  # thread-pool size used when use_threads is True
    experiment_renames: Optional[Dict[str, str]] = None  # source experiment name -> destination name
    model_renames: Optional[Dict[str, str]] = None  # source model name -> destination name


@dataclass
class ExportResult:
    """Result of an export operation.

    ``success`` is False when the operation aborted with an exception; a
    successful result may still carry informational messages in ``errors``
    (e.g. when a pattern matched no experiments).
    """
    success: bool
    duration: float  # wall-clock seconds measured with time.time()
    experiments_exported: int
    models_exported: int
    runs_exported: int  # as reported by the exporter
    errors: List[str]  # human-readable messages collected during the operation
    output_dir: str  # echo of the configured output directory


@dataclass
class ImportResult:
    """Result of an import operation.

    ``success`` is False when the operation aborted with an exception; a
    successful result may still carry messages in ``errors``.
    """
    success: bool
    duration: float  # wall-clock seconds measured with time.time()
    experiments_imported: int
    models_imported: int
    runs_imported: int  # as reported by the importer
    errors: List[str]  # human-readable messages collected during the operation
    input_dir: str  # echo of the configured input directory


class MLflowExportImportTool:
    """
    Comprehensive tool for exporting and importing MLflow data between tracking servers.

    This tool supports:
    - Exporting entire tracking servers
    - Pattern-based experiment filtering
    - Bulk import operations
    - Model and experiment migration
    - Artifact preservation
    - Parallel processing for large datasets

    Layout of an export directory as written by the export methods::

        experiments/experiments.json                       # index of exported experiments
        experiments/<experiment_id>/experiment.json
        experiments/<experiment_id>/runs/<run_id>/run.json
        experiments/<experiment_id>/runs/<run_id>/artifacts/...
        models/models.json                                 # index of registered models
        models/<model_name>/model.json
        models/<model_name>/versions/<version>/version.json
        models/<model_name>/versions/<version>/model/...

    Limitations: only the latest value of each metric is exported
    (``run.data.metrics``), so metric history is not preserved, and imported
    runs and model versions get new IDs on the destination server.
    """

    # Destination experiment hosting the runs that back imported model
    # versions whose original run was not imported by this instance.
    MODEL_IMPORT_EXPERIMENT = "model-registry-imports"

    # Run statuses preserved verbatim on import; anything else (e.g. RUNNING)
    # is recorded as FINISHED because the imported copy is complete.
    _TERMINAL_RUN_STATUSES = ("FINISHED", "FAILED", "KILLED")

    def __init__(self, source_tracking_uri: str, dest_tracking_uri: Optional[str] = None):
        """
        Initialize the export-import tool.

        Args:
            source_tracking_uri: URI of the source MLflow tracking server
            dest_tracking_uri: URI of the destination MLflow tracking server (for import operations)
        """
        self.source_tracking_uri = source_tracking_uri
        self.dest_tracking_uri = dest_tracking_uri

        # Every server interaction goes through these explicit clients, never
        # through the fluent mlflow.* API, which targets the process-global URI.
        self.source_client = MlflowClient(tracking_uri=source_tracking_uri)
        self.dest_client = MlflowClient(tracking_uri=dest_tracking_uri) if dest_tracking_uri else None

        # Source run_id -> destination run_id for runs imported by this
        # instance; used to place imported model versions next to their runs.
        self._run_id_map: Dict[str, str] = {}

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def export_all(self, config: ExportConfig) -> ExportResult:
        """
        Export entire tracking server.

        Args:
            config: Export configuration

        Returns:
            ExportResult with operation details. Per-experiment and per-model
            failures are reported in ``errors`` without failing the whole export.
        """
        start_time = time.time()
        errors: List[str] = []

        try:
            os.makedirs(config.output_dir, exist_ok=True)

            experiments_exported = 0
            runs_exported = 0
            if config.export_experiments:
                experiments_exported, runs_exported = self._export_experiments_with_counts(
                    self._list_source_experiments(), config, errors
                )

            models_exported = 0
            if config.export_models:
                models_exported = self._export_all_models(config, errors)

            return ExportResult(
                success=True,
                duration=time.time() - start_time,
                experiments_exported=experiments_exported,
                models_exported=models_exported,
                runs_exported=runs_exported,
                errors=errors,
                output_dir=config.output_dir
            )

        except Exception as e:
            errors.append(f"Export failed: {str(e)}")
            return ExportResult(
                success=False,
                duration=time.time() - start_time,
                experiments_exported=0,
                models_exported=0,
                runs_exported=0,
                errors=errors,
                output_dir=config.output_dir
            )

    def export_experiments_by_pattern(self, pattern: str, config: ExportConfig) -> ExportResult:
        """
        Export experiments whose names match a regular expression.

        Matching uses ``re.search`` with ``re.IGNORECASE``, so the pattern may
        match anywhere in the name. Registered models are not exported.

        Args:
            pattern: Regex pattern to match experiment names
            config: Export configuration

        Returns:
            ExportResult with operation details. When nothing matches, the
            result is successful and carries an explanatory message in ``errors``.
        """
        start_time = time.time()
        errors: List[str] = []

        try:
            pattern_re = re.compile(pattern, re.IGNORECASE)
            matching_experiments = [
                exp for exp in self._list_source_experiments() if pattern_re.search(exp.name)
            ]

            logger.info(f"Found {len(matching_experiments)} experiments matching pattern '{pattern}'")

            if not matching_experiments:
                return ExportResult(
                    success=True,
                    duration=time.time() - start_time,
                    experiments_exported=0,
                    models_exported=0,
                    runs_exported=0,
                    errors=["No experiments found matching pattern"],
                    output_dir=config.output_dir
                )

            experiments_exported, runs_exported = self._export_experiments_with_counts(
                matching_experiments, config, errors
            )

            return ExportResult(
                success=True,
                duration=time.time() - start_time,
                experiments_exported=experiments_exported,
                models_exported=0,
                runs_exported=runs_exported,
                errors=errors,
                output_dir=config.output_dir
            )

        except Exception as e:
            errors.append(f"Pattern export failed: {str(e)}")
            return ExportResult(
                success=False,
                duration=time.time() - start_time,
                experiments_exported=0,
                models_exported=0,
                runs_exported=0,
                errors=errors,
                output_dir=config.output_dir
            )

    def import_all(self, config: ImportConfig) -> ImportResult:
        """
        Import all data from export directory.

        Experiments (with their runs) are imported before models, so imported
        model versions can be placed in the experiment of their imported run.

        Args:
            config: Import configuration

        Returns:
            ImportResult with operation details. Per-experiment failures are
            reported in ``errors`` without failing the whole import.

        Raises:
            ValueError: if the tool was created without a destination tracking URI.
        """
        if not self.dest_client:
            raise ValueError("Destination tracking URI must be set for import operations")

        start_time = time.time()
        errors: List[str] = []

        try:
            if not os.path.exists(config.input_dir):
                raise FileNotFoundError(f"Input directory not found: {config.input_dir}")

            experiments_imported = 0
            runs_imported = 0
            if config.import_experiments:
                experiments_imported, runs_imported = self._import_experiments_with_counts(config, errors)

            models_imported = 0
            if config.import_models:
                models_imported = self._import_models(config)

            return ImportResult(
                success=True,
                duration=time.time() - start_time,
                experiments_imported=experiments_imported,
                models_imported=models_imported,
                runs_imported=runs_imported,
                errors=errors,
                input_dir=config.input_dir
            )

        except Exception as e:
            errors.append(f"Import failed: {str(e)}")
            return ImportResult(
                success=False,
                duration=time.time() - start_time,
                experiments_imported=0,
                models_imported=0,
                runs_imported=0,
                errors=errors,
                input_dir=config.input_dir
            )

    # ------------------------------------------------------------------ #
    # Source listing helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _collect_pages(search_fn, **kwargs) -> list:
        """Call a paginated MLflow search method until exhausted and return every item.

        Each page's ``token`` attribute (absent or empty on the last page) is
        passed back as ``page_token`` for the next call.
        """
        items: list = []
        page_token = None
        while True:
            page = search_fn(page_token=page_token, **kwargs)
            items.extend(page)
            page_token = getattr(page, "token", None)
            if not page_token:
                return items

    def _list_source_experiments(self) -> List[Experiment]:
        """Return all active experiments on the source server, across all pages."""
        client = self.source_client
        if hasattr(client, "search_experiments"):
            return self._collect_pages(client.search_experiments)
        # Older MLflow clients (< 1.28) only provide list_experiments.
        return client.list_experiments()

    def _list_source_registered_models(self) -> list:
        """Return all registered models on the source server, across all pages."""
        client = self.source_client
        # search_registered_models returns 100 models per page by default, so
        # paging is required to see more than the first page.
        search_fn = getattr(client, "search_registered_models", None) or client.list_registered_models
        return self._collect_pages(search_fn)

    @staticmethod
    def _build_run_filter(run_start_time: Optional[str]) -> str:
        """Translate ``ExportConfig.run_start_time`` into an MLflow run filter string.

        Accepted values:
        - empty/None: no filter;
        - an integer string: epoch milliseconds;
        - an ISO-8601 date/datetime (naive values are treated as UTC);
        - anything else is passed through unchanged as a raw filter expression,
          preserving the behavior of earlier versions of this tool.
        """
        from datetime import datetime, timezone

        if not run_start_time:
            return ""
        value = run_start_time.strip()
        if value.isdigit():
            millis = int(value)
        else:
            iso_value = value[:-1] + "+00:00" if value.endswith("Z") else value
            try:
                parsed = datetime.fromisoformat(iso_value)
            except ValueError:
                return value
            if parsed.tzinfo is None:
                parsed = parsed.replace(tzinfo=timezone.utc)
            millis = int(parsed.timestamp() * 1000)
        return f"attributes.start_time >= {millis}"

    @staticmethod
    def _experiment_to_dict(experiment: Experiment) -> Dict[str, Any]:
        """Serialize the experiment fields the exporter/importer relies on.

        ``Experiment`` is not a dataclass, so ``dataclasses.asdict`` cannot be used.
        """
        return {
            "experiment_id": experiment.experiment_id,
            "name": experiment.name,
            "artifact_location": experiment.artifact_location,
            "lifecycle_stage": experiment.lifecycle_stage,
            "tags": experiment.tags
        }

    # ------------------------------------------------------------------ #
    # Experiment / run export
    # ------------------------------------------------------------------ #

    def _export_all_experiments(self, config: ExportConfig) -> int:
        """Export all experiments from the tracking server; return how many succeeded."""
        return self._export_experiments(self._list_source_experiments(), config)

    def _export_experiments(self, experiments: List[Experiment], config: ExportConfig) -> int:
        """Export a list of experiments; return how many succeeded."""
        return self._export_experiments_with_counts(experiments, config, [])[0]

    def _export_experiments_with_counts(
        self, experiments: List[Experiment], config: ExportConfig, errors: List[str]
    ) -> Tuple[int, int]:
        """Export experiments in parallel and write ``experiments/experiments.json``.

        Returns:
            (experiments exported, runs exported). A message is appended to
            ``errors`` for every experiment that failed.
        """
        experiments_dir = os.path.join(config.output_dir, "experiments")
        os.makedirs(experiments_dir, exist_ok=True)

        exported_count = 0
        runs_count = 0
        # ThreadPoolExecutor rejects max_workers <= 0.
        max_workers = max(1, config.max_workers) if config.use_threads else 1

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_experiment = {
                executor.submit(self._export_experiment_and_count_runs, exp, experiments_dir, config): exp
                for exp in experiments
            }

            for future in as_completed(future_to_experiment):
                exp = future_to_experiment[future]
                try:
                    run_count = future.result()
                except Exception as e:
                    logger.error(f"Failed to export experiment: {e}")
                    run_count = None
                if run_count is None:
                    errors.append(f"Failed to export experiment {exp.name}")
                else:
                    exported_count += 1
                    runs_count += run_count

        # The importer reads "experiment_id" and "name" from each entry.
        experiments_metadata = {
            "export_time": time.time(),
            "source_tracking_uri": self.source_tracking_uri,
            "experiments": [self._experiment_to_dict(exp) for exp in experiments]
        }

        metadata_file = os.path.join(experiments_dir, "experiments.json")
        with open(metadata_file, 'w') as f:
            json.dump(experiments_metadata, f, indent=2)

        return exported_count, runs_count

    def _export_single_experiment(self, experiment: Experiment, output_dir: str, config: ExportConfig) -> bool:
        """Export a single experiment with all its runs; return True on success."""
        return self._export_experiment_and_count_runs(experiment, output_dir, config) is not None

    def _export_experiment_and_count_runs(
        self, experiment: Experiment, output_dir: str, config: ExportConfig
    ) -> Optional[int]:
        """Export one experiment and its runs.

        Returns:
            The number of runs exported, or None if the experiment failed.
        """
        from mlflow.entities import ViewType

        try:
            exp_dir = os.path.join(output_dir, experiment.experiment_id)
            os.makedirs(exp_dir, exist_ok=True)

            with open(os.path.join(exp_dir, "experiment.json"), 'w') as f:
                json.dump(self._experiment_to_dict(experiment), f, indent=2)

            exported_runs = 0
            if config.export_runs:
                view_type = ViewType.ALL if config.export_deleted_runs else ViewType.ACTIVE_ONLY
                # search_runs is paginated (1000 runs per page by default).
                runs = self._collect_pages(
                    self.source_client.search_runs,
                    experiment_ids=[experiment.experiment_id],
                    filter_string=self._build_run_filter(config.run_start_time),
                    run_view_type=view_type,
                )

                runs_dir = os.path.join(exp_dir, "runs")
                os.makedirs(runs_dir, exist_ok=True)

                for run in runs:
                    if self._export_single_run(run, runs_dir, config):
                        exported_runs += 1

            logger.info(f"Exported experiment: {experiment.name} ({experiment.experiment_id})")
            return exported_runs

        except Exception as e:
            logger.error(f"Failed to export experiment {experiment.name}: {e}")
            return None

    def _export_single_run(self, run: Run, runs_dir: str, config: ExportConfig) -> bool:
        """Export a single run (metadata and, optionally, artifacts); return True on success."""
        try:
            run_dir = os.path.join(runs_dir, run.info.run_id)
            os.makedirs(run_dir, exist_ok=True)

            run_metadata = {
                "run_id": run.info.run_id,
                "experiment_id": run.info.experiment_id,
                "status": run.info.status,
                "start_time": run.info.start_time,
                "end_time": run.info.end_time,
                "lifecycle_stage": run.info.lifecycle_stage,
                "params": run.data.params,
                "metrics": run.data.metrics,  # latest value per key only
                "tags": run.data.tags
            }

            with open(os.path.join(run_dir, "run.json"), 'w') as f:
                json.dump(run_metadata, f, indent=2)

            if config.export_artifacts:
                self._export_artifacts(run, os.path.join(run_dir, "artifacts"))

            logger.debug(f"Exported run: {run.info.run_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to export run {run.info.run_id}: {e}")
            return False

    def _export_artifacts(self, run: Run, artifacts_dir: str):
        """Download all artifacts of ``run`` into ``artifacts_dir`` (best effort)."""
        try:
            # The destination must exist before downloading into it.
            os.makedirs(artifacts_dir, exist_ok=True)
            # Signature is download_artifacts(run_id, path, dst_path): "" selects
            # the artifact root, and dst_path is where the files land locally.
            self.source_client.download_artifacts(run.info.run_id, "", dst_path=artifacts_dir)
        except Exception as e:
            logger.warning(f"Failed to export artifacts for run {run.info.run_id}: {e}")

    # ------------------------------------------------------------------ #
    # Model export
    # ------------------------------------------------------------------ #

    def _export_all_models(self, config: ExportConfig, errors: Optional[List[str]] = None) -> int:
        """Export all registered models and write ``models/models.json``.

        Args:
            config: Export configuration.
            errors: Optional list that receives a message for each failure.

        Returns:
            The number of models exported successfully (0 if listing failed).
        """
        models_dir = os.path.join(config.output_dir, "models")
        os.makedirs(models_dir, exist_ok=True)

        try:
            models = self._list_source_registered_models()
            exported_count = 0

            for model in models:
                if self._export_single_model(model, models_dir, config):
                    exported_count += 1
                elif errors is not None:
                    errors.append(f"Failed to export model {model.name}")

            models_metadata = {
                "export_time": time.time(),
                "source_tracking_uri": self.source_tracking_uri,
                "models": [{"name": model.name, "description": model.description} for model in models]
            }

            metadata_file = os.path.join(models_dir, "models.json")
            with open(metadata_file, 'w') as f:
                json.dump(models_metadata, f, indent=2)

            return exported_count

        except Exception as e:
            logger.error(f"Failed to export models: {e}")
            if errors is not None:
                errors.append(f"Failed to export models: {e}")
            return 0

    @staticmethod
    def _select_versions(versions: list, config: ExportConfig) -> list:
        """Apply the ``stages`` and ``export_latest_versions_only`` settings.

        ``stages`` (case-insensitive) keeps only versions in those stages.
        ``export_latest_versions_only`` keeps the highest version number per
        stage, mirroring MLflow's "latest versions" notion.
        """
        if config.stages:
            wanted = {stage.lower() for stage in config.stages}
            versions = [v for v in versions if (v.current_stage or "None").lower() in wanted]
        if not config.export_latest_versions_only:
            return list(versions)

        latest: Dict[str, Any] = {}
        for version in versions:
            current = latest.get(version.current_stage)
            if current is None or int(version.version) > int(current.version):
                latest[version.current_stage] = version
        return list(latest.values())

    def _export_single_model(self, model, models_dir: str, config: ExportConfig) -> bool:
        """Export a single registered model and its selected versions; return True on success."""
        try:
            model_dir = os.path.join(models_dir, model.name)
            os.makedirs(model_dir, exist_ok=True)

            model_metadata = {
                "name": model.name,
                "description": model.description,
                "tags": model.tags,
                "creation_timestamp": model.creation_timestamp,
                "last_updated_timestamp": model.last_updated_timestamp
            }

            with open(os.path.join(model_dir, "model.json"), 'w') as f:
                json.dump(model_metadata, f, indent=2)

            versions = self.source_client.search_model_versions(f"name='{model.name}'")

            versions_dir = os.path.join(model_dir, "versions")
            os.makedirs(versions_dir, exist_ok=True)

            for version in self._select_versions(versions, config):
                self._export_model_version(version, versions_dir, config)

            logger.info(f"Exported model: {model.name}")
            return True

        except Exception as e:
            logger.error(f"Failed to export model {model.name}: {e}")
            return False

    def _export_model_version(self, version, versions_dir: str, config: ExportConfig):
        """Export a single model version's metadata and model files (best effort)."""
        try:
            version_dir = os.path.join(versions_dir, str(version.version))
            os.makedirs(version_dir, exist_ok=True)

            version_metadata = {
                "version": version.version,
                "run_id": version.run_id,
                "status": version.status,
                "current_stage": version.current_stage,
                "description": version.description,
                "tags": version.tags
            }

            with open(os.path.join(version_dir, "version.json"), 'w') as f:
                json.dump(version_metadata, f, indent=2)

            model_path = os.path.join(version_dir, "model")
            os.makedirs(model_path, exist_ok=True)
            # Resolve the storage location through the *source* registry. A bare
            # "models:/" URI would be resolved against the process-global registry.
            # NOTE(review): proxied "mlflow-artifacts:" locations are still
            # resolved via the global tracking URI.
            download_uri = self.source_client.get_model_version_download_uri(version.name, version.version)
            mlflow.artifacts.download_artifacts(artifact_uri=download_uri, dst_path=model_path)

        except Exception as e:
            logger.error(f"Failed to export model version {version.name}/{version.version}: {e}")

    # ------------------------------------------------------------------ #
    # Experiment / run import
    # ------------------------------------------------------------------ #

    def _import_experiments(self, config: ImportConfig) -> int:
        """Import experiments from export directory; return how many succeeded."""
        return self._import_experiments_with_counts(config, [])[0]

    def _import_experiments_with_counts(self, config: ImportConfig, errors: List[str]) -> Tuple[int, int]:
        """Import the experiments listed in ``experiments/experiments.json`` in parallel.

        Returns:
            (experiments imported, runs imported). A message is appended to
            ``errors`` for every experiment that failed.
        """
        experiments_dir = os.path.join(config.input_dir, "experiments")

        if not os.path.exists(experiments_dir):
            logger.warning("No experiments directory found")
            return 0, 0

        metadata_file = os.path.join(experiments_dir, "experiments.json")
        if not os.path.exists(metadata_file):
            logger.warning("No experiments metadata found")
            return 0, 0

        with open(metadata_file, 'r') as f:
            metadata = json.load(f)

        imported_count = 0
        runs_count = 0
        # ThreadPoolExecutor rejects max_workers <= 0.
        max_workers = max(1, config.max_workers) if config.use_threads else 1

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_name = {
                executor.submit(self._import_experiment_and_count_runs, exp_data, experiments_dir, config):
                    exp_data.get("name", "unknown")
                for exp_data in metadata["experiments"]
            }

            for future in as_completed(future_to_name):
                try:
                    run_count = future.result()
                except Exception as e:
                    logger.error(f"Failed to import experiment: {e}")
                    run_count = None
                if run_count is None:
                    errors.append(f"Failed to import experiment {future_to_name[future]}")
                else:
                    imported_count += 1
                    runs_count += run_count

        return imported_count, runs_count

    def _import_single_experiment(self, exp_data: Dict, experiments_dir: str, config: ImportConfig) -> bool:
        """Import a single experiment with its runs; return True on success."""
        return self._import_experiment_and_count_runs(exp_data, experiments_dir, config) is not None

    def _import_experiment_and_count_runs(
        self, exp_data: Dict, experiments_dir: str, config: ImportConfig
    ) -> Optional[int]:
        """Create (or reuse) the destination experiment and import its runs.

        Returns:
            The number of runs imported, or None if the experiment failed.
        """
        try:
            exp_id = exp_data["experiment_id"]
            exp_name = exp_data["name"]

            if config.experiment_renames and exp_name in config.experiment_renames:
                exp_name = config.experiment_renames[exp_name]

            experiment = self.dest_client.get_experiment_by_name(exp_name)
            if experiment is None:
                experiment_id = self.dest_client.create_experiment(exp_name)
            else:
                experiment_id = experiment.experiment_id

            runs_dir = os.path.join(experiments_dir, exp_id, "runs")
            imported_runs = 0

            if os.path.isdir(runs_dir):
                for run_id in sorted(os.listdir(runs_dir)):
                    run_dir = os.path.join(runs_dir, run_id)
                    if os.path.isdir(run_dir) and self._import_single_run(run_dir, experiment_id, config):
                        imported_runs += 1

            logger.info(f"Imported experiment: {exp_name}")
            return imported_runs

        except Exception as e:
            logger.error(f"Failed to import experiment {exp_data.get('name', 'unknown')}: {e}")
            return None

    def _import_single_run(self, run_dir: str, experiment_id: str, config: ImportConfig) -> bool:
        """Recreate one exported run in ``experiment_id`` on the destination; return True on success.

        Uses the destination client rather than the fluent ``mlflow.*`` API: the
        fluent API writes to the process-global tracking URI and keeps
        active-run state that is unsafe to share between worker threads.
        """
        try:
            with open(os.path.join(run_dir, "run.json"), 'r') as f:
                run_data = json.load(f)

            client = self.dest_client
            # Passing the exported tags at creation also carries over
            # mlflow.runName / mlflow.user when present.
            run = client.create_run(
                experiment_id=experiment_id,
                start_time=run_data.get("start_time"),
                tags=run_data.get("tags") or {},
            )
            dest_run_id = run.info.run_id

            try:
                for key, value in (run_data.get("params") or {}).items():
                    client.log_param(dest_run_id, key, value)
                for key, value in (run_data.get("metrics") or {}).items():
                    client.log_metric(dest_run_id, key, value)

                artifacts_dir = os.path.join(run_dir, "artifacts")
                if os.path.isdir(artifacts_dir):
                    client.log_artifacts(dest_run_id, artifacts_dir)
            except Exception:
                client.set_terminated(dest_run_id, status="FAILED")
                raise

            status = run_data.get("status")
            if status not in self._TERMINAL_RUN_STATUSES:
                status = "FINISHED"
            client.set_terminated(dest_run_id, status=status, end_time=run_data.get("end_time"))

            source_run_id = run_data.get("run_id")
            if source_run_id:
                self._run_id_map[source_run_id] = dest_run_id

            logger.debug(f"Imported run: {run_data['run_id']}")
            return True

        except Exception as e:
            logger.error(f"Failed to import run from {run_dir}: {e}")
            return False

    # ------------------------------------------------------------------ #
    # Model import
    # ------------------------------------------------------------------ #

    def _import_models(self, config: ImportConfig) -> int:
        """Import models listed in ``models/models.json``; return how many succeeded."""
        models_dir = os.path.join(config.input_dir, "models")

        if not os.path.exists(models_dir):
            logger.warning("No models directory found")
            return 0

        metadata_file = os.path.join(models_dir, "models.json")
        if not os.path.exists(metadata_file):
            logger.warning("No models metadata found")
            return 0

        with open(metadata_file, 'r') as f:
            metadata = json.load(f)

        return sum(
            1 for model_data in metadata["models"]
            if self._import_single_model(model_data, models_dir, config)
        )

    @staticmethod
    def _version_sort_key(entry: str) -> Tuple[int, int, str]:
        """Sort key placing numeric version directories first, in numeric order."""
        return (0, int(entry), "") if entry.isdigit() else (1, 0, entry)

    def _import_single_model(self, model_data: Dict, models_dir: str, config: ImportConfig) -> bool:
        """Create (or reuse) a registered model on the destination and import its versions."""
        try:
            source_name = model_data["name"]
            model_name = source_name

            if config.model_renames and model_name in config.model_renames:
                model_name = config.model_renames[model_name]

            client = self.dest_client

            if config.delete_existing_models:
                try:
                    client.delete_registered_model(model_name)
                except MlflowException as e:
                    # Typically the model does not exist yet.
                    logger.debug(f"Could not delete model {model_name}: {e}")

            model_dir = os.path.join(models_dir, source_name)

            tags = None
            model_json = os.path.join(model_dir, "model.json")
            if os.path.isfile(model_json):
                with open(model_json, 'r') as f:
                    tags = json.load(f).get("tags") or None

            try:
                client.get_registered_model(model_name)
            except MlflowException:
                client.create_registered_model(
                    model_name, tags=tags, description=model_data.get("description")
                )

            versions_dir = os.path.join(model_dir, "versions")
            if os.path.isdir(versions_dir):
                # Numeric order so a freshly created model receives matching
                # version numbers.
                for version_dir in sorted(os.listdir(versions_dir), key=self._version_sort_key):
                    version_path = os.path.join(versions_dir, version_dir)
                    if os.path.isdir(version_path):
                        self._import_model_version(version_path, model_name, config)

            logger.info(f"Imported model: {model_name}")
            return True

        except Exception as e:
            logger.error(f"Failed to import model {model_data.get('name', 'unknown')}: {e}")
            return False

    @staticmethod
    def _find_model_root(model_path: str) -> str:
        """Return the shallowest directory under ``model_path`` containing an MLmodel file.

        The download step may nest the model one level deep. If no MLmodel file
        is found, ``model_path`` itself is returned.
        """
        candidates = sorted(Path(model_path).rglob("MLmodel"), key=lambda p: len(p.parts))
        return str(candidates[0].parent) if candidates else model_path

    def _model_import_experiment_id(self, source_run_id: Optional[str]) -> str:
        """Pick the destination experiment for the run backing an imported model version."""
        dest_run_id = self._run_id_map.get(source_run_id) if source_run_id else None
        if dest_run_id:
            return self.dest_client.get_run(dest_run_id).info.experiment_id

        experiment = self.dest_client.get_experiment_by_name(self.MODEL_IMPORT_EXPERIMENT)
        if experiment is not None:
            return experiment.experiment_id
        return self.dest_client.create_experiment(self.MODEL_IMPORT_EXPERIMENT)

    def _import_model_version(self, version_dir: str, model_name: str, config: ImportConfig):
        """Register one exported model version on the destination (best effort).

        The model files are logged to a new destination run under the artifact
        path "model", and a model version is created from that location. The
        exported description, tags and stage are then applied.
        """
        try:
            with open(os.path.join(version_dir, "version.json"), 'r') as f:
                version_data = json.load(f)

            model_path = os.path.join(version_dir, "model")
            if not os.path.isdir(model_path):
                logger.warning(f"No model files for {model_name}/{version_data.get('version')}; skipping")
                return

            client = self.dest_client
            source_run_id = version_data.get("run_id")
            run_tags = {"mlflow.runName": f"{model_name}-v{version_data.get('version')}"}
            if source_run_id:
                run_tags["import.source_run_id"] = source_run_id

            run = client.create_run(
                experiment_id=self._model_import_experiment_id(source_run_id),
                tags=run_tags,
            )
            run_id = run.info.run_id
            try:
                client.log_artifacts(run_id, self._find_model_root(model_path), artifact_path="model")
            except Exception:
                client.set_terminated(run_id, status="FAILED")
                raise
            client.set_terminated(run_id, status="FINISHED")

            new_version = client.create_model_version(
                name=model_name,
                source=f"{run.info.artifact_uri}/model",
                run_id=run_id,
                tags=version_data.get("tags") or None,
                description=version_data.get("description"),
            )

            stage = version_data.get("current_stage")
            if stage and stage != "None":
                client.transition_model_version_stage(model_name, new_version.version, stage)

            logger.debug(f"Imported model version: {model_name}/{version_data['version']}")

        except Exception as e:
            logger.error(f"Failed to import model version from {version_dir}: {e}")


# CLI interface
@click.group()
def cli():
    """MLflow Export-Import Tools"""
    # The docstring above doubles as the group's --help text. Subcommands
    # attach themselves via @cli.command(), so the group itself does no work.


@cli.command()
@click.option('--source-uri', required=True, help='Source MLflow tracking URI')
@click.option('--output-dir', required=True, help='Output directory for export')
@click.option('--pattern', help='Regex pattern to filter experiments by name')
@click.option('--export-models/--no-export-models', default=True, help='Export registered models')
@click.option('--export-experiments/--no-export-experiments', default=True, help='Export experiments')
@click.option('--export-artifacts/--no-export-artifacts', default=True, help='Export run artifacts')
@click.option('--use-threads/--no-threads', default=True, help='Use parallel processing')
@click.option('--max-workers', default=4, help='Maximum number of worker threads')
def export(source_uri, output_dir, pattern, export_models, export_experiments, 
           export_artifacts, use_threads, max_workers):
    """Export MLflow data from tracking server."""
    # The docstring above is click's --help text for this command.
    config = ExportConfig(
        output_dir=output_dir,
        export_models=export_models,
        export_experiments=export_experiments,
        export_artifacts=export_artifacts,
        use_threads=use_threads,
        max_workers=max_workers
    )

    tool = MLflowExportImportTool(source_uri)

    # --pattern restricts the export to matching experiments (no models).
    if pattern:
        result = tool.export_experiments_by_pattern(pattern, config)
    else:
        result = tool.export_all(config)

    if not result.success:
        click.echo(f"Export failed: {result.errors}", err=True)
        # `exit` is injected by the site module and is not guaranteed to exist
        # in scripts; SystemExit is the portable way to set the exit code.
        raise SystemExit(1)

    click.echo(f"Export completed successfully in {result.duration:.2f}s")
    click.echo(f"Experiments exported: {result.experiments_exported}")
    click.echo(f"Models exported: {result.models_exported}")
    # A successful result can still carry messages (e.g. no pattern matches).
    for message in result.errors:
        click.echo(f"Warning: {message}", err=True)


@cli.command()
@click.option('--dest-uri', required=True, help='Destination MLflow tracking URI')
@click.option('--input-dir', required=True, help='Input directory from export')
@click.option('--import-models/--no-import-models', default=True, help='Import registered models')
@click.option('--import-experiments/--no-import-experiments', default=True, help='Import experiments')
@click.option('--delete-existing-models', is_flag=True, help='Delete existing models before import')
@click.option('--use-threads/--no-threads', default=True, help='Use parallel processing')
@click.option('--max-workers', default=4, help='Maximum number of worker threads')
def import_data(dest_uri, input_dir, import_models, import_experiments, 
                delete_existing_models, use_threads, max_workers):
    """Import MLflow data to tracking server."""
    # The docstring above is click's --help text for this command.
    config = ImportConfig(
        input_dir=input_dir,
        import_models=import_models,
        import_experiments=import_experiments,
        delete_existing_models=delete_existing_models,
        use_threads=use_threads,
        max_workers=max_workers
    )

    # No source server is involved in an import; "" is passed as a placeholder.
    tool = MLflowExportImportTool("", dest_uri)
    result = tool.import_all(config)

    if not result.success:
        click.echo(f"Import failed: {result.errors}", err=True)
        # `exit` is injected by the site module and is not guaranteed to exist
        # in scripts; SystemExit is the portable way to set the exit code.
        raise SystemExit(1)

    click.echo(f"Import completed successfully in {result.duration:.2f}s")
    click.echo(f"Experiments imported: {result.experiments_imported}")
    click.echo(f"Models imported: {result.models_imported}")
    # Surface any non-fatal messages collected during a successful import.
    for message in result.errors:
        click.echo(f"Warning: {message}", err=True)


if __name__ == "__main__":
    # Script entry point: hand control to the click command group.
    cli()