import torch
import torch.nn as nn
import numpy as np
import random
import sys
import os

# Make the project-local ``model`` package (imported just below) importable by
# appending the parent of the *current working directory* to sys.path.
# NOTE(review): this depends on where the process was launched from, not on
# the location of this file.
path = os.getcwd()  # directory the process was started in
_parent_of_cwd = os.path.abspath(os.path.join(path, os.pardir))
sys.path.append(_parent_of_cwd)

from model import binarization


class Client:
    """Federated-learning client that trains and evaluates a local model.

    With ``args.mask`` set, the model is expected to contain
    ``binarization.MaskedMLP`` / ``binarization.MaskedConv2d`` layers, each
    exposing ``weight`` and ``threshold`` tensors. Local training then runs
    ``local_epoch - 1`` epochs in which the masked layers' thresholds are
    frozen, followed by one epoch in which their weights are frozen and the
    thresholds are trained. Without a mask, every epoch trains all
    parameters (FedAvg).

    ``self.FLOPs`` accumulates a training-cost estimate built from the
    hard-coded per-layer counts below for the reference CIFAR-10 and
    CIFAR-100 networks.
    """

    # Forward multiply-accumulate count per training sample for each prunable
    # layer, keyed by the module name reported by ``named_modules``. The
    # factors appear to be in_ch * out_ch * k * k * out_h * out_w for convs and
    # in * out for dense layers. Insertion order is the order in which the
    # unmasked counters add them up.
    _CIFAR10_LAYER_MACS = {
        'conv1': 3 * 64 * 5 * 5 * 28 * 28,
        'conv2': 64 * 64 * 5 * 5 * 12 * 12,
        'conv3': 64 * 128 * 5 * 5 * 8 * 8,
        'conv4': 128 * 128 * 5 * 5 * 2 * 2,
        'dense1': 512 * 128,
        'dense2': 128 * 128,
        'dense3': 128 * 10,
    }
    _CIFAR100_LAYER_MACS = {
        'conv1': 3 * 16 * 3 * 3 * 32 * 32,
        'block1.layer.0.conv1': 16 * 16 * 3 * 3 * 32 * 32,
        'block1.layer.0.conv2': 16 * 16 * 3 * 3 * 32 * 32,
        'block2.layer.0.conv1': 16 * 32 * 3 * 3 * 16 * 16,
        'block2.layer.0.conv2': 32 * 32 * 3 * 3 * 16 * 16,
        'block2.layer.0.convShortcut': 16 * 32 * 1 * 1 * 16 * 16,
        'block3.layer.0.conv1': 32 * 64 * 3 * 3 * 8 * 8,
        'block3.layer.0.conv2': 64 * 64 * 3 * 3 * 8 * 8,
        'block3.layer.0.convShortcut': 32 * 64 * 1 * 1 * 8 * 8,
        # Dense layer: no spatial factor. The unmasked and threshold counters
        # used to multiply this entry by 8 * 8, disagreeing with the masked
        # weight counter and with the CIFAR-10 dense layers.
        'fc': 64 * 100,
    }

    def __init__(self, args, model, loss, client_id, tr_loader, te_loader, device, scheduler=None):
        """Store the training setup and count the model's weights/thresholds.

        Args:
            args: configuration namespace (reads e.g. ``mask``, ``local_epoch``,
                ``learning_rate``, ``lr_decay``, ``momentum``, ``weight_decay``,
                ``th_coeff``, ``dataset``).
            model: the local ``nn.Module``.
            loss: callable ``loss(output, label)`` returning a scalar tensor.
            client_id: identifier of this client (stored only).
            tr_loader / te_loader: training and test data loaders.
            device: device that batches are moved to.
            scheduler: optional LR scheduler, stepped once per training batch.
        """
        self.args = args
        self.model = model
        self.loss = loss
        self.scheduler = scheduler
        self.client_id = client_id
        self.tr_loader = tr_loader
        self.te_loader = te_loader
        self.device = device
        self.optimizer = None
        self.FLOPs = 0  # running training-cost estimate
        self.ratio_per_layer = {}  # module name -> layer.ratio recorded when the mask is fixed
        self.density = 1  # fraction of kept weights, refreshed during training/testing
        self.num_weights, self.num_thresholds = self.get_model_numbers()

    @staticmethod
    def _is_masked_layer(layer):
        """Return True if *layer* is one of the binarization masked layers."""
        return isinstance(layer, (binarization.MaskedMLP, binarization.MaskedConv2d))

    def local_training(self, comm_rounds):
        """Run one round of local training and add its cost to ``self.FLOPs``.

        With ``args.mask`` set, epochs ``1 .. local_epoch - 1`` train with the
        masked layers' thresholds frozen. On every batch of epoch 1 the mask
        is regenerated and each layer's ``ratio`` is recorded. The final epoch
        freezes the masked layers' weights, trains their thresholds and adds
        an ``exp(-threshold)`` penalty to the loss. Otherwise all epochs train
        all parameters (FedAvg).

        Args:
            comm_rounds: index of the current communication round; the
                learning rate is ``learning_rate * lr_decay ** comm_rounds``.
        """
        lr = self.args.learning_rate * self.args.lr_decay ** comm_rounds
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr,
                                         momentum=self.args.momentum,
                                         weight_decay=self.args.weight_decay)
        if self.args.mask:
            for epoch in range(1, self.args.local_epoch + 1):
                if epoch < self.args.local_epoch:
                    self._train_weights_epoch(fix_mask=(epoch == 1))
                else:
                    self._train_thresholds_epoch()

            # Calculate the number of FLOPs spent in this local training.
            if self.args.dataset == 'cifar10':
                self.FLOP_count_weight_cifar10()
                self.FLOP_count_threshold_cifar10()
            elif self.args.dataset == 'cifar100':
                self.FLOP_count_weight_cifar100()
                self.FLOP_count_threshold_cifar100()
        else:
            for _ in range(self.args.local_epoch):
                self._train_fedavg_epoch()

            if self.args.dataset == 'cifar10':
                self.FLOP_count_weight_cifar10()
            elif self.args.dataset == 'cifar100':
                self.FLOP_count_weight_cifar100()

    def _optimizer_step(self, loss_val):
        """Backpropagate *loss_val* and apply one optimizer step."""
        self.optimizer.zero_grad()
        loss_val.backward()
        self.optimizer.step()

    def _remove_weight_hooks(self):
        """Call ``layer.hook.remove()`` on every masked layer whose weight requires grad."""
        for layer in self.model.modules():
            if self._is_masked_layer(layer) and layer.weight.requires_grad:
                layer.hook.remove()

    def _train_weights_epoch(self, fix_mask):
        """Train one epoch with the masked layers' thresholds frozen.

        When *fix_mask* is True (first local epoch), the mask is regenerated
        from the current weights/thresholds before every batch, and each
        layer's ``ratio`` is stored in ``self.ratio_per_layer``.
        """
        self.model.train()
        for data, label in self.tr_loader:
            # Tensor.to is not in-place: keep the returned tensors.
            data, label = data.to(self.device), label.to(self.device)
            if fix_mask:
                for name, layer in self.model.named_modules():
                    if self._is_masked_layer(layer):
                        layer.simple_mask_generation(layer.weight, layer.threshold)
                        self.ratio_per_layer[name] = layer.ratio
                        # Stays False until the threshold epoch sets it back.
                        layer.threshold.requires_grad = False
                        layer.weight.requires_grad = True
            loss_val = self.loss(self.model(data), label)
            self._optimizer_step(loss_val)
            self._remove_weight_hooks()
            if self.scheduler is not None:
                self.scheduler.step()

    def _train_thresholds_epoch(self):
        """Train one epoch with masked-layer weights frozen and thresholds trainable.

        Also refreshes ``self.density`` as the mean of the recorded
        ``ratio_per_layer`` values.
        """
        self.model.train()
        # The ratios are only written during the weight epochs, so the density
        # is constant here. When nothing has been recorded yet (local_epoch == 1
        # on the first round), keep the previous density instead of dividing
        # by zero.
        if self.ratio_per_layer:
            self.density = sum(self.ratio_per_layer.values()) * (1 / len(self.ratio_per_layer))
        for data, label in self.tr_loader:
            data, label = data.to(self.device), label.to(self.device)
            for layer in self.model.modules():
                if self._is_masked_layer(layer):
                    layer.threshold.requires_grad = True
                    layer.weight.requires_grad = False
                    layer.mask_generation(layer.weight, layer.threshold)
            loss_val = self.loss(self.model(data), label)
            # Penalty sum(exp(-threshold)) decreases as thresholds grow.
            for layer in self.model.modules():
                if self._is_masked_layer(layer):
                    loss_val += self.args.th_coeff * torch.sum(torch.exp(-layer.threshold))
            self._optimizer_step(loss_val)
            # Weights are frozen in this epoch, so this finds nothing to remove.
            # NOTE(review): confirm mask_generation does not register hooks that
            # need removal here.
            self._remove_weight_hooks()
            if self.scheduler is not None:
                self.scheduler.step()

    def _train_fedavg_epoch(self):
        """Train one epoch over all parameters (no masking)."""
        self.model.train()
        for data, label in self.tr_loader:
            data, label = data.to(self.device), label.to(self.device)
            loss_val = self.loss(self.model(data), label)
            self._optimizer_step(loss_val)
            if self.scheduler is not None:
                self.scheduler.step()

    def local_test(self):
        """Evaluate the model on ``te_loader`` without gradients.

        The mask of every masked layer is regenerated before each batch.

        Returns:
            ``(accuracy_percent, mean_batch_loss)``; when ``args.mask`` is set,
            additionally the density returned by ``get_density``. Raises
            ZeroDivisionError if ``te_loader`` yields no batches.
        """
        total_acc = 0.0
        num = 0
        std_loss = 0.
        iteration = 0.
        self.model.eval()
        with torch.no_grad():
            for data, label in self.te_loader:
                data, label = data.to(self.device), label.to(self.device)
                for layer in self.model.modules():
                    if self._is_masked_layer(layer):
                        layer.simple_mask_generation(layer.weight, layer.threshold)
                output = self.model(data)
                pred = torch.max(output, dim=1)[1]
                total_acc += (pred.cpu().numpy() == label.cpu().numpy()).astype(np.float32).sum()
                num += output.shape[0]
                std_loss += self.loss(output, label)
                iteration += 1
        std_acc = total_acc / num * 100.
        std_loss /= iteration

        if self.args.mask:
            return std_acc, std_loss, self.get_density()
        return std_acc, std_loss

    def get_model_numbers(self):
        """Count the weights (and thresholds) used by the FLOP estimates.

        Returns:
            ``(total_weights, total_thresholds)``. With ``args.mask == 1`` both
            come from the masked binarization layers. Otherwise weights are
            counted over ``nn.Linear`` / ``nn.Conv2d`` layers and thresholds
            is 0.
        """
        total_weights = 0
        total_thresholds = 0
        if self.args.mask == 1:
            for layer in self.model.modules():
                if self._is_masked_layer(layer):
                    total_weights += layer.weight.numel()
                    total_thresholds += layer.threshold.numel()
        else:
            for layer in self.model.modules():
                if isinstance(layer, (nn.Linear, nn.Conv2d)):
                    total_weights += layer.weight.numel()
        return total_weights, total_thresholds

    def get_density(self):
        """Return, and store in ``self.density``, the kept-weight fraction.

        For every masked layer the mask is ``layer.step(|w| - threshold)``,
        with the threshold viewed as ``(out, -1)`` so it broadcasts over each
        output row. The density is the sum of all masks divided by their total
        element count (the kept fraction, assuming ``step`` yields 0/1).
        Returned as a 0-d NumPy array.
        """
        total = 0.
        keep = 0.
        for layer in self.model.modules():
            if isinstance(layer, binarization.MaskedMLP):
                abs_weight = torch.abs(layer.weight)
                threshold = layer.threshold.view(abs_weight.shape[0], -1)
                mask = layer.step(abs_weight - threshold)
                total += mask.numel()
                keep += torch.sum(mask)

            if isinstance(layer, binarization.MaskedConv2d):
                weight_shape = layer.weight.shape
                threshold = layer.threshold.view(weight_shape[0], -1)
                weight = torch.abs(layer.weight).view(weight_shape[0], -1)
                mask = layer.step(weight - threshold)
                total += mask.numel()
                keep += torch.sum(mask)
        density = (keep / total).cpu().detach().numpy()
        self.density = density
        return density

    def th_update(self, global_difference):
        """Shift masked-layer weights along a sign-corrected global difference.

        For each masked layer ``name``:
        - ``global_difference[name]`` is viewed as ``(out, -1)``.
        - It is multiplied by the sign of each output row's weight sum.
        - It is scaled by ``1 / kernel_size[0] ** 2`` (conv) or ``1 / in_size``
          (MLP).
        - The result is subtracted from ``layer.weight`` in place, under
          ``torch.no_grad``.

        Args:
            global_difference: mapping from module name to a tensor viewable as
                ``(out, -1)``. Presumably the aggregated threshold change from
                the server — TODO confirm with the caller.
        """
        with torch.no_grad():
            for name, layer in self.model.named_modules():
                if not self._is_masked_layer(layer):
                    continue
                weight_shape = layer.weight.shape
                weight = layer.weight.view(weight_shape[0], -1)

                # Broadcast the per-row sign to the full (out, -1) weight shape
                # so the product below has that shape even when
                # global_difference[name] holds one entry per output row.
                row_sign = torch.sign(torch.sum(weight, 1)).view(weight_shape[0], -1)
                row_sign = torch.mul(row_sign, torch.ones(weight.shape, device=self.device))

                threshold_dir = global_difference[name].view(weight_shape[0], -1)
                threshold_dir = torch.mul(threshold_dir, row_sign)
                if isinstance(layer, binarization.MaskedConv2d):
                    update_direction = threshold_dir * 1 / (layer.kernel_size[0] ** 2)
                else:
                    update_direction = threshold_dir * 1 / (layer.in_size)

                layer.weight += update_direction.view(weight_shape) * (-1)

    def _add_masked_layer_flops(self, layer_macs, epochs):
        """Add masked-layer forward and backward FLOPs over *epochs* epochs.

        Only layers present in both ``self.ratio_per_layer`` and *layer_macs*
        are counted. The forward cost is ``epochs * ratio * macs * n_samples``
        and the backward cost is ``1 + ratio`` times that.
        """
        n_samples = len(self.tr_loader.dataset)
        for name, ratio in self.ratio_per_layer.items():
            macs = layer_macs.get(name)
            if macs is None:
                continue
            forward = epochs * ratio * macs * n_samples
            self.FLOPs += forward + (1 + ratio) * forward

    def _count_weight_flops(self, layer_macs):
        """Add the weight-training FLOPs of this round for *layer_macs*.

        Masked (``args.mask == 1``): the ``local_epoch - 1`` weight epochs at
        masked cost, plus a weight-update term. Unmasked: all ``local_epoch``
        epochs at full cost (backward = 2 x forward), plus one weight update
        per batch.
        """
        if self.args.mask == 1:
            self._add_masked_layer_flops(layer_macs, self.args.local_epoch - 1)
            # Weight update with importance extraction.
            self.FLOPs += (self.num_weights * self.density) * len(self.tr_loader) + self.num_weights * 1.5
        else:
            n_samples = len(self.tr_loader.dataset)
            for macs in layer_macs.values():
                forward = self.args.local_epoch * macs * n_samples
                self.FLOPs += forward + 2 * forward
            # Weight update.
            self.FLOPs += (self.num_weights * 1) * len(self.tr_loader)

    def _count_threshold_flops(self, layer_macs):
        """Add the FLOPs of the single threshold epoch (only when ``args.mask == 1``)."""
        if self.args.mask == 1:
            self._add_masked_layer_flops(layer_macs, 1)
            # Threshold update.
            self.FLOPs += self.num_thresholds * len(self.tr_loader) + self.num_weights

    def FLOP_count_weight_cifar10(self):
        """Add this round's weight-training FLOPs for the CIFAR-10 network."""
        self._count_weight_flops(self._CIFAR10_LAYER_MACS)

    def FLOP_count_weight_cifar100(self):
        """Add this round's weight-training FLOPs for the CIFAR-100 network."""
        self._count_weight_flops(self._CIFAR100_LAYER_MACS)

    def FLOP_count_threshold_cifar10(self):
        """Add this round's threshold-training FLOPs for the CIFAR-10 network."""
        self._count_threshold_flops(self._CIFAR10_LAYER_MACS)

    def FLOP_count_threshold_cifar100(self):
        """Add this round's threshold-training FLOPs for the CIFAR-100 network."""
        self._count_threshold_flops(self._CIFAR100_LAYER_MACS)

