# Modified from: https://github.com/pliang279/LG-FedAvg/blob/master/models/Nets.py
# credit goes to: Paul Pu Liang

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import models
import json
import numpy as np

from functools import partial

from .dcf import Conv_DCFr

class Shared(torch.nn.Module):
    """Stack of ``Conv_DCFr`` layers that maps an image batch to flat features.

    Only the convolution layers are owned by this module; the bases they
    consume are supplied by the caller on every forward call.
    """

    def __init__(self, num_bases, kernel_size):
        """Build the ``Conv_DCFr`` stack and the parameter-free helper layers.

        Args:
            num_bases: forwarded to every ``Conv_DCFr`` as ``num_bases``.
            kernel_size: forwarded to every ``Conv_DCFr`` as ``kernel_size``.
        """
        super().__init__()

        # Width handed to the first layer as its first positional argument.
        self.ncha = 3
        self.hiddens = [64, 256, 256]
        self.num_layers = len(self.hiddens)
        # Layers whose index is below this value are followed by max-pooling.
        self.pool = self.num_layers - 1

        # Each layer takes the previous layer's width as its first argument.
        in_widths = [self.ncha] + self.hiddens[:-1]
        self.conv_layers = nn.ModuleList([
            Conv_DCFr(c_in, c_out, kernel_size=kernel_size, padding=0,
                      num_bases=num_bases)
            for c_in, c_out in zip(in_widths, self.hiddens)
        ])

        self.gap = torch.nn.AdaptiveAvgPool2d(1)
        self.maxpool = torch.nn.MaxPool2d(2)
        self.relu = torch.nn.ReLU()
        self.drop1 = torch.nn.Dropout2d(0.2)
        # Not used by forward(); kept so the attribute stays available.
        self.drop2 = torch.nn.Dropout(0.5)

    def forward(self, input):
        """Run the layer stack and return one flat feature row per sample.

        Args:
            input: pair ``(h, bases)``. ``h`` must have a 4-element ``shape``;
                ``bases[i]`` is passed to layer ``i`` together with the
                current activations.

        Returns:
            The globally average-pooled activations viewed as
            ``(h.shape[0], -1)``.
        """
        feats, bases = input
        # Unpacking four values rejects anything that is not 4-dimensional.
        batch, _, _, _ = feats.shape

        for idx, layer in enumerate(self.conv_layers):
            feats = self.drop1(self.relu(layer((feats, bases[idx]))))
            if idx < self.pool:
                feats = self.maxpool(feats)

        return self.gap(feats).view(batch, -1)


class AlexNet3(torch.nn.Module):
    """Three-convolution classifier for 3-channel images with a linear head."""

    def __init__(self, taskcla=100):
        """Create the convolution stack and the classification head.

        Args:
            taskcla: output width of the final linear layer.
        """
        super().__init__()

        self.ncha = 3
        self.taskcla = taskcla
        self.latent_dim = 256
        self.hidden1 = 128

        widths = [64, 256, 256]
        self.nlayers = len(widths)
        # Convolutions whose index is below this value are followed by
        # max-pooling; the last one is not.
        self.pool = self.nlayers - 1

        # Bias-free 3x3 convolutions without padding, chained width to width.
        in_widths = [self.ncha] + widths[:-1]
        self.conv = nn.ModuleList([
            torch.nn.Conv2d(c_in, c_out, kernel_size=3, padding=0, bias=False)
            for c_in, c_out in zip(in_widths, widths)
        ])

        self.gap = torch.nn.AdaptiveAvgPool2d(1)
        self.maxpool = torch.nn.MaxPool2d(2)
        self.relu = torch.nn.ReLU()
        self.drop1 = torch.nn.Dropout2d(0.2)
        # Not used by forward(); kept so the attribute stays available.
        self.drop2 = torch.nn.Dropout(0.5)

        ## classification head
        self.head = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim, self.hidden1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(self.hidden1, self.taskcla),
        )

    def forward(self, x, task_id=0):
        """Compute the head's output for a 4-D batch.

        Args:
            x: tensor with a 4-element ``shape``; the first dimension is used
                as the batch size.
            task_id: accepted but not used by this method.

        Returns:
            The output of ``self.head`` applied to the globally
            average-pooled features, one row per sample.
        """
        # Shape-preserving view of the input onto itself.
        x = x.view(x.size())
        # Unpacking four values rejects anything that is not 4-dimensional.
        batch, _, _, _ = x.shape

        for idx, layer in enumerate(self.conv):
            x = self.drop1(self.relu(layer(x)))
            if idx < self.pool:
                x = self.maxpool(x)

        pooled = self.gap(x).view(batch, -1)
        return self.head(pooled)


class AlexNet3DCF(torch.nn.Module):
    """Classifier built from a ``Shared`` feature stack, learnable bases and a linear head.

    The bases are stored on this module as a single parameter and handed to
    ``Shared`` on every forward call.
    """

    def __init__(self, taskcla=100, num_bases=9):
        """Create the shared stack, the bases parameter and the head.

        Args:
            taskcla: output width of the final linear layer.
            num_bases: number of bases per layer; also forwarded to ``Shared``.
        """
        super(AlexNet3DCF, self).__init__()
        self.ncha = 3
        self.taskcla = taskcla
        self.latent_dim = 256
        self.num_bases = num_bases
        kernel_size = 3

        self.hidden1 = 128

        self.shared = Shared(self.num_bases, kernel_size)
        ## number of dcf layers
        self.nlayers = self.shared.num_layers

        ## initialize bases (after ``Shared`` so the random stream is
        ## consumed in the same order as before)
        self.bases_tasks = torch.nn.Parameter(
            self._init_bases(self.num_bases, kernel_size, self.nlayers),
            requires_grad=True,
        )

        ## initialize cls heads
        self.head = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim, self.hidden1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(self.hidden1, self.taskcla),
        )

    @staticmethod
    def _init_bases(num_bases, kernel_size, num_layers):
        """Return the initial bases tensor, one identical copy per layer.

        A random ``(num_bases, kernel_size, kernel_size)`` tensor is
        orthogonalised in place with ``torch.nn.init.orthogonal_`` (the
        non-deprecated spelling of ``torch.nn.init.orthogonal``), moved to
        ``(kernel_size, kernel_size, num_bases)`` layout and tiled along a
        new leading dimension.

        Returns:
            Tensor of shape ``(num_layers, kernel_size, kernel_size, num_bases)``.
        """
        bases = torch.randn((num_bases, kernel_size, kernel_size))
        bases = torch.nn.init.orthogonal_(bases).permute(1, 2, 0).contiguous()
        return bases.repeat(num_layers, 1, 1, 1)

    def forward(self, x):
        """Return the head's output for the features ``Shared`` extracts from ``x``.

        Args:
            x: input tensor; passed to ``self.shared`` together with
                ``self.bases_tasks``.
        """
        # Shape-preserving view of the input onto itself.
        x = x.view_as(x)
        x = self.shared((x, self.bases_tasks))
        return self.head(x)
    
    