"""
Best Checkpoint Tracker for VERL Training

Tracks best validation accuracy per evaluation dataset, saves checkpoints in
legacy-compatible flat format, and handles async uploads to HuggingFace.

Usage:
    tracker = BestCheckpointTracker(
        output_dir="/path/to/output",
        eval_dataset_names=["bridges_5x5de_intformat", "bridges_5x5dm_test200_intformat"],
        hf_repo_id="anon-neurips26/my-model",
        hf_token="hf_xxx"
    )
    
    # After each validation in VERL training loop
    tracker.check_and_save(val_metrics, global_step, verl_checkpoint_dir)
    
    # At end of training
    tracker.wait_for_uploads()
"""

import os
import shutil
import json
from datetime import datetime
from typing import List, Dict, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, Future


class BestCheckpointTracker:
    """
    Track best validation accuracy per dataset, save LoRA checkpoints,
    and async upload to HuggingFace Hub.

    Directory structure matches our legacy flat layout:
    {output_dir}/
    ├── best_checkpoint_{dataset1}/
    │   ├── adapter_config.json
    │   ├── adapter_model.safetensors
    │   ├── best_checkpoint_info.txt
    │   ├── tokenizer.json
    │   └── ...
    └── best_checkpoint_{dataset2}/
        └── ...

    A best accuracy is only recorded once the corresponding checkpoint files
    have actually been copied, so ``best_accuracies`` always describes what is
    on disk.
    """

    # Key prefix of per-dataset accuracy entries in the validation metrics
    # ("val/test_score/{data_source}").
    _VAL_SCORE_PREFIX = "val/test_score/"

    def __init__(
        self,
        output_dir: str,
        eval_dataset_names: List[str],
        hf_repo_id: Optional[str] = None,
        hf_token: Optional[str] = None,
        async_upload: bool = True
    ):
        """
        Initialize the tracker and create ``output_dir`` if needed.

        Args:
            output_dir: Directory to save best checkpoints
            eval_dataset_names: List of dataset names to track (used for metric parsing)
            hf_repo_id: HuggingFace repo to upload to (e.g., "anon-neurips26/my-model").
                        Uploads are disabled when this is falsy.
            hf_token: HuggingFace API token for uploads
            async_upload: Whether to upload asynchronously (non-blocking)
        """
        self.output_dir = output_dir
        self.eval_dataset_names = eval_dataset_names
        self.hf_repo_id = hf_repo_id
        self.hf_token = hf_token
        self.async_upload = async_upload

        # Best accuracy seen (and saved) so far per dataset; -1.0 means "none yet".
        self.best_accuracies: Dict[str, float] = {ds: -1.0 for ds in eval_dataset_names}

        # Target directory of the best checkpoint for each dataset.
        self.best_checkpoint_dirs: Dict[str, str] = {
            ds: os.path.join(output_dir, f"best_checkpoint_{ds}")
            for ds in eval_dataset_names
        }

        # Async upload executor (single worker to serialize uploads). It is
        # only created when uploads are enabled and requested to be async.
        self._upload_executor: Optional[ThreadPoolExecutor] = None
        if async_upload and hf_repo_id:
            self._upload_executor = ThreadPoolExecutor(max_workers=1)
        self._pending_uploads: List[Future] = []

        os.makedirs(output_dir, exist_ok=True)

        print("[BestCheckpointTracker] Initialized")
        print(f"  Output dir: {output_dir}")
        print(f"  Tracking datasets: {eval_dataset_names}")
        if hf_repo_id:
            print(f"  HuggingFace upload: {hf_repo_id} (async={async_upload})")

    def check_and_save(
        self,
        val_metrics: Dict[str, float],
        global_step: int,
        verl_checkpoint_dir: str
    ) -> Dict[str, bool]:
        """
        Check if any dataset has new best accuracy, save checkpoint if so.

        Args:
            val_metrics: Dictionary of validation metrics from VERL
                         Expected keys like "val/test_score/bridges_5x5de_intformat"
            global_step: Current training step
            verl_checkpoint_dir: Path to VERL checkpoint directory containing actor/lora_adapter/

        Returns:
            Dict mapping dataset names to whether a new best was recorded.
            A dataset maps to False when its accuracy did not improve, or
            when it improved but the checkpoint files could not be copied.
        """
        results: Dict[str, bool] = {}

        for ds_name, accuracy in self._parse_val_metrics(val_metrics):
            # NaN compares False here, so it is never treated as an improvement.
            if not accuracy > self.best_accuracies.get(ds_name, -1.0):
                results[ds_name] = False
                continue

            best_dir = self.best_checkpoint_dirs.get(ds_name)
            if best_dir is None:
                # Dataset was not declared up front; derive its directory.
                best_dir = os.path.join(self.output_dir, f"best_checkpoint_{ds_name}")

            print(f"\n{'='*80}")
            print(f"[BestCheckpointTracker] New best for {ds_name}: {accuracy:.4f} at step {global_step}")
            print(f"  Saving to: {best_dir}")
            print(f"{'='*80}\n")

            # Only record the new best once the files are really in place.
            # Otherwise the info file would either fail to be written or be
            # stamped onto a stale checkpoint from an earlier step.
            if not self._copy_checkpoint_files(verl_checkpoint_dir, best_dir):
                print(
                    f"[BestCheckpointTracker] Warning: checkpoint for {ds_name} "
                    f"was not saved; best accuracy left unchanged"
                )
                results[ds_name] = False
                continue

            self.best_accuracies[ds_name] = accuracy
            self.best_checkpoint_dirs[ds_name] = best_dir
            results[ds_name] = True

            self._write_info_file(best_dir, ds_name, accuracy, global_step, val_metrics)

            if self.hf_repo_id:
                self._queue_upload(best_dir, f"best_checkpoint_{ds_name}")

        return results

    def _parse_val_metrics(self, val_metrics: Dict[str, float]) -> List[Tuple[str, float]]:
        """
        Parse VERL validation metrics to extract per-dataset accuracies.

        VERL metrics format:
            val/test_score/{data_source} -> accuracy (0.0 to 1.0)

        Keys without that prefix are ignored, as are values that cannot be
        converted with ``float()`` (a warning is printed for those).

        Args:
            val_metrics: Dictionary of validation metrics

        Returns:
            List of (dataset_name, accuracy) tuples
        """
        prefix = self._VAL_SCORE_PREFIX
        results: List[Tuple[str, float]] = []

        for key, value in val_metrics.items():
            if not key.startswith(prefix):
                continue

            # Strip only the leading prefix; a dataset name may itself
            # contain the prefix text.
            ds_name = key[len(prefix):]
            if not ds_name:
                continue

            try:
                results.append((ds_name, float(value)))
            except (TypeError, ValueError):
                print(f"[BestCheckpointTracker] Warning: Could not parse accuracy from {key}={value}")

        return results

    @staticmethod
    def _copy_flat_files(src_dir: str, dst_dir: str) -> None:
        """Copy every regular file directly inside src_dir into dst_dir (no recursion)."""
        for filename in sorted(os.listdir(src_dir)):
            src = os.path.join(src_dir, filename)
            if os.path.isfile(src):
                shutil.copy2(src, os.path.join(dst_dir, filename))
                print(f"  Copied: {filename}")

    def _copy_checkpoint_files(self, verl_checkpoint_dir: str, best_dir: str) -> bool:
        """
        Copy LoRA adapter and tokenizer files from VERL checkpoint to best checkpoint dir.

        Creates flat structure matching our legacy flat layout:
            best_checkpoint_{dataset}/
            ├── adapter_config.json
            ├── adapter_model.safetensors
            ├── tokenizer.json
            ├── vocab.json
            └── ...

        Args:
            verl_checkpoint_dir: VERL checkpoint directory (e.g., global_step_100/)
            best_dir: Target directory for best checkpoint

        Returns:
            True when the adapter files were copied; False when the LoRA
            adapter directory is missing, in which case ``best_dir`` is left
            untouched.
        """
        lora_dir = os.path.join(verl_checkpoint_dir, "actor", "lora_adapter")
        hf_dir = os.path.join(verl_checkpoint_dir, "actor", "huggingface")

        # The adapter is mandatory: without it there is nothing worth saving.
        if not os.path.isdir(lora_dir):
            print(f"[BestCheckpointTracker] Warning: LoRA adapter not found at {lora_dir}")
            return False

        # The tokenizer/config directory is optional; warn and carry on.
        has_hf_dir = os.path.isdir(hf_dir)
        if not has_hf_dir:
            print(f"[BestCheckpointTracker] Warning: HuggingFace config not found at {hf_dir}")

        # Start from an empty directory so files from a previous best do not linger.
        # NOTE(review): an async upload of this same directory may still be
        # running in the worker thread while it is cleared here — confirm
        # whether that overlap can occur in practice.
        if os.path.exists(best_dir):
            shutil.rmtree(best_dir)
        os.makedirs(best_dir, exist_ok=True)

        # Adapter files first, then tokenizer/config files (which win on a name clash).
        self._copy_flat_files(lora_dir, best_dir)
        if has_hf_dir:
            self._copy_flat_files(hf_dir, best_dir)

        return True

    def _write_info_file(
        self,
        best_dir: str,
        ds_name: str,
        accuracy: float,
        global_step: int,
        all_metrics: Optional[Dict[str, float]] = None
    ):
        """
        Write best_checkpoint_info.txt with accuracy, step, and timestamp.

        Args:
            best_dir: Best checkpoint directory (created if it does not exist)
            ds_name: Dataset name
            accuracy: Best accuracy achieved
            global_step: Training step when best was achieved
            all_metrics: Optional dict of all validation metrics at this step;
                         only keys starting with "val/" are written
        """
        os.makedirs(best_dir, exist_ok=True)
        info_file = os.path.join(best_dir, "best_checkpoint_info.txt")

        with open(info_file, "w", encoding="utf-8") as f:
            f.write(f"Best Accuracy for {ds_name}: {accuracy:.4f}\n")
            f.write(f"Step: {global_step}\n")
            f.write(f"Timestamp: {datetime.now().isoformat()}\n")
            f.write(f"Checkpoint Path: {best_dir}\n")

            if all_metrics:
                f.write("\nAll validation metrics at this step:\n")
                for key, value in sorted(all_metrics.items()):
                    if not key.startswith("val/"):
                        continue
                    try:
                        f.write(f"  {key}: {float(value):.4f}\n")
                    except (TypeError, ValueError):
                        # Non-numeric metric: record it verbatim.
                        f.write(f"  {key}: {value}\n")

        print("  Wrote: best_checkpoint_info.txt")

    def _queue_upload(self, checkpoint_dir: str, path_in_repo: str):
        """
        Upload a checkpoint to HuggingFace Hub, asynchronously when an
        executor exists and synchronously otherwise.

        Args:
            checkpoint_dir: Local directory to upload
            path_in_repo: Subfolder path in HF repo (e.g., "best_checkpoint_bridges_5x5de")
        """
        if not self.hf_repo_id:
            return

        if self._upload_executor is None:
            # No executor: upload in the calling thread.
            self._do_upload(checkpoint_dir, path_in_repo)
            return

        future = self._upload_executor.submit(
            self._do_upload, checkpoint_dir, path_in_repo
        )
        self._pending_uploads.append(future)
        print(f"[Async Upload] Queued: {self.hf_repo_id}/{path_in_repo}")

    def _do_upload(self, checkpoint_dir: str, path_in_repo: str) -> bool:
        """
        Actually perform the upload to HuggingFace Hub.

        Failures are reported on stdout rather than raised, so a failed
        upload never interrupts training.

        Args:
            checkpoint_dir: Local directory to upload
            path_in_repo: Subfolder path in HF repo

        Returns:
            True if the upload finished, False if any step failed.
        """
        try:
            # Imported lazily so the tracker works without huggingface_hub
            # installed when uploads are not used.
            from huggingface_hub import HfApi, create_repo

            api = HfApi(token=self.hf_token)

            # Create repo if it doesn't exist
            create_repo(
                repo_id=self.hf_repo_id,
                token=self.hf_token,
                private=True,
                exist_ok=True,
                repo_type="model"
            )

            api.upload_folder(
                folder_path=checkpoint_dir,
                repo_id=self.hf_repo_id,
                path_in_repo=path_in_repo,
                repo_type="model",
                ignore_patterns=[".git*", "__pycache__", "*.pyc"]
            )

            url = f"https://huggingface.co/{self.hf_repo_id}/tree/main/{path_in_repo}"
            print(f"[Upload Complete] {url}")
            return True

        except Exception as e:
            # Deliberate best-effort boundary: report and signal failure.
            print(f"[Upload Error] Failed to upload {checkpoint_dir}: {e}")
            return False

    def wait_for_uploads(self, timeout_per_upload: int = 600):
        """
        Wait for all pending async uploads to complete.
        Call at end of training to ensure all uploads finish.

        An upload counts as failed when it reported failure, raised, or did
        not finish within the timeout.

        Args:
            timeout_per_upload: Timeout in seconds per upload (default 10 minutes)
        """
        if not self._pending_uploads:
            return

        print(f"\n[BestCheckpointTracker] Waiting for {len(self._pending_uploads)} pending upload(s)...")

        completed = 0
        errors = 0

        for future in self._pending_uploads:
            try:
                succeeded = future.result(timeout=timeout_per_upload)
            except Exception as e:
                print(f"[Upload Error] {e}")
                errors += 1
                continue

            # _do_upload reports failure by returning False instead of raising.
            if succeeded is False:
                errors += 1
            else:
                completed += 1

        self._pending_uploads.clear()
        print(f"[BestCheckpointTracker] Uploads complete: {completed} succeeded, {errors} failed")

    def get_summary(self) -> Dict:
        """
        Get summary of best checkpoints.

        Returns:
            Dict with copies of the best accuracies and checkpoint paths,
            plus the HuggingFace repo id
        """
        return {
            "best_accuracies": dict(self.best_accuracies),
            "best_checkpoint_dirs": dict(self.best_checkpoint_dirs),
            "hf_repo_id": self.hf_repo_id
        }

    def save_summary(self, path: Optional[str] = None):
        """
        Save summary (plus a timestamp) to JSON file.

        Args:
            path: Output path (defaults to {output_dir}/best_checkpoints_summary.json)
        """
        if path is None:
            path = os.path.join(self.output_dir, "best_checkpoints_summary.json")

        summary = self.get_summary()
        summary["timestamp"] = datetime.now().isoformat()

        with open(path, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2)

        print(f"[BestCheckpointTracker] Summary saved to {path}")

    def shutdown(self):
        """
        Shutdown the tracker, waiting for uploads and saving summary.

        The upload executor is shut down even if waiting or saving the
        summary raises.
        """
        try:
            self.wait_for_uploads()
            self.save_summary()
        finally:
            if self._upload_executor:
                self._upload_executor.shutdown(wait=True)
                self._upload_executor = None

        print("[BestCheckpointTracker] Shutdown complete")


# Convenience function for creating tracker from environment variables
def create_tracker_from_env(
    output_dir: str,
    eval_dataset_names: List[str]
) -> Optional[BestCheckpointTracker]:
    """
    Create BestCheckpointTracker from environment variables.

    Environment variables:
        VERL_HF_REPO_ID: HuggingFace repo ID (required for upload)
        VERL_HF_TOKEN: HuggingFace API token
        VERL_ASYNC_UPLOAD: "true"/"1"/"yes"/"on" enable async upload
            (case-insensitive, surrounding whitespace ignored); any other
            value disables it (default: "true")

    Args:
        output_dir: Directory to save best checkpoints
        eval_dataset_names: List of dataset names to track

    Returns:
        A BestCheckpointTracker instance. A tracker is always returned: when
        VERL_HF_REPO_ID is not set, a notice is printed and the repo id is
        passed through as unset. (The Optional return annotation is kept for
        compatibility with existing callers.)
    """
    hf_repo_id = os.environ.get("VERL_HF_REPO_ID")
    hf_token = os.environ.get("VERL_HF_TOKEN")

    # Accept the common truthy spellings rather than only the literal "true".
    async_flag = os.environ.get("VERL_ASYNC_UPLOAD", "true").strip().lower()
    async_upload = async_flag in ("1", "true", "yes", "on")

    if not hf_repo_id:
        print("[BestCheckpointTracker] VERL_HF_REPO_ID not set, HuggingFace upload disabled")

    return BestCheckpointTracker(
        output_dir=output_dir,
        eval_dataset_names=eval_dataset_names,
        hf_repo_id=hf_repo_id,
        hf_token=hf_token,
        async_upload=async_upload
    )
