# Modified from: https://github.com/pliang279/LG-FedAvg/blob/master/models/Nets.py
# credit goes to: Paul Pu Liang

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import models
import json
import numpy as np
from models.Transformer import TransformerEncoder

class CNN_FEMNIST(nn.Module):
    """Three-layer convolutional feature extractor that records activation gradients.

    There is no classifier head: ``forward`` returns the feature map produced by
    conv1 -> ReLU -> conv2 -> ReLU -> conv3 -> 2x2 max-pool -> ReLU.

    While the module is in training mode, a tensor hook is attached to each of the
    three intermediate activations. During backward, the gradient flowing into each
    activation is stored in ``self.gradients`` under ``'h1'``, ``'h2'`` and ``'h3'``.

    Args:
        args: Accepted for interface compatibility with other model constructors;
            not used by this class.
    """

    def __init__(self, args):
        super().__init__()
        # Unpadded 5x5 convolutions: each one shrinks height and width by 4.
        # The first layer expects 3-channel input.
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.conv3 = nn.Conv2d(64, 128, 5)

        # Hook name ('h1' / 'h2' / 'h3') -> most recent gradient w.r.t. that activation.
        self.gradients = {}

        # Parameter-name groups, one per conv layer, ordered from the deepest
        # layer (conv3) to the shallowest (conv1).
        self.weight_keys = [['conv3.weight', 'conv3.bias'],
                            ['conv2.weight', 'conv2.bias'],
                            ['conv1.weight', 'conv1.bias'],
                            ]

    def save_grads(self, name):
        """Return a tensor hook that stores the incoming gradient under ``name``.

        The hook returns ``None``, so it does not modify the gradient itself.
        """
        def hook(grad):
            self.gradients[name] = grad

        return hook

    def get_activations_gradient(self):
        """Return the dict of gradients captured by the activation hooks."""
        return self.gradients

    def _track(self, x, name):
        """Attach a gradient-saving hook to ``x`` if training and ``x`` needs grad; return ``x``."""
        # ``self.training`` is nn.Module's mode flag. ``self.train`` is the bound
        # method that *sets* the mode and is always truthy, so it must not be
        # used as the condition here.
        # ``register_hook`` raises on tensors that do not require grad, hence
        # the second check.
        if self.training and x.requires_grad:
            x.register_hook(self.save_grads(name))
        return x

    def forward(self, x):
        """Compute the conv feature map, recording activation gradients in training mode."""
        x = self._track(F.relu(self.conv1(x)), 'h1')
        x = self._track(F.relu(self.conv2(x)), 'h2')
        x = self._track(F.relu(F.max_pool2d(self.conv3(x), 2)), 'h3')
        return x

class CNN_FAMNIST(nn.Module):
    """Three-layer convolutional feature extractor that records activation gradients.

    There is no classifier head: ``forward`` returns the feature map produced by
    conv1 -> ReLU -> conv2 -> ReLU -> conv3 -> 2x2 max-pool -> ReLU.

    While the module is in training mode, a tensor hook is attached to each of the
    three intermediate activations. During backward, the gradient flowing into each
    activation is stored in ``self.gradients`` under ``'h1'``, ``'h2'`` and ``'h3'``.

    Args:
        args: Accepted for interface compatibility with other model constructors;
            not used by this class.
    """

    def __init__(self, args):
        super().__init__()
        # Unpadded 5x5 convolutions: each one shrinks height and width by 4.
        # The first layer expects 3-channel input.
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.conv3 = nn.Conv2d(64, 128, 5)

        # Hook name ('h1' / 'h2' / 'h3') -> most recent gradient w.r.t. that activation.
        self.gradients = {}

        # Parameter-name groups, one per conv layer, ordered from the deepest
        # layer (conv3) to the shallowest (conv1).
        self.weight_keys = [['conv3.weight', 'conv3.bias'],
                            ['conv2.weight', 'conv2.bias'],
                            ['conv1.weight', 'conv1.bias'],
                            ]

    def save_grads(self, name):
        """Return a tensor hook that stores the incoming gradient under ``name``.

        The hook returns ``None``, so it does not modify the gradient itself.
        """
        def hook(grad):
            self.gradients[name] = grad

        return hook

    def get_activations_gradient(self):
        """Return the dict of gradients captured by the activation hooks."""
        return self.gradients

    def _track(self, x, name):
        """Attach a gradient-saving hook to ``x`` if training and ``x`` needs grad; return ``x``."""
        # ``self.training`` is nn.Module's mode flag. ``self.train`` is the bound
        # method that *sets* the mode and is always truthy, so it must not be
        # used as the condition here.
        # ``register_hook`` raises on tensors that do not require grad, hence
        # the second check.
        if self.training and x.requires_grad:
            x.register_hook(self.save_grads(name))
        return x

    def forward(self, x):
        """Compute the conv feature map, recording activation gradients in training mode."""
        x = self._track(F.relu(self.conv1(x)), 'h1')
        x = self._track(F.relu(self.conv2(x)), 'h2')
        x = self._track(F.relu(F.max_pool2d(self.conv3(x), 2)), 'h3')
        return x

