#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn
import torch.nn.functional as F
from ibp_torch import *
import pdb


class MLP(nn.Module):
    """Two-layer perceptron classifier that returns log-probabilities.

    Each sample is flattened to ``x.shape[1] * x.shape[-2] * x.shape[-1]``
    features and passed through ``Linear -> Dropout -> LeakyReLU -> Linear``,
    followed by ``log_softmax`` over the class dimension.
    """

    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.LeakyReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)
        # Not used by forward() (which calls F.log_softmax); kept so existing
        # attribute access keeps working.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, dim_out)``.

        NOTE(review): assumes ``x`` is image-like, e.g. (batch, C, H, W); the
        flatten only uses dims 1, -2 and -1.
        """
        x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1])
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        # Explicit class dimension: the implicit-dim form is deprecated and warns.
        return F.log_softmax(x, dim=1)

class IBPMLP(nn.Module):
    """MLP whose layer weights are factorised as ``wa @ diag(z) @ wb``.

    Each ``model_arch`` entry is a ``(dim_in, dim_hidden, dim_out)`` tuple and
    each ``factorized_w`` entry holds one layer's tensors ``(wa, r, p, wb)``.
    ``z`` is ``r`` masked by a sample from ``torch_sample_BernConcrete`` driven
    by ``sigmoid(p)``. Per-layer ``A``/``B`` logits feed the
    stick-breaking prior over the mask probabilities.
    """

    def __init__(self, model_arch, factorized_w, a_prior, lambda_post, lambda_prior, p_threshold, global_forward=False):
        super(IBPMLP, self).__init__()

        self.model_arch = model_arch
        self.factorized_w = factorized_w  # per layer: (wa, r, p, wb)
        self.trainable_param = list()

        self.a_prior = a_prior

        # Second argument of the Concrete sampler / logistic density helpers
        # (presumably their temperatures -- confirm in ibp_torch).
        self.lambda_post = lambda_post
        self.lambda_prior = lambda_prior
        # Mask probabilities not above this are zeroed when expectation=True.
        self.p_threshold = p_threshold
        self.relu = nn.LeakyReLU()

        self.dropout = nn.Dropout()  # not used in forward(); kept for compatibility
        self.num_layers = len(model_arch) - 1
        # If True, forward() ignores the binary mask and uses r directly.
        self.global_forward = global_forward

        self.A = list()
        self.B = list()
        self.initial_AB()
        self.register_param()

    def initial_AB(self):
        """Create and register the per-layer ``A<ix>``/``B<ix>`` logit parameters."""
        for ix, (dim_in, dim_hidden, dim_out) in enumerate(self.model_arch):

            # Global parameter
            # The posterior for v is v ~ Kumaraswamy(a, b)
            # The prior is v ~ Beta(\alpha, 1)

            Al = nn.Parameter(get_torch_logit_variable(init_value=self.a_prior, shape=(dim_hidden, 1)))
            Bl = nn.Parameter(get_torch_logit_variable(init_value=1.0, shape=(dim_hidden, 1)))

            # setattr registers them with the module; the lists give indexed access.
            setattr(self, "A" + str(ix), Al)
            setattr(self, "B" + str(ix), Bl)

            self.A.append(Al)
            self.B.append(Bl)

    def register_param(self):
        """Wrap each layer's ``(wa, r, p, wb)`` as parameters named ``mlp_<ix>_<name>``."""
        param_names = ['wa', 'r', 'p', 'wb']
        for ix, item in enumerate(self.factorized_w):
            layer_name = 'mlp_{}'.format(ix)
            current_layer_param = list()
            for j, we in enumerate(item):
                we = nn.Parameter(we)
                setattr(self, layer_name + '_' + param_names[j], we)
                current_layer_param.append(getattr(self, layer_name + '_' + param_names[j]))
            self.trainable_param.append(current_layer_param)

    def forward(self, x, A_prior=None, B_prior=None, R_prior=None, expectation=False, initial=True, external_Zl=None):
        """Run the factorised MLP.

        Args:
            x: input; flattened to ``x.shape[1] * x.shape[-2] * x.shape[-1]``
                features per sample.
            A_prior, B_prior, R_prior: unused; kept for interface compatibility.
            expectation: if True, mask ``r`` with the posterior probabilities
                ``sigmoid(p)`` that exceed ``p_threshold`` instead of a sample.
            initial: if True, accumulate the ``kld_v`` and ``kld_r`` terms.
            external_Zl: optional per-layer vectors used as the diagonal instead
                of the computed ``Zl``.

        Returns:
            ``(log_softmax(out, dim=1), kld_binary, kld_v, kld_r)``. ``kld_v`` and
            ``kld_r`` stay ``0.0`` when ``initial`` is False. As a side effect,
            ``self.y_post``, ``self.log_q`` and ``self.log_p`` hold the last layer's values.
        """
        H = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1])
        kld_binary = 0.0
        kld_v = 0.0
        kld_r = 0.0

        for ix, (dim_in, dim_hidden, dim_out) in enumerate(self.model_arch):
            wa, real_l, pi_post, wb = self.trainable_param[ix]

            pi_post = torch.sigmoid(pi_post)
            Al = self.A[ix].to(wa.device)
            Bl = self.B[ix].to(wa.device)
            Al = torch_odds(torch.sigmoid(Al))
            Bl = torch_odds(torch.sigmoid(Bl))

            # Append a constant-1 column so the last row of W acts as a bias.
            H = torch.cat((H, torch.ones((H.shape[0], 1)).to(wa.device)), dim=1)
            Y_post, Binary_l = torch_sample_BernConcrete(pi_post, self.lambda_post)

            if expectation:
                # Threshold the posterior probabilities themselves; zeros_like keeps
                # the fallback on the same device/dtype so torch.where works on GPU.
                pi_post_squeeze = torch.squeeze(pi_post)
                pi_activated = torch.where(pi_post_squeeze > self.p_threshold,
                                           pi_post_squeeze,
                                           torch.zeros_like(pi_post_squeeze))
                pi_activated = torch.reshape(pi_activated, (dim_hidden, 1))
                Zl = torch.squeeze(pi_activated * real_l)
            else:
                Zl = torch.squeeze(Binary_l * real_l)

            if self.global_forward:
                Zl = torch.squeeze(real_l)

            Dl = torch.diag(Zl)
            if external_Zl is not None:
                Dl = torch.diag(external_Zl[ix])

            W = wa.mm(Dl).mm(wb)
            H = H.mm(W)

            # LeakyReLU on every layer except the last (ix runs to num_layers).
            if ix <= self.num_layers - 1:
                H = self.relu(H)

            if initial:
                kld_v += torch.sum(torch_kullback_kumar_beta(Al, Bl, torch.Tensor([self.a_prior]).to(wa.device), torch.Tensor([1.0]).to(wa.device)))
                kld_r += torch.sum(torch_kullback_normal_normal(real_l, torch.ones_like(real_l), torch.Tensor([0.0]).to(wa.device), torch.Tensor([1.0]).to(wa.device)))

            pi_prior_log = torch_stick_breaking_weights(Al, Bl)
            pi_prior_log = torch.reshape(pi_prior_log, (-1, 1))

            pi_prior = torch.exp(pi_prior_log)
            log_q = torch_log_density_logistic(pi_post, torch.Tensor([self.lambda_post]).to(wa.device), Y_post)
            log_p = torch_log_density_logistic(pi_prior, torch.Tensor([self.lambda_prior]).to(wa.device), Y_post)

            self.y_post = Y_post
            self.log_q = log_q
            self.log_p = log_p

            kld_binary += torch.sum(log_q - log_p)

        out = H

        return F.log_softmax(out, dim=1), kld_binary, kld_v, kld_r






# Separate r and pi: the mask logits pi are created in get_pi() rather than passed in factorized_w.
class IBPMLP_V3(nn.Module):
    """Factorised IBP MLP with separately created mask logits and a plain output layer.

    Hidden layers (``model_arch[:-1]``) use weights ``wa @ diag(z) @ wb``. The mask
    logits ``pi`` are created in ``get_pi`` (or taken from ``local_bayes``) rather
    than passed in ``factorized_w``. The last ``model_arch`` entry is an ordinary
    ``nn.Linear`` whose weight and bias come from ``factorized_w[-1]``.

    Args:
        model_arch: list of ``(dim_in, dim_hidden, dim_out)`` tuples.
        factorized_w: per hidden layer ``(wa, r, wb)``; last entry ``(weight, bias)``.
        a_prior: used to initialise ``A`` and the ``pi`` logits.
        lambda_post, lambda_prior: passed to the Concrete sampler / logistic
            density helpers.
        p_threshold: probability cut-off used when ``expectation=True``.
        local_bayes: optional dict with ``'Al'``, ``'Bl'`` and ``'pi'`` lists that
            replace fresh initialisation.
        global_forward: if True, forward() ignores the mask and uses ``r`` directly.
    """

    def __init__(self, model_arch, factorized_w, a_prior, lambda_post, lambda_prior, p_threshold, local_bayes=None, global_forward=False):
        super(IBPMLP_V3, self).__init__()

        self.model_arch = model_arch
        self.factorized_w = factorized_w
        self.trainable_param = list()

        self.a_prior = a_prior
        self.lambda_post = lambda_post
        self.lambda_prior = lambda_prior
        self.p_threshold = p_threshold
        self.relu = nn.LeakyReLU()
        self.num_layers = len(model_arch) - 1

        self.local_bayes = local_bayes
        self.global_forward = global_forward

        self.masks = list()
        self.A = list()
        self.B = list()

        self.pi_list = list()
        self.initial_AB()
        self.get_pi()

        last_in, _, last_out = self.model_arch[-1]
        self.linear = nn.Linear(last_in, last_out)
        self.register_param()

    def initial_AB(self):
        """Create (or load from ``local_bayes``) and register the ``A<ix>``/``B<ix>`` parameters."""
        for ix, (dim_in, dim_hidden, dim_out) in enumerate(self.model_arch[:-1]):

            if self.local_bayes is None:
                Al = nn.Parameter(get_torch_logit_variable(init_value=self.a_prior, shape=(dim_hidden, 1)))
                Bl = nn.Parameter(get_torch_logit_variable(init_value=1.0, shape=(dim_hidden, 1)))
            else:
                Al = nn.Parameter(self.local_bayes['Al'][ix])
                Bl = nn.Parameter(self.local_bayes['Bl'][ix])

            setattr(self, "A" + str(ix), Al)
            setattr(self, "B" + str(ix), Bl)

            self.A.append(Al)
            self.B.append(Bl)

    def get_pi(self):
        """Create the per-layer mask logits and append them to ``pi_list`` and ``masks``."""
        import numpy as np  # explicit import rather than relying on a star import for ``np``

        for ix, (dim_in, dim_hidden, dim_out) in enumerate(self.model_arch[:-1]):

            if self.local_bayes is None:
                # Take \E[v] = a / (a + 1) for each v, cumulated down the units.
                init = np.ones((dim_hidden, 1), dtype=np.float32) * (self.a_prior / (self.a_prior + 1))
                init = np.cumprod(init, axis=0)

                init = np.clip(init, a_min=0.001, a_max=0.999)
                # convert probabilities to logits
                init = np.log(init / (1 - init))
                pi_post_logit, _ = get_torch_variable("", init=init, shape=(dim_hidden, 1), var_list=[], init_std=0.1)
                pi_post_logit = torch.reshape(pi_post_logit, shape=(dim_hidden, 1))

            else:
                pi_post_logit = self.local_bayes['pi'][ix]

            self.pi_list.append(pi_post_logit)
            self.masks.append(pi_post_logit)

    def register_param(self):
        """Register ``(wa, r, wb, p)`` per hidden layer and load the output layer's weight/bias."""
        param_names = ['wa', 'r', 'wb', 'p']
        for ix, item in enumerate(self.factorized_w[:-1]):
            layer_name = 'mlp_{}'.format(ix)
            current_layer_param = list()

            for j, we in enumerate(list(item) + [self.pi_list[ix]]):
                we = nn.Parameter(we)
                setattr(self, layer_name + '_' + param_names[j], we)
                current_layer_param.append(getattr(self, layer_name + '_' + param_names[j]))
            self.trainable_param.append(current_layer_param)

        def _as_param(t):
            # nn.Module only accepts nn.Parameter (or None) for an existing
            # parameter slot such as Linear.weight; plain tensors raise TypeError.
            return t if t is None or isinstance(t, nn.Parameter) else nn.Parameter(t)

        self.linear.weight = _as_param(self.factorized_w[-1][0])
        self.linear.bias = _as_param(self.factorized_w[-1][1])

    def forward(self, x, A_prior=None, B_prior=None, R_prior=None, expectation=False, initial=True, external_Zl=None, inference=False):
        """Run the factorised hidden layers and the output ``nn.Linear``.

        Args:
            x: input; flattened to ``x.shape[1] * x.shape[-2] * x.shape[-1]`` features.
            A_prior, B_prior, R_prior, inference: unused; kept for interface compatibility.
            expectation: if True, mask ``r`` with the posterior probabilities above
                ``p_threshold`` instead of a sample.
            initial: if True, accumulate ``kld_v`` and ``kld_r``.
            external_Zl: optional per-layer vectors used as the diagonal.

        Returns:
            ``(log_softmax(out, dim=1), kld_binary, kld_v, kld_r)``.
        """
        H = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1])
        kld_binary = 0.0
        kld_v = 0.0
        kld_r = 0.0

        for ix, (dim_in, dim_hidden, dim_out) in enumerate(self.model_arch[:-1]):
            wa, real_l, wb, pi_post = self.trainable_param[ix]

            pi_post_prob = torch.sigmoid(pi_post)

            Al = self.A[ix].to(wa.device)
            Bl = self.B[ix].to(wa.device)
            Al_pos = torch_odds(torch.sigmoid(Al))
            Bl_pos = torch_odds(torch.sigmoid(Bl))

            # to use the previous update al and bl as new prior: a detached copy of
            # the current A/B, so the KL gradient flows only through Al_pos/Bl_pos.
            Al_pre = self.A[ix].clone().to(wa.device).detach()
            Bl_pre = self.B[ix].clone().to(wa.device).detach()
            Al_pos_pre = torch_odds(torch.sigmoid(Al_pre))
            Bl_pos_pre = torch_odds(torch.sigmoid(Bl_pre))

            # Append a constant-1 column so the last row of W acts as a bias.
            H = torch.cat((H, torch.ones((H.shape[0], 1)).to(wa.device)), dim=1)
            Y_post, Binary_l = torch_sample_BernConcrete(pi_post_prob, self.lambda_post)

            if expectation:
                # Threshold the posterior probabilities; zeros_like keeps device/dtype.
                pi_post_squeeze = torch.squeeze(pi_post_prob)
                pi_activated = torch.where(pi_post_squeeze > self.p_threshold,
                                           pi_post_squeeze,
                                           torch.zeros_like(pi_post_squeeze))
                pi_activated = torch.reshape(pi_activated, (dim_hidden, 1))
                Zl = torch.squeeze(pi_activated * real_l)
            else:
                Zl = torch.squeeze(Binary_l * real_l)

            if self.global_forward:
                Zl = torch.squeeze(real_l)

            Dl = torch.diag(Zl)
            if external_Zl is not None:
                Dl = torch.diag(external_Zl[ix])

            W = wa.mm(Dl).mm(wb)
            H = H.mm(W)

            if ix <= self.num_layers - 1:
                H = self.relu(H)

            if initial:

                kld_v += torch.sum(torch_kullback_kumar_beta(Al_pos, Bl_pos, Al_pos_pre, Bl_pos_pre))

                kld_r += torch.sum(torch_kullback_normal_normal(real_l, torch.ones_like(real_l), torch.Tensor([0.0]).to(wa.device), torch.Tensor([1.0]).to(wa.device)))

            pi_prior_log = torch_stick_breaking_weights(Al_pos, Bl_pos)
            pi_prior_log = torch.reshape(pi_prior_log, (-1, 1))

            pi_prior = torch.exp(pi_prior_log)
            log_q = torch_log_density_logistic(pi_post_prob, torch.Tensor([self.lambda_post]).to(wa.device), Y_post)
            log_p = torch_log_density_logistic(pi_prior, torch.Tensor([self.lambda_prior]).to(wa.device), Y_post)

            self.y_post = Y_post
            self.log_q = log_q
            self.log_p = log_p

            kld_binary += torch.sum(log_q - log_p)

        out = self.linear(H)

        return F.log_softmax(out, dim=1), kld_binary, kld_v, kld_r






class IBPCNN(nn.Module):
    """CNN/MLP hybrid whose layer weights are factorised as ``wa @ diag(z) @ wb``.

    Each ``model_arch`` entry is ``(layer_name, dim_in, dim_hidden, dim_out)``.
    When ``layer_name == 'conv'``, the product ``W`` is reshaped to conv filters
    using ``conv_arch[ix]`` (index 0 = in channels, 1 = out channels,
    2 = kernel size, per the view below). Otherwise, the layer is a dense layer
    with an appended bias column. ``factorized_w`` holds ``(wa, r, p, wb)`` per layer.
    """

    def __init__(self, model_arch, conv_arch, factorized_w, a_prior, lambda_post, lambda_prior, p_threshold):
        super(IBPCNN, self).__init__()

        self.model_arch = model_arch
        self.conv_arch = conv_arch
        self.factorized_w = factorized_w
        self.trainable_param = list()

        self.a_prior = a_prior

        self.lambda_post = lambda_post
        self.lambda_prior = lambda_prior
        self.p_threshold = p_threshold
        self.relu = nn.LeakyReLU()

        self.dropout = nn.Dropout()  # not used in forward(); kept for compatibility
        self.num_layers = len(model_arch) - 1

        self.A = list()
        self.B = list()
        self.initial_AB()
        self.register_param()

    def initial_AB(self):
        """Create and register ``<layer_name>_A<ix>`` / ``<layer_name>_B<ix>`` logit parameters."""
        for ix, (layer_name, dim_in, dim_hidden, dim_out) in enumerate(self.model_arch):

            Al = nn.Parameter(get_torch_logit_variable(init_value=self.a_prior, shape=(dim_hidden, 1)))
            Bl = nn.Parameter(get_torch_logit_variable(init_value=1.0, shape=(dim_hidden, 1)))

            setattr(self, layer_name + "_A" + str(ix), Al)
            setattr(self, layer_name + "_B" + str(ix), Bl)

            self.A.append(Al)
            self.B.append(Bl)

    def register_param(self):
        """Wrap each layer's ``(wa, r, p, wb)`` as parameters named ``<type>_<ix>_<name>``."""
        param_names = ['wa', 'r', 'p', 'wb']
        for ix, item in enumerate(self.factorized_w):
            layer_type = self.model_arch[ix][0]  # conv or mlp
            layer_name = layer_type + '_{}'.format(ix)
            current_layer_param = list()
            for j, we in enumerate(item):
                we = nn.Parameter(we)
                setattr(self, layer_name + '_' + param_names[j], we)
                current_layer_param.append(getattr(self, layer_name + '_' + param_names[j]))
            self.trainable_param.append(current_layer_param)

    def forward(self, x, A_prior=None, B_prior=None, R_prior=None, expectation=False, initial=True, external_Zl=None):
        """Run the factorised layers.

        Args:
            x: input batch; conv layers receive it as-is, dense layers flatten it
                to ``(batch, -1)``.
            A_prior, B_prior, R_prior: unused; kept for interface compatibility.
            expectation: if True, mask ``r`` with posterior probabilities above
                ``p_threshold`` instead of a sample.
            initial: if True, accumulate ``kld_v`` and ``kld_r``.
            external_Zl: optional per-layer vectors used as the diagonal.

        Returns:
            ``(log_softmax(out, dim=1), kld_binary, 2.0 * kld_v, kld_r)``.

        All auxiliary tensors are created on the layer parameters' device instead of
        being hard-wired to CUDA, so the model also runs on CPU.
        """
        H = x
        batch_size = x.shape[0]
        kld_binary = 0.0
        kld_v = 0.0
        kld_r = 0.0

        for ix, (layer_name, dim_in, dim_hidden, dim_out) in enumerate(self.model_arch):
            wa, real_l, pi_post, wb = self.trainable_param[ix]
            device = wa.device

            pi_post = torch.sigmoid(pi_post)
            Al = self.A[ix].to(device)
            Bl = self.B[ix].to(device)
            Al = torch_odds(torch.sigmoid(Al))
            Bl = torch_odds(torch.sigmoid(Bl))

            Y_post, Binary_l = torch_sample_BernConcrete(pi_post, self.lambda_post)

            if expectation:
                # Threshold the posterior probabilities; zeros_like keeps device/dtype.
                pi_post_squeeze = torch.squeeze(pi_post)
                pi_activated = torch.where(pi_post_squeeze > self.p_threshold,
                                           pi_post_squeeze,
                                           torch.zeros_like(pi_post_squeeze))
                pi_activated = torch.reshape(pi_activated, (dim_hidden, 1))
                Zl = torch.squeeze(pi_activated * real_l)
            else:
                Zl = torch.squeeze(Binary_l * real_l)

            Dl = torch.diag(Zl)
            if external_Zl is not None:
                Dl = torch.diag(external_Zl[ix])

            W = wa.mm(Dl).mm(wb)

            if layer_name == 'conv':
                conv_filters = W.view(self.conv_arch[ix][1], self.conv_arch[ix][0], self.conv_arch[ix][2], self.conv_arch[ix][2])  # [dimout,dimin,H,W]
                H = self.relu(F.max_pool2d(F.conv2d(H, conv_filters, bias=None), 2))

            else:  # mlp
                H = H.view(batch_size, -1)
                # Constant-1 column so the last row of W acts as a bias.
                H = torch.cat((H, torch.ones((H.shape[0], 1), device=H.device)), dim=1)
                H = H.mm(W)

            # LeakyReLU on dense layers other than the last one
            # (NOTE(review): assumes the conv layers come first in model_arch).
            if len(self.conv_arch) <= ix <= self.num_layers - 1:
                H = self.relu(H)

            if initial:
                kld_v += torch.sum(torch_kullback_kumar_beta(Al, Bl, torch.Tensor([self.a_prior]).to(device), torch.Tensor([1.0]).to(device)))
                # NOTE(review): prior mean here is 1.0, while the other IBP models use 0.0 -- confirm intended.
                kld_r += torch.sum(torch_kullback_normal_normal(real_l, torch.ones_like(real_l), torch.Tensor([1.0]).to(device), torch.Tensor([1.]).to(device)))

            pi_prior_log = torch_stick_breaking_weights(Al, Bl)
            pi_prior_log = torch.reshape(pi_prior_log, (-1, 1))

            pi_prior = torch.exp(pi_prior_log)
            log_q = torch_log_density_logistic(pi_post, torch.Tensor([self.lambda_post]).to(device), Y_post)
            log_p = torch_log_density_logistic(pi_prior, torch.Tensor([self.lambda_prior]).to(device), Y_post)

            self.y_post = Y_post
            self.log_q = log_q
            self.log_p = log_p

            kld_binary += torch.sum(log_q - log_p)

        out = H

        # NOTE(review): kld_v is doubled here but not in the other IBP models -- confirm intended.
        return F.log_softmax(out, dim=1), kld_binary, 2.0 * kld_v, kld_r



class IBPCNN_V3(nn.Module):
    """Factorised IBP CNN with separately created mask logits and a plain output layer.

    Each ``model_arch`` entry is ``(layer_name, dim_in, dim_hidden, dim_out)``. The
    entries in ``model_arch[:-1]`` are factorised conv (``layer_name == 'conv'``)
    or dense layers. The last entry is an ordinary ``nn.Linear`` loaded from
    ``factorized_w[-1]``. Mask logits ``pi`` are created in ``get_pi`` or taken
    from ``local_bayes``.
    """

    def __init__(self, model_arch, conv_arch, factorized_w, a_prior, lambda_post, lambda_prior, p_threshold, local_bayes=None, global_forward=False):
        super(IBPCNN_V3, self).__init__()

        self.model_arch = model_arch
        self.conv_arch = conv_arch
        self.factorized_w = factorized_w
        self.trainable_param = list()

        self.a_prior = a_prior

        self.lambda_post = lambda_post
        self.lambda_prior = lambda_prior
        self.p_threshold = p_threshold
        self.relu = nn.LeakyReLU()

        self.dropout = nn.Dropout()  # not used in forward(); kept for compatibility
        self.num_layers = len(model_arch) - 1

        self.local_bayes = local_bayes
        self.global_forward = global_forward

        self.masks = list()
        self.A = list()
        self.B = list()

        self.pi_list = list()
        self.initial_AB()
        self.get_pi()

        _, last_in, _, last_out = self.model_arch[-1]
        self.linear = nn.Linear(last_in, last_out)
        self.register_param()

    def initial_AB(self):
        """Create (or load from ``local_bayes``) and register the ``A<ix>``/``B<ix>`` parameters."""
        for ix, (layer_name, dim_in, dim_hidden, dim_out) in enumerate(self.model_arch[:-1]):

            if self.local_bayes is None:
                Al = nn.Parameter(get_torch_logit_variable(init_value=self.a_prior, shape=(dim_hidden, 1)))
                Bl = nn.Parameter(get_torch_logit_variable(init_value=1.0, shape=(dim_hidden, 1)))
            else:
                Al = nn.Parameter(self.local_bayes['Al'][ix])
                Bl = nn.Parameter(self.local_bayes['Bl'][ix])

            setattr(self, "A" + str(ix), Al)
            setattr(self, "B" + str(ix), Bl)

            self.A.append(Al)
            self.B.append(Bl)

    def get_pi(self):
        """Create the per-layer mask logits and append them to ``pi_list`` and ``masks``."""
        import numpy as np  # explicit import rather than relying on a star import for ``np``

        for ix, (layer_name, dim_in, dim_hidden, dim_out) in enumerate(self.model_arch[:-1]):

            if self.local_bayes is None:
                # Take \E[v] = a / (a + 1) for each v, cumulated down the units.
                init = np.ones((dim_hidden, 1), dtype=np.float32) * (self.a_prior / (self.a_prior + 1))
                init = np.cumprod(init, axis=0)

                init = np.clip(init, a_min=0.001, a_max=0.999)
                # convert probabilities to logits
                init = np.log(init / (1 - init))
                pi_post_logit, _ = get_torch_variable("", init=init, shape=(dim_hidden, 1), var_list=[], init_std=0.1)
                pi_post_logit = torch.reshape(pi_post_logit, shape=(dim_hidden, 1))

            else:
                pi_post_logit = self.local_bayes['pi'][ix]

            self.pi_list.append(pi_post_logit)
            self.masks.append(pi_post_logit)

    def register_param(self):
        """Register ``(wa, r, wb, p)`` per factorised layer and load the output layer's weight/bias."""
        param_names = ['wa', 'r', 'wb', 'p']
        for ix, item in enumerate(self.factorized_w[:-1]):
            layer_type = self.model_arch[ix][0]  # conv or mlp
            layer_name = layer_type + '_{}'.format(ix)
            current_layer_param = list()
            for j, we in enumerate(list(item) + [self.pi_list[ix]]):
                we = nn.Parameter(we)
                setattr(self, layer_name + '_' + param_names[j], we)
                current_layer_param.append(getattr(self, layer_name + '_' + param_names[j]))
            self.trainable_param.append(current_layer_param)

        def _as_param(t):
            # nn.Module only accepts nn.Parameter (or None) for an existing
            # parameter slot such as Linear.weight; plain tensors raise TypeError.
            return t if t is None or isinstance(t, nn.Parameter) else nn.Parameter(t)

        self.linear.weight = _as_param(self.factorized_w[-1][0])
        self.linear.bias = _as_param(self.factorized_w[-1][1])

    def forward(self, x, A_prior=None, B_prior=None, R_prior=None, expectation=False, initial=True, external_Zl=None, inference=False):
        """Run the factorised conv/dense layers and the output ``nn.Linear``.

        Args:
            x: input batch; conv layers receive it as-is, dense layers flatten it.
            A_prior, B_prior, R_prior, inference: unused; kept for interface compatibility.
            expectation: if True, mask ``r`` with posterior probabilities above
                ``p_threshold`` instead of a sample.
            initial: if True, accumulate ``kld_v`` and ``kld_r``.
            external_Zl: optional per-layer vectors used as the diagonal.

        Returns:
            ``(log_softmax(out, dim=1), kld_binary, kld_v, kld_r)``.
        """
        H = x
        batch_size = x.shape[0]
        kld_binary = 0.0
        kld_v = 0.0
        kld_r = 0.0

        for ix, (layer_name, dim_in, dim_hidden, dim_out) in enumerate(self.model_arch[:-1]):
            wa, real_l, wb, pi_post = self.trainable_param[ix]

            pi_post_prob = torch.sigmoid(pi_post)

            Al = self.A[ix].to(wa.device)
            Bl = self.B[ix].to(wa.device)
            Al_pos = torch_odds(torch.sigmoid(Al))
            Bl_pos = torch_odds(torch.sigmoid(Bl))

            # to use the previous update al and bl as new prior: a detached copy of
            # the current A/B, so the KL gradient flows only through Al_pos/Bl_pos.
            Al_pre = self.A[ix].clone().to(wa.device).detach()
            Bl_pre = self.B[ix].clone().to(wa.device).detach()
            Al_pos_pre = torch_odds(torch.sigmoid(Al_pre))
            Bl_pos_pre = torch_odds(torch.sigmoid(Bl_pre))

            Y_post, Binary_l = torch_sample_BernConcrete(pi_post_prob, self.lambda_post)

            if expectation:
                # Threshold the posterior probabilities; zeros_like keeps device/dtype.
                pi_post_squeeze = torch.squeeze(pi_post_prob)
                pi_activated = torch.where(pi_post_squeeze > self.p_threshold,
                                           pi_post_squeeze,
                                           torch.zeros_like(pi_post_squeeze))
                pi_activated = torch.reshape(pi_activated, (dim_hidden, 1))
                Zl = torch.squeeze(pi_activated * real_l)
            else:
                Zl = torch.squeeze(Binary_l * real_l)

            if self.global_forward:
                Zl = torch.squeeze(real_l)

            Dl = torch.diag(Zl)
            if external_Zl is not None:
                Dl = torch.diag(external_Zl[ix])

            W = wa.mm(Dl).mm(wb)

            if layer_name == 'conv':
                conv_filters = W.view(self.conv_arch[ix][1], self.conv_arch[ix][0], self.conv_arch[ix][2], self.conv_arch[ix][2])  # [dimout,dimin,H,W]
                if self.num_layers < 3:
                    # padding=2 variant (originally commented as the (F)MNIST setup)
                    H = self.relu(F.max_pool2d(F.conv2d(H, conv_filters, bias=None, padding=2), 2))
                else:
                    # unpadded variant (originally commented as the LeNet setup)
                    H = self.relu(F.max_pool2d(F.conv2d(H, conv_filters, bias=None), 2))

            else:  # mlp
                H = H.view(batch_size, -1)
                # Constant-1 column (on H's device, not hard-wired CUDA) so the
                # last row of W acts as a bias.
                H = torch.cat((H, torch.ones((H.shape[0], 1), device=H.device)), dim=1)
                H = H.mm(W)

            # NOTE(review): assumes the conv layers come first in model_arch.
            if len(self.conv_arch) <= ix <= self.num_layers - 1:
                H = self.relu(H)

            if initial:

                kld_v += torch.sum(torch_kullback_kumar_beta(Al_pos, Bl_pos, Al_pos_pre, Bl_pos_pre))

                kld_r += torch.sum(torch_kullback_normal_normal(real_l, torch.ones_like(real_l), torch.Tensor([0.0]).to(wa.device), torch.Tensor([1.]).to(wa.device)))

            pi_prior_log = torch_stick_breaking_weights(Al_pos, Bl_pos)
            pi_prior_log = torch.reshape(pi_prior_log, (-1, 1))

            pi_prior = torch.exp(pi_prior_log)
            log_q = torch_log_density_logistic(pi_post_prob, torch.Tensor([self.lambda_post]).to(wa.device), Y_post)
            log_p = torch_log_density_logistic(pi_prior, torch.Tensor([self.lambda_prior]).to(wa.device), Y_post)

            self.y_post = Y_post
            self.log_q = log_q
            self.log_p = log_p

            kld_binary += torch.sum(log_q - log_p)

        out = self.linear(H.view(batch_size, -1))

        return F.log_softmax(out, dim=1), kld_binary, kld_v, kld_r





class CNNMnist(nn.Module):
    """Two conv blocks followed by two fully connected layers.

    ``args.num_channels`` sets the input channels and ``args.num_classes`` the
    number of outputs; ``forward`` returns log-probabilities over classes.
    """

    def __init__(self, args):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        """Return ``log_softmax`` scores of shape ``(batch, num_classes)``."""
        h = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Channel dropout sits between conv2 and its pooling.
        h = self.conv2_drop(self.conv2(h))
        h = F.relu(F.max_pool2d(h, 2))
        h = torch.flatten(h, 1)
        h = F.dropout(F.relu(self.fc1(h)), training=self.training)
        logits = self.fc2(h)
        return F.log_softmax(logits, dim=1)


class CNNFashion_Mnist(nn.Module):
    """Two ``Conv(5x5, pad 2) -> BatchNorm -> ReLU -> MaxPool(2)`` blocks and a linear head.

    The head expects ``32 * 7 * 7`` features (28x28 single-channel input halved
    twice) and produces 10 log-probabilities. ``args`` is accepted but unused.
    """

    def __init__(self, args):
        super(CNNFashion_Mnist, self).__init__()
        self.layer1 = self._conv_block(1, 16)
        self.layer2 = self._conv_block(16, 32)
        self.fc = nn.Linear(7*7*32, 10)

    @staticmethod
    def _conv_block(c_in, c_out):
        # Submodule order (conv, bn, relu, pool) fixes the state_dict keys.
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=5, padding=2),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
            nn.MaxPool2d(2))

    def forward(self, x):
        """Return ``log_softmax`` scores of shape ``(batch, 10)``."""
        features = self.layer2(self.layer1(x))
        flat = features.reshape(features.size(0), -1)
        return F.log_softmax(self.fc(flat), dim=1)





class CNNCifar(nn.Module):
    """Classic LeNet-style CIFAR network with ``args.num_classes`` outputs.

    Two ``Conv(5x5) -> ReLU -> MaxPool(2)`` stages, then three linear layers;
    ``forward`` returns log-probabilities.
    """

    def __init__(self, args):
        super(CNNCifar, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        """Return ``log_softmax`` scores of shape ``(batch, num_classes)``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat_features = 16 * 5 * 5
        x = x.view(-1, flat_features)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return F.log_softmax(self.fc3(x), dim=1)


class CNNCifarBN(nn.Module):
    """CIFAR CNN with batch norm: two ``Conv(3x3) -> BN -> ReLU -> MaxPool(2)`` stages.

    Three linear layers follow, sized for ``16 * 6 * 6`` flattened features;
    ``forward`` returns log-probabilities over ``args.num_classes``.
    """

    def __init__(self, args):
        super(CNNCifarBN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.bn1 = nn.BatchNorm2d(6)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.bn2 = nn.BatchNorm2d(16)
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        """Return ``log_softmax`` scores of shape ``(batch, num_classes)``."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = self.pool(F.relu(norm(conv(x))))
        x = x.view(-1, 16 * 6 * 6)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return F.log_softmax(self.fc3(x), dim=1)


class CNNCifarV(nn.Module):
    """CIFAR CNN with two 16-channel ``Conv(3x3) -> ReLU -> MaxPool(2)`` stages.

    Three linear layers follow; ``forward`` returns log-probabilities over
    ``args.num_classes``. ``bn1``/``bn2`` are registered (so they appear in the
    state_dict) but ``forward`` does not apply them.
    """

    def __init__(self, args):
        super(CNNCifarV, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.bn1 = nn.BatchNorm2d(16)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 16, 3)
        self.bn2 = nn.BatchNorm2d(16)
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        """Return ``log_softmax`` scores of shape ``(batch, num_classes)``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.view(-1, 16 * 6 * 6)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return F.log_softmax(self.fc3(x), dim=1)







class modelC(nn.Module):
    """All-convolutional classifier (3x3 conv stacks, strided downsampling, 1x1 class conv).

    Args:
        input_size: number of input channels.
        n_classes: number of output classes (default 10).
        **kwargs: accepted and ignored.

    ``forward`` returns log-probabilities of shape ``(batch, n_classes)`` after
    global average pooling.
    """

    def __init__(self, input_size, n_classes=10, **kwargs):
        super(modelC, self).__init__()
        self.conv1 = nn.Conv2d(input_size, 96, 3, padding=1)
        self.conv2 = nn.Conv2d(96, 96, 3, padding=1)
        self.conv3 = nn.Conv2d(96, 96, 3, padding=1, stride=2)
        self.conv4 = nn.Conv2d(96, 192, 3, padding=1)
        self.conv5 = nn.Conv2d(192, 192, 3, padding=1)
        self.conv6 = nn.Conv2d(192, 192, 3, padding=1, stride=2)
        self.conv7 = nn.Conv2d(192, 192, 3, padding=1)
        self.conv8 = nn.Conv2d(192, 192, 1)

        self.class_conv = nn.Conv2d(192, n_classes, 1)

    def forward(self, x):
        """Return ``log_softmax`` scores of shape ``(batch, n_classes)``.

        Dropout is tied to ``self.training``: ``F.dropout`` defaults to
        ``training=True``, which would otherwise keep dropout active in eval mode.
        """
        x_drop = F.dropout(x, .2, training=self.training)
        conv1_out = F.relu(self.conv1(x_drop))
        conv2_out = F.relu(self.conv2(conv1_out))
        conv3_out = F.relu(self.conv3(conv2_out))
        conv3_out_drop = F.dropout(conv3_out, .5, training=self.training)
        conv4_out = F.relu(self.conv4(conv3_out_drop))
        conv5_out = F.relu(self.conv5(conv4_out))
        conv6_out = F.relu(self.conv6(conv5_out))
        conv6_out_drop = F.dropout(conv6_out, .5, training=self.training)
        conv7_out = F.relu(self.conv7(conv6_out_drop))
        conv8_out = F.relu(self.conv8(conv7_out))

        class_out = F.relu(self.class_conv(conv8_out))
        # (batch, n_classes, 1, 1) -> (batch, n_classes)
        pool_out = torch.flatten(F.adaptive_avg_pool2d(class_out, 1), 1)
        return F.log_softmax(pool_out, dim=1)


class BertBinaryClassifier(nn.Module):
    """Pretrained ``bert-base-uncased`` encoder with a linear head on the pooled output.

    NOTE(review): ``BertModel`` is not imported by any line visible in this file;
    it presumably comes from a star import -- confirm. The ``dropout`` argument
    is accepted but not used.
    """

    def __init__(self, args, dropout=0.1):
        super(BertBinaryClassifier, self).__init__()
        self.args = args

        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.linear = nn.Linear(768, args.num_classes)

    def forward(self, tokens, masks=None):
        """Return ``log_softmax`` class scores for ``tokens`` (``masks`` is the attention mask).

        NOTE(review): unpacking assumes ``self.bert`` returns a 2-item sequence
        whose second item is the pooled output -- confirm for the BERT library in use.
        """
        bert_result = self.bert(tokens, attention_mask=masks)
        _, pooled = bert_result
        scores = self.linear(pooled)
        return F.log_softmax(scores, dim=1)



### other big models for cifar10


#### VGGMODEL

# VGG layer plans consumed by VGG._make_layers: an int is the output channel
# count of a 3x3 Conv -> BatchNorm -> ReLU unit, and 'M' is a 2x2 max-pool.
cfg = dict(
    VGG11=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG13=[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG16=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    VGG19=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
)


class VGG(nn.Module):
    """VGG-style network built from a layer plan in the module-level ``cfg`` dict.

    Args:
        vgg_name: key into ``cfg`` (e.g. ``'VGG11'``).
        num_classes: size of the linear classifier output (default 10).

    The classifier expects 512 flattened features, i.e. the final feature map
    must be 1x1 spatially (e.g. 32x32 input after five max-pools).
    """

    def __init__(self, vgg_name, num_classes=10):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return ``log_softmax`` scores of shape ``(batch, num_classes)``."""
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        # Explicit class dimension: the implicit-dim form is deprecated and warns.
        return F.log_softmax(out, dim=1)

    def _make_layers(self, cfg):
        """Build the feature extractor from a plan of channel counts and ``'M'`` pools."""
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        # Identity-sized pool; kept so Sequential indices / state_dict keys stay unchanged.
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)



class LeNet(nn.Module):
    """LeNet-5-style network for 3-channel input producing 10 raw logits.

    Unlike most models in this file, ``forward`` returns unnormalised scores
    (no ``log_softmax``).
    """

    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return logits of shape ``(batch, 10)``."""
        h = x
        for conv in (self.conv1, self.conv2):
            h = F.max_pool2d(F.relu(conv(h)), 2)
        h = h.reshape(h.size(0), -1)
        for hidden in (self.fc1, self.fc2):
            h = F.relu(hidden(h))
        return self.fc3(h)


