import math
import os
import sys
import traceback
import numpy as np
import ipdb

import torch
from torch import nn
from torch.nn import functional as F

class Learner(nn.Module):

    def __init__(self, config, args = None):
        """

        :param config: network config file, type:list of (string, list)
        :param imgc: 1 or 3
        :param imgsz:  28 or 84
        """
        super(Learner, self).__init__()

        self.config = config
        self.tf_counter = 0
        self.args = args

        # this dict contains all tensors needed to be optimized
        self.vars = nn.ParameterList()
        # running_mean and running_var
        self.vars_bn = nn.ParameterList()

        self.names = []

        for i, (name, param, extra_name) in enumerate(self.config):
            print(name)
            if name is 'conv2d':
                # [ch_out, ch_in, kernelsz, kernelsz]                
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param[:4]))
                    b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    self.vars.append(b)
                else:
                    w = nn.Parameter(torch.ones(*param[:4]))
                    # gain=1 according to cbfin's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                    # [ch_out]
                    self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name is 'convt2d':
                # [ch_in, ch_out, kernelsz, kernelsz, stride, padding]
                w = nn.Parameter(torch.ones(*param[:4]))
                # gain=1 according to cbfin's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # [ch_in, ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[1])))

            elif name is 'linear':
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

                         #linear-simple
            elif name == 'linear-simple':
                print('Pass as linear-simple')
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'linear-simple-n':
                print('Pass as linear-simple-n')
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'linear-simple-sort':
                print('Pass as linear-simple-sort')
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))


            elif name == 'non-linear-relu':
                print('Pass as non-linar-relu')
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'non-linear-sparse':
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'non-linear':
                print("pass as non-linear")
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'non-linear-[e-no-copy]':
                print("pass as non-linear-[e-no-copy]")
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'non-linear-random':
                print("pass as non-linear random")
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'linear-[v0]':
                print("pass as linear-[v0]")
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'non-linear-doub':
                print("pass as non-linear double")
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'non-linear-soft':
                print("pass as non-linear-soft")
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'non-linear-v0':
                print("pass as non-linear-v0")
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'non-linear-v1':
                print("pass as non-linear-v1")
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'non-linear-v2':
                print("pass as non-linear-v2")
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'non-linear-full':
                print("pass as non-linear-full")
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'non-linear-relu':
                print("pass as non-linear-relu")
                # layer += 1
                if(self.args.xav_init):
                    w = nn.Parameter(torch.ones(*param))
                    # b = nn.Parameter(torch.zeros(param[0]))
                    torch.nn.init.xavier_normal_(w.data)
                    # b.data.normal_(0, math.sqrt(2)/math.sqrt(1+9*b.data.shape[0]))
                    self.vars.append(w)
                    # self.vars.append(b)
                else:     
                    # [ch_out, ch_in]
                    w = nn.Parameter(torch.ones(*param))
                    # gain=1 according to cbfinn's implementation
                    torch.nn.init.kaiming_normal_(w)
                    self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name is 'cat':
                pass
            elif name is 'cat_start':
                pass
            elif name is "rep":
                pass
            elif name in ["residual3", "residual5", "in"]:
                pass
            elif name in ["BN", "LN", "GN", "IN", "CN"]:
                # [ch_out]
                w = nn.Parameter(torch.ones(param[0]))
                self.vars.append(w)
                # [ch_out]
                b = nn.Parameter(torch.zeros(param[0]))
                self.vars.append(b)

                # must set requires_grad=False
                running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
                running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
                self.vars_bn.extend([running_mean, running_var])

            elif name in "CN-1":
                num_of_bins_curent = self.args.num_of_bins
                bin_size_curent = self.args.sub_bin_size                
                # [ch_out]
                #w = nn.Parameter(torch.ones(param[0]))
                for i in range(0, num_of_bins_curent):
                    w = nn.Parameter(torch.ones(bin_size_curent))
                    self.vars.append(w)
                    b = nn.Parameter(torch.zeros(bin_size_curent))
                    self.vars.append(b)
                # [ch_out]
                #self.vars.append(nn.Parameter(torch.zeros(param[0])))
                    

                # must set requires_grad=False
                for i in range(0, num_of_bins_curent):
                    running_mean = nn.Parameter(torch.zeros(bin_size_curent), requires_grad=False)
                    running_var  = nn.Parameter(torch.ones(bin_size_curent), requires_grad=False)
                    self.vars_bn.extend([running_mean, running_var])

            elif name in ['tanh', 'relu', 'upsample', 'avg_pool2d', 'max_pool2d',
                          'flatten', 'reshape', 'leakyrelu', 'sigmoid', 'only-sparse']:
                continue
            else:
                print(name)
                raise NotImplementedError

    def extra_repr(self):

        info = ''

        for name, param, extra_name in self.config:

            if name is 'conv2d':
                tmp = 'conv2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)' \
                      % (param[1], param[0], param[2], param[3], param[4], param[5],)
                info += tmp + '\n'

            elif name is 'convt2d':
                tmp = 'convTranspose2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)' \
                      % (param[0], param[1], param[2], param[3], param[4], param[5],)
                info += tmp + '\n'

            elif name is 'linear':
                tmp = 'linear:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

                         #linear-simple
                         #linear-simple
            elif name is 'linear-simple':
                tmp = 'linear-simple:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name is 'linear-simple-n':
                tmp = 'linear-simple-n:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name is 'linear-simple-sort':
                tmp = 'linear-simple-sort:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name is 'non-linear-relu':
                tmp = 'non-linear-relu:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name is 'non-linear-sparse':
                tmp = 'non-linear-sparse:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name is 'non-linear-random':
                tmp = 'non-linear-random:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name is 'linear-[v0]':
                tmp = 'linear-[v0]:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'
                #linear-[v0]

            elif name is 'non-linear-[e-no-copy]':
                tmp = 'non-linear-[e-no-copy]:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name is 'non-linear-relu':
                tmp = 'non-linear-relu:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name is 'non-linear-doub':
                tmp = 'non-linear:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name is 'non-linear-v0':
                tmp = 'non-linear-v0:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name is 'non-linear-full':
                tmp = 'non-linear-full:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name is 'non-linear-v1':
                tmp = 'non-linear-v1:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name is 'non-linear-v2':
                tmp = 'non-linear-v2:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name is 'non-linear-relu':
                tmp = 'non-linear-relu:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name is 'leakyrelu':
                tmp = 'leakyrelu:(slope:%f)' % (param[0])
                info += tmp + '\n'

            elif name is 'cat':
                tmp = 'cat'
                info += tmp + "\n"
            elif name is 'cat_start':
                tmp = 'cat_start'
                info += tmp + "\n"

            elif name is 'rep':
                tmp = 'rep'
                info += tmp + "\n"


            elif name is 'avg_pool2d':
                tmp = 'avg_pool2d:(k:%d, stride:%d, padding:%d)' % (param[0], param[1], param[2])
                info += tmp + '\n'
            elif name is 'max_pool2d':
                tmp = 'max_pool2d:(k:%d, stride:%d, padding:%d)' % (param[0], param[1], param[2])
                info += tmp + '\n'
            elif name in ['flatten', 'tanh', 'relu', 'upsample', 'reshape', 'sigmoid', 'use_logits', 'bn', 'only-sparse']:
                tmp = name + ':' + str(tuple(param))
                info += tmp + '\n'
            else:
                raise NotImplementedError

        return info

    def forward(self, x, vars=None, bn_training=False, feature=False):
        """
        This function can be called by finetunning, however, in finetunning, we dont wish to update
        running_mean/running_var. Thought weights/bias of bn is updated, it has been separated by fast_weights.
        Indeed, to not update running_mean/running_var, we need set update_bn_statistics=False
        but weight/bias will be updated and not dirty initial theta parameters via fast_weiths.
        :param x: [b, 1, 28, 28]
        :param vars:
        :param bn_training: set False to not update
        :return: x, loss, likelihood, kld
        """

        cat_var = False
        cat_list = []

        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0
        non_linear_layer = 0
        pass_e = 'pass_e[no]'
        if self.args.parametric_normalization != None:
            non_linear_layer = 1
            e = []
        try:

            for (name, param, extra_name) in self.config:
                # assert(name == "conv2d")
                if name == 'conv2d':
                    w, b = vars[idx], vars[idx + 1]
                    x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                    idx += 2

                    # print(name, param, '\tout:', x.shape)
                elif name == 'convt2d':
                    w, b = vars[idx], vars[idx + 1]
                    x = F.conv_transpose2d(x, w, b, stride=param[4], padding=param[5])
                    idx += 2


                elif name == 'linear':

                    # ipdb.set_trace()
                    #print('In the linear layer ...')
                    if extra_name == 'cosine':
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        #w, b = vars[idx], vars[idx + 1]
                        #x = F.linear(x, w, b)
                        #idx += 2
                        
                        w, b = vars[idx], vars[idx + 1]
                        if pass_e == 'pass_e[no]':
                            x = F.linear(x, w, b)
                            #print('Pass trough x')
                        elif pass_e == 'pass_e[yes]':
                            x = F.linear(e, w, b)
                            #print('Pass trough e')
                        idx += 2
                        #x = F.linear(x, w, b)
                        #idx += 2

                    if cat_var:
                        cat_list.append(x)

                elif name == 'non-linear-old':

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2		    
                    #x_abs = torch.abs(x)
                    #range_var = torch.tensor(range(0, 40)).long()
                    #tr_mean = torch.mean(x_abs[:, range_var], dim=1)
                    #tr = ( x_abs[:, range_var] <= tr_mean.unsqueeze(1) )
                    ##tr = .5 * ( ( x_abs[:, range_var] <= tr_mean.detach() ) + ( x_abs[:, range_var].detach() <= tr_mean )  )
                    #tr = tr.float()
                    #tr = tr.detach()
                    #e = x[:, range_var] * tr
                    ##e = .5 * ( x[:, range_var] * tr.detach() + x[:, range_var].detach() * tr )
                    ##print(x.size())
                    ##print(e.size())
                    #for ind_per_block in range(1, 8):
                    #    #not l-2 norm normilized error !
                    #    tr_mean = torch.mean(x_abs[:, 40 * ind_per_block + range_var], dim=1)
                    #    tr = ( x_abs[:, 40 * ind_per_block + range_var] <= tr_mean.unsqueeze(1) )
                    #    #tr = .5 * ( ( x_abs[:, range_var] <= tr_mean.detach() ) + ( x_abs[:, range_var].detach() <= tr_mean )  )
                    #    tr = tr.float()
                    #    tr = tr.detach()
                    #    #print(tr.size())
                    #    #e = torch.cat( (e, x[:, 40 * ind_per_block + range_var] * tr), 1 )
                    #    tmp = x[:, 40 * ind_per_block + range_var] * tr
                    #    #tmp = .5 * ( x[:, range_var] * tr.detach() + x[:, range_var].detach() * tr )
                    #    e = torch.cat( (e, tmp), 1 )
                    #x = x - e.detach()
                    r = x.reshape(x.shape[0], 40, 8).detach().clone()
                    #r = r.detach()
                    tr = torch.abs( r ) <=  torch.mean( torch.abs(r), 1).unsqueeze(1)
                    tr = tr.float()
                    e_i = r * tr #.detach() 
                    e_i = e_i.reshape(x.shape[0], 320)
                    x = x - e_i #.detach() 
                    e = x.clone()
                    if cat_var:
                        cat_list.append(x)

                    non_linear_layer = 1

                elif name == 'non-linear-soft':

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2                    

                    r = x.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                    tr = torch.abs(r) <  torch.mean( torch.abs(r), 1).unsqueeze(1)
                    e_i = r * tr.float()
                    tr_0 = e_i == 0
                    e_i = e_i + tr_0.float() * torch.sign(r) * torch.max(torch.abs(e_i), 1).values.unsqueeze(1)
                    e_i = e_i.reshape(x.shape[0], 320)
                    x = x - e_i

                    x = x.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins )
                    if self.args.normilize_sub_bins == 'normilize_sub_bins[ON]':
                        if self.args.const_normalization == 'yes':
                            z = x.detach().clone()
                            z_norm = torch.sqrt( torch.einsum('ijk -> ik', z*z).unsqueeze(1) )
                            x = x/(z_norm + .0000001) 
                        elif self.args.const_normalization == 'no':
                            x = F.normalize(x, dim=1)
                            
                    e = x.reshape(x.shape[0], 320).clone()
                    if self.args.block_sparse == 1:
                        x_tmp = x.detach().clone()
                        x_tmp_per_bl = torch.einsum('ijk -> ik', torch.abs(x_tmp) ).unsqueeze(1)
                        x_tmp_per_bl_mean = torch.einsum( 'ijk -> ij', x_tmp_per_bl).unsqueeze(2)/self.args.num_of_bins 
                        tr = x_tmp_per_bl <= x_tmp_per_bl_mean
                        x = x * tr.float()
                    x = x.reshape(x.shape[0], 320)

                    if cat_var:
                        cat_list.append(x)

                    non_linear_layer = 1


                elif name == 'non-linear-doub-va':

                    r = x.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                    tr = torch.abs( r ) <=  torch.mean( torch.abs(r), 1).unsqueeze(1)
                    tr = tr.float()
                    e_i = r * tr
                    e_i = e_i.reshape(x.shape[0], 320)
                    x = x - e_i
                    x = F.normalize(x.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins ), dim=1).reshape(x.shape[0], 320)

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2                    

                    r = x.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                    tr = torch.abs( r ) <=  torch.mean( torch.abs(r), 1).unsqueeze(1)
                    tr = tr.float()
                    e_i = r * tr
                    e_i = e_i.reshape(x.shape[0], 320)
                    x = x - e_i
                    x = F.normalize(x.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins ), dim=1).reshape(x.shape[0], 320)
                    e = x.clone()

                    if cat_var:
                        cat_list.append(x)

                    non_linear_layer = 1

                elif name == 'non-linear-doub':

                    r = x.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                    tr = torch.abs( r ) <=  torch.mean( torch.abs(r), 1).unsqueeze(1)
                    tr = tr.float()
                    e_i = r * tr
                    e_i = e_i.reshape(x.shape[0], 320)
                    x = x - e_i
                    x = F.normalize(x.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins ), dim=1).reshape(x.shape[0], 320)
                    e_pr = x.clone()

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2                    

                    r = x.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                    tr = torch.abs( r ) <=  torch.mean( torch.abs(r), 1).unsqueeze(1)
                    tr = tr.float()
                    e_i = r * tr
                    e_i = e_i.reshape(x.shape[0], 320)
                    x = x - e_i
                    x = F.normalize(x.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins ), dim=1).reshape(x.shape[0], 320)
                    e = torch.stack((x.clone(), e_pr))
                    
                    if cat_var:
                        cat_list.append(x)

                    non_linear_layer = 1

                elif name == 'non-linear':
                    # Linear (or cosine) projection followed by per-sub-bin
                    # thresholding, optional sub-bin normalisation and optional
                    # block sparsification. `e` receives a copy of the activations
                    # taken after normalisation but before the block-sparse mask.

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        # Cosine head: weight and input are normalised
                        # (F.normalize defaults), no bias; consumes one parameter.
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        # Affine head: consumes weight and bias.
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2 
                   
                    #print(x.shape)
                    shape_0 =x.shape[0]
                    shape_1 =x.shape[1]

                    # Thresholding on a detached copy viewed as
                    # (shape_0, sub_bin_size, num_of_bins): entries whose magnitude
                    # is <= the mean magnitude along dim 1 are collected in e_i and
                    # subtracted, which zeroes them in x. Because r is detached the
                    # subtraction does not alter the gradient w.r.t. x.
                    r = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                    tr = torch.abs( r ) <=  torch.mean( torch.abs(r), 1).unsqueeze(1)
                    e_i = r * tr.float()
                    e_i = e_i.reshape(shape_0, shape_1)
                    x = x - e_i

                    x = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins )
                    if self.args.normilize_sub_bins == 'normilize_sub_bins[ON]':
                        if self.args.const_normalization == 'yes':
                            # L2 norm along dim 1 computed on a detached copy, so
                            # the norm acts as a constant in the backward pass.
                            z = x.detach().clone()
                            z_norm = torch.sqrt( torch.einsum('ijk -> ik', z*z).unsqueeze(1) )
                            x = x/(z_norm + .0000001) 
                        elif self.args.const_normalization == 'no':
                            # Differentiable L2 normalisation along dim 1.
                            x = F.normalize(x, dim=1)
                    e = x.reshape(shape_0, shape_1).clone()
                    if self.args.block_sparse == 1:
                        # Sum |x| along dim 1 for each bin, average those sums over
                        # the bins, and keep only bins whose sum is <= that average;
                        # the mask is built from a detached copy.
                        x_tmp = x.detach().clone()
                        x_tmp_per_bl = torch.einsum('ijk -> ik', torch.abs(x_tmp) ).unsqueeze(1)
                        x_tmp_per_bl_mean = torch.einsum( 'ijk -> ij', x_tmp_per_bl).unsqueeze(2)/self.args.num_of_bins 
                        tr = x_tmp_per_bl <= x_tmp_per_bl_mean
                        x = x * tr.float()
                    x = x.reshape(shape_0, shape_1)
                    #x = F.normalize(x, dim=1).reshape(x.shape[0], 320)
                    #e = x.clone()

                    if cat_var:
                        cat_list.append(x)

                    # Flag read outside this branch (consumer not visible here).
                    non_linear_layer = 1

                elif name == 'non-linear-sparse':

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2 
                   
                    shape_0 =x.shape[0]
                    shape_1 =x.shape[1]

                    #nonlinear activation
                    if self.args.nonlinearity_at_penultimate == 'sT[hard]':
                        r = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                        tr = torch.abs( r ) <=  torch.mean( torch.abs(r), 1).unsqueeze(1)
                        e_i = r * tr.float()
                        e_i = e_i.reshape(shape_0, shape_1)
                        x = x - e_i
                    if self.args.nonlinearity_at_penultimate == 'sT[soft]':      
                        r = x.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                        tr = torch.abs(r) <  torch.mean( torch.abs(r), 1).unsqueeze(1)
                        e_i = r * tr.float()
                        tr_0 = e_i == 0
                        e_i = e_i + tr_0.float() * torch.sign(r) * torch.max(torch.abs(e_i), 1).values.unsqueeze(1)
                        e_i = e_i.reshape(x.shape[0], 320)
                        x = x - e_i                  
                    if self.args.nonlinearity_at_penultimate == 'relu':      
                        x = F.relu(x)
                    if self.args.nonlinearity_at_penultimate == 'LWTA':                              
                        g = x.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                        val, ind = torch.max(g, dim=1)
                        tr = (g>=val.unsqueeze(1))
                        x = x * tr.float().reshape(shape_0, shape_1)

                    #GNC section
                    x = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins )
                    if self.args.normilize_sub_bins == 'normilize_sub_bins[ON]':
                        if self.args.const_normalization == 'yes':
                            z = x.detach().clone()
                            z_norm = torch.sqrt( torch.einsum('ijk -> ik', z*z).unsqueeze(1) )
                            x = x/(z_norm + .0000001) 
                        elif self.args.const_normalization == 'no':
                            if self.args.normalization_type == 'l-norm-2':
                                x = F.normalize(x, p=2, dim=1)
                            elif self.args.normalization_type == 'l-norm-1':
                                x = F.normalize(x, p=1, dim=1)
                            elif self.args.normalization_type == 'l-norm-inf':
                                val, ind = torch.max(x, dim=1)
                                x = x/( val.unsqueeze(1) + .00000001)
                            elif self.args.normalization_type == 'max':
                                g = x.detach().clone()
                                val, ind = torch.max(g, dim=1)
                                tr = (g>=val.unsqueeze(1))
                                x = x * tr.float()
                            elif self.args.normalization_type == 'max_abs':
                                g = torch.abs(x).detach().clone()
                                val, ind = torch.max(g, dim=1)
                                tr = (g>=val.unsqueeze(1))
                                x = x * tr.float()      

                    e = x.reshape(shape_0, shape_1).clone()

                    x = x.reshape(shape_0, shape_1)
                    if cat_var:
                        cat_list.append(x)
                    non_linear_layer = 1
                              #linear-simple
                             #linear-simple

                elif name == 'linear-simple':
                    # Linear (or cosine) projection with optional per-sub-bin
                    # normalisation and optional block sparsification; no
                    # thresholding step. `e` is a copy of the activations taken
                    # before the block-sparse mask is applied.

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        # Cosine head: weight and input normalised, no bias.
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2 

                    shape_0 =x.shape[0]
                    shape_1 =x.shape[1]

                    # View as (shape_0, sub_bin_size, num_of_bins); every
                    # normalisation below acts along dim 1.
                    x = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins )
                    if self.args.normilize_sub_bins == 'normilize_sub_bins[ON]':
                        if self.args.const_normalization == 'yes':
                            # L2 norm computed on a detached copy, i.e. treated as
                            # a constant in the backward pass.
                            z = x.detach().clone()
                            z_norm = torch.sqrt( torch.einsum('ijk -> ik', z*z).unsqueeze(1) )
                            x = x/(z_norm + .0000001) 
                        elif self.args.const_normalization == 'no':
                            if self.args.normalization_type == 'l-norm-2':
                                x = F.normalize(x, p=2, dim=1)
                            elif self.args.normalization_type == 'l-norm-1':
                                x = F.normalize(x, p=1, dim=1)
                            elif self.args.normalization_type == 'l-norm-inf':
                                # NOTE(review): uses the signed maximum along dim 1,
                                # not the maximum magnitude - confirm this is intended.
                                val, ind = torch.max(x, dim=1)
                                x = x/( val.unsqueeze(1) + .00000001)
                            elif self.args.normalization_type == 'max':
                                # Keep only entries equal to the maximum along dim 1.
                                g = x.detach().clone()
                                val, ind = torch.max(g, dim=1)
                                tr = (g>=val.unsqueeze(1))
                                x = x * tr.float()
                            elif self.args.normalization_type == 'max_abs':
                                # Keep only entries with the largest magnitude along dim 1.
                                g = torch.abs(x).detach().clone()
                                val, ind = torch.max(g, dim=1)
                                tr = (g>=val.unsqueeze(1))
                                x = x * tr.float()      
                    e = x.reshape(shape_0, shape_1).clone()
                    if self.args.block_sparse == 1:
                        # Printed on every pass through this branch.
                        print('block sparse is on')
                        # Keep only bins whose summed |x| (over dim 1) is <= the
                        # average of those sums over the bins.
                        x_tmp = x.detach().clone()
                        x_tmp_per_bl = torch.einsum('ijk -> ik', torch.abs(x_tmp) ).unsqueeze(1)
                        x_tmp_per_bl_mean = torch.einsum( 'ijk -> ij', x_tmp_per_bl).unsqueeze(2)/self.args.num_of_bins 
                        tr = x_tmp_per_bl <= x_tmp_per_bl_mean
                        x = x * tr.float()
                    x = x.reshape(shape_0, shape_1)

                    if cat_var:
                        cat_list.append(x)

                    # Flag read outside this branch (consumer not visible here).
                    non_linear_layer = 1

                elif name == 'linear-simple-n':
                    # Linear (or cosine) projection with optional per-sub-bin
                    # normalisation and optional block sparsification; no
                    # thresholding step. `e` is a copy of the activations taken
                    # before the block-sparse mask is applied.

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        # Cosine head: weight and input normalised, no bias.
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2 

                    shape_0 =x.shape[0]
                    shape_1 =x.shape[1]

                    # View as (shape_0, sub_bin_size, num_of_bins); every
                    # normalisation below acts along dim 1.
                    x = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins )
                    if self.args.normilize_sub_bins == 'normilize_sub_bins[ON]':
                        if self.args.const_normalization == 'yes':
                            # L2 norm computed on a detached copy, i.e. treated as
                            # a constant in the backward pass.
                            z = x.detach().clone()
                            z_norm = torch.sqrt( torch.einsum('ijk -> ik', z*z).unsqueeze(1) )
                            x = x/(z_norm + .0000001) 
                        elif self.args.const_normalization == 'no':
                            if self.args.normalization_type == 'l-norm-2':
                                x = F.normalize(x, p=2, dim=1)
                            elif self.args.normalization_type == 'l-norm-1':
                                x = F.normalize(x, p=1, dim=1)
                            elif self.args.normalization_type == 'l-norm-inf':
                                # NOTE(review): uses the signed maximum along dim 1,
                                # not the maximum magnitude - confirm this is intended.
                                val, ind = torch.max(x, dim=1)
                                x = x/( val.unsqueeze(1) + .00000001)
                            elif self.args.normalization_type == 'max':
                                # Keep only entries equal to the maximum along dim 1.
                                g = x.detach().clone()
                                val, ind = torch.max(g, dim=1)
                                tr = (g>=val.unsqueeze(1))
                                x = x * tr.float()
                            elif self.args.normalization_type == 'max_abs':
                                # Keep only entries with the largest magnitude along dim 1.
                                g = torch.abs(x).detach().clone()
                                val, ind = torch.max(g, dim=1)
                                tr = (g>=val.unsqueeze(1))
                                x = x * tr.float()      
                    e = x.reshape(shape_0, shape_1).clone()
                    if self.args.block_sparse == 1:
                        # Printed on every pass through this branch.
                        print('block sparse is on')
                        # Keep only bins whose summed |x| (over dim 1) is <= the
                        # average of those sums over the bins.
                        x_tmp = x.detach().clone()
                        x_tmp_per_bl = torch.einsum('ijk -> ik', torch.abs(x_tmp) ).unsqueeze(1)
                        x_tmp_per_bl_mean = torch.einsum( 'ijk -> ij', x_tmp_per_bl).unsqueeze(2)/self.args.num_of_bins 
                        tr = x_tmp_per_bl <= x_tmp_per_bl_mean
                        x = x * tr.float()
                    x = x.reshape(shape_0, shape_1)

                    if cat_var:
                        cat_list.append(x)

                    # Flag read outside this branch (consumer not visible here).
                    non_linear_layer = 1


                elif name == 'linear-simple-sort':
                    # Linear (or cosine) projection with optional per-sub-bin
                    # normalisation and optional block sparsification. No sorting
                    # is performed inside this branch itself. `e` is a copy of the
                    # activations taken before the block-sparse mask is applied.

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        # Cosine head: weight and input normalised, no bias.
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2

                    shape_0 =x.shape[0]
                    shape_1 =x.shape[1]

                    # View as (shape_0, sub_bin_size, num_of_bins); every
                    # normalisation below acts along dim 1.
                    x = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins )
                    if self.args.normilize_sub_bins == 'normilize_sub_bins[ON]':
                        if self.args.const_normalization == 'yes':
                            # L2 norm computed on a detached copy, i.e. treated as
                            # a constant in the backward pass.
                            z = x.detach().clone()
                            z_norm = torch.sqrt( torch.einsum('ijk -> ik', z*z).unsqueeze(1) )
                            x = x/(z_norm + .0000001)
                        elif self.args.const_normalization == 'no':
                            if self.args.normalization_type == 'l-norm-2':
                                x = F.normalize(x, p=2, dim=1)
                            elif self.args.normalization_type == 'l-norm-1':
                                x = F.normalize(x, p=1, dim=1)
                            elif self.args.normalization_type == 'l-norm-inf':
                                # NOTE(review): uses the signed maximum along dim 1,
                                # not the maximum magnitude - confirm this is intended.
                                val, ind = torch.max(x, dim=1)
                                x = x/( val.unsqueeze(1) + .00000001)
                            elif self.args.normalization_type == 'max':
                                # Keep only entries equal to the maximum along dim 1.
                                g = x.detach().clone()
                                val, ind = torch.max(g, dim=1)
                                tr = (g>=val.unsqueeze(1))
                                x = x * tr.float()
                            elif self.args.normalization_type == 'max_abs':
                                # Keep only entries with the largest magnitude along dim 1.
                                g = torch.abs(x).detach().clone()
                                val, ind = torch.max(g, dim=1)
                                tr = (g>=val.unsqueeze(1))
                                x = x * tr.float()
                    e = x.reshape(shape_0, shape_1).clone()
                    if self.args.block_sparse == 1:
                        # Keep only bins whose summed |x| (over dim 1) is <= the
                        # average of those sums over the bins.
                        x_tmp = x.detach().clone()
                        x_tmp_per_bl = torch.einsum('ijk -> ik', torch.abs(x_tmp) ).unsqueeze(1)
                        x_tmp_per_bl_mean = torch.einsum( 'ijk -> ij', x_tmp_per_bl).unsqueeze(2)/self.args.num_of_bins
                        tr = x_tmp_per_bl <= x_tmp_per_bl_mean
                        x = x * tr.float()
                    x = x.reshape(shape_0, shape_1)

                    if cat_var:
                        cat_list.append(x)

                    # Flag read outside this branch (consumer not visible here).
                    non_linear_layer = 1


                elif name == 'non-linear-norm-const':
                    # Like the thresholded linear layer, but when sub-bin
                    # normalisation is ON the norm is always computed from a
                    # detached copy (self.args.const_normalization is not read).

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        # Cosine head: weight and input normalised, no bias.
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2 
                   
                    #print(x.shape)
                    shape_0 =x.shape[0]
                    shape_1 =x.shape[1]

                    # Zero entries whose magnitude is <= the mean magnitude along
                    # dim 1 of the (shape_0, sub_bin_size, num_of_bins) view. r is
                    # detached, so the gradient w.r.t. x is unaffected.
                    r = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                    tr = torch.abs( r ) <=  torch.mean( torch.abs(r), 1).unsqueeze(1)
                    e_i = r * tr.float()
                    e_i = e_i.reshape(shape_0, shape_1)
                    x = x - e_i

                    x = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins )
                    if self.args.normilize_sub_bins == 'normilize_sub_bins[ON]':
                        # Divide by the L2 norm along dim 1, with the norm taken
                        # from a detached copy (constant in the backward pass).
                        z = x.detach().clone()
                        z_norm = torch.sqrt( torch.einsum('ijk -> ik', z*z).unsqueeze(1) )
                        x = x/(z_norm + .0000001) 
                    # Copy taken before the optional block-sparse mask.
                    e = x.reshape(shape_0, shape_1).clone()
                    if self.args.block_sparse == 1:
                        # Keep only bins whose summed |x| (over dim 1) is <= the
                        # average of those sums over the bins.
                        x_tmp = x.detach().clone()
                        x_tmp_per_bl = torch.einsum('ijk -> ik', torch.abs(x_tmp) ).unsqueeze(1)
                        x_tmp_per_bl_mean = torch.einsum( 'ijk -> ij', x_tmp_per_bl).unsqueeze(2)/self.args.num_of_bins 
                        tr = x_tmp_per_bl <= x_tmp_per_bl_mean
                        x = x * tr.float()
                    x = x.reshape(shape_0, shape_1)
                    #x = F.normalize(x, dim=1).reshape(x.shape[0], 320)
                    #e = x.clone()

                    if cat_var:
                        cat_list.append(x)

                    # Flag read outside this branch (consumer not visible here).
                    non_linear_layer = 1
                elif name == 'non-linear-[e-no-copy]':
                    # Thresholded linear layer whose `e` is NOT cloned: `e` is a
                    # reshape of the normalised activations and stays attached to
                    # the graph. Differences from the sibling branches visible
                    # here: x is left in its 3-D (shape_0, sub_bin_size,
                    # num_of_bins) form, `e` (not x) is appended to cat_list, and
                    # pass_e is set.

                    #print('In the nonlinear layer...')
                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        # Cosine head: weight and input normalised, no bias.
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2 
                   
                    #print(x.shape)
                    shape_0 =x.shape[0]
                    shape_1 =x.shape[1]

                    r = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                    if self.args.sub_bin_size > 1:
                        # Threshold against the mean magnitude along dim 1.
                        tr = torch.abs( r ) <=  torch.mean( torch.abs(r), 1).unsqueeze(1)
                    else:
                        # With a single element along dim 1 the mean over dim 1
                        # would equal the element itself, so the mean is taken
                        # along dim 2 instead.
                        tr = torch.abs( r ) <=  torch.mean( torch.abs(r), 2).unsqueeze(2)
                    e_i = r * tr.float()
                    e_i = e_i.reshape(shape_0, shape_1)
                    # e_i is detached: values at or below the threshold become
                    # zero while the gradient w.r.t. x is unchanged.
                    x = x - e_i

                    x = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins )
                    if self.args.normilize_sub_bins == 'normilize_sub_bins[ON]':
                        if self.args.const_normalization == 'yes':
                            # L2 norm from a detached copy: constant in backward.
                            z = x.detach().clone()
                            z_norm = torch.sqrt( torch.einsum('ijk -> ik', z*z).unsqueeze(1) )
                            x = x/(z_norm + .0000001) 
                        elif self.args.const_normalization == 'no':
                            x = F.normalize(x, dim=1)
                    # No .clone() here, unlike the other branches.
                    e = x.reshape(shape_0, shape_1)
                    if self.args.block_sparse == 1:
                        # Keep only bins whose summed |x| (over dim 1) is <= the
                        # average of those sums over the bins. This rebinds x to a
                        # new tensor, so `e` above is not masked.
                        x_tmp = x.detach().clone()
                        x_tmp_per_bl = torch.einsum('ijk -> ik', torch.abs(x_tmp) ).unsqueeze(1)
                        x_tmp_per_bl_mean = torch.einsum( 'ijk -> ij', x_tmp_per_bl).unsqueeze(2)/self.args.num_of_bins 
                        tr = x_tmp_per_bl <= x_tmp_per_bl_mean
                        x = x * tr.float()
                    #x = x.reshape(shape_0, shape_1)
                    #x = F.normalize(x, dim=1).reshape(x.shape[0], 320)
                    #e = x.clone()

                    if cat_var:
                        cat_list.append(e)
                    # Marker read outside this branch (consumer not visible here).
                    pass_e = 'pass_e[yes]'

                    non_linear_layer = 1

                elif name == 'non-linear-random':

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2 
                   
                    #print(x.shape)
                    shape_0 =x.shape[0]
                    shape_1 =x.shape[1]

                    r = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                    if self.args.evaluate_on_tasks == 'no':
                        z = torch.randn(r.shape)
                        tr = torch.abs( z ) <=  torch.mean( torch.abs(z), 1).unsqueeze(1)
                        if self.cuda:
                            tr = tr.cuda()
                    elif self.args.evaluate_on_tasks == 'yes':
                        #print('I evlauate on tasks')
                        tr = torch.abs( r ) <=  torch.mean( torch.abs(r), 1).unsqueeze(1)
                    e_i = r * tr.float()
                    e_i = e_i.reshape(shape_0, shape_1)
                    x = x - e_i

                    x = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins )
                    if self.args.normilize_sub_bins == 'normilize_sub_bins[ON]':
                        x = F.normalize(x, dim=1)
                    e = x.reshape(shape_0, shape_1).clone()
                    if self.args.block_sparse == 1:
                        x_tmp = x.detach().clone()
                        x_tmp_per_bl = torch.einsum('ijk -> ik', torch.abs(x_tmp) ).unsqueeze(1)
                        x_tmp_per_bl_mean = torch.einsum( 'ijk -> ij', x_tmp_per_bl).unsqueeze(2)/self.args.num_of_bins 
                        tr = x_tmp_per_bl <= x_tmp_per_bl_mean
                        x = x * tr.float()
                    x = x.reshape(shape_0, shape_1)
                    #x = F.normalize(x, dim=1).reshape(x.shape[0], 320)
                    #e = x.clone()

                    if cat_var:
                        cat_list.append(x)

                    non_linear_layer = 1

                elif name == 'linear-[v0]':
                    # Linear (or cosine) projection with optional L2 sub-bin
                    # normalisation and optional block sparsification. The
                    # random-mask thresholding is commented out below, so no
                    # thresholding is applied.

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        # Cosine head: weight and input normalised, no bias.
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2 
                   
                    #print(x.shape)
                    shape_0 =x.shape[0]
                    shape_1 =x.shape[1]

                    #r = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                    #z = torch.randn(r.shape)
                    #tr = torch.abs( z ) <=  torch.mean( torch.abs(z), 1).unsqueeze(1)
                    #if self.cuda:
                    #    tr = tr.cuda()
                    #e_i = r * tr.float()
                    #e_i = e_i.reshape(shape_0, shape_1)
                    #x = x - e_i

                    # View as (shape_0, sub_bin_size, num_of_bins).
                    x = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins )
                    if self.args.normilize_sub_bins == 'normilize_sub_bins[ON]':
                        # Differentiable L2 normalisation along dim 1.
                        x = F.normalize(x, dim=1)
                    # Copy taken before the optional block-sparse mask.
                    e = x.reshape(shape_0, shape_1).clone()
                    if self.args.block_sparse == 1:
                        # Keep only bins whose summed |x| (over dim 1) is <= the
                        # average of those sums over the bins.
                        x_tmp = x.detach().clone()
                        x_tmp_per_bl = torch.einsum('ijk -> ik', torch.abs(x_tmp) ).unsqueeze(1)
                        x_tmp_per_bl_mean = torch.einsum( 'ijk -> ij', x_tmp_per_bl).unsqueeze(2)/self.args.num_of_bins 
                        tr = x_tmp_per_bl <= x_tmp_per_bl_mean
                        x = x * tr.float()
                    x = x.reshape(shape_0, shape_1)
                    #x = F.normalize(x, dim=1).reshape(x.shape[0], 320)
                    #e = x.clone()

                    if cat_var:
                        cat_list.append(x)

                    # Flag read outside this branch (consumer not visible here).
                    non_linear_layer = 1

                elif name == 'non-linear-relu':

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2                    

                    x = F.relu(x)
                    #r = b.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                    #tr = r  <= 0
                    #tr = tr.float()
                    #e_i = r * tr
                    #e_i = e_i.reshape(x.shape[0], 320)
                    #x = x - e_i
                    #x = F.normalize(x.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins ), dim=1)
                    if self.args.normilize_sub_bins == 'normilize_sub_bins[ON]':
                        if self.args.const_normalization == 'yes':
                            z = x.detach().clone()
                            z_norm = torch.sqrt( torch.einsum('ijk -> ik', z*z).unsqueeze(1) )
                            x = x/(z_norm + .0000001) 
                        elif self.args.const_normalization == 'no':
                            if self.args.normalization_type == 'l-norm-2':
                                x = F.normalize(x, p=2, dim=1)
                            elif self.args.normalization_type == 'l-norm-1':
                                x = F.normalize(x, p=1, dim=1)
                            elif self.args.normalization_type == 'l-norm-inf':
                                val, ind = torch.max(x, dim=1)
                                x = x/( val.unsqueeze(1) + .00000001)
                            elif self.args.normalization_type == 'max':
                                g = x.detach().clone()
                                val, ind = torch.max(g, dim=1)
                                tr = (g>=val.unsqueeze(1))
                                x = x * tr.float()
                            elif self.args.normalization_type == 'max_abs':
                                g = torch.abs(x).detach().clone()
                                val, ind = torch.max(g, dim=1)
                                tr = (g>=val.unsqueeze(1))
                                x = x * tr.float()
                    x = x.reshape(x.shape[0], self.args.sub_bin_size * self.args.num_of_bins)
                    e = x.clone()

                    if cat_var:
                        cat_list.append(x)

                    non_linear_layer = 1


                elif name == 'non-linear-vd':
                    # Linear (or cosine) projection followed by block-wise
                    # thresholding over 8 consecutive blocks of 40 columns
                    # (columns 0..319). The removed part is detached when it is
                    # subtracted, so the gradient w.r.t. x is unchanged by the
                    # thresholding.

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        # Cosine head: weight and input normalised, no bias.
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2
                    
                    x_abs = torch.abs(x)
                    # Column offsets 0..39 inside one block.
                    range_var = torch.tensor(range(0, 40)).long()

                    # Block 0: mask of entries whose magnitude is <= the block's
                    # mean magnitude (per row); e holds those entries.
                    tr_mean = torch.mean(x_abs[:, range_var], dim=1)
                    tr = ( x_abs[:, range_var] <= tr_mean.unsqueeze(1) )
                    tr = tr.float()
                    tr = tr.detach()
                    e = x[:, range_var] * tr
                    #print(x.size())
                    #print(e.size())
                    # Blocks 1..7: same rule, concatenated along the columns.
                    for ind_per_block in range(1, 8):
                        #not l-2 norm normilized error !
                        tr_mean = torch.mean(x_abs[:, 40 * ind_per_block + range_var], dim=1)
                        tr = ( x_abs[:, 40 * ind_per_block + range_var] <= tr_mean.unsqueeze(1) )
                        tr = tr.float()
                        tr = tr.detach()
                        #print(tr.size())
                        e = torch.cat( (e, x[:, 40 * ind_per_block + range_var] * tr), 1 )
                    # Only the subtracted term is detached; `e` itself stays
                    # attached to the graph through x.
                    x = x - e.detach()
                    #e = x
                    if cat_var:
                        cat_list.append(x)

                    # Flag read outside this branch (consumer not visible here).
                    non_linear_layer = 1

                elif name == 'non-linear-v0':
                    # Linear (or cosine) projection followed by block-wise
                    # thresholding over 8 consecutive blocks of 40 columns
                    # (columns 0..319). The removed part `e` is subtracted WITHOUT
                    # detaching, so gradients flow only through the entries that
                    # survive the threshold.

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        # Cosine head: weight and input normalised, no bias.
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2
                    
                    x_abs = torch.abs(x)
                    # Column offsets 0..39 inside one block.
                    range_var = torch.tensor(range(0, 40)).long()

                    # Block 0: mask of entries whose magnitude is <= the block's
                    # mean magnitude (per row); e holds those entries.
                    tr_mean = torch.mean(x_abs[:, range_var], dim=1)
                    tr = ( x_abs[:, range_var] <= tr_mean.unsqueeze(1) )
                    tr = tr.float()
                    #tr = tr.detach()
                    e = x[:, range_var] * tr
                    #print(x.size())
                    #print(e.size())
                    # Blocks 1..7: same rule, concatenated along the columns.
                    for ind_per_block in range(1, 8):
                        #not l-2 norm normilized error !
                        tr_mean = torch.mean(x_abs[:, 40 * ind_per_block + range_var], dim=1)
                        tr = ( x_abs[:, 40 * ind_per_block + range_var] <= tr_mean.unsqueeze(1) )
                        tr = tr.float()
                        #tr = tr.detach()
                        #print(tr.size())
                        e = torch.cat( (e, x[:, 40 * ind_per_block + range_var] * tr), 1 )
                    x = x - e

                    if cat_var:
                        cat_list.append(x)

                    # Flag read outside this branch (consumer not visible here).
                    non_linear_layer = 1

                elif name == 'non-linear-v1':
                    # Linear (or cosine) projection followed by block-wise
                    # thresholding over 8 consecutive blocks of 40 columns
                    # (columns 0..319). The removed part `e` is subtracted WITHOUT
                    # detaching, so gradients flow only through the entries that
                    # survive the threshold.

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        # Cosine head: weight and input normalised, no bias.
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2
                    
                    x_abs = torch.abs(x)
                    # Column offsets 0..39 inside one block.
                    range_var = torch.tensor(range(0, 40)).long()

                    # Block 0: mask of entries whose magnitude is <= the block's
                    # mean magnitude (per row); e holds those entries.
                    tr_mean = torch.mean(x_abs[:, range_var], dim=1)
                    tr = ( x_abs[:, range_var] <= tr_mean.unsqueeze(1) )
                    tr = tr.float()
                    #tr = tr.detach()
                    e = x[:, range_var] * tr
                    #print(x.size())
                    #print(e.size())
                    # Blocks 1..7: same rule, concatenated along the columns.
                    for ind_per_block in range(1, 8):
                        #not l-2 norm normilized error !
                        tr_mean = torch.mean(x_abs[:, 40 * ind_per_block + range_var], dim=1)
                        tr = ( x_abs[:, 40 * ind_per_block + range_var] <= tr_mean.unsqueeze(1) )
                        tr = tr.float()
                        #tr = tr.detach()
                        #print(tr.size())
                        e = torch.cat( (e, x[:, 40 * ind_per_block + range_var] * tr), 1 )
                    x = x - e

                    if cat_var:
                        cat_list.append(x)

                    # Flag read outside this branch (consumer not visible here).
                    non_linear_layer = 1

                elif name == 'non-linear-v2':

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2
                    
                    x_abs = torch.abs(x)
                    range_var = torch.tensor(range(0, 40)).long()

                    tr_mean = torch.mean(x_abs[:, range_var], dim=1)
                    tr = ( x_abs[:, range_var] <= tr_mean.unsqueeze(1) )
                    tr = tr.float()
                    tr = tr.detach()
                    e = x[:, range_var] * tr
                    #print(x.size())
                    #print(e.size())
                    for ind_per_block in range(1, 8):
                        #not l-2 norm normilized error !
                        tr_mean = torch.mean(x_abs[:, 40 * ind_per_block + range_var], dim=1)
                        tr = ( x_abs[:, 40 * ind_per_block + range_var] <= tr_mean.unsqueeze(1) )
                        tr = tr.float()
                        tr = tr.detach()
                        #print(tr.size())
                        e = torch.cat( (e, x[:, 40 * ind_per_block + range_var] * tr), 1 )
                    x = x - e

                    if cat_var:
                        cat_list.append(x)

                    non_linear_layer = 1


                elif name == 'non-linear-full':

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2
                    
                    #x_abs = torch.abs(x)
                    #tr_mean = torch.mean(x_abs, dim=1)
                    #tr = ( x_abs <= tr_mean.unsqueeze(1) )
                    #tr = tr.float()
                    #tr = tr.detach()
                    #e = x * tr

                    #x = x - e.detach()

                    tr = torch.abs(x.detach().clone())
                    tr = ( tr <= torch.mean(tr, dim=1).unsqueeze(1) )
                    tr = tr.float()
                    x = tr * x
                    x = F.normalize(x, dim=1)
                    e = x.clone()

                    if cat_var:
                        cat_list.append(x)

                    non_linear_layer = 1

                elif name == 'non-linear-relu-1':

                    # ipdb.set_trace()
                    if extra_name == 'cosine':
                        w = F.normalize(vars[idx])
                        x = F.normalize(x)
                        x = F.linear(x, w)
                        idx += 1
                    else:
                        w, b = vars[idx], vars[idx + 1]
                        x = F.linear(x, w, b)
                        idx += 2
                    
                    #e = x.clone()
                    #e = x
                    #r = x.reshape(x.shape[0], 40, 8).clone()
                    #tr =  F.relu(  x.reshape(x.shape[0], 40, 8) -  torch.mean( torch.abs(x.reshape(x.shape[0], 40, 8).detach()), 1).unsqueeze(1) ) - F.relu(- x.reshape(x.shape[0], 40, 8) -  torch.mean( torch.abs(x.reshape(x.shape[0], 40, 8).detach()), 1).unsqueeze(1) )
                    #e = x.reshape(x.shape[0], 320) - tr.reshape(x.shape[0], 320)
                    #x = tr.reshape(x.shape[0], 320)
                    #e = F.relu(-x)
                    
                    x = F.relu(x)
                    e = x

                    if cat_var:
                        cat_list.append(x)

                    non_linear_layer = 1

                elif name == 'rep':
                    # print('rep')
                    # print(x.shape)
                    if feature:
                        return x

                elif name == "cat_start":
                    cat_var = True
                    cat_list = []

                elif name == "cat":
                    cat_var = False
                    x = torch.cat(cat_list, dim=1)

                elif name == 'BN':
                    w, b = vars[idx], vars[idx + 1]
                    #x = F.group_norm(x, 32, weight=w, bias=b)
                    #x = F.group_norm(x, 1, weight=w, bias=b)
                    #x = F.group_norm(x, 320, weight=w, bias=b)
                    running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx + 1]
                    x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
                    #print('normalization ...')

                    idx += 2
                    bn_idx += 2

                elif name == 'LN':
                    #x = F.group_norm(x, 32, weight=w, bias=b)

                    #x = F.layer_norm(x, x.shape[1:], weight=w, bias=b)
                    w, b = vars[idx], vars[idx + 1]
                    x = F.group_norm(x, 1, weight=w, bias=b)
                    #x = F.group_norm(x, 1, weight=w, bias=b)

                    #x = F.group_norm(x, 320, weight=w, bias=b)
                    #running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx + 1]
                    #x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
                    #print('normalization ...')

                    idx += 2
                    bn_idx += 2

                elif name == 'GN':
                    w, b = vars[idx], vars[idx + 1]
                    x = F.group_norm(x, self.args.num_of_bins, weight=w, bias=b)
                    #x = F.group_norm(x, 1, weight=w, bias=b)
                    #x = F.group_norm(x, 320, weight=w, bias=b)
                    #running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx + 1]
                    #x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
                    #print('normalization ...')

                    idx += 2
                    bn_idx += 2 

                elif name == 'IN':
                    w, b = vars[idx], vars[idx + 1]
                    running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx + 1]
                    #x = F.instance_norm(x, running_mean, running_var, weight=w, bias=b, use_input_stats=True, momentum=0.1, eps=1e-05)
                    x = F.instance_norm(x, running_mean, running_var, weight=w, bias=b)

                    idx += 2
                    bn_idx += 2                     
                elif name == 'CN':
                    
                    #x = x.reshape(x.shape[0], self.args.sub_bin_size, self.args.num_of_bins )
                    #num_of_bins_curent = self.args.num_of_bins
                    #bin_size_curent = self.args.sub_bin_size
                    #for i in range(0, num_of_bins_curent):
                    #    w, b = vars[idx], vars[idx + 1]
                    #    running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx + 1]
                    #    x[:,:,i] = F.batch_norm(x[:,:,i], running_mean, running_var, weight=w, bias=b, training=bn_training)
                    #    idx += 2
                    #    bn_idx += 2
                    #x = x.reshape(x.shape[0], self.args.sub_bin_size*self.args.num_of_bins )
                    #dk=1
                    w, b = vars[idx], vars[idx + 1]
                    running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx + 1]
                    x = F.group_norm(x, self.args.num_of_bins, None, None, 1e-5)
                    x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
                    #x = F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, self.training, self.momentum, self.eps)
                    
                    #x = F.group_norm(x, 32, weight=w, bias=b)
                    #x = F.group_norm(x, 1, weight=w, bias=b)
                    #x = F.group_norm(x, 320, weight=w, bias=b)
                    #running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx + 1]
                    #x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
                    #print('normalization ...')

                    idx += 2
                    bn_idx += 2

                elif name == 'flatten':
                    # print('flatten')
                    # print(x.shape)

                    x = x.view(x.size(0), -1)

                elif name == 'reshape':
                    # [b, 8] => [b, 2, 2, 2]
                    x = x.view(x.size(0), *param)
                elif name == 'relu':
                    x = F.relu(x, inplace=param[0])
                elif name == 'only-sparse':

                    shape_0 =x.shape[0]
                    shape_1 =x.shape[1]
                    
                    r = x.reshape(shape_0, self.args.sub_bin_size, self.args.num_of_bins ).detach().clone()
                    tr = torch.abs( r ) <=  torch.mean( torch.abs(r), 1).unsqueeze(1)
                    e_i = r * tr.float()
                    e_i = e_i.reshape(shape_0, shape_1)
                    x = x - e_i
                    x = x.reshape(shape_0, shape_1)
                    
                elif name == 'leakyrelu':
                    x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
                elif name == 'tanh':
                    x = F.tanh(x)
                elif name == 'sigmoid':
                    x = torch.sigmoid(x)
                elif name == 'upsample':
                    x = F.upsample_nearest(x, scale_factor=param[0])
                elif name == 'max_pool2d':
                    x = F.max_pool2d(x, param[0], param[1], param[2])
                elif name == 'avg_pool2d':
                    x = F.avg_pool2d(x, param[0], param[1], param[2])

                else:
                    print(name)
                    raise NotImplementedError

        except:
            traceback.print_exc(file=sys.stdout)
            ipdb.set_trace()

        # make sure variable is used properly
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)
        
        if non_linear_layer == 1:
            return x, e
        else:
            return x

    def zero_grad(self, vars=None):
        """

        :param vars:
        :return:
        """
        with torch.no_grad():
            if vars is None:
                for p in self.vars:
                    if p.grad is not None:
                        p.grad.zero_()
            else:
                for p in vars:
                    if p.grad is not None:
                        p.grad.zero_()

    def define_task_lr_params(self, alpha_init=1e-3): 
        # Setup learning parameters
        self.alpha_lr = nn.ParameterList([])

        self.lr_name = []
        for n, p in self.named_parameters():
            self.lr_name.append(n)

        for p in self.parameters():
            self.alpha_lr.append(nn.Parameter(alpha_init * torch.ones(p.shape, requires_grad=True)))                                           

    def define_task_lr_params_e(self, alpha_init=1e-3): 
        # Setup learning parameters
        self.alpha_lr_e = nn.ParameterList([])

        self.lr_name_e = []
        for n, p in self.named_parameters():
            self.lr_name_e.append(n)

        for p in self.parameters():
            self.alpha_lr_e.append(nn.Parameter(alpha_init * torch.ones(p.shape, requires_grad=True)))  

    def define_reg_param(self, gamma_init=1e-3, type_parameter='multi'): 
        # Setup regularization parameters
        #self.gamma_name.append('gamma_reg')

        if type_parameter == 'single':
            self.gamma_r = nn.ParameterList([])
            self.gamma_r.append( nn.Parameter(gamma_init * torch.ones(1, requires_grad=True)) )
        else:
            self.gamma_r = nn.ParameterList([])
            self.lr_name_gamma = []
            for n, p in self.named_parameters():
                self.lr_name_gamma.append(n)

            for p in self.parameters():
                self.gamma_r.append(nn.Parameter(gamma_init * torch.ones(1, requires_grad=True)))  

    def parameters(self):
        """
        Return the learner's optimizable tensors as a single container.

        Overrides ``nn.Module.parameters``, which would hand back a generator;
        this version returns the ``self.vars`` object itself, so entries of
        ``self.vars_bn`` and of any other registered parameter list are not
        included.

        :return: ``self.vars`` (created as an ``nn.ParameterList`` in ``__init__``)
        """
        return self.vars


