#!/usr/bin/env python
# coding: utf-8

# Code for variational energy minimization with an IBP prior for weight factors

import matplotlib.pyplot as plt
import math
import os
import pdb

#import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.app.flags.DEFINE_string('f', '', 'kernel')

import numpy as np
import time
from tqdm import trange

from sklearn.utils import shuffle
import sklearn.covariance
from collections import defaultdict
from functools import reduce

import torch
from torch import nn
import torch.nn.functional as F
# Small positive constant added inside logarithms (e.g. log(v + SMALL) in the
# stick-breaking helpers) so that a zero-valued sample does not produce -inf.
SMALL = 1e-7

def load_dataset(fashion=False, home_dir=None, fashion_data_dir=None):
    """Load the pickled ``split_mnist.npz`` task archive.

    Adapted from: https://github.com/rahafaljundi/Gradient-based-Sample-Selection/blob/master/main.py

    Args:
        fashion: If True, read from ``fashion_data_dir``; otherwise read from
            ``<home_dir>/data/NP_BNN_Experiments/binary_mnist``.
        home_dir: Root directory of the binary-MNIST data. Defaults to a module-level
            ``HOME_DIR`` if one has been defined.
        fashion_data_dir: Directory that holds the Fashion-MNIST ``split_mnist.npz``.
            Defaults to a module-level ``fashion_data_dir`` if one has been defined.

    Returns:
        ``(d_tr, d_te, n_inputs, n_outputs, n_tasks)``:
        ``d_tr`` / ``d_te`` are the ``tasks_tr`` / ``tasks_te`` arrays. Each task entry
        is indexed as ``task[1]`` (2-D inputs) and ``task[2]`` (labels).
        ``n_inputs`` is the input width of the first training task.
        ``n_outputs`` is the largest label in either split plus one.
        ``n_tasks`` is ``len(d_tr)``.

    Raises:
        ValueError: if the needed directory is neither passed in nor defined at module level.
    """
    if fashion:
        data_dir = fashion_data_dir if fashion_data_dir is not None else globals().get('fashion_data_dir')
        if data_dir is None:
            raise ValueError("fashion=True needs `fashion_data_dir` (argument or module-level global)")
    else:
        root = home_dir if home_dir is not None else globals().get('HOME_DIR')
        if root is None:
            raise ValueError("loading binary MNIST needs `home_dir` (argument or module-level HOME_DIR)")
        data_dir = os.path.join(root, 'data', 'NP_BNN_Experiments', 'binary_mnist')
    data_file = os.path.join(data_dir, 'split_mnist.npz')

    # NOTE: allow_pickle unpickles the archive contents, so only load trusted local files.
    # The context manager closes the NpzFile handle once the arrays have been read.
    with np.load(data_file, allow_pickle=True) as data:
        d_tr = data['tasks_tr']
        d_te = data['tasks_te']

    n_inputs = d_tr[0][1].shape[1]
    n_outputs = 0
    for i, task_tr in enumerate(d_tr):
        n_outputs = max(n_outputs, task_tr[2].max(), d_te[i][2].max())

    return d_tr, d_te, n_inputs, n_outputs + 1, len(d_tr)




def one_hot_encode(x, n_classes):
    """Look up one-hot rows for integer class labels.

    Args:
        x: Integer label(s) used to index the identity matrix.
        n_classes: Number of classes (width of each one-hot row).

    Returns:
        NumPy float array of the rows of ``np.eye(n_classes)`` selected by ``x``.
    """
    identity = np.eye(n_classes)
    return identity[x]


def binary_labels(Y, task_id):
    """Relabel ``2*task_id`` as 0 and ``2*task_id + 1`` as 1.

    Other values are left as they are. ``Y`` is not modified; a float32 copy is returned.
    """
    relabelled = Y.astype(np.float32)
    negative_class, positive_class = 2 * task_id, 2 * task_id + 1
    relabelled[relabelled == negative_class] = 0.
    relabelled[relabelled == positive_class] = 1.
    return relabelled


def compute_mahalanobis(mu, sigma_inv, h):
    """Compute the squared Mahalanobis distance of every row of ``h`` from ``mu``.

    For each row ``h_i`` the result is ``(h_i - mu) @ sigma_inv @ (h_i - mu)``.
    """
    centered = h - mu
    return (centered @ sigma_inv * centered).sum(axis=1)

# Load data
def get_task(task_no):
    """Return ``(X, Y, X_test, Y_test)`` for task ``task_no``.

    NOTE(review): delegates to ``load_permuted_dataset``, which is not defined in this
    file, so calling this raises NameError unless that name is provided elsewhere.
    """
    splits = load_permuted_dataset(task_no)
    train_x, train_y, test_x, test_y = splits
    return train_x, train_y, test_x, test_y

def np_binary_accuracy(logit, Y_gt):
    """Compute the fraction of samples whose logit sign agrees with a 0/1 label.

    Labels equal to 0 are treated as -1. A prediction counts as correct when
    ``sign(label * logit) > 0``. ``Y_gt`` itself is not modified.
    """
    signed_labels = Y_gt.copy()
    signed_labels[signed_labels == 0] = -1.

    agreement = np.sign(signed_labels * logit)
    agreement[agreement <= 0] = 0.
    return np.mean(agreement)

def np_accuracy(logit, Y_gt, binary):
    """Compute the fraction of rows where argmax(logit) equals argmax of the one-hot ``Y_gt``.

    ``binary`` is accepted for signature compatibility but is not used.
    """
    predicted = np.argmax(logit, axis=1)
    target = np.argmax(Y_gt, axis=1)
    hits = predicted == target
    return np.mean(hits)

def IBP(N, alpha):
    """Draw a binary feature-allocation matrix from the Indian Buffet Process.

    Customer 0 takes Poisson(alpha) new dishes. Each later customer ``i`` (0-based)
    takes existing dish ``k`` with probability ``m_k / (i + 1)``, where ``m_k`` counts the
    earlier customers who took it. Customer ``i`` also opens Poisson(alpha / (i + 1))
    new dishes.

    Args:
        N: Number of customers (rows).
        alpha: IBP concentration parameter.

    Returns:
        Float array of shape ``(N, K)`` with 0/1 entries, where ``K`` is the total number
        of dishes drawn. ``N <= 0`` returns an empty ``(0, 0)`` array.
    """
    if N <= 0:
        return np.zeros(shape=(0, 0))

    # Generate the initial number of dishes from a Poisson.
    n_init = np.random.poisson(alpha, 1)[0]
    Z = np.zeros(shape=(N, n_init))
    Z[0, :] = 1
    m = Z.sum(axis=0)  # dish popularity counts
    K = n_init

    for i in range(1, N):
        # Probability of revisiting each existing dish.
        prob = m / (i + 1)
        Z[i, :] = np.greater(prob, np.random.rand(1, K)).astype(int)
        # Number of new dishes visited by customer i.
        knew = np.random.poisson(alpha / (i + 1), 1)[0]
        if knew > 0:
            # Grow Z only when needed: every concatenate copies the whole matrix.
            Z = np.concatenate((Z, np.zeros(shape=(N, knew))), axis=1)
            Z[i, K:K + knew] = 1
            K += knew
        # Update dish popularity counts (now including customer i).
        m = Z.sum(axis=0)

    return Z

def stick_breaking(N, alpha, kumaraswamy=False, truncation=100, probs=False):
    """Sample truncated stick-breaking probabilities or a binary mask drawn from them.

    For each of ``N`` rows, draw ``truncation`` proportions v_k from Beta(alpha, 1), or
    from Kumaraswamy(alpha, 1), which is the same distribution. Then form
    ``pi_k = prod_{j<=k} v_j``.

    Args:
        N: Number of rows.
        alpha: Stick-breaking concentration.
        kumaraswamy: Sample v through ``np_sample_kumaraswamy`` instead of ``np.random.beta``.
        truncation: Number of sticks per row.
        probs: Return ``pi`` itself rather than a binary sample.

    Returns:
        ``pi`` with shape ``(N, truncation)`` if ``probs`` is True. Otherwise an int 0/1
        array of the same shape, with entry 1 where a uniform draw falls below ``pi_k``.
    """
    shape = (N, truncation)
    if kumaraswamy:
        v = np_sample_kumaraswamy(alpha, 1, size=shape)
    else:
        v = np.random.beta(alpha, 1, size=shape)

    # (The leftover debug `print(v.shape)` was removed.)
    pi = np.cumprod(v, axis=1)
    if probs:
        return pi
    return np.greater(pi, np.random.rand(N, truncation)).astype(int)

def tf_stick_breaking_weights(a, b, size=None):
    """Return the log stick-breaking weights for v ~ Kumaraswamy(a, b) (TensorFlow).

    Args:
        a: First Kumaraswamy parameter (TF tensor).
        b: Second Kumaraswamy parameter.
        size: Shape of the sampled 2-D tensor; defaults to ``a``'s static shape.

    Returns:
        Cumulative sum along axis 0 of ``log(v + SMALL)``, i.e. the log of
        ``prod_{j<=k} v_j`` (up to the SMALL offset).
    """
    sample_shape = a.get_shape() if size is None else size
    log_v = tf.log(tf_sample_kumaraswamy(a, b, sample_shape) + SMALL)
    return tf.cumsum(log_v, axis=0)

def torch_stick_breaking_weights(a, b, size=None):
    """Return the log stick-breaking weights for v ~ Kumaraswamy(a, b) (PyTorch).

    Args:
        a: First Kumaraswamy parameter (tensor).
        b: Second Kumaraswamy parameter.
        size: Shape of the sampled tensor; defaults to ``a.shape``.

    Returns:
        Cumulative sum along axis 0 of ``log(v + SMALL)``, i.e. the log of
        ``prod_{j<=k} v_j`` (up to the SMALL offset).
    """
    sample_shape = a.shape if size is None else size
    log_v = torch.log(torch_sample_kumaraswamy(a, b, sample_shape) + SMALL)
    return torch.cumsum(log_v, axis=0)

def np_sample_kumaraswamy(a, b, size):
    """Sample k ~ Kumaraswamy(a, b) with NumPy, using the inverse CDF.

    Args:
        a: First shape parameter (scalar or array broadcastable to ``size``), > 0.
        b: Second shape parameter (scalar or array broadcastable to ``size``), > 0.
        size: Output shape, passed to ``np.random.uniform``.

    Returns:
        Array of samples in [0, 1).

    Raises:
        ValueError: if any entry of ``a`` or ``b`` is not strictly positive.
    """
    # An explicit check (not `assert`) still runs under `python -O`. np.all lets
    # array-valued parameters through instead of hitting an ambiguous truth value.
    if not (np.all(np.asarray(a) > 0) and np.all(np.asarray(b) > 0)):
        raise ValueError("Parameters must be strictly positive")

    U = np.random.uniform(size=size)
    return (1 - (1 - U) ** (1.0 / b)) ** (1.0 / a)

def tf_sample_kumaraswamy(a, b, size=None):
    """Sample k ~ Kumaraswamy(a, b) in TensorFlow, using the inverse CDF.

    The uniform noise is drawn from [0.0001, 0.9999) to stay away from the endpoints.

    Args:
        a: First shape parameter.
        b: Second shape parameter.
        size: Shape of the returned tensor.
    """
    noise = tf.random.uniform(minval=0.0001, maxval=0.9999, shape=size)
    inner = 1 - (1 - noise) ** (1.0 / b)
    return inner ** (1.0 / a)

def torch_sample_kumaraswamy(a, b, size=None, minval=0.0001, maxval=0.9999):
    """Sample k ~ Kumaraswamy(a, b) in PyTorch, using the inverse CDF.

    Args:
        a: First shape parameter (tensor or number).
        b: Second shape parameter (tensor or number).
        size: Sample shape (int, tuple, list or ``torch.Size``). Defaults to ``a.shape``.
        minval, maxval: Range of the uniform noise. Keeping it inside (0, 1) keeps the
            transform away from the endpoints.

    Returns:
        float32 tensor of samples, on ``a``'s device (on the CPU when ``a`` is not a tensor).
    """
    if size is None:
        size = a.shape
    device = a.device if torch.is_tensor(a) else None
    # torch.empty always treats `size` as a shape. The legacy torch.FloatTensor(tuple)
    # built a tensor whose *values* were the tuple entries. Sampling on the target
    # device also avoids a host-to-device copy.
    U = torch.empty(size, dtype=torch.float32, device=device).uniform_(minval, maxval)
    return (1 - (1 - U) ** (1.0 / b)) ** (1.0 / a)

def Beta_fn(a, b):
    """Compute the Beta function B(a, b) = Gamma(a) Gamma(b) / Gamma(a + b) in TensorFlow, via log-gamma."""
    log_beta = tf.math.lgamma(a) + tf.math.lgamma(b) - tf.math.lgamma(a + b)
    return tf.exp(log_beta)

def torch_Beta_fn(a, b):
    """Compute the Beta function B(a, b) = Gamma(a) Gamma(b) / Gamma(a + b) in PyTorch, via log-gamma."""
    log_beta = torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a + b)
    return torch.exp(log_beta)

def kullback_kumar_beta(a1, b1, prior_alpha, prior_beta=1):
    """
    Credit: Nalisnick et al. "SBP DGMS" ICLR 2017

    TF function to approximate the Kullback Leibler Distance
    between Kumar(a1, b1) and Beta(prior_alpha, prior_beta)
    D = KL(Kumar(a1, b1) || Beta(prior_alpha, prior_beta))
    Important note: Kumar (a, 1) and Beta (a, 1) are the same.

    Args:
        a1, b1: Posterior Kumaraswamy parameters (float tensors, > 0).
        prior_alpha, prior_beta: Beta prior parameters (numbers or tensors, > 0).

    Returns:
        Element-wise KL approximation. The E[log(1 - v)] series is truncated at 10 terms
        and digamma(b1) is replaced by a Taylor approximation.
    """
    # tf.math.lgamma only has floating-point kernels, so the default integer
    # prior_beta=1 failed inside Beta_fn. Cast the priors to a1's float dtype.
    float_dtype = tf.as_dtype(getattr(a1, 'dtype', tf.float32)).base_dtype
    prior_alpha = tf.cast(prior_alpha, float_dtype)
    prior_beta = tf.cast(prior_beta, float_dtype)

    # Taylor expansion for the E[log(1 - v)] term. It is unrolled in Python, so no tf Scan is needed.
    kl = 1. / (1 + a1 * b1) * Beta_fn(1. / a1, b1)
    for m in range(2, 11):
        kl += 1. / (m + a1 * b1) * Beta_fn(float(m) / a1, b1)
    kl *= (prior_beta - 1) * b1

    # use another taylor approx for Digamma function
    psi_b_taylor_approx = tf.log(b1) - 1. / (2 * b1) - 1. / (12 * b1 ** 2)
    # -0.57721 is minus the Euler-Mascheroni constant (E[log v] term)
    kl += (a1 - prior_alpha) / a1 * (-0.57721 - psi_b_taylor_approx - 1 / b1)

    # add normalization constants
    kl += tf.log(a1 * b1) + tf.log(Beta_fn(prior_alpha, prior_beta))

    # final term
    kl += -(b1 - 1) / b1

    return kl

def torch_kullback_kumar_beta(a1, b1, prior_alpha, prior_beta=1):
    """
    Credit: Nalisnick et al. "SBP DGMS" ICLR 2017

    PyTorch function to approximate the Kullback Leibler Distance
    between Kumar(a1, b1) and Beta(prior_alpha, prior_beta)
    D = KL(Kumar(a1, b1) || Beta(prior_alpha, prior_beta))
    Important note: Kumar (a, 1) and Beta (a, 1) are the same.

    Args:
        a1, b1: Posterior Kumaraswamy parameters (float tensors, > 0).
        prior_alpha, prior_beta: Beta prior parameters (numbers or tensors, > 0).

    Returns:
        Element-wise KL approximation, broadcast over the inputs. The E[log(1 - v)]
        series is truncated at 10 terms and digamma(b1) is replaced by a Taylor
        approximation.
    """
    def _beta_fn(x, y):
        # Beta function via log-gamma (same formula as torch_Beta_fn).
        return torch.exp(torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y))

    # torch.lgamma / torch.log accept only tensors, so the default prior_beta=1 (and
    # any numeric prior) raised TypeError. Promote the priors to a1's dtype and device.
    prior_alpha = torch.as_tensor(prior_alpha, dtype=a1.dtype, device=a1.device)
    prior_beta = torch.as_tensor(prior_beta, dtype=a1.dtype, device=a1.device)

    # Taylor expansion for the E[log(1 - v)] term (10 terms).
    kl = sum(1. / (m + a1 * b1) * _beta_fn(m / a1, b1) for m in range(1, 11))
    # Out-of-place ops so the result may broadcast to a larger shape than the series.
    kl = kl * ((prior_beta - 1) * b1)

    # use another taylor approx for Digamma function
    psi_b_taylor_approx = torch.log(b1) - 1. / (2 * b1) - 1. / (12 * b1 ** 2)
    # -0.57721 is minus the Euler-Mascheroni constant (E[log v] term)
    kl = kl + (a1 - prior_alpha) / a1 * (-0.57721 - psi_b_taylor_approx - 1 / b1)

    # add normalization constants
    kl = kl + (torch.log(a1 * b1) + torch.log(_beta_fn(prior_alpha, prior_beta)))

    # final term
    kl = kl - (b1 - 1) / b1

    return kl

def kullback_kumar_kumar(a1, b1, prior_alpha, prior_beta):
    """Approximate KL(Kumar(a1, b1) || Kumar(prior_alpha, prior_beta)) in TensorFlow.

    The E_q[log(1 - v^{prior_alpha})] term is a series truncated at 15 terms, unrolled
    in Python so no tf Scan is needed. digamma(b1) is replaced by a Taylor approximation.
    """
    series = 1. / (1 * prior_alpha + a1 * b1) * Beta_fn(1. * prior_alpha / a1, b1)
    for m in range(2, 16):
        series += 1. / (m * prior_alpha + a1 * b1) * Beta_fn(float(m) * prior_alpha / a1, b1)
    series *= (prior_beta - 1) * prior_alpha * b1

    # Taylor approximation of digamma(b1).
    psi_b = tf.log(b1) - 1. / (2 * b1) - 1. / (12 * b1 ** 2)

    # -0.57721 is minus the Euler-Mascheroni constant (E[log v] term)
    divergence = series
    divergence += tf.log(a1 * b1) + (a1 - prior_alpha) / a1 * (-0.57721 - psi_b - 1 / b1)
    divergence += -(b1 - 1) / b1
    divergence += -tf.log(prior_alpha * prior_beta)
    return divergence


def torch_kullback_kumar_kumar(a1, b1, prior_alpha, prior_beta):
    """Approximate KL(Kumar(a1, b1) || Kumar(prior_alpha, prior_beta)) in PyTorch.

    Args:
        a1, b1: Posterior Kumaraswamy parameters (float tensors, > 0).
        prior_alpha, prior_beta: Prior Kumaraswamy parameters (numbers or tensors, > 0).

    Returns:
        Element-wise KL approximation. The E_q[log(1 - v^{prior_alpha})] series is
        truncated at 15 terms and digamma(b1) is replaced by a Taylor approximation.
    """
    def _beta_fn(x, y):
        # Beta function via log-gamma (same formula as torch_Beta_fn).
        return torch.exp(torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y))

    # torch.log accepts only tensors, so numeric priors made the final
    # torch.log(prior_alpha * prior_beta) raise TypeError. Promote them to tensors.
    prior_alpha = torch.as_tensor(prior_alpha, dtype=a1.dtype, device=a1.device)
    prior_beta = torch.as_tensor(prior_beta, dtype=a1.dtype, device=a1.device)

    # Taylor expansion for the E[log(1 - v^{prior_alpha})] term (15 terms).
    kl = sum(1. / (m * prior_alpha + a1 * b1) * _beta_fn(m * prior_alpha / a1, b1)
             for m in range(1, 16))
    kl = kl * ((prior_beta - 1) * prior_alpha * b1)

    # use another taylor approx for Digamma function
    psi_b_taylor_approx = torch.log(b1) - 1. / (2 * b1) - 1. / (12 * b1 ** 2)

    # -0.57721 is minus the Euler-Mascheroni constant (E[log v] term)
    kl = kl + (torch.log(a1 * b1) + (a1 - prior_alpha) / a1 * (-0.57721 - psi_b_taylor_approx - 1 / b1))
    kl = kl - (b1 - 1) / b1
    kl = kl - torch.log(prior_alpha * prior_beta)

    return kl


def kullback_normal_normal(mu_1, sigma2_1, mu_2, sigma2_2):
    """Compute KL(N(mu_1, sigma2_1) || N(mu_2, sigma2_2)) element-wise in TensorFlow.

    ``sigma2_1`` and ``sigma2_2`` are variances, not standard deviations.
    """
    log_term = 0.5 * tf.log(sigma2_2 / sigma2_1)
    mismatch = (sigma2_1 + (mu_1 - mu_2) ** 2) / (2 * sigma2_2)
    return log_term + mismatch + (-0.5)

def torch_kullback_normal_normal(mu_1, sigma2_1, mu_2, sigma2_2):
    """Compute KL(N(mu_1, sigma2_1) || N(mu_2, sigma2_2)) element-wise in PyTorch.

    ``sigma2_1`` and ``sigma2_2`` are variances, not standard deviations. The additions
    are out-of-place, so the result may broadcast to a larger shape than the log-variance
    term (e.g. vector means with scalar variances). In-place ``+=`` raised in that case.
    """
    kl = 0.5 * torch.log(sigma2_2 / sigma2_1)
    kl = kl + (sigma2_1 + (mu_1 - mu_2) ** 2) / (2 * sigma2_2)
    return kl - 0.5

def tf_sample_logistic_Y(p, lambd):
    """Sample Y = (log(p / (1 - p)) + L) / lambd with L ~ Logistic(0, 1) (TensorFlow).

    ``sigmoid(Y)`` is a BinConcrete sample with alpha = p / (1 - p) and temperature
    ``lambd``. ``p`` is clipped to [0.001, 0.999]. The uniform noise behind L is drawn
    from [0.001, 0.999).

    Args:
        p: Bernoulli parameter tensor.
        lambd: Temperature in (0, 1].
    """
    assert (lambd > 0 and lambd <= 1), "Temperature not in (0,1]"
    p_safe = tf.clip_by_value(p, clip_value_min=0.001, clip_value_max=0.999)
    odds = p_safe / (1 - p_safe)

    u = tf.random.uniform(minval=0.001, maxval=0.999, shape=tf.shape(p_safe))
    logistic_noise = tf.log(u) - tf.log(1 - u)
    return (tf.log(odds) + logistic_noise) / lambd

def torch_sample_logistic_Y(p, lambd, minval=0.001, maxval=0.999):
    """Sample Y = (log(p / (1 - p)) + L) / lambd with L ~ Logistic(0, 1) (PyTorch).

    ``sigmoid(Y)`` is a BinConcrete sample with alpha = p / (1 - p) and temperature ``lambd``.

    Args:
        p: Bernoulli parameter tensor (clipped to [0.001, 0.999]).
        lambd: Temperature in (0, 1].
        minval, maxval: Range of the uniform noise used to build L.

    Returns:
        Tensor with the same shape as ``p``.
    """
    assert (lambd > 0 and lambd <= 1), "Temperature not in (0,1]"
    p = torch.clamp(p, min=0.001, max=0.999)

    alpha = p / (1 - p)
    # Draw the noise directly with alpha's dtype/device. The previous version sampled
    # on the CPU and copied to the device on every call.
    U = torch.empty_like(alpha).uniform_(minval, maxval)
    L = torch.log(U) - torch.log(1 - U)
    return (torch.log(alpha) + L) / lambd

def tf_sample_BernConcrete(p, lambd):
    """Sample the BinConcrete relaxation of Bernoulli(p) at temperature ``lambd`` (TensorFlow).

    Args:
        p: Bernoulli parameter in (0, 1).
        lambd: Temperature in (0, 1].

    Returns:
        ``(Y, sigmoid(Y))``: the logistic sample before the sigmoid, and the relaxed
        value. Both have the same shape as ``p``.
    """
    logit_sample = tf_sample_logistic_Y(p, lambd)
    relaxed = tf.math.sigmoid(logit_sample)
    return logit_sample, relaxed

def torch_sample_BernConcrete(p, lambd):
    """Sample the BinConcrete relaxation of Bernoulli(p) at temperature ``lambd`` (PyTorch).

    Args:
        p: Bernoulli parameter in (0, 1).
        lambd: Temperature in (0, 1].

    Returns:
        ``(Y, sigmoid(Y))``: the logistic sample before the sigmoid, and the relaxed
        value. Both have the same shape as ``p``.
    """
    logit_sample = torch_sample_logistic_Y(p, lambd)
    relaxed = torch.sigmoid(logit_sample)
    return logit_sample, relaxed


