"""
Experiment management tools:
- Automatic experiment configuration logging
- Experiment result version control
- Automatic experiment report generation
"""

import os
import json
import time
import hashlib
import pickle
import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import shutil

from utils.logger import setup_logger
from utils.io_utils import save_results, load_results

logger = setup_logger(__name__)


class ExperimentManager:
    """
    Experiment manager for tracking and organizing experiments

    Every experiment owns the directory ``<base_dir>/<experiment_name>/``
    containing the sub-directories ``results``, ``configs``, ``plots`` and
    ``reports`` plus the files ``metadata.json`` and ``metrics_log.json``.
    Each sub-directory keeps a ``latest_*`` symlink that points at the most
    recently written artifact.

    The manager can be used as a context manager: entering marks the
    experiment as running, leaving marks it as completed (or failed when the
    body raised).
    """

    def __init__(self, experiment_name: str, base_dir: str = "./outputs"):
        """
        Initialize the experiment manager

        Creates the directory layout if necessary and writes a fresh
        ``metadata.json`` (an existing one for the same experiment name is
        overwritten).

        Args:
            experiment_name: Name of the experiment
            base_dir: Base directory
        """
        self.experiment_name = experiment_name
        self.base_dir = Path(base_dir)
        self.experiment_dir = self.base_dir / experiment_name
        self.results_dir = self.experiment_dir / "results"
        self.config_dir = self.experiment_dir / "configs"
        self.plots_dir = self.experiment_dir / "plots"
        self.reports_dir = self.experiment_dir / "reports"

        # Create necessary directories
        for directory in (
            self.experiment_dir,
            self.results_dir,
            self.config_dir,
            self.plots_dir,
            self.reports_dir,
        ):
            directory.mkdir(parents=True, exist_ok=True)

        # Initialize experiment records
        self.config = {}
        self.results = {}
        self.metadata = {
            "experiment_name": experiment_name,
            "created_at": time.time(),
            "start_time": None,
            "end_time": None,
            "status": "initialized"
        }

        # Save initial state
        self._save_metadata()

        logger.info(f"Initialized experiment manager, name: {experiment_name}, dir: {self.experiment_dir}")

    @staticmethod
    def _update_latest_link(target: Path, link_path: Path) -> None:
        """
        Make ``link_path`` a symlink that points at ``target``

        Args:
            target: File the link should point at
            link_path: Location of the "latest" symlink
        """
        # A symlink target is resolved relative to the directory that holds
        # the link, not relative to the current working directory, so the
        # target is stored relative to the link's own directory. This keeps
        # the link valid for relative base directories and after the
        # experiment directory is moved or archived.
        relative_target = os.path.relpath(str(target), start=str(link_path.parent))

        if relative_target == link_path.name:
            # The artifact itself was saved under the "latest" name; replacing
            # it with a link to itself would destroy the file just written.
            return

        # exists() follows symlinks and reports False for a dangling link,
        # so is_symlink() is needed to detect and remove stale links.
        if link_path.is_symlink() or link_path.exists():
            link_path.unlink()

        os.symlink(relative_target, link_path)

    @staticmethod
    def _read_latest_results(results_dir: Path):
        """
        Load ``latest_results.pkl`` from a results directory

        Args:
            results_dir: Directory that contains the results files

        Returns:
            The unpickled results, or None if no latest results exist
        """
        latest_result_path = results_dir / "latest_results.pkl"
        if not latest_result_path.exists():
            return None

        # SECURITY: pickle.load can execute arbitrary code. Only load result
        # files from experiment directories you trust.
        with open(latest_result_path, 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def _format_timestamp(timestamp: Optional[float]) -> str:
        """Format a POSIX timestamp for reports; 'N/A' when it is missing or zero"""
        if not timestamp:
            return 'N/A'
        return datetime.datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')

    def log_config(self, config: Dict):
        """
        Record experiment configuration

        The configuration is written to ``configs/config_<hash>.json`` where
        ``<hash>`` is the first 8 hex digits of the MD5 of the key-sorted JSON
        dump, and ``configs/latest_config.json`` is linked to it.

        Args:
            config: Configuration dictionary (must be JSON serializable)
        """
        # Calculate configuration hash value
        config_str = json.dumps(config, sort_keys=True)
        config_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]

        # Update configuration
        self.config = {
            "config": config,
            "config_hash": config_hash,
            "config_str": config_str
        }

        # Save configuration file
        config_path = self.config_dir / f"config_{config_hash}.json"
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2)

        # Create soft link to latest configuration
        self._update_latest_link(config_path, self.config_dir / "latest_config.json")

        logger.info(f"Recording experiment configuration, hash: {config_hash}")

        # Update metadata
        self.metadata["config_hash"] = config_hash
        self._save_metadata()

    def log_results(self, results: Dict, version: Optional[str] = None):
        """
        Record experiment results

        Results are pickled to ``results/results_<version>.pkl`` and, when
        possible, also written as JSON next to it. The experiment is marked
        as running if it was not already.

        Args:
            results: Results dictionary
            version: Results version; defaults to the current Unix time in
                whole seconds
        """
        # Generate results version
        if version is None:
            version = f"{int(time.time())}"

        # Save results
        results_path = self.results_dir / f"results_{version}.pkl"
        with open(results_path, 'wb') as f:
            pickle.dump(results, f)

        # Also save as JSON format (if serializable). Serialize to a string
        # first so a failure does not leave a truncated JSON file behind.
        try:
            results_json = json.dumps(results, indent=2, default=str)
        except (TypeError, ValueError):
            logger.warning("Unable to save results as JSON format")
        else:
            json_path = self.results_dir / f"results_{version}.json"
            with open(json_path, 'w') as f:
                f.write(results_json)

        # Update results record
        self.results = {
            "latest_version": version,
            "versions": self._get_result_versions(),
            "latest_results": results
        }

        # Create symlink to latest results
        self._update_latest_link(results_path, self.results_dir / "latest_results.pkl")

        logger.info(f"Logged experiment results, version: {version}")

        # Update metadata
        if self.metadata["status"] != "running":
            self.metadata["status"] = "running"
            self.metadata["start_time"] = time.time()

        self.metadata["latest_result_version"] = version
        self._save_metadata()

    def log_metrics(self, metrics: Dict, step: Optional[int] = None):
        """
        Record metrics (for monitoring during training)

        The record is appended to ``metrics_log.json``; the whole log is read
        and rewritten on every call.

        Args:
            metrics: Metrics dictionary (must be JSON serializable)
            step: Step number; stored as "N/A" when omitted
        """
        # Generate metrics record
        metric_record = {
            "timestamp": time.time(),
            "step": step if step is not None else "N/A",
            "metrics": metrics
        }

        # Add to metrics log
        metrics_log = self.get_metrics_log()
        metrics_log.append(metric_record)

        metrics_log_path = self.experiment_dir / "metrics_log.json"
        with open(metrics_log_path, 'w') as f:
            json.dump(metrics_log, f, indent=2)

    def _save_metadata(self):
        """Save metadata to file"""
        metadata_path = self.experiment_dir / "metadata.json"
        with open(metadata_path, 'w') as f:
            json.dump(self.metadata, f, indent=2)

    def _load_metadata(self):
        """
        Load metadata from file

        Returns:
            Metadata dictionary, or an empty dictionary if no file exists
        """
        metadata_path = self.experiment_dir / "metadata.json"
        if not metadata_path.exists():
            return {}

        with open(metadata_path, 'r') as f:
            return json.load(f)

    def save_plot(self, fig, filename: str):
        """Save plot figure

        Args:
            fig: Figure object; must provide ``savefig(path, dpi=..., bbox_inches=...)``
            filename: Filename to save (inside the plots directory)
        """
        plot_path = self.plots_dir / filename

        # Save plot
        fig.savefig(plot_path, dpi=150, bbox_inches='tight')

        # Create soft link to latest plot
        self._update_latest_link(plot_path, self.plots_dir / "latest_plot.png")

    def save_report(self, report: str, filename: str = "report.md"):
        """Save experiment report

        Args:
            report: Report content
            filename: Filename to save (inside the reports directory)
        """
        report_path = self.reports_dir / filename
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(report)

        # Create soft link to latest report
        self._update_latest_link(report_path, self.reports_dir / "latest_report.md")

    def load_latest_results(self):
        """
        Load latest experiment results

        Returns:
            The most recently logged results, or None if there are none
        """
        return self._read_latest_results(self.results_dir)

    def start_experiment(self):
        """
        Mark experiment start
        """
        self.metadata["status"] = "running"
        self.metadata["start_time"] = time.time()
        self._save_metadata()

        logger.info(f"Experiment started: {self.experiment_name}")

    def end_experiment(self, status: str = "completed"):
        """
        Mark experiment end

        Records the end time and, when a start time is known, the duration.

        Args:
            status: Final status to store in the metadata
        """
        self.metadata["status"] = status
        self.metadata["end_time"] = time.time()

        if self.metadata["start_time"] is not None:
            duration = self.metadata["end_time"] - self.metadata["start_time"]
            self.metadata["duration_seconds"] = duration
            self.metadata["duration_human"] = str(datetime.timedelta(seconds=duration))

        self._save_metadata()

        logger.info(f"Experiment ended: {self.experiment_name}, duration: {self.metadata.get('duration_human', 'N/A')}")

    def generate_report(self, template: Optional[str] = None) -> Optional[Path]:
        """
        Automatically generate experiment report

        Args:
            template: Report template path

        Returns:
            Path of the generated report, or None if there are no results
        """
        latest_results_path = self.results_dir / "latest_results.pkl"

        if not latest_results_path.exists():
            logger.warning("No available results, unable to generate report")
            return None

        # Load results
        results = self.load_latest_results()

        # Generate report content
        report_content = self._generate_report_content(results, template)

        # Save report (UTF-8, consistent with save_report)
        timestamp = int(time.time())
        report_path = self.reports_dir / f"report_{timestamp}.md"

        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(report_content)

        # Create symlink to latest report
        self._update_latest_link(report_path, self.reports_dir / "latest_report.md")

        logger.info(f"Generated experiment report: {report_path}")

        return report_path

    def compare_experiments(self, other_experiment_names: List[str],
                          metrics: Optional[List[str]] = None) -> Dict:
        """
        Compare results of multiple experiments

        The latest results of this experiment are the base; the latest
        results of every other experiment found under the same base directory
        are compared against it. The comparison is also written to
        ``reports/comparison_<timestamp>.json``.

        Args:
            other_experiment_names: List of other experiment names
            metrics: List of metrics to compare (dot-separated paths); all
                numeric values are compared when omitted

        Returns:
            Comparison results
        """
        # Load current experiment results
        current_results = self.load_latest_results()

        # Load other experiment results. The files are read directly instead
        # of instantiating a manager per experiment, because constructing a
        # manager rewrites that experiment's metadata.json.
        other_results = {}
        for exp_name in other_experiment_names:
            exp_dir = self.base_dir / exp_name
            if not exp_dir.exists():
                logger.warning(f"Experiment not found, skipping comparison: {exp_name}")
                continue
            other_results[exp_name] = self._read_latest_results(exp_dir / "results")

        # Compare results. The base results are passed unwrapped so that
        # metric paths line up with those of the other experiments.
        comparison = self._compare_results(current_results, other_results, metrics)

        # Save comparison results
        comparison_path = self.reports_dir / f"comparison_{int(time.time())}.json"
        with open(comparison_path, 'w') as f:
            json.dump(comparison, f, indent=2, default=str)

        logger.info(f"Generated experiment comparison: {comparison_path}")

        return comparison

    def load_config(self, version: Optional[str] = None) -> Dict:
        """
        Load configuration

        Args:
            version: Configuration version (the config hash); the latest
                configuration is loaded when omitted

        Returns:
            Configuration dictionary, empty if the configuration is missing
        """
        if version is None:
            config_path = self.config_dir / "latest_config.json"
        else:
            config_path = self.config_dir / f"config_{version}.json"

        if not config_path.exists():
            return {}

        with open(config_path, 'r') as f:
            return json.load(f)

    def get_metrics_log(self) -> List[Dict]:
        """
        Get metrics log

        Returns:
            List of metric records, empty if nothing was logged yet
        """
        metrics_log_path = self.experiment_dir / "metrics_log.json"

        if not metrics_log_path.exists():
            return []

        with open(metrics_log_path, 'r') as f:
            return json.load(f)

    def get_experiment_info(self) -> Dict:
        """
        Get experiment information

        Returns:
            Experiment information dictionary with the keys ``metadata``,
            ``result_versions``, ``result_count``, ``plot_count``,
            ``report_count`` and ``directory``
        """
        # Refresh the in-memory metadata from disk
        stored_metadata = self._load_metadata()
        if stored_metadata:
            self.metadata.update(stored_metadata)

        # Count number of results
        result_versions = self._get_result_versions()

        # Count number of plots and reports ("latest" symlinks are not counted)
        plot_count = sum(1 for f in self.plots_dir.glob("*") if not f.is_symlink())
        report_count = sum(1 for f in self.reports_dir.glob("*.md") if not f.is_symlink())

        return {
            "metadata": self.metadata,
            "result_versions": result_versions,
            "result_count": len(result_versions),
            "plot_count": plot_count,
            "report_count": report_count,
            "directory": str(self.experiment_dir)
        }

    def archive_experiment(self, archive_name: Optional[str] = None):
        """
        Archive experiment

        Packs the experiment directory into
        ``<base_dir>/archive/<archive_name>.tar.gz``.

        Args:
            archive_name: Archive name without extension; defaults to
                ``<experiment_name>_<YYYYmmdd_HHMMSS>``

        Returns:
            Path of the created archive
        """
        if archive_name is None:
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            archive_name = f"{self.experiment_name}_{timestamp}"

        # Create archive directory
        archive_dir = self.base_dir / "archive"
        archive_dir.mkdir(parents=True, exist_ok=True)

        archive_path = archive_dir / f"{archive_name}.tar.gz"

        # make_archive appends ".tar.gz" itself, so it must be given the
        # name without any extension. Stripping a suffix from archive_path
        # would remove only ".gz" and would also mangle names with dots.
        shutil.make_archive(
            str(archive_dir / archive_name),
            'gztar',
            str(self.experiment_dir.parent),
            self.experiment_dir.name
        )

        logger.info(f"Experiment archived: {archive_path}")

        return archive_path

    def _get_result_versions(self) -> List[str]:
        """
        Get all result versions

        Returns:
            Sorted list of the versions of all ``results_<version>.pkl`` files
        """
        prefix = "results_"
        return sorted(
            result_file.stem[len(prefix):]
            for result_file in self.results_dir.glob("results_*.pkl")
        )

    @classmethod
    def _collect_metric_paths(cls, data, prefix: str, found: set) -> None:
        """Add the dot-separated path of every numeric value in ``data`` to ``found``"""
        if not isinstance(data, dict):
            # Missing results (None) or a non-dict payload contain no metrics
            return

        for key, value in data.items():
            if not isinstance(key, str):
                # Paths are dot-joined strings; other key types cannot be addressed
                continue
            if isinstance(value, (int, float)):
                found.add(prefix + key)
            elif isinstance(value, dict):
                cls._collect_metric_paths(value, prefix + key + ".", found)

    @staticmethod
    def _lookup_metric(data, path: str):
        """Return the value at the dot-separated ``path`` in ``data``, or None if absent"""
        value = data
        for part in path.split("."):
            if not isinstance(value, dict) or part not in value:
                return None
            value = value[part]
        return value

    def _compare_results(self, base_results: Dict, other_results: Dict,
                         metrics: Optional[List[str]] = None) -> Dict:
        """
        Compare multiple results

        Args:
            base_results: Base results (may be None when there are none)
            other_results: Mapping of experiment name to its results
            metrics: Metrics to compare as dot-separated paths; when omitted,
                every numeric value found in any of the results is used

        Returns:
            Comparison results with the keys ``experiments``, ``metrics`` and
            ``differences``. ``differences[metric]`` holds the base value and,
            per experiment that has the metric, its value, the absolute
            difference and the percentage change relative to the base
            (``inf`` when the base is 0; both are None for non-numeric values)
        """
        if metrics is None:
            # If no metrics specified, try to auto-extract
            discovered = set()
            for results in [base_results, *other_results.values()]:
                self._collect_metric_paths(results, "", discovered)

            # Sorted so that the output is deterministic
            metrics = sorted(discovered)

        comparison = {
            "experiments": {
                "base": base_results,
                **other_results
            },
            "metrics": metrics,
            "differences": {}
        }

        numeric = (int, float)

        for metric in metrics:
            metric_diff = {}
            comparison["differences"][metric] = metric_diff

            base_value = self._lookup_metric(base_results, metric)
            if base_value is None:
                continue

            metric_diff["base"] = base_value

            for exp_name, exp_results in other_results.items():
                exp_value = self._lookup_metric(exp_results, metric)
                if exp_value is None:
                    continue

                if isinstance(base_value, numeric) and isinstance(exp_value, numeric):
                    diff = exp_value - base_value
                    pct_change = (diff / base_value) * 100 if base_value != 0 else float('inf')
                else:
                    # Explicitly requested metrics may point at non-numeric values
                    diff = None
                    pct_change = None

                metric_diff[exp_name] = {
                    "value": exp_value,
                    "diff": diff,
                    "pct_change": pct_change
                }

        return comparison

    def _generate_report_content(self, results: Dict, template_path: Optional[str] = None) -> str:
        """
        Generate report content

        Args:
            results: Results dictionary
            template_path: Template path; when the file exists, its
                ``{{experiment_name}}`` and ``{{current_time}}`` placeholders
                are substituted and the result is returned as the report

        Returns:
            Report content
        """
        # If template provided, use template
        if template_path and Path(template_path).exists():
            with open(template_path, 'r', encoding='utf-8') as f:
                content = f.read()

            # Simple template replacement (can use more complex template engine in real project)
            content = content.replace("{{experiment_name}}", self.experiment_name)
            content = content.replace("{{current_time}}", datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

            return content

        # Otherwise use default template
        content = [
            f"# Experiment Report: {self.experiment_name}",
            "",
            f"Generated at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            "",
            "## Experiment Configuration",
            "",
            f"- Experiment Name: {self.experiment_name}",
            f"- Configuration Hash: {self.metadata.get('config_hash', 'N/A')}",
            f"- Start Time: {self._format_timestamp(self.metadata.get('start_time'))}",
            f"- End Time: {self._format_timestamp(self.metadata.get('end_time'))}",
            f"- Experiment Status: {self.metadata.get('status', 'N/A')}",
            f"- Experiment Duration: {self.metadata.get('duration_human', 'N/A')}",
            "",
            "## Experiment Results",
            "",
            "```json",
            json.dumps(results, indent=2, default=str),
            "```",
            "",
            "## Result Analysis",
            "",
            "TODO: Add result analysis",
            "",
            "## Conclusion",
            "",
            "TODO: Add conclusion"
        ]

        # Try to extract some key results for analysis
        if isinstance(results, dict):
            # Numeric values whose path mentions one of these names are listed
            metric_keywords = ["accuracy", "error", "loss", "precision", "recall", "f1"]
            key_metrics = {}

            def find_metrics(d, path=""):
                for k, v in d.items():
                    current_path = f"{path}.{k}" if path else k

                    if isinstance(v, (int, float)) and any(
                        keyword in current_path.lower()
                        for keyword in metric_keywords
                    ):
                        key_metrics[current_path] = v
                    elif isinstance(v, dict):
                        find_metrics(v, current_path)

            find_metrics(results)

            if key_metrics:
                content.append("\n### Key Metrics")
                content.append("")

                for metric, value in key_metrics.items():
                    content.append(f"- {metric}: {value}")

                content.append("")

        return "\n".join(content)

    def __enter__(self):
        """
        Context manager entry: marks the experiment as running
        """
        self.start_experiment()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Context manager exit: marks the experiment as completed, or as
        failed when the body raised. The exception is never suppressed.
        """
        self.end_experiment(status="completed" if exc_type is None else "failed")
        return False


if __name__ == "__main__":
    # Demo: run a small fake experiment end to end and print its summary.
    with ExperimentManager("test_experiment") as exp:
        # Record the hyper-parameters used for this run
        exp.log_config({
            "learning_rate": 0.001,
            "batch_size": 32,
            "num_epochs": 10
        })

        # Record a few synthetic training metrics, one entry per step
        step = 0
        while step < 5:
            exp.log_metrics(
                {
                    "loss": 0.1 * (0.9 ** step),
                    "accuracy": 0.5 + 0.1 * step
                },
                step,
            )
            step += 1

        # Record the final evaluation results
        exp.log_results({
            "test_accuracy": 0.92,
            "test_loss": 0.05,
            "training_time": 123.45
        })

        # Write a Markdown report from the results just logged
        exp.generate_report()

    # The experiment has been marked as ended by now; dump its summary
    summary = exp.get_experiment_info()
    print(json.dumps(summary, indent=2, default=str))