from pathlib import Path
from typing import Union

import torch
import torch.nn as nn
import torch.onnx

from adversarial_superposition.cifar.utils.model import ViT


def load_vit_model(
    model_path: Union[str, Path],
    bottleneck_dim: int,
    bottleneck_after_dim: int,
    device: str,
    image_size: int = 32,
    patch_size: int = 4,
    num_classes: int = 10,
    dim: int = 512,
    mlp_dim: int = 512,
    depth: int = 6,
    heads: int = 8,
) -> torch.nn.Module:
    """Build a ViT, swap in the requested classifier head, and load a checkpoint.

    The head (``model.mlp_head``) is replaced by one of three layouts, all of
    which start with a ``LayerNorm`` over the width of the original head's
    first layer:

    * ``bottleneck_after_dim`` truthy (takes precedence over ``bottleneck_dim``):
      ``Linear(width, num_classes)`` followed by two bias-free linear layers
      ``num_classes -> bottleneck_after_dim -> num_classes``.
    * ``bottleneck_dim`` truthy:
      ``Linear(width, bottleneck_dim)`` then ``Linear(bottleneck_dim, num_classes)``.
    * otherwise: a single ``Linear(width, num_classes)``.

    Args:
        model_path: Path of the checkpoint file passed to ``torch.load``.
        bottleneck_dim: Width of the bottleneck placed before the class
            logits; a falsy value (e.g. ``0``) disables it.
        bottleneck_after_dim: Width of the bias-free bottleneck placed after
            the class logits; a falsy value (e.g. ``0``) disables it.
        device: Device used as ``map_location`` for loading and as the target
            of ``model.to``.
        image_size: Forwarded to ``ViT``.
        patch_size: Forwarded to ``ViT``.
        num_classes: Forwarded to ``ViT`` and used as the output width of the
            replacement head.
        dim: Forwarded to ``ViT``.
        mlp_dim: Forwarded to ``ViT``.
        depth: Forwarded to ``ViT``.
        heads: Forwarded to ``ViT``.

    Returns:
        The model with checkpoint weights loaded, moved to ``device`` and
        switched to evaluation mode.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint
            keys or tensor shapes do not match the constructed architecture.
    """
    model = ViT(
        image_size=image_size,
        patch_size=patch_size,
        num_classes=num_classes,
        dim=dim,
        depth=depth,
        heads=heads,
        mlp_dim=mlp_dim,
        dropout=0.1,
        emb_dropout=0.1,
    )

    # Width of the features entering the head, read from the LayerNorm that
    # is the first layer of the head ViT constructs.
    original_dim = model.mlp_head[0].normalized_shape[0]

    if bottleneck_dim:
        # NOTE: there is no activation between the two linear layers. Adding
        # one would shift the nn.Sequential indices and therefore the
        # state-dict keys expected from the checkpoint.
        model.mlp_head = nn.Sequential(
            nn.LayerNorm(original_dim),
            nn.Linear(original_dim, bottleneck_dim),
            nn.Linear(bottleneck_dim, num_classes),
        )
    else:
        model.mlp_head = nn.Sequential(
            nn.LayerNorm(original_dim),
            nn.Linear(original_dim, num_classes),
        )

    # A post-logit bottleneck overrides whichever head was selected above.
    if bottleneck_after_dim:
        model.mlp_head = nn.Sequential(
            nn.LayerNorm(original_dim),
            nn.Linear(original_dim, num_classes),
            nn.Linear(num_classes, bottleneck_after_dim, bias=False),
            nn.Linear(bottleneck_after_dim, num_classes, bias=False),
        )

    # SECURITY NOTE(review): torch.load unpickles the file, so only load
    # checkpoints from trusted sources (consider `weights_only=True` where the
    # installed torch version and the checkpoint format allow it).
    checkpoint = torch.load(model_path, map_location=device)

    # Checkpoints may either wrap the weights under a "model" key or be the
    # bare state dict itself.
    state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model
