import torch
import torch.nn as nn
import numpy as np
import random
import sys
import os

path = os.getcwd() #current working directory (NOT the directory containing this file)
# Put the parent of the working directory on sys.path, presumably so that the
# `from model import binarization` import below resolves.
# NOTE(review): because this is based on os.getcwd(), the import only works when
# the process is launched from a directory whose parent contains `model/`;
# anchoring on os.path.dirname(os.path.abspath(__file__)) would be launch-independent.
sys.path.append(os.path.abspath(os.path.join(path, os.pardir))) #import the parent directory

from model import binarization


class Client:
    """A federated-learning client that trains and evaluates a local model.

    If ``args.mask`` is truthy the model is expected to contain
    ``binarization.MaskedMLP`` / ``binarization.MaskedConv2d`` layers, and local
    training runs ``args.local_epoch - 1`` weight-only epochs followed by one
    threshold-only epoch. Otherwise the client does plain FedAvg-style local SGD.

    Attributes:
        FLOPs: running training-FLOP estimate, accumulated (never reset) by
            ``FLOP_count_weight`` / ``FLOP_count_threshold``.
        ratio_per_layer: ``{module_name: layer.ratio}`` recorded while masks are
            generated in the first local epoch of masked training.
        density: mean of ``ratio_per_layer`` during masked training, or the
            value last computed by ``get_density``; starts at 1.
        num_weights, num_thresholds: counts returned by ``get_model_numbers``.
    """

    # Per-sample multiply-accumulate counts used by the FLOP counters. They
    # reproduce the original hard-coded formulas and therefore assume a fixed
    # four-layer model with layers named conv1/conv2/fc3/fc4 (presumably a
    # LeNet-style network on 28x28 inputs -- TODO confirm against the model).
    _LAYER_MACS_PER_SAMPLE = {
        'conv1': 20 * 5 * 5 * 24 * 24,
        'conv2': 50 * 20 * 5 * 5 * 8 * 8,
        'fc3': 4 * 4 * 50 * 500,
        'fc4': 500 * 10,
    }

    def __init__(self, args, model, loss, client_id, tr_loader, te_loader, device, scheduler=None):
        self.args = args
        self.model = model
        self.loss = loss
        self.scheduler = scheduler
        self.client_id = client_id
        self.tr_loader = tr_loader
        self.te_loader = te_loader
        self.device = device
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.learning_rate,
                                         momentum=self.args.momentum, weight_decay=self.args.weight_decay)
        self.FLOPs = 0
        self.ratio_per_layer = {}
        self.density = 1
        self.num_weights, self.num_thresholds = self.get_model_numbers()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_masked(layer):
        """Return True if *layer* is a ``binarization`` masked layer."""
        return isinstance(layer, (binarization.MaskedMLP, binarization.MaskedConv2d))

    def _masked_layers(self):
        """Return a list of the model's masked layers in ``modules()`` order."""
        return [layer for layer in self.model.modules() if self._is_masked(layer)]

    def _to_device(self, data, label):
        """Return *data* and *label* moved to ``self.device``.

        ``Tensor.to`` returns a new tensor instead of converting in place, so
        the caller must re-bind the results.
        """
        return data.to(self.device), label.to(self.device)

    def _optimize(self, loss_val):
        """Back-propagate *loss_val* and take one optimizer step."""
        self.optimizer.zero_grad()
        loss_val.backward()
        self.optimizer.step()

    def _scheduler_step(self):
        """Advance the learning-rate scheduler, if one was supplied."""
        if self.scheduler is not None:
            self.scheduler.step()

    def _remove_weight_hooks(self):
        """Call ``layer.hook.remove()`` on masked layers whose weights are trainable."""
        for layer in self._masked_layers():
            if layer.weight.requires_grad:
                layer.hook.remove()

    def _train_weights_epoch(self, fix_mask):
        """Run one epoch that updates weights while thresholds stay frozen.

        Args:
            fix_mask: if True (first local epoch), before every batch each
                masked layer regenerates its mask with ``simple_mask_generation``,
                its ``ratio`` is recorded in ``ratio_per_layer`` and its
                threshold is frozen.
        """
        self.model.train()
        for data, label in self.tr_loader:
            data, label = self._to_device(data, label)
            if fix_mask:
                for name, layer in self.model.named_modules():
                    if self._is_masked(layer):
                        layer.simple_mask_generation(layer.weight, layer.threshold)
                        self.ratio_per_layer[name] = layer.ratio
                        # Stays frozen until the threshold epoch re-enables it.
                        layer.threshold.requires_grad = False
                        layer.weight.requires_grad = True
            self._optimize(self.loss(self.model(data), label))
            self._remove_weight_hooks()
            self._scheduler_step()

    def _update_density_from_ratios(self):
        """Set ``self.density`` to the mean of ``ratio_per_layer``.

        Leaves ``self.density`` unchanged when no ratio has been recorded yet
        (e.g. ``local_epoch == 1`` on the first round) instead of dividing by zero.
        """
        if not self.ratio_per_layer:
            return
        # Explicit loop (not sum()) keeps the exact float accumulation order.
        current_density = 0
        for layer_density in self.ratio_per_layer.values():
            current_density += layer_density
        current_density *= 1 / len(self.ratio_per_layer)
        self.density = current_density

    def _penalty_scale(self, comm_rounds):
        """Return the threshold-penalty coefficient, or None if no penalty applies.

        Before ``args.penalty_scheduler`` rounds the coefficient is
        ``args.th_coeff`` times a sigmoid ramp in *comm_rounds*; afterwards it is
        ``args.th_coeff`` while ``self.density`` is above 0.10.
        """
        if comm_rounds < self.args.penalty_scheduler:
            ramp = 1 / (1 + np.exp(5 - 10 * comm_rounds / (self.args.penalty_scheduler)))
            return self.args.th_coeff * ramp
        if self.density > 0.10:
            return self.args.th_coeff
        return None

    def _train_thresholds_epoch(self, comm_rounds):
        """Run one epoch that updates thresholds while weights stay frozen.

        For each masked layer the loss is augmented with
        ``scale * sum(exp(-threshold))``, where ``scale`` comes from
        ``_penalty_scale(comm_rounds)``.
        """
        self.model.train()
        # ratio_per_layer is not modified during this epoch, so the density
        # that gates the penalty (and the penalty scale) are computed once.
        self._update_density_from_ratios()
        penalty_scale = self._penalty_scale(comm_rounds)
        for data, label in self.tr_loader:
            data, label = self._to_device(data, label)
            masked = self._masked_layers()
            for layer in masked:
                layer.threshold.requires_grad = True
                layer.weight.requires_grad = False
                layer.mask_generation(layer.weight, layer.threshold)
            loss_val = self.loss(self.model(data), label)
            if penalty_scale is not None:
                for layer in masked:
                    # Out-of-place add: avoids mutating a tensor autograd may need.
                    loss_val = loss_val + penalty_scale * torch.sum(torch.exp(-layer.threshold))
            self._optimize(loss_val)
            self._remove_weight_hooks()
            self._scheduler_step()

    def _shift_weights_by_threshold(self, layer, threshold_diff, divisor):
        """Subtract ``threshold_diff[i] * sign(sum of row i) / divisor`` from row i of ``layer.weight``."""
        weight_shape = layer.weight.shape
        rows = weight_shape[0]
        flat_weight = layer.weight.view(rows, -1)
        row_sign = torch.sign(torch.sum(flat_weight, 1)).view(rows, -1)
        direction = torch.mul(row_sign, torch.ones(flat_weight.shape).to(self.device))
        step = torch.mul(threshold_diff.view(rows, -1), direction) / divisor
        layer.weight -= step.view(weight_shape)

    def _layer_training_flops(self, key, epochs, ratio):
        """Return forward + backward FLOPs for layer *key* over *epochs* passes of the training set.

        Forward is ``epochs * ratio * MACs-per-sample * len(dataset)``; backward
        is counted as ``(1 + ratio)`` times the forward cost.
        """
        forward = epochs * ratio * self._LAYER_MACS_PER_SAMPLE[key] * len(self.tr_loader.dataset)
        return forward + (1 + ratio) * forward

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def local_training(self, comm_rounds):
        """Train the local model for ``args.local_epoch`` epochs and update ``self.FLOPs``.

        Masked mode (``args.mask`` truthy): epochs ``1 .. local_epoch - 1`` train
        only weights (the first also regenerates masks and records per-layer
        ratios); the last epoch trains only thresholds with a sparsity penalty
        scheduled by *comm_rounds*. Otherwise every epoch is plain SGD over all
        parameters (FedAvg).

        Args:
            comm_rounds: index of the current communication round; drives the
                threshold-penalty schedule in masked mode.

        Returns:
            None. ``self.model`` is trained in place.
        """
        if self.args.mask:
            for epoch in range(1, self.args.local_epoch + 1):
                if epoch < self.args.local_epoch:
                    self._train_weights_epoch(fix_mask=(epoch == 1))
                else:
                    self._train_thresholds_epoch(comm_rounds)
            self.FLOP_count_weight()
            self.FLOP_count_threshold()
        else:
            for _ in range(self.args.local_epoch):
                self.model.train()
                for data, label in self.tr_loader:
                    data, label = self._to_device(data, label)
                    self._optimize(self.loss(self.model(data), label))
                    self._scheduler_step()
            self.FLOP_count_weight()

    def local_test(self):
        """Evaluate the model on ``te_loader``.

        Masks of masked layers are regenerated with ``simple_mask_generation``
        before every batch.

        Returns:
            ``(accuracy_percent, mean_batch_loss)``, plus the density from
            ``get_density`` when ``args.mask`` is truthy. ``mean_batch_loss`` has
            whatever type ``self.loss`` returns (a tensor for torch losses).
        """
        total_acc = 0.0
        num = 0
        std_loss = 0.
        iteration = 0.
        self.model.eval()
        with torch.no_grad():
            for data, label in self.te_loader:
                data, label = self._to_device(data, label)
                for layer in self._masked_layers():
                    layer.simple_mask_generation(layer.weight, layer.threshold)
                output = self.model(data)
                pred = torch.max(output, dim=1)[1]
                total_acc += (pred.cpu().numpy() == label.cpu().numpy()).astype(np.float32).sum()
                num += output.shape[0]
                std_loss += self.loss(output, label)
                iteration += 1
        std_acc = total_acc / num * 100.
        std_loss /= iteration
        if self.args.mask:
            return std_acc, std_loss, self.get_density()
        return std_acc, std_loss

    def get_model_numbers(self):
        """Count prunable weights and thresholds.

        Returns:
            ``(total_weights, total_thresholds)``. With ``args.mask == 1`` both are
            summed over masked layers; otherwise weights are summed over
            ``nn.Linear`` / ``nn.Conv2d`` layers and thresholds are 0.
        """
        if self.args.mask == 1:
            masked = self._masked_layers()
            return (sum(layer.weight.numel() for layer in masked),
                    sum(layer.threshold.numel() for layer in masked))
        total_weights = sum(layer.weight.numel() for layer in self.model.modules()
                            if isinstance(layer, (nn.Linear, nn.Conv2d)))
        return total_weights, 0

    def get_density(self):
        """Return (and store in ``self.density``) the fraction of weights kept by the masks.

        For every masked layer the mask is ``layer.step(|W| - threshold)`` with
        ``W`` viewed as ``(out_rows, -1)``; the result is the total number of
        kept entries divided by the total number of entries, as a 0-d numpy array.
        """
        total = 0.
        keep = 0.
        with torch.no_grad():  # read-only statistic; no autograd graph needed
            for layer in self.model.modules():
                if isinstance(layer, binarization.MaskedMLP):
                    rows = layer.weight.shape[0]
                    scores = torch.abs(layer.weight) - layer.threshold.view(rows, -1)
                elif isinstance(layer, binarization.MaskedConv2d):
                    rows = layer.weight.shape[0]
                    scores = torch.abs(layer.weight).view(rows, -1) - layer.threshold.view(rows, -1)
                else:
                    continue
                mask = layer.step(scores)
                total += mask.numel()
                keep += torch.sum(mask)
        self.density = (keep / total).cpu().detach().numpy()
        return self.density

    def th_update(self, global_difference):
        """Fold global threshold changes back into the weights.

        For each key ``'<layer>.threshold'`` with ``<layer>`` in conv1, conv2,
        fc3, fc4, every entry of output row ``i`` of that layer's weight is
        decreased by ``diff[i] * sign(sum of row i) / divisor``. The divisor is
        ``kernel_size[0] ** 2`` for conv1/conv2 and ``in_size`` for fc3/fc4. All
        other keys are ignored.

        NOTE(review): the sign is taken from the row sum rather than from each
        weight, and the conv divisor ignores input channels -- confirm intended.

        Args:
            global_difference: mapping of state-dict-style keys to tensors;
                threshold entries must be reshapeable to ``(out_rows, -1)``.
        """
        with torch.no_grad():
            for key in global_difference:
                if key in ('conv1.threshold', 'conv2.threshold'):
                    layer = getattr(self.model, key.split('.', 1)[0])
                    divisor = layer.kernel_size[0] ** 2
                elif key in ('fc3.threshold', 'fc4.threshold'):
                    layer = getattr(self.model, key.split('.', 1)[0])
                    divisor = layer.in_size
                else:
                    continue
                self._shift_weights_by_threshold(layer, global_difference[key], divisor)

    def FLOP_count_weight(self):
        """Add the FLOPs of the weight-training epochs to ``self.FLOPs``.

        Masked mode (``args.mask == 1``): ``local_epoch - 1`` epochs per layer
        recorded in ``ratio_per_layer``, scaled by that layer's ratio, plus
        ``num_weights * density * len(tr_loader) + 1.5 * num_weights`` for the
        updates. Otherwise: ``local_epoch`` dense epochs of every layer plus
        ``num_weights * len(tr_loader)`` for the updates.
        """
        if self.args.mask == 1:
            for key, ratio in self.ratio_per_layer.items():
                if key in self._LAYER_MACS_PER_SAMPLE:
                    self.FLOPs += self._layer_training_flops(key, self.args.local_epoch - 1, ratio)
            self.FLOPs += (self.num_weights * self.density) * len(self.tr_loader) + self.num_weights * 1.5
        else:
            for key in self._LAYER_MACS_PER_SAMPLE:
                self.FLOPs += self._layer_training_flops(key, self.args.local_epoch, 1)
            # weight update
            self.FLOPs += (self.num_weights * 1) * len(self.tr_loader)

    def FLOP_count_threshold(self):
        """Add the FLOPs of the single threshold-training epoch to ``self.FLOPs``.

        Only active when ``args.mask == 1``: one epoch per layer recorded in
        ``ratio_per_layer``, plus ``num_thresholds * len(tr_loader) + num_weights``
        for the threshold updates.
        """
        if self.args.mask == 1:
            for key, ratio in self.ratio_per_layer.items():
                if key in self._LAYER_MACS_PER_SAMPLE:
                    self.FLOPs += self._layer_training_flops(key, 1, ratio)
            # threshold update
            self.FLOPs += self.num_thresholds * len(self.tr_loader) + self.num_weights
