'''The implementation of ASAM and SAM are borrowed from
    https://github.com/debcaldarola/fedsam
   Caldarola, D., Caputo, B., & Ciccone, M.
   Improving Generalization in Federated Learning by Seeking Flat Minima,
   European Conference on Computer Vision (ECCV) 2022.
'''
import os
import re
from typing import Callable

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from federatedscope.register import register_model


class Conv2Model(nn.Module):
    def __init__(self, num_classes):
        super(Conv2Model, self).__init__()
        self.num_classes = num_classes

        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5),
            nn.ReLU(), nn.MaxPool2d(kernel_size=2))

        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5),
            nn.ReLU(), nn.MaxPool2d(kernel_size=2))

        self.classifier = nn.Sequential(nn.Linear(64 * 5 * 5, 384), nn.ReLU(),
                                        nn.Linear(384, 192), nn.ReLU(),
                                        nn.Linear(192, self.num_classes))

        self.size = self.model_size()

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = torch.reshape(x, (x.shape[0], -1))
        x = self.classifier(x)
        return x

    def model_size(self):
        tot_size = 0
        for param in self.parameters():
            tot_size += param.size()[0]
        return tot_size


def call_fedsam_conv2(model_config, local_data):
    if model_config.type == 'fedsam_conv2':
        model = Conv2Model(10)
        return model


register_model('fedsam_conv2', call_fedsam_conv2)