def tf_log_density_logistic(p, lambd, y_sample):
    """Compute the log-density of Y = (log(p / (1 - p)) + L) / lambd at ``y_sample`` (TensorFlow).

    Used for Monte Carlo estimates. ``p`` is clipped to [0.001, 0.999].

    Args:
        p: Bernoulli parameter.
        lambd: Concrete temperature.
        y_sample: Points where the density is evaluated.
    """
    p_clipped = tf.clip_by_value(p, clip_value_min=0.001, clip_value_max=0.999)
    log_alpha = tf.log(p_clipped / (1 - p_clipped))

    shifted = log_alpha - lambd * y_sample
    return shifted + tf.log(lambd) - 2 * tf.math.softplus(log_alpha - lambd * y_sample)

def torch_log_density_logistic(p, lambd, y_sample):
    """Compute the log-density of Y = (log(p / (1 - p)) + L) / lambd, with L ~ Logistic(0, 1).

    Used for Monte Carlo estimates of KL terms. ``p`` is clipped to [0.001, 0.999].

    Args:
        p: Bernoulli parameter tensor.
        lambd: Concrete temperature (Python number or tensor).
        y_sample: Tensor of points where the density is evaluated.

    Returns:
        Element-wise log-density tensor.
    """
    p = torch.clamp(p, min=0.001, max=0.999)
    lAlpha = torch.log(p / (1 - p))

    # torch.log rejects plain Python numbers, and the temperature is normally one.
    log_lambd = torch.log(torch.as_tensor(lambd, dtype=lAlpha.dtype, device=lAlpha.device))
    scaled_y = lambd * y_sample
    return lAlpha - scaled_y + log_lambd - 2 * F.softplus(lAlpha - scaled_y)




def tf_kl_logistic(y_sample, p_post, lambd_post, p_prior, lambd_prior):
    """Return a single-sample Monte Carlo estimate of KL(q(Y) || p(Y)) (TensorFlow).

    q is the logistic/Concrete density with (p_post, lambd_post) and p uses
    (p_prior, lambd_prior). ``y_sample`` should be drawn from q.

    Returns:
        ``log q(y_sample) - log p(y_sample)``, element-wise.
    """
    log_posterior = tf_log_density_logistic(p_post, lambd_post, y_sample)
    log_prior = tf_log_density_logistic(p_prior, lambd_prior, y_sample)
    estimate = log_posterior - log_prior
    return estimate



def torch_kl_logistic(y_sample, p_post, lambd_post, p_prior, lambd_prior):
    """Return a single-sample Monte Carlo estimate of KL(q(Y) || p(Y)) (PyTorch).

    q is the logistic/Concrete density with (p_post, lambd_post) and p uses
    (p_prior, lambd_prior). ``y_sample`` should be drawn from q.

    Returns:
        ``log q(y_sample) - log p(y_sample)``, element-wise.
    """
    log_posterior = torch_log_density_logistic(p_post, lambd_post, y_sample)
    log_prior = torch_log_density_logistic(p_prior, lambd_prior, y_sample)
    estimate = log_posterior - log_prior
    return estimate



def tf_odds(p):
    """Return the odds p / (1 - p), not their log, after clipping p to [0.001, 0.999] (TensorFlow)."""
    clipped = tf.clip_by_value(p, clip_value_min=0.001, clip_value_max=0.999)
    return clipped / (1 - clipped)

def torch_odds(p):
    """Return the odds p / (1 - p), not their log, after clipping p to [0.001, 0.999] (PyTorch)."""
    clipped = torch.clamp(p, min=0.001, max=0.999)
    return clipped / (1 - clipped)

def tf_inv_odds(odds):
    """Invert the odds transform: map odds o back to the probability o / (o + 1)."""
    denominator = odds + 1
    return odds / denominator

def torch_inv_odds(odds):
    """Invert the odds transform: map odds o back to the probability o / (o + 1)."""
    denominator = odds + 1
    return odds / denominator

def tf_logit(p):
    """Return log(p / (1 - p)) after clipping p to [0.001, 0.999] (TensorFlow)."""
    clipped = tf.clip_by_value(p, clip_value_min=0.001, clip_value_max=0.999)
    odds = clipped / (1 - clipped)
    return tf.log(odds)

def torch_logit(p):
    """Return log(p / (1 - p)) after clipping p to [0.001, 0.999] (PyTorch)."""
    clipped = torch.clamp(p, min=0.001, max=0.999)
    odds = clipped / (1 - clipped)
    return torch.log(odds)

def np_logit(p):
    """Return log(p / (1 - p)) after clipping p to [0.001, 0.999].

    NumPy arrays and Python numbers are handled with NumPy, as the name promises;
    previously torch.clamp rejected them. Torch tensors still take the torch path,
    which keeps backward compatibility and autograd.
    """
    if torch.is_tensor(p):
        p = torch.clamp(p, min=0.001, max=0.999)
        return torch.log(p / (1 - p))
    p = np.clip(p, 0.001, 0.999)
    return np.log(p / (1 - p))

def get_tf_variable(name='W', shape=(10, 10), init=None, var_list=None, init_std=0.1):
    """Create a float32 TF variable called ``name``, or reuse it via AUTO_REUSE.

    The variable lives in a variable scope that is also called ``name``. With
    ``init=None`` its initial value is drawn from N(0, init_std^2) with ``shape``.
    Otherwise ``init`` is used as the initializer and ``shape`` is not used.

    Returns:
        ``W`` alone. When ``var_list`` is given, ``W`` is appended to it and
        ``(W, var_list)`` is returned.
    """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        initializer = init
        if initializer is None:
            initializer = tf.random_normal(shape, mean=0.0, stddev=init_std, dtype=tf.float32)
        W = tf.get_variable(name, initializer=initializer, dtype=tf.float32)

    if var_list is None:
        return W
    var_list.append(W)
    return W, var_list

def get_tf_logit_variable(name='POS_LOGIT', init_value=0.1, shape=(10,10)):
    """Create (or reuse) a TF variable stored as a logit and return it mapped to positive reals.

    The variable ``name`` (in scope ``name``) holds ``tf_logit(tf_inv_odds(init_value))``.
    The returned tensor is ``tf_odds(sigmoid(logit))``, which starts at ``init_value``
    up to the clipping inside those helpers.
    """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        start_prob = tf_inv_odds(tf.constant(init_value, shape=shape))
        alpha_logit = tf.get_variable(name, initializer=tf_logit(start_prob))  # stored as a logit

    # logit -> probability -> positive real number
    probability = tf.math.sigmoid(alpha_logit)
    return tf_odds(probability)

