# Modified from: https://github.com/pliang279/LG-FedAvg/blob/master/models/Nets.py
# credit goes to: Paul Pu Liang

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import models
import json
import numpy as np
from models.Transformer import TransformerEncoder

class CNNCifar(nn.Module):
    """Three-layer convolutional feature extractor for 3-channel images.

    ``forward`` returns the ReLU-activated output of ``conv3``; no flattening
    or classifier head is applied in this module.

    While the module is in training mode, the gradients that reach the
    pre-activation output of each convolution during ``backward()`` are
    stored in ``self.gradients`` under the keys ``'h1'``, ``'h2'`` and
    ``'h3'``.

    Args:
        args: configuration object; only ``args.num_classes`` is read.
    """

    def __init__(self, args):
        super().__init__()
        self.num_classes = args.num_classes
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        # NOTE(review): `drop` is created but never applied in forward().
        self.drop = nn.Dropout(0.6)
        # Latest captured gradients, keyed by 'h1' / 'h2' / 'h3'.
        self.gradients = {}

        # Parameter names grouped per layer, listed from the last conv to
        # the first. Presumably consumed by external code that handles the
        # layers group-wise -- confirm against callers.
        self.weight_keys = [
            ['conv3.weight', 'conv3.bias'],
            ['conv2.weight', 'conv2.bias'],
            ['conv1.weight', 'conv1.bias'],
        ]

    def save_grads(self, name):
        """Return a tensor hook that stores the received gradient under *name*."""
        def hook(grad):
            self.gradients[name] = grad
        return hook

    def _track(self, x, name):
        """Attach a gradient-capturing hook to *x* when training; return *x*.

        ``self.training`` is the mode flag toggled by ``train()``/``eval()``.
        ``self.train`` is a bound method and therefore always truthy, so
        testing it would register hooks in evaluation mode as well.
        """
        if self.training and x.requires_grad:
            x.register_hook(self.save_grads(name))
        return x

    def forward(self, x):
        """Run the three convolutions and return the activated conv3 output."""
        x = self._track(self.conv1(x), 'h1')
        x = self.pool(F.relu(x))

        x = self._track(self.conv2(x), 'h2')
        x = F.relu(x)

        x = self._track(self.conv3(x), 'h3')
        return F.relu(x)

    def get_activations_gradient(self):
        """Return the dict of gradients captured by the hooks (not a copy)."""
        return self.gradients


class CNN_FAMNIST(nn.Module):
    """Three-layer convolutional feature extractor for 1-channel images.

    ``forward`` returns the output of ``conv3`` after 2x2 max-pooling and
    ReLU; no flattening or classifier head is applied in this module.

    While the module is in training mode, the gradients that reach each
    stage's post-ReLU output during ``backward()`` are stored in
    ``self.gradients`` under the keys ``'h1'``, ``'h2'`` and ``'h3'``.

    Args:
        args: configuration object; accepted but not read by this class.
    """

    def __init__(self, args):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.conv3 = nn.Conv2d(64, 128, 5)

        # Latest captured gradients, keyed by 'h1' / 'h2' / 'h3'.
        self.gradients = {}

        # Parameter names grouped per layer, listed from the last conv to
        # the first. Presumably consumed by external code that handles the
        # layers group-wise -- confirm against callers.
        self.weight_keys = [
            ['conv3.weight', 'conv3.bias'],
            ['conv2.weight', 'conv2.bias'],
            ['conv1.weight', 'conv1.bias'],
        ]

    def save_grads(self, name):
        """Return a tensor hook that stores the received gradient under *name*."""
        def hook(grad):
            self.gradients[name] = grad

        return hook

    def get_activations_gradient(self):
        """Return the dict of gradients captured by the hooks (not a copy)."""
        return self.gradients

    def _track(self, x, name):
        """Attach a gradient-capturing hook to *x* when training; return *x*.

        ``self.training`` is the mode flag toggled by ``train()``/``eval()``.
        ``self.train`` is a bound method and therefore always truthy, so
        testing it would register hooks in evaluation mode as well.
        """
        if self.training and x.requires_grad:
            x.register_hook(self.save_grads(name))
        return x

    def forward(self, x):
        """Run the three convolutions and return the pooled, activated result."""
        # Hooks are attached after the ReLU at every stage.
        x = self._track(F.relu(self.conv1(x)), 'h1')
        x = self._track(F.relu(self.conv2(x)), 'h2')
        x = self._track(F.relu(F.max_pool2d(self.conv3(x), 2)), 'h3')
        return x