import torch
from torch import nn

def get_lenet_cifar10(ckpt_path=None, n_classes=10, **kwargs):
    """Build a ``LeNet`` sized for CIFAR-10-shaped inputs, optionally loading weights.

    Args:
        ckpt_path: Optional path to a checkpoint readable by ``torch.load``. It may
            be a raw state dict, or a dict wrapping one under ``"state_dict"`` or
            ``"model_state_dict"``. Falsy values skip loading.
        n_classes: Number of output classes.
        **kwargs: Accepted for interface compatibility with other factories; ignored.

    Returns:
        The constructed ``LeNet``, with weights loaded when ``ckpt_path`` is given.
    """
    # 10816 = 64 * 13 * 13: the flattened feature size LeNet produces for a
    # 3 x 32 x 32 input (32 -> conv5 -> 28 -> conv3 -> 26 -> maxpool2 -> 13).
    net = LeNet(n_classes, latent_dim=10816)
    if ckpt_path:
        # Map to CPU so checkpoints saved on a GPU also load on CPU-only hosts;
        # load_state_dict copies the tensors into the freshly built model anyway.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        # Unwrap common training-loop checkpoint layouts (kept consistent with
        # get_lenet_cmnist).
        if "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        elif "model_state_dict" in checkpoint:
            checkpoint = checkpoint["model_state_dict"]
        net.load_state_dict(checkpoint)
    return net

def get_lenet_cmnist(ckpt_path=None, n_classes=10, **kwargs):
    """Build a ``LeNet`` sized for 3-channel 28x28 inputs, optionally loading weights.

    Args:
        ckpt_path: Optional path to a checkpoint readable by ``torch.load``. It may
            be a raw state dict, or a dict wrapping one under ``"state_dict"`` or
            ``"model_state_dict"``. Falsy values skip loading.
        n_classes: Number of output classes.
        **kwargs: Accepted for interface compatibility with other factories; ignored.

    Returns:
        The constructed ``LeNet``, with weights loaded when ``ckpt_path`` is given.
    """
    # 7744 = 64 * 11 * 11: the flattened feature size LeNet produces for a
    # 3 x 28 x 28 input (28 -> conv5 -> 24 -> conv3 -> 22 -> maxpool2 -> 11).
    net = LeNet(n_classes, latent_dim=7744)
    if ckpt_path:
        # Map to CPU so checkpoints saved on a GPU also load on CPU-only hosts;
        # load_state_dict copies the tensors into the freshly built model anyway.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        # Unwrap common training-loop checkpoint layouts.
        if "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        elif "model_state_dict" in checkpoint:
            checkpoint = checkpoint["model_state_dict"]
        net.load_state_dict(checkpoint)
    return net

def get_lenet_canonizer():
    """Return the canonizer for ``LeNet``.

    No canonizer is provided for this architecture, so the result is always ``None``.
    """


class LeNet(nn.Module):
    """
    Simple LeNet-like architecture.

    Layout: two unpadded stride-1 convolutions (5x5 -> 32 channels, then
    3x3 -> 64 channels), each followed by ReLU; a 2x2 max-pool; then a flattened
    MLP head (dropout 0.25, Linear -> 128, ReLU, dropout 0.5, Linear -> n_classes)
    ending in ``LogSoftmax``, so ``forward`` returns log-probabilities.

    Args:
        n_classes: Number of output classes.
        latent_dim: Size of the flattened pooled feature map fed to the head. For an
            ``H x W`` input it must equal ``64 * ((H - 6) // 2) * ((W - 6) // 2)``,
            e.g. 10816 for 32x32 inputs and 7744 for 28x28 inputs.
        in_channels: Number of input image channels (default 3, as before).
    """

    def __init__(self, n_classes, latent_dim, in_channels=3):
        super().__init__()

        # No-op module at the input. NOTE(review): presumably kept so hooks /
        # attribution tooling have a module to attach to at the input — confirm.
        self.input_identity = nn.Identity()
        # Layer order (and thus state_dict keys "features.0", "features.2") must stay
        # fixed so existing checkpoints keep loading.
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, 5, 1),
            nn.ReLU(False),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(False),
        )
        self.pool = nn.MaxPool2d(2)
        # Layer order fixes state_dict keys "head.1" and "head.4".
        self.head = nn.Sequential(
            nn.Dropout(0.25),
            nn.Linear(latent_dim, 128),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(128, n_classes),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, n_classes)``.

        ``x`` is expected to be ``(batch, in_channels, H, W)`` with ``H``/``W``
        consistent with the ``latent_dim`` given at construction.
        """
        x = self.input_identity(x)
        x = self.features(x)
        x = self.pool(x)
        # Flatten everything except the batch dimension for the linear head.
        x = torch.flatten(x, 1)
        out = self.head(x)
        return out