# forward_forward/utils/freeze.py

import torch
from typing import List, Dict, Optional


def load_and_prepare_pretrained_blocks(
    model: torch.nn.Module,
    epochs_per_layer: List[int],
    pretrained_blocks: Dict[str, str],
    device: torch.device,
    trained_epochs_so_far: Optional[List[int]] = None,
):
    """
    Load pretrained FFBlocks from checkpoint files and decide whether to freeze them or keep training.

    For every entry in ``model.trainable_names`` whose block name appears in ``pretrained_blocks``,
    the referenced checkpoint is loaded into ``model`` (non-strictly). If the matching entry in
    ``epochs_per_layer`` is 0, the parameters of ``block.layer`` are frozen
    (``requires_grad = False``); otherwise the block is left trainable.

    Args:
        model (nn.Module): Model exposing ``trainable_names`` (keys/indices into ``model.layers``)
            and ``layers``.
        epochs_per_layer (List[int]): Epochs to train per trainable block, aligned with
            ``model.trainable_names`` (0 means freeze).
        pretrained_blocks (Dict[str, str]): Mapping from block name to checkpoint path. Blocks
            without a ``name`` attribute are looked up as ``"block_<idx>"``.
        device (torch.device): Passed as ``map_location`` when loading checkpoints.
        trained_epochs_so_far (List[int], optional): Epochs each block has already been trained,
            aligned with ``model.trainable_names``. Only read for log messages here.

    Returns:
        List[int]: ``trained_epochs_so_far`` itself (not modified), or a fresh zero-filled list
        with one entry per trainable block when it was None.

    Raises:
        RuntimeError: If a checkpoint cannot be loaded or applied to the model; the underlying
            error is chained as ``__cause__``.
    """
    if trained_epochs_so_far is None:
        trained_epochs_so_far = [0] * len(model.trainable_names)

    for idx, ff_idx in enumerate(model.trainable_names):
        block = model.layers[ff_idx]
        block_name = getattr(block, "name", f"block_{idx}")

        if block_name not in pretrained_blocks:
            continue

        ckpt_path = pretrained_blocks[block_name]
        print(f"📥 Loading pretrained block '{block_name}' from {ckpt_path}")

        try:
            # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
            checkpoint = torch.load(ckpt_path, map_location=device)
            # The checkpoint is applied to the whole model, not just `block`. With strict=False,
            # keys that do not match a model parameter/buffer are skipped rather than raising.
            incompatible = model.load_state_dict(checkpoint, strict=False)
        except Exception as e:
            raise RuntimeError(
                f"❌ Failed to load checkpoint for '{block_name}' from {ckpt_path}: {e}"
            ) from e

        # If no key matched (e.g. keys saved relative to the block instead of the model),
        # strict=False would otherwise leave the model untouched without any signal.
        if checkpoint and set(checkpoint).issubset(incompatible.unexpected_keys):
            print(f"⚠️ No keys from {ckpt_path} matched the model; block '{block_name}' was not updated.")

        if epochs_per_layer[idx] == 0:
            # Only `block.layer` is frozen; any other parameters owned by the block stay trainable.
            for p in block.layer.parameters():
                p.requires_grad = False
            print(f"⛔ Block '{block_name}' frozen and will not be trained further.")
        else:
            print(f"🔄 Block '{block_name}' loaded and will continue training for {epochs_per_layer[idx]} epochs "
                  f"(already trained for {trained_epochs_so_far[idx]}).")

    return trained_epochs_so_far