def get_torch_pos_variable(name='POS_VARIABLE', init_value=0.1, shape=(10,10)):
    """Create a positive tensor of ``shape`` that is parameterised through a logit.

    A leaf logit tensor (requires_grad=True) is filled with
    ``torch_logit(torch_inv_odds(init_value))``. The returned value is
    ``torch_odds(sigmoid(logit))``, which starts at ``init_value`` up to the helpers'
    clipping. ``name`` is unused.

    NOTE(review): only the derived non-leaf tensor is returned. The caller gets no
    handle on the trainable logit leaf to give to an optimizer.
    """
    start = torch.nn.init.constant_(torch.empty(shape), init_value)
    alpha_logit = torch_logit(torch_inv_odds(start))
    alpha_logit.requires_grad = True  # the logit is the underlying trainable leaf
    positive = torch_odds(torch.sigmoid(alpha_logit))
    return positive

def get_torch_logit_variable(init_value=0.1, shape=(10,10)):
    """Return a leaf logit tensor of ``shape`` with requires_grad=True.

    Every entry equals ``torch_logit(torch_inv_odds(init_value))``, which is log(init_value)
    up to the clipping inside torch_logit.
    """
    filled = torch.nn.init.constant_(torch.empty(shape), init_value)
    logit = torch_logit(torch_inv_odds(filled))
    logit.requires_grad = True
    return logit

def get_torch_variable(name='W', shape=(10, 10), init=None, var_list=None, init_std=0.1):
    """Create a trainable float32 leaf tensor.

    Args:
        name: Unused; kept for signature parity with get_tf_variable.
        shape: Shape used when ``init`` is None.
        init: Optional initial value (NumPy array, tensor or nested sequence). Its own
            shape is used and ``shape`` is ignored. The data is copied.
        var_list: Optional list that the new tensor is appended to.
        init_std: Standard deviation of the N(0, init_std^2) initialisation used when
            ``init`` is None.

    Returns:
        ``W``, or ``(W, var_list)`` when ``var_list`` is given.
    """
    if init is None:
        W = torch.empty(shape)
        torch.nn.init.normal_(W, mean=0.0, std=init_std)
        W = W.type(torch.float)
    else:
        # Copy the data. torch.from_numpy shared memory with the caller's array, so
        # in-place optimizer updates on W silently rewrote that array.
        W = torch.as_tensor(init, dtype=torch.float).detach().clone()
    W.requires_grad = True

    if var_list is not None:
        var_list.append(W)
        return W, var_list
    return W


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> Dropout -> ReLU -> Linear -> Softmax(dim=1).

    The input is flattened from dim 1 onward, so any ``(batch, ...)`` tensor whose
    trailing dims multiply to ``dim_in`` is accepted, e.g. ``(N, C, H, W)`` or ``(N, dim_in)``.
    """

    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return per-row class probabilities of shape ``(batch, dim_out)``."""
        # Flatten everything except the batch dim. The old
        # x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1]) matched this only for 4-D
        # input and gave the wrong width for 2-D and 3-D input.
        x = torch.flatten(x, start_dim=1)
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return self.softmax(x)






class Server_MLP_V3(nn.Module):
    """Federated-averaging server for stacked factorised hidden layers plus a linear head.

    ``model_arch`` is a list of ``(dim_in, dim_hidden, dim_out)`` tuples. Every entry
    except the last is a hidden layer. Its global weights are the factor pair ``WA[ix]``
    (dim_in x dim_hidden) and ``WB[ix]`` (dim_hidden x dim_out), a per-unit vector
    ``fixed_r[ix]`` (dim_hidden x 1), pi logits ``fixed_pi[ix]`` and one relaxed-Bernoulli
    mask sample ``fixed_mask[ix]``. The last entry gives the input and output sizes of the
    ``nn.Linear`` head; its middle value is ignored.
    """

    def __init__(self, model_arch, a_prior, lambda_prior, lambda_post, p_threshold):
        super(Server_MLP_V3, self).__init__()

        # Client registries checked by forward(); nothing in this class fills them.
        self.Rt = {}
        self.PIt = {}

        self.clients = dict()
        self.model_arch = model_arch

        self.a_prior = a_prior
        self.lambda_prior = lambda_prior
        self.lambda_post = lambda_post
        self.p_threshold = p_threshold
        self.sumed_vec = []

        self.fixed_keys = 0
        self.fixed_r = []
        self.fixed_pi = []
        self.fixed_mask = []
        self.init_r_pi()

        # These are global parameters: one (WA, WB) factor pair per hidden layer.
        self.WA = []
        self.WB = []

        for dim_in, dim_hidden, dim_out in model_arch[:-1]:
            # Scaled Gaussian init; no bias is created for the factors.
            # NOTE(review): scaling a requires_grad tensor yields a non-leaf tensor, so
            # Wa.grad / Wb.grad are never populated.
            Wa = np.sqrt(2.0 / dim_hidden) * torch.randn(dim_in, dim_hidden, dtype=torch.float, requires_grad=True)
            Wb = np.sqrt(2.0 / dim_out) * torch.randn(dim_hidden, dim_out, dtype=torch.float, requires_grad=True)
            self.WA.append(Wa)
            self.WB.append(Wb)

        last_dim_in, _, last_dim_out = model_arch[-1]
        self.linear = nn.Linear(last_dim_in, last_dim_out)
        self.last_w = self.linear.weight
        self.last_b = self.linear.bias

    def forward(self, X, client_id, A_prior=None, B_prior=None, R_prior=None, grad_masks=None, expectation=False):
        """Check that ``client_id`` is registered in ``Rt`` and ``PIt``.

        NOTE(review): no forward computation is implemented; the method returns None.
        """
        # grad_mask is Fx1
        assert client_id in self.Rt, "client id not registered in Rt"
        assert client_id in self.PIt, "client id not registered in PIt"

    def get_client_name(self, T):
        return "_client_{:d}".format(T)

    def get_layer_name(self, ix):
        return "_layer_{:d}".format(ix)

    def init_r_pi(self, pi_init=None, R_init=None):
        """For each hidden layer, create R, the pi logits and one mask sample.

        Args:
            pi_init: Optional per-layer initial probabilities for pi. When None,
                pi_k = (a_prior / (a_prior + 1))^k for k = 1..dim_hidden.
            R_init: Optional per-layer initial values for R. When None, R ~ N(0, 0.1^2).

        Appends to ``self.fixed_r``, ``self.fixed_pi`` and ``self.fixed_mask``.
        """
        local_vars = []  # collects the created tensors (for initialization)

        for ix, (dim_in, dim_hidden, dim_out) in enumerate(self.model_arch[:-1]):
            layer_name = self.get_layer_name(ix)

            # `is None` instead of `== None`: an ndarray argument would otherwise be
            # compared element-wise, and its truth value would raise.
            if R_init is None:
                Rl, local_vars = get_torch_variable("R" + layer_name, shape=(dim_hidden, 1), var_list=local_vars, init_std=0.1)
            else:
                Rl, local_vars = get_torch_variable("R" + layer_name, init=R_init[ix], shape=(dim_hidden, 1), var_list=local_vars)

            if pi_init is None:
                # Take \E[v] for v to initialize pi_post_logit
                init = np.ones((dim_hidden, 1), dtype=np.float32) * (self.a_prior / (self.a_prior + 1))
                init = np.cumprod(init, axis=0)
            else:
                # use pi_{t} as the initial value of pi_{t+1}
                init = pi_init[ix]

            init = np.clip(init, a_min=0.001, a_max=0.999)
            # convert to logit
            init = np.log(init / (1 - init))
            pi_post_logit, local_vars = get_torch_variable("pi_logit" + layer_name, init=init, shape=(dim_hidden, 1), var_list=local_vars)
            pi_post_logit = torch.reshape(pi_post_logit, shape=(dim_hidden, 1))

            self.fixed_r.append(Rl)
            self.fixed_pi.append(pi_post_logit)

            _, binary_l = torch_sample_BernConcrete(torch.sigmoid(pi_post_logit), self.lambda_post)
            self.fixed_mask.append(binary_l)

    def send_weights(self, client_id, expectation=False):
        """Return the global weights to send to a client.

        Returns a list with one ``(WA[ix], fixed_r[ix], WB[ix])`` tuple per hidden layer,
        followed by ``(last_w, last_b)``. ``client_id`` and ``expectation`` are unused,
        and the mask samples are not included.
        """
        weights = [(self.WA[ix], self.fixed_r[ix], self.WB[ix])
                   for ix in range(len(self.model_arch[:-1]))]
        weights.append((self.last_w, self.last_b))
        return weights

    def plain_update_weights(self, client_refer, epoch_ix):
        """Replace the global weights with the plain average of the client weights.

        Args:
            client_refer: dict mapping client id -> ``(client_weights, mask_samples)``.
                ``client_weights`` provides ``'mlp_{ix}_wa'``, ``'mlp_{ix}_wb'``,
                ``'mlp_{ix}_r'``, ``'linear.weight'`` and ``'linear.bias'``.
            epoch_ix: Unused.

        Side effects:
            Overwrites ``WA``, ``WB``, ``fixed_r``, ``last_w`` and ``last_b``. Sets
            ``delta_Ns``, per layer ||mean_i r_i - r||, and ``delta_Es``, per layer
            mean_i ||r_i - r||. Both are measured against the pre-update ``r``.

        Raises:
            ValueError: if ``client_refer`` is empty. Previously the averages became NaN
                and were stored before ``reduce`` failed.
        """
        if not client_refer:
            raise ValueError("client_refer is empty; there is nothing to average")
        n_clients = len(client_refer)

        self.delta_Ns = list()
        self.delta_Es = list()

        for ix in range(len(self.model_arch[:-1])):
            wa_key = 'mlp_{}_wa'.format(ix)
            wb_key = 'mlp_{}_wb'.format(ix)
            r_key = 'mlp_{}_r'.format(ix)

            Wa = self.WA[ix]
            Wb = self.WB[ix]
            # Detached copy of the current global r; only used as the reference value.
            old_r = self.fixed_r[ix].detach()

            new_wa = torch.zeros_like(Wa, dtype=Wa.dtype)
            new_wb = torch.zeros_like(Wb, dtype=Wb.dtype)
            new_r = torch.zeros_like(old_r, dtype=old_r.dtype)
            delta_E = 0

            for client_weights, _ in client_refer.values():
                current_r = client_weights[r_key].cpu()
                new_wa += client_weights[wa_key].cpu()
                new_wb += client_weights[wb_key].cpu()
                new_r += current_r
                # Distance of this client's r from the global r. The old code
                # accumulated the norm of a running *sum* of differences instead.
                delta_E += torch.norm(current_r - old_r)

            r_updated = new_r / n_clients
            self.delta_Ns.append(torch.norm(r_updated - old_r))
            self.delta_Es.append(delta_E / n_clients)

            self.WA[ix] = new_wa / n_clients
            self.WB[ix] = new_wb / n_clients
            self.fixed_r[ix] = r_updated

        new_last_w = reduce(lambda x, y: x + y, [entry[0]['linear.weight'] for entry in client_refer.values()])
        new_last_b = reduce(lambda x, y: x + y, [entry[0]['linear.bias'] for entry in client_refer.values()])

        self.last_w = torch.nn.Parameter(new_last_w / n_clients)
        self.last_b = torch.nn.Parameter(new_last_b / n_clients)





class Server_V3(nn.Module):
    """Federated-averaging server for named, factorised hidden layers plus a linear head.

    ``model_arch`` is a list of ``(layer_name, dim_in, dim_hidden, dim_out)`` tuples.
    Every entry except the last is a hidden layer. Its global weights are the factor
    pair ``WA[ix]`` (dim_in x dim_hidden) and ``WB[ix]`` (dim_hidden x dim_out), a
    per-unit vector ``fixed_r[ix]`` (dim_hidden x 1), pi logits ``fixed_pi[ix]`` and one
    relaxed-Bernoulli mask sample ``fixed_mask[ix]``. The last entry gives the input and
    output sizes of the ``nn.Linear`` head; its name and middle value are ignored.
    """

    def __init__(self, model_arch, a_prior, lambda_prior, lambda_post, p_threshold):
        super(Server_V3, self).__init__()

        # Client registries checked by forward(); nothing in this class fills them.
        self.Rt = {}
        self.PIt = {}

        self.clients = dict()
        self.model_arch = model_arch

        self.a_prior = a_prior
        self.lambda_prior = lambda_prior
        self.lambda_post = lambda_post
        self.p_threshold = p_threshold
        self.sumed_vec = []

        self.fixed_keys = 0
        self.fixed_r = []
        self.fixed_pi = []
        self.fixed_mask = []
        self.init_r_pi()

        # These are global parameters: one (WA, WB) factor pair per hidden layer.
        self.WA = []
        self.WB = []

        for layer_name, dim_in, dim_hidden, dim_out in model_arch[:-1]:
            # NOTE(review): scaling a requires_grad tensor yields a non-leaf tensor, so
            # Wa.grad / Wb.grad are never populated.
            Wa = np.sqrt(2.0 / dim_hidden) * torch.randn(dim_in, dim_hidden, dtype=torch.float, requires_grad=True)
            Wb = np.sqrt(2.0 / dim_out) * torch.randn(dim_hidden, dim_out, dtype=torch.float, requires_grad=True)
            self.WA.append(Wa)
            self.WB.append(Wb)

        _, last_dim_in, _, last_dim_out = model_arch[-1]
        self.linear = nn.Linear(last_dim_in, last_dim_out)
        self.last_w = self.linear.weight
        self.last_b = self.linear.bias

    def forward(self, X, client_id, A_prior=None, B_prior=None, R_prior=None, grad_masks=None, expectation=False):
        """Check that ``client_id`` is registered in ``Rt`` and ``PIt``.

        NOTE(review): no forward computation is implemented; the method returns None.
        """
        assert client_id in self.Rt, "client id not registered in Rt"
        assert client_id in self.PIt, "client id not registered in PIt"

    def get_client_name(self, T):
        return "_client_{:d}".format(T)

    def get_layer_name(self, ix):
        return "_layer_{:d}".format(ix)

    def init_r_pi(self, pi_init=None, R_init=None):
        """For each hidden layer, create R, the pi logits and one mask sample.

        Args:
            pi_init: Optional per-layer initial probabilities for pi. When None,
                pi_k = (a_prior / (a_prior + 1))^k for k = 1..dim_hidden.
            R_init: Optional per-layer initial values for R. When None, R ~ N(0, 0.1^2).

        Appends to ``self.fixed_r``, ``self.fixed_pi`` and ``self.fixed_mask``.
        """
        local_vars = []  # collects the created tensors (for initialization)

        for ix, (layer_name, dim_in, dim_hidden, dim_out) in enumerate(self.model_arch[:-1]):
            layer_name = layer_name + '_' + str(ix)

            # `is None` instead of `== None`: an ndarray argument would otherwise be
            # compared element-wise, and its truth value would raise.
            if R_init is None:
                Rl, local_vars = get_torch_variable("R" + layer_name, shape=(dim_hidden, 1), var_list=local_vars, init_std=0.1)
            else:
                Rl, local_vars = get_torch_variable("R" + layer_name, init=R_init[ix], shape=(dim_hidden, 1), var_list=local_vars)

            if pi_init is None:
                # Take \E[v] for v to initialize pi_post_logit
                init = np.ones((dim_hidden, 1), dtype=np.float32) * (self.a_prior / (self.a_prior + 1))
                init = np.cumprod(init, axis=0)
            else:
                # use pi_{t} as the initial value of pi_{t+1}
                init = pi_init[ix]

            init = np.clip(init, a_min=0.001, a_max=0.999)
            # convert to logit
            init = np.log(init / (1 - init))
            pi_post_logit, local_vars = get_torch_variable("pi_logit" + layer_name, init=init, shape=(dim_hidden, 1), var_list=local_vars)
            pi_post_logit = torch.reshape(pi_post_logit, shape=(dim_hidden, 1))

            self.fixed_r.append(Rl)
            self.fixed_pi.append(pi_post_logit)

            _, binary_l = torch_sample_BernConcrete(torch.sigmoid(pi_post_logit), self.lambda_post)
            self.fixed_mask.append(binary_l)

    def send_weights(self, client_id, expectation=False):
        """Return the global weights to send to a client.

        Returns a list with one ``(WA[ix], fixed_r[ix], WB[ix])`` tuple per hidden layer,
        followed by ``(last_w, last_b)``. ``client_id`` and ``expectation`` are unused,
        and the mask samples are not included.
        """
        weights = [(self.WA[ix], self.fixed_r[ix], self.WB[ix])
                   for ix in range(len(self.model_arch[:-1]))]
        weights.append((self.last_w, self.last_b))
        return weights

    def plain_update_weights(self, client_refer, epoch_ix):
        """Replace the global weights with the plain average of the client weights.

        Args:
            client_refer: dict mapping client id -> ``(client_weights, mask_samples)``.
                ``client_weights`` provides ``'<layer_name>_{ix}_wa'``,
                ``'<layer_name>_{ix}_wb'``, ``'<layer_name>_{ix}_r'``,
                ``'linear.weight'`` and ``'linear.bias'``.
            epoch_ix: Unused.

        Side effects:
            Overwrites ``WA``, ``WB``, ``fixed_r``, ``last_w`` and ``last_b``.

        Raises:
            ValueError: if ``client_refer`` is empty. Previously the averages became NaN
                and were stored before ``reduce`` failed.
        """
        if not client_refer:
            raise ValueError("client_refer is empty; there is nothing to average")
        n_clients = len(client_refer)

        for ix, (layer_name, _, _, _) in enumerate(self.model_arch[:-1]):
            wa_key = layer_name + '_{}_wa'.format(ix)
            wb_key = layer_name + '_{}_wb'.format(ix)
            r_key = layer_name + '_{}_r'.format(ix)

            Wa = self.WA[ix]
            Wb = self.WB[ix]
            rl = self.fixed_r[ix]

            # Running sums of this layer's client values.
            new_wa = torch.zeros_like(Wa, dtype=Wa.dtype)
            new_wb = torch.zeros_like(Wb, dtype=Wb.dtype)
            new_r = torch.zeros_like(rl, dtype=rl.dtype)

            for client_weights, _ in client_refer.values():
                new_wa += client_weights[wa_key].cpu()
                new_wb += client_weights[wb_key].cpu()
                new_r += client_weights[r_key].cpu()

            self.WA[ix] = new_wa / n_clients
            self.WB[ix] = new_wb / n_clients
            self.fixed_r[ix] = new_r / n_clients

        new_last_w = reduce(lambda x, y: x + y, [entry[0]['linear.weight'] for entry in client_refer.values()])
        new_last_b = reduce(lambda x, y: x + y, [entry[0]['linear.bias'] for entry in client_refer.values()])

        self.last_w = torch.nn.Parameter(new_last_w / n_clients)
        self.last_b = torch.nn.Parameter(new_last_b / n_clients)

