import glob
import os
import re
import shutil

import imageio
import matplotlib.pyplot as plt
import torch


def create_gif(time_series, saving_directory, name_file="density", delete_imgs=True):
    """
    Render each frame of a time series with ``imshow`` and assemble them into a GIF.

    Every frame is drawn with the same colour limits (the global min/max of
    ``time_series``) so intensities are comparable across frames. Frames are
    written as PNGs to ``<saving_directory>/img_for_gif`` and then combined into
    ``<saving_directory>/<name_file>.gif``.

    Args:
        time_series: Array-like of shape (time, height, width) supporting
            ``.ndim``, ``.shape``, ``.min()``, ``.max()`` and integer indexing
            (e.g. a NumPy array or torch tensor).
        saving_directory: Directory in which the GIF and the temporary frames
            are written.
        name_file: Base name of the output GIF, without the ``.gif`` extension.
        delete_imgs: If True, remove the temporary ``img_for_gif`` directory
            once the GIF has been written.

    On invalid input an error message is printed and nothing is written.
    """
    if time_series.ndim != 3:
        print("Error: The time series should be (time, height, width)")
        return
    n_frames = time_series.shape[0]
    if n_frames == 0:
        # min()/max() would raise on an empty array, and a GIF needs >= 1 frame.
        print("Error: The time series is empty")
        return

    # Shared colour limits so every frame uses the same colour scale.
    time_series_min = time_series.min()
    time_series_max = time_series.max()

    img_dir = os.path.join(saving_directory, "img_for_gif")
    os.makedirs(img_dir, exist_ok=True)

    # Frame paths are derived from the frame index rather than by globbing the
    # directory and parsing digits out of the full path: that approach picked the
    # wrong number (or raised IndexError) depending on digits in
    # saving_directory, and also swept in stale PNGs from a previous run.
    frame_paths = [
        os.path.join(img_dir, f"time_series_{i}.png") for i in range(n_frames)
    ]

    for i, frame_path in enumerate(frame_paths):
        plt.imshow(
            time_series[i], origin="lower", vmin=time_series_min, vmax=time_series_max
        )
        plt.axis("off")
        plt.savefig(frame_path, bbox_inches="tight", pad_inches=0)
        plt.close()

    images = [imageio.imread(frame_path) for frame_path in frame_paths]
    imageio.mimsave(
        os.path.join(saving_directory, name_file + ".gif"), images, duration=0.1
    )
    if delete_imgs:
        shutil.rmtree(img_dir)


def load_common_weights(model, checkpoint_state_dict, strict=False, verbose=True):
    """
    Load the weights a checkpoint shares with a model, skipping the rest.

    An entry is copied in place into the model when its key exists in both the
    model's state dict and the checkpoint and the shapes match. Keys present only
    in the model ("missing"), only in the checkpoint ("unexpected"), or in both
    with different shapes ("size_mismatches") are skipped and reported.

    Args:
        model: The target ``torch.nn.Module`` to load weights into.
        checkpoint_state_dict: Mapping of parameter/buffer names to tensors.
        strict: If True, raise instead of loading anything when there are
            missing, unexpected, or shape-mismatched keys (default: False).
        verbose: If True, print a summary of loaded/skipped entries.

    Returns:
        dict: With keys ``'loaded'``, ``'missing'``, ``'unexpected'`` and
        ``'size_mismatches'``, each a list of key names. ``loaded``, ``missing``
        and ``size_mismatches`` follow the model's state-dict order;
        ``unexpected`` follows the checkpoint's order.

    Raises:
        RuntimeError: If ``strict`` is True and the key sets or shapes disagree.
            The model is left unmodified in that case.
    """
    # Build the state dict once. Its tensors reference the module's parameters
    # and buffers, so copy_() into them updates the model in place.
    model_state_dict = model.state_dict()

    # Lists (not sets) so results are deterministic and follow state-dict order.
    common_keys = [k for k in model_state_dict if k in checkpoint_state_dict]
    missing_keys = [k for k in model_state_dict if k not in checkpoint_state_dict]
    unexpected_keys = [k for k in checkpoint_state_dict if k not in model_state_dict]

    # Decide what is loadable before touching the model, so strict mode can fail
    # without leaving a partially loaded model behind.
    loadable_keys = []
    size_mismatches = []
    for key in common_keys:
        model_shape = model_state_dict[key].shape
        checkpoint_shape = checkpoint_state_dict[key].shape
        if model_shape == checkpoint_shape:
            loadable_keys.append(key)
        else:
            size_mismatches.append(key)
            if verbose:
                print(
                    f"Size mismatch for {key}: model {model_shape} vs checkpoint {checkpoint_shape}"
                )

    if strict and (missing_keys or unexpected_keys or size_mismatches):
        raise RuntimeError(
            "Error loading weights with strict=True: "
            f"missing keys {missing_keys}, "
            f"unexpected keys {unexpected_keys}, "
            f"size-mismatched keys {size_mismatches}"
        )

    with torch.no_grad():
        for key in loadable_keys:
            # copy_ also casts dtype/device to match the model's tensor.
            model_state_dict[key].copy_(checkpoint_state_dict[key])
    loaded_keys = loadable_keys

    if verbose:
        print(f"Loaded {len(loaded_keys)} common parameters")
        print(f"Missing in checkpoint: {len(missing_keys)} parameters")
        print(f"Extra in checkpoint: {len(unexpected_keys)} parameters")
        print(f"Size mismatches: {len(size_mismatches)} parameters")

        if missing_keys:
            print("Missing parameters (will be randomly initialized):")
            for key in sorted(missing_keys):
                print(f"  - {key}")

    return {
        "loaded": loaded_keys,
        "missing": missing_keys,
        "unexpected": unexpected_keys,
        "size_mismatches": size_mismatches,
    }
