import logging

logger = logging.getLogger(__name__)

# Start from a clean slate: drop any handlers already attached to this logger
# (e.g. left over from an earlier import/reload of this module) so the handler
# configured below is the only one on this logger.
# NOTE(review): propagation is left enabled, so if an ancestor logger (such as
# the root logger) also has handlers, records are emitted by those as well.
logger.handlers.clear()

# Module-specific output format: every record is tagged "BIRDSNAP SNAPSHOT".
formatter = logging.Formatter(
    "%(asctime)s BIRDSNAP SNAPSHOT %(levelname)s %(name)s: %(message)s"
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)

from snapshot_utils import save_snapshot, load_snapshot, compare_snapshots

def collect_model_outputs(model, config=None):
    """
    Collects key model outputs and configuration for consistency checking.

    The attributes ``logits``, ``y_preds`` and ``labels`` are read from
    ``model``; a missing attribute is recorded as ``None``. Any value that
    exposes both ``cpu`` and ``numpy`` (e.g. a PyTorch tensor) is converted via
    ``.cpu().numpy()``. If the value also exposes ``detach``, it is detached
    first, because PyTorch refuses ``numpy()`` on tensors that require grad.

    Args:
        model: The trained model instance (should have attributes like logits, y_preds, labels, etc.).
        config (dict, optional): The configuration or hyperparameters used for this run.

    Returns:
        dict: Dictionary with keys ``"logits"``, ``"y_preds"``, ``"labels"``
        (in that order), plus ``"config"`` when ``config`` is not None.
    """
    snapshot_results = {}
    # Add more fields as needed
    for key in ("logits", "y_preds", "labels"):
        value = getattr(model, key, None)
        # Convert tensors to numpy if necessary
        if hasattr(value, "cpu") and hasattr(value, "numpy"):
            if hasattr(value, "detach"):
                # numpy() raises on tensors that require grad; detaching
                # shares storage, so this adds no copy of its own.
                value = value.detach()
            value = value.cpu().numpy()
        snapshot_results[key] = value
    if config is not None:
        snapshot_results["config"] = config
    return snapshot_results

def snapshot_matches_config(snapshot_file, config):
    """
    Report whether the snapshot stored in ``snapshot_file`` was made with ``config``.

    The file is read with ``load_snapshot`` and its ``"config"`` entry is compared
    to ``config`` with ``==``. A snapshot without a ``"config"`` entry (or with
    ``None`` stored there) never matches; a warning is logged in that case.

    Returns:
        bool: True only when the stored config equals ``config``.
    """
    stored_config = load_snapshot(snapshot_file).get("config", None)
    if stored_config is not None:
        return stored_config == config
    logger.warning("No config found in snapshot.")
    return False

def check_or_update_snapshot(model, config, snapshot_file="birdsnap_model_output_snapshot.pkl", atol=1e-6):
    """
    Compare the model's current outputs to a stored snapshot, or (re)create the snapshot.

    The outputs are gathered with ``collect_model_outputs(model, config)``. An
    existing snapshot is reused only when its ``"config"`` entry equals
    ``config``. In that case, the outputs and the snapshot (each without the
    ``"config"`` key) are compared with ``compare_snapshots`` using tolerance
    ``atol``. If the file does not exist or holds a different config, it is
    overwritten with the current outputs plus ``config``.

    Returns:
        bool: True if the config matched and the outputs matched the snapshot;
        False if the outputs did not match, or if a new snapshot was written.
    """
    import os

    current = collect_model_outputs(model, config)

    if os.path.isfile(snapshot_file):
        stored = load_snapshot(snapshot_file)
        if stored.get("config", None) == config:
            # Same config: compare everything except the config entry itself.
            current_outputs = {k: v for k, v in current.items() if k != "config"}
            stored_outputs = {k: v for k, v in stored.items() if k != "config"}
            match = compare_snapshots(current_outputs, stored_outputs, atol=atol)
            if not match:
                logger.warning("Results do NOT match the existing snapshot for this config!")
            else:
                logger.info("ALL OKAY. Results match the existing snapshot for this config.")
            return match
        logger.info("Snapshot config does not match. Creating new snapshot.")
    else:
        logger.info("No snapshot found. Creating new snapshot.")

    # Either no snapshot or one made for another config: replace it.
    save_snapshot(current, snapshot_file)
    return False

def collect_combined_accuracies(accuracies: dict, config: dict = None) -> dict:
    """
    Build a snapshot payload from combined vision-location accuracies.

    A shallow copy of ``accuracies`` is returned, so the caller's dict is never
    modified. When ``config`` is given, it is stored under the ``"config"`` key
    of the copy.

    Args:
        accuracies (dict): Combined accuracies (as returned by combined_vision_location_encoding_evaluation).
        config (dict, optional): Configuration / hyperparameters of this run.

    Returns:
        dict: The copied accuracies, plus ``"config"`` when ``config`` is not None.
    """
    payload = dict(accuracies)
    if config is None:
        return payload
    payload["config"] = config
    return payload

def check_or_update_combined_accuracies_snapshot(
    accuracies: dict,
    config: dict,
    snapshot_file: str = "combined_accuracies_snapshot.pkl",
    atol: float = 1e-6
) -> bool:
    """
    Compare combined accuracies to a stored snapshot, or (re)create the snapshot.

    The payload is built with ``collect_combined_accuracies(accuracies, config)``.
    An existing snapshot is reused only when its ``"config"`` entry equals
    ``config``. In that case, the accuracies and the snapshot (each without the
    ``"config"`` key) are compared with ``compare_snapshots`` using tolerance
    ``atol``. If the file does not exist or holds a different config, it is
    overwritten with the current payload.

    Uses the module-level ``save_snapshot`` / ``load_snapshot`` /
    ``compare_snapshots`` imports, like ``check_or_update_snapshot``.

    Returns:
        bool: True if the config matched and the accuracies matched the snapshot;
        False if the accuracies did not match, or if a new snapshot was written.
    """
    import os

    results_with_config = collect_combined_accuracies(accuracies, config)

    if os.path.isfile(snapshot_file):
        snapshot = load_snapshot(snapshot_file)
        if snapshot.get("config", None) == config:
            # Config matches, compare results (excluding config key)
            results = dict(results_with_config)
            results.pop("config", None)
            snapshot_no_config = dict(snapshot)
            snapshot_no_config.pop("config", None)
            match = compare_snapshots(results, snapshot_no_config, atol=atol)
            if match:
                logger.info("ALL OKAY. Combined accuracies match the existing snapshot for this config.")
            else:
                logger.warning("Combined accuracies do NOT match the existing snapshot for this config!")
            return match
        logger.info("Combined accuracies snapshot config does not match. Creating new snapshot.")
    else:
        logger.info("No combined accuracies snapshot found. Creating new snapshot.")

    # No usable snapshot (missing, or recorded for another config): replace it.
    save_snapshot(results_with_config, snapshot_file)
    return False