import os

import torch
import wandb
import yaml

from src.config import Config
from src.models.model_factory import build_model
from src.utils.logger_ctx import get_logger


def load_model_from_checkpoint(checkpoint_path: str, device: str | None = None) -> torch.nn.Module:
    """
    Rebuild a model from a checkpoint file saved alongside its config.

    The checkpoint is read onto the CPU. Its stored ``config`` object and
    ``model_state_dict`` are handed to ``build_model``.

    Parameters
    ----------
    checkpoint_path : str
        Path to the checkpoint file. It must hold at least the keys
        ``'config'`` and ``'model_state_dict'``.
    device : str or None, optional
        If truthy, written into ``config.training.device`` before the model
        is built. Falsy values (``None``, ``""``) leave the saved device untouched.

    Returns
    -------
    model : torch.nn.Module
        The model returned by ``build_model``.

    Notes
    -----
    ``weights_only=False`` lets ``torch.load`` unpickle arbitrary Python
    objects (needed here for the stored config object). Only load
    checkpoints from trusted sources.
    """
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    saved_config = ckpt['config']

    # Optionally retarget the device recorded in the saved training config.
    if device:
        saved_config.training.device = device

    return build_model(saved_config, state_dict=ckpt['model_state_dict'])


def _download_run_file(run, filename, root):
    """Download ``filename`` from a W&B run into ``root`` and return its local path."""
    # File.download() hands back an open file object; close it so repeated
    # restores do not leak file descriptors. Guard for versions returning None.
    handle = run.file(filename).download(root=root, replace=True)
    if handle is not None:
        handle.close()
    return os.path.join(root, filename)


def recreate_model_from_wandb(run_path, build_model, config_filename="toy_config.yaml",
                              weight_filename="model.pth", device=None, strict=True,
                              download_root="./wandb_restore"):
    """
    Recreate a model from a W&B run (non-Lightning).

    The run's config file and weight file are downloaded into
    ``<download_root>/wandb_restore_<run_id>``, overwriting earlier copies.
    The config is parsed with ``yaml.safe_load`` and the model is built with
    ``build_model(config)``. The weights are then loaded onto the CPU.

    Parameters
    ----------
    run_path : str
        W&B run path "entity/project/run_id".
    build_model : callable
        Function that builds the model from the parsed config. This parameter
        shadows the module-level ``build_model`` import inside this function.
    config_filename : str
        Name of the config file stored in the W&B run files.
    weight_filename : str
        Name of the weights file stored in the W&B run files.
    device : str or torch.device, optional
        If truthy, the model is moved to this device (e.g., "cuda") after loading.
    strict : bool
        Forwarded to ``load_state_dict``; whether keys must match exactly.
    download_root : str
        Parent directory for the per-run download folder.

    Returns
    -------
    model : torch.nn.Module
    config : object
        Whatever ``yaml.safe_load`` produced for the config file (typically a dict).
    """
    api = wandb.Api()
    run = api.run(run_path)
    download_dir = os.path.join(download_root, f"wandb_restore_{run.id}")
    os.makedirs(download_dir, exist_ok=True)

    # Download config + weights, closing the handles W&B returns.
    config_path = _download_run_file(run, config_filename, download_dir)
    weight_path = _download_run_file(run, weight_filename, download_dir)

    # Load config
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Build model
    model = build_model(config)

    # Load weights. NOTE: on older torch versions torch.load defaults to full
    # unpickling; only restore runs you trust.
    state_dict = torch.load(weight_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=strict)

    if device:
        model = model.to(device)

    return model, config


def nested_dicts_equal(d1, d2):
    """
    Compare two (possibly nested) dicts for equality.

    Key sets must match at every level. A value pair is compared recursively
    only when both sides are dicts; any other pair is compared with ``!=``.
    Evaluation stops at the first difference found.

    Parameters
    ----------
    d1, d2 : dict
        Mappings to compare. Both must provide ``keys()`` and item access.

    Returns
    -------
    bool
        True if no difference is found, False otherwise.
    """
    def _pair_matches(left, right):
        # Descend only when both sides are dicts; otherwise defer to the values' own !=.
        if isinstance(left, dict) and isinstance(right, dict):
            return nested_dicts_equal(left, right)
        return not left != right

    return d1.keys() == d2.keys() and all(_pair_matches(d1[name], d2[name]) for name in d1)


def overwrite_model_config(config: Config, checkpoint: dict) -> Config:
    """
    Overwrite the model configuration with the one from the checkpoint.

    The ``model`` sections of ``config`` and ``checkpoint['config']`` are
    compared via ``to_dict()``. If they match, ``config`` is returned
    unchanged. If they differ and ``config.model.force_overwrite`` is truthy,
    ``config.model`` is replaced in place by the checkpoint's ``model``
    section. Otherwise a ``ValueError`` is raised.

    Parameters
    ----------
    config : Config
        The original configuration (mutated in place on overwrite).
    checkpoint : dict
        The loaded checkpoint; ``checkpoint['config']`` must be a Config.

    Returns
    -------
    Config
        The same ``config`` object, possibly with its ``model`` section replaced.

    Raises
    ------
    ValueError
        If the model sections differ and ``force_overwrite`` is not set.
    """
    loaded_config = checkpoint['config']

    if nested_dicts_equal(config.to_dict()['model'], loaded_config.to_dict()['model']):
        return config

    # NOTE(review): force_overwrite itself lives in the compared `model` section,
    # so a differing flag alone counts as a mismatch — confirm this is intended.
    if not config.model.force_overwrite:
        raise ValueError(
            "Loaded config does not match the original config. Set force_overwrite=True to overwrite.")

    logger = get_logger()
    logger.info("Loaded config does not match the original config. Overwriting with loaded config.")
    # Copy the same `model` section that was compared above, so the result
    # actually matches the checkpoint.
    config.model = loaded_config.model
    return config
