#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Checkpoint-Model Validation Script for TorchTitan

This script validates that all model attributes match between checkpoints and the actual model,
providing comprehensive logging and warnings for any mismatches.

Usage:
    python validate_checkpoint_model_match.py --checkpoint_path <path> --model_name <name> --model_flavor <flavor>

    Or import as a module:
    from validate_checkpoint_model_match import validate_checkpoint_model_match
    validate_checkpoint_model_match(checkpoint_path, model, logger)
"""

import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint import HuggingFaceStorageReader

# Make the torchtitan package importable when this file is executed directly:
# the parent of the directory containing this file is appended to sys.path
# (presumably the repository root -- confirm against the repo layout).
sys.path.append(str(Path(__file__).parent.parent))

# Import the TorchTitan pieces this script depends on. If any of them cannot
# be imported, print the reason plus a hint and exit with status 1 instead of
# letting the ImportError traceback propagate.
try:
    import torchtitan.protocols.train_spec as train_spec_module
    from torchtitan.components.checkpoint import (
        excluded_parameters_for_model_only,
        ModelWrapper,
    )
except ImportError as e:
    print(f"Error importing TorchTitan modules: {e}")
    print("Make sure you're running this script from the TorchTitan root directory")
    sys.exit(1)


class CheckpointModelValidator:
    """
    Validates model state dict against checkpoint state dict with comprehensive logging.

    The outcome of the most recent run is stored in ``self.validation_results``,
    a dict with the following keys:

    * ``missing_in_model``: set of keys present in the checkpoint only
    * ``missing_in_checkpoint``: set of keys present in the model only
    * ``shape_mismatches``: ``{key: {"model_shape", "checkpoint_shape"}}``
    * ``dtype_mismatches``: ``{key: {"model_dtype", "checkpoint_dtype"}}``
    * ``matched_keys``: set of keys present on both sides
    * ``excluded_keys``: set of keys dropped from the model state dict
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        """
        Args:
            logger: Logger used for all output. When omitted, a stdout logger
                named ``checkpoint_validator`` is created.
        """
        self.logger = logger if logger is not None else self._setup_default_logger()
        self.validation_results = self._new_results()

    @staticmethod
    def _new_results() -> Dict[str, Any]:
        """Return a fresh, empty results structure (single source of truth for its keys)."""
        return {
            "missing_in_model": set(),
            "missing_in_checkpoint": set(),
            "shape_mismatches": {},
            "dtype_mismatches": {},
            "matched_keys": set(),
            "excluded_keys": set(),
        }

    def _setup_default_logger(self) -> logging.Logger:
        """Setup a default logger if none provided."""
        logger = logging.getLogger("checkpoint_validator")
        logger.setLevel(logging.INFO)

        # Only attach a handler once; getLogger returns the same object on
        # every call, so repeated construction must not duplicate output.
        if not logger.handlers:
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(
                logging.Formatter(
                    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
                )
            )
            logger.addHandler(handler)

        return logger

    def validate_checkpoint_model_match(
        self,
        checkpoint_path: str,
        model: nn.Module,
        from_hf: bool = False,
        sd_adapter: Optional[Any] = None,
        strict_validation: bool = False,
    ) -> Dict[str, Any]:
        """
        Validate that model attributes match checkpoint attributes.

        Args:
            checkpoint_path: Path to the checkpoint directory or file
            model: The PyTorch model to validate against
            from_hf: Whether the checkpoint is in HuggingFace format
            sd_adapter: State dict adapter for format conversion
            strict_validation: If True, raise exceptions on mismatches

        Returns:
            Dictionary containing validation results

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
            ValueError: If ``from_hf`` is set without ``sd_adapter``, or if
                ``strict_validation`` is set and mismatches were found.
            Exception: Any error raised while loading the checkpoint is
                logged and re-raised unchanged.
        """
        self.logger.info("=" * 80)
        self.logger.info("CHECKPOINT-MODEL VALIDATION STARTING")
        self.logger.info("=" * 80)

        # Clear previous results so the validator instance can be reused.
        self.validation_results = self._new_results()

        try:
            # Get model state dict
            model_state_dict = self._get_model_state_dict(model)
            self.logger.info(f"Model has {len(model_state_dict)} parameters/buffers")

            # Load checkpoint state dict
            checkpoint_state_dict = self._load_checkpoint_state_dict(
                checkpoint_path, from_hf, sd_adapter, model_state_dict
            )
            self.logger.info(
                f"Checkpoint has {len(checkpoint_state_dict)} parameters/buffers"
            )

            # Perform validation
            self._compare_state_dicts(model_state_dict, checkpoint_state_dict)

            # Log results
            self._log_validation_results()

            # Handle strict validation
            if strict_validation:
                self._enforce_strict_validation()

        except Exception as e:
            self.logger.error(f"Validation failed with error: {e}")
            raise

        self.logger.info("=" * 80)
        self.logger.info("CHECKPOINT-MODEL VALIDATION COMPLETED")
        self.logger.info("=" * 80)

        return self.validation_results

    def _get_model_state_dict(self, model: nn.Module) -> Dict[str, torch.Tensor]:
        """
        Return the model's state dict with model-only exclusions removed.

        ``model`` only needs a ``state_dict()`` method. The returned mapping is a
        shallow copy, so removing excluded keys never mutates a mapping owned by
        the caller (for example a cached dict held by a wrapper object).
        Each removed key is recorded in ``validation_results["excluded_keys"]``.
        """
        state_dict = dict(model.state_dict())

        # Apply exclusions similar to ModelWrapper
        for excluded_key in excluded_parameters_for_model_only:
            if excluded_key in state_dict:
                del state_dict[excluded_key]
                self.validation_results["excluded_keys"].add(excluded_key)
                self.logger.info(f"Excluded parameter from model: {excluded_key}")

        return state_dict

    def _load_checkpoint_state_dict(
        self,
        checkpoint_path: str,
        from_hf: bool,
        sd_adapter: Optional[Any],
        model_state_dict: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Load checkpoint state dict with appropriate format handling.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Checkpoint path does not exist: {checkpoint_path}"
            )

        self.logger.info(f"Loading checkpoint from: {checkpoint_path}")
        self.logger.info(f"HuggingFace format: {from_hf}")

        if from_hf:
            return self._load_hf_checkpoint(
                checkpoint_path, sd_adapter, model_state_dict
            )
        return self._load_dcp_checkpoint(checkpoint_path, model_state_dict)

    def _load_hf_checkpoint(
        self,
        checkpoint_path: str,
        sd_adapter: Optional[Any],
        model_state_dict: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Load HuggingFace format checkpoint.

        The model state dict is converted to HF naming with ``sd_adapter.to_hf``,
        filled by ``dcp.load`` through a ``HuggingFaceStorageReader``, and
        converted back with ``sd_adapter.from_hf``.

        Raises:
            ValueError: If ``sd_adapter`` is None.
        """
        if sd_adapter is None:
            raise ValueError(
                "sd_adapter is required for HuggingFace format checkpoints"
            )

        # Convert model state dict to HF format for loading
        hf_state_dict = sd_adapter.to_hf(model_state_dict)

        # Load from HF checkpoint
        dcp.load(
            hf_state_dict, storage_reader=HuggingFaceStorageReader(path=checkpoint_path)
        )

        # Convert back to TorchTitan format
        return sd_adapter.from_hf(hf_state_dict)

    def _load_dcp_checkpoint(
        self, checkpoint_path: str, model_state_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Load DCP (Distributed Checkpoint) format checkpoint.

        DCP requires a destination structure with the expected keys/shapes.
        A cloned copy of ``model_state_dict`` is used as that destination so the
        load never writes into the model's own tensors.

        NOTE(review): because the destination is built from the model's keys,
        the returned dict can only contain model keys; keys that exist solely
        in the checkpoint are not discoverable through this path. Reading the
        checkpoint metadata instead would lift that limitation -- confirm
        whether that is wanted before changing the loading strategy.
        """
        # Clone every tensor (not just the dict) so dcp.load fills private
        # copies rather than the tensors referenced by model_state_dict.
        destination: Dict[str, torch.Tensor] = {
            k: v.clone() for k, v in model_state_dict.items()
        }

        try:
            # Load the checkpoint
            dcp.load(destination, checkpoint_id=checkpoint_path)
            self.logger.info(
                "Successfully loaded DCP checkpoint into provided model structure"
            )

            # Heuristic kept from the original implementation: entries that are
            # not tensors or have zero elements are treated as "not populated"
            # and will therefore be reported as missing in the checkpoint.
            populated = {
                k: v
                for k, v in destination.items()
                if isinstance(v, torch.Tensor) and v.numel() > 0
            }
            self.logger.info(
                f"Found {len(populated)} model parameters in checkpoint (after load)"
            )

            return populated

        except Exception as e:
            self.logger.error(f"Failed to load DCP checkpoint: {e}")
            raise

    def _compare_state_dicts(
        self,
        model_state_dict: Dict[str, torch.Tensor],
        checkpoint_state_dict: Dict[str, torch.Tensor],
    ) -> None:
        """
        Compare model and checkpoint state dicts.

        Populates ``missing_in_checkpoint``, ``missing_in_model``,
        ``matched_keys``, ``shape_mismatches`` and ``dtype_mismatches`` in
        ``self.validation_results``. Only ``.shape`` and ``.dtype`` of the
        values are inspected; tensor contents are not compared.
        """
        results = self.validation_results
        model_keys = set(model_state_dict)
        checkpoint_keys = set(checkpoint_state_dict)

        self.logger.info(
            f"Comparing {len(model_keys)} model keys with {len(checkpoint_keys)} checkpoint keys"
        )

        # Find missing keys
        results["missing_in_checkpoint"] = model_keys - checkpoint_keys
        results["missing_in_model"] = checkpoint_keys - model_keys

        # Find matching keys
        common_keys = model_keys & checkpoint_keys
        results["matched_keys"] = common_keys

        # Walk the shared keys in sorted order so the mismatch dicts (and the
        # logs derived from them) have a deterministic ordering.
        for key in sorted(common_keys):
            model_tensor = model_state_dict[key]
            checkpoint_tensor = checkpoint_state_dict[key]

            if model_tensor.shape != checkpoint_tensor.shape:
                results["shape_mismatches"][key] = {
                    "model_shape": model_tensor.shape,
                    "checkpoint_shape": checkpoint_tensor.shape,
                }

            if model_tensor.dtype != checkpoint_tensor.dtype:
                results["dtype_mismatches"][key] = {
                    "model_dtype": model_tensor.dtype,
                    "checkpoint_dtype": checkpoint_tensor.dtype,
                }

    def _log_validation_results(self) -> None:
        """
        Log comprehensive validation results.

        Emits a summary block followed by one detail section per non-empty
        category. Missing keys and dtype mismatches are logged as warnings,
        shape mismatches as errors, everything else as info. At most ten
        matched keys are listed.
        """
        results = self.validation_results
        log = self.logger

        matched = results["matched_keys"]
        missing_in_checkpoint = results["missing_in_checkpoint"]
        missing_in_model = results["missing_in_model"]
        shape_mismatches = results["shape_mismatches"]
        dtype_mismatches = results["dtype_mismatches"]
        excluded = results["excluded_keys"]

        # Totals are reconstructed from the comparison outcome.
        total_model_keys = len(matched) + len(missing_in_checkpoint)
        total_checkpoint_keys = len(matched) + len(missing_in_model)

        log.info("\n" + "=" * 60)
        log.info("VALIDATION SUMMARY")
        log.info("=" * 60)
        log.info(f"Total model parameters/buffers: {total_model_keys}")
        log.info(f"Total checkpoint parameters/buffers: {total_checkpoint_keys}")
        log.info(f"Matched keys: {len(matched)}")
        log.info(f"Missing in checkpoint: {len(missing_in_checkpoint)}")
        log.info(f"Missing in model: {len(missing_in_model)}")
        log.info(f"Shape mismatches: {len(shape_mismatches)}")
        log.info(f"Dtype mismatches: {len(dtype_mismatches)}")
        log.info(f"Excluded parameters: {len(excluded)}")

        # Detailed logging for issues
        if missing_in_checkpoint:
            log.warning("\n" + "-" * 40)
            log.warning("MISSING IN CHECKPOINT (present in model)")
            log.warning("-" * 40)
            for key in sorted(missing_in_checkpoint):
                log.warning(f"  ❌ {key}")

        if missing_in_model:
            log.warning("\n" + "-" * 40)
            log.warning("MISSING IN MODEL (present in checkpoint)")
            log.warning("-" * 40)
            for key in sorted(missing_in_model):
                log.warning(f"  ❌ {key}")

        if shape_mismatches:
            log.error("\n" + "-" * 40)
            log.error("SHAPE MISMATCHES")
            log.error("-" * 40)
            for key in sorted(shape_mismatches):
                mismatch = shape_mismatches[key]
                log.error(
                    f"  ❌ {key}: model={mismatch['model_shape']} vs "
                    f"checkpoint={mismatch['checkpoint_shape']}"
                )

        if dtype_mismatches:
            log.warning("\n" + "-" * 40)
            log.warning("DTYPE MISMATCHES")
            log.warning("-" * 40)
            for key in sorted(dtype_mismatches):
                mismatch = dtype_mismatches[key]
                log.warning(
                    f"  ⚠️  {key}: model={mismatch['model_dtype']} vs "
                    f"checkpoint={mismatch['checkpoint_dtype']}"
                )

        if excluded:
            log.info("\n" + "-" * 40)
            log.info("EXCLUDED PARAMETERS")
            log.info("-" * 40)
            for key in sorted(excluded):
                log.info(f"  ℹ️  {key}")

        # Log successful matches (sample)
        if matched:
            matched_sample = sorted(matched)[:10]
            log.info("\n" + "-" * 40)
            log.info(
                f"SUCCESSFULLY MATCHED PARAMETERS (showing {len(matched_sample)}/{len(matched)})"
            )
            log.info("-" * 40)
            for key in matched_sample:
                log.info(f"  ✅ {key}")
            if len(matched) > 10:
                log.info(f"  ... and {len(matched) - 10} more")

    def _enforce_strict_validation(self) -> None:
        """
        Enforce strict validation by raising exceptions for mismatches.

        Missing keys (either direction) and shape mismatches are fatal; dtype
        mismatches are not. Offending keys are listed in sorted order so the
        error message is deterministic.

        Raises:
            ValueError: If any fatal mismatch category is non-empty.
        """
        results = self.validation_results
        errors = []

        if results["missing_in_checkpoint"]:
            errors.append(
                f"Missing in checkpoint: {sorted(results['missing_in_checkpoint'])}"
            )

        if results["missing_in_model"]:
            errors.append(f"Missing in model: {sorted(results['missing_in_model'])}")

        if results["shape_mismatches"]:
            errors.append(f"Shape mismatches: {sorted(results['shape_mismatches'])}")

        if errors:
            raise ValueError("Strict validation failed:\n" + "\n".join(errors))


def validate_checkpoint_model_match(
    checkpoint_path: str,
    model: nn.Module,
    logger: Optional[logging.Logger] = None,
    from_hf: bool = False,
    sd_adapter: Optional[Any] = None,
    strict_validation: bool = False,
) -> Dict[str, Any]:
    """
    One-shot helper: build a validator and run a single validation pass.

    Args:
        checkpoint_path: Location of the checkpoint to inspect
        model: Model whose state dict is compared with the checkpoint
        logger: Logger handed to the validator; may be None
        from_hf: True when the checkpoint uses the HuggingFace format
        sd_adapter: Adapter used to convert between state dict formats
        strict_validation: When True, mismatches raise instead of only being logged

    Returns:
        The validation results dictionary produced by the validator
    """
    return CheckpointModelValidator(logger).validate_checkpoint_model_match(
        checkpoint_path,
        model,
        from_hf,
        sd_adapter,
        strict_validation,
    )


def main():
    """
    Main function for command-line usage.

    Parses the command line, builds the requested model on CPU, runs the
    checkpoint/model validation and terminates the process via ``sys.exit``:

    * 0 when everything matches, or when issues were found without ``--strict``
    * 1 when the model flavor is unknown or any exception is raised
      (including the ``ValueError`` raised by strict validation)
    """
    parser = argparse.ArgumentParser(
        description="Validate model attributes against checkpoint"
    )
    parser.add_argument(
        "--checkpoint_path",
        required=True,
        help="Path to the checkpoint directory or file",
    )
    parser.add_argument(
        "--model_name", required=True, help="Model name (e.g., 'llama3')"
    )
    parser.add_argument(
        "--model_flavor", required=True, help="Model flavor/size (e.g., '8b', '1b')"
    )
    parser.add_argument(
        "--from_hf", action="store_true", help="Checkpoint is in HuggingFace format"
    )
    parser.add_argument(
        "--strict",
        action="store_true",
        help="Enable strict validation (raise exceptions on mismatches)",
    )
    parser.add_argument(
        "--log_level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level",
    )
    parser.add_argument(
        "--print_all_matched",
        action="store_true",
        help="Print all matched parameter keys after validation",
    )

    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    logger = logging.getLogger("checkpoint_validator")

    try:
        # Load model
        logger.info(f"Loading model: {args.model_name} ({args.model_flavor})")
        train_spec = train_spec_module.get_train_spec(args.model_name)

        # Check if model flavor exists and provide helpful error message
        if args.model_flavor not in train_spec.model_args:
            available_flavors = list(train_spec.model_args.keys())
            logger.error(
                f"Model flavor '{args.model_flavor}' not found for model '{args.model_name}'"
            )
            logger.error(f"Available flavors: {available_flavors}")

            # Suggest flavors that differ from the request only by letter case.
            wanted = args.model_flavor.casefold()
            suggestions = [
                flavor for flavor in available_flavors if flavor.casefold() == wanted
            ]
            if suggestions:
                logger.error(f"Did you mean: {suggestions}")

            # SystemExit is not an Exception subclass, so the handler below
            # does not intercept it.
            sys.exit(1)

        model_args = train_spec.model_args[args.model_flavor]

        with torch.device("cpu"):
            model = train_spec.model_cls(model_args)

        # Get state dict adapter if needed
        sd_adapter = None
        if args.from_hf:
            # Check the adapter factory before calling it: a train spec
            # without one would otherwise fail with an opaque
            # "'NoneType' object is not callable" error.
            adapter_factory = getattr(train_spec, "state_dict_adapter", None)
            if adapter_factory is None:
                raise ValueError(
                    "State dict adapter is required for HuggingFace format"
                )
            sd_adapter = adapter_factory(model_args)
            if sd_adapter is None:
                raise ValueError(
                    "State dict adapter is required for HuggingFace format"
                )

        # Validate
        results = validate_checkpoint_model_match(
            args.checkpoint_path, model, logger, args.from_hf, sd_adapter, args.strict
        )

        if args.print_all_matched:
            matched_sorted = sorted(results.get("matched_keys", []))
            print(f"\nMATCHED PARAMETERS ({len(matched_sorted)}):")
            for k in matched_sorted:
                print(k)

        # Exit with appropriate code. Dtype mismatches are deliberately not
        # counted as issues here, matching strict validation.
        has_issues = (
            results["missing_in_checkpoint"]
            or results["missing_in_model"]
            or results["shape_mismatches"]
        )

        if has_issues:
            logger.warning("Validation completed with issues!")
            sys.exit(1 if args.strict else 0)
        else:
            logger.info("Validation completed successfully - all parameters match!")
            sys.exit(0)

    except Exception as e:
        logger.error(f"Validation failed: {e}")
        sys.exit(1)


# Run the command-line entry point only when this file is executed as a
# script; importing it as a module must not trigger argument parsing.
if __name__ == "__main__":
    main()
