import torch
import torch.nn as nn
import numpy as np
import random
import sys
import os
import copy
from collections import OrderedDict

# Directory the process was launched from (NOTE: the current working
# directory, not the directory containing this file).
path = os.getcwd()
# Put the parent of the launch directory on sys.path so that project-level
# packages (presumably including `model`, imported below) can be found.
sys.path.append(os.path.dirname(path))

from model import binarization


class Client:
    """Federated-learning client owning a local recurrent model and its data loaders.

    Every input batch is moved to ``device`` and reshaped to
    ``(-1, sequence_length, input_size)`` before the forward pass. When
    ``args.mask`` is set, local training updates weights in all but the last
    local epoch and thresholds in the last one; otherwise all parameters are
    trained FedAvg-style. Estimated training FLOPs accumulate in ``self.FLOPs``.
    """

    def __init__(self, args, model, loss, client_id, tr_loader, te_loader, device, scheduler = None):
        self.args = args
        self.model = model
        self.loss = loss
        self.scheduler = scheduler
        self.client_id = client_id
        self.tr_loader = tr_loader
        self.te_loader = te_loader
        self.device = device
        # Re-created at the start of every local_training() call.
        self.optimizer = None
        self.FLOPs = 0
        self.ratio_per_layer = {}
        self.density = 1
        self.num_weights, self.num_thresholds = self.get_model_numbers()
        # Each sample is reshaped into a (sequence_length, input_size) sequence.
        self.sequence_length = 28
        self.input_size = 28

    def repackage_hidden(self, h):
        """Wraps hidden states in new Tensors, to detach them from their history."""
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(self.repackage_hidden(v) for v in h)

    # ------------------------------------------------------------------ helpers

    def _build_optimizer(self):
        """Create a fresh optimizer over all model parameters according to ``args.optim``."""
        if self.args.optim == 'adam':
            return torch.optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return torch.optim.SGD(self.model.parameters(), lr=self.args.learning_rate,
                               momentum=self.args.momentum)

    def _prepare_batch(self, data, label):
        """Move a batch to ``self.device`` and reshape the inputs into sequences."""
        # Tensor.to() is not in-place: the moved tensors must be reassigned.
        data, label = data.to(self.device), label.to(self.device)
        data = data.reshape(-1, self.sequence_length, self.input_size)
        return data, label

    def _set_trainable(self, train_thresholds):
        """Select whether thresholds or weights receive gradients.

        Parameters whose name contains 'threshold' get
        ``requires_grad=train_thresholds``; other parameters whose name contains
        'weight' get the opposite. All remaining parameters are left untouched.
        """
        for name, param in self.model.named_parameters():
            if 'threshold' in name:
                param.requires_grad = train_thresholds
            elif 'weight' in name:
                param.requires_grad = not train_thresholds

    def _threshold_regularization(self):
        """Return ``sum(exp(-t))`` over every threshold parameter ``t`` (penalizes small thresholds)."""
        sparse_regularization = torch.tensor(0.).to(self.device)
        for name, param in self.model.named_parameters():
            if 'threshold' in name:
                sparse_regularization += torch.sum(torch.exp(-param))
        return sparse_regularization

    def _run_epoch(self, *, regularize=False, max_norm=None, retain_graph=False):
        """Run one optimisation pass over ``tr_loader`` with the current requires_grad settings.

        Args:
            regularize: add ``args.th_coeff * sum(exp(-threshold))`` to the loss.
            max_norm: if not None, clip the total gradient norm to this value.
            retain_graph: forwarded to ``loss.backward``.
        """
        self.model.train()
        hidden = self.model.init_hidden(self.args.batch_size)
        for data, label in self.tr_loader:
            data, label = self._prepare_batch(data, label)
            hidden = self.repackage_hidden(hidden)

            output, hidden = self.model(data, hidden)
            loss_val = self.loss(output, label)
            if regularize:
                loss_val = loss_val + self.args.th_coeff * self._threshold_regularization()

            # set_to_none=True: frozen parameters keep grad=None and are skipped
            # by the optimizer instead of drifting on stale momentum/Adam state.
            self.optimizer.zero_grad(set_to_none=True)
            loss_val.backward(retain_graph=retain_graph)
            if max_norm is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm)
            self.optimizer.step()

            # The scheduler (if any) is stepped once per batch.
            if self.scheduler is not None:
                self.scheduler.step()

    def _evaluate(self, model):
        """Evaluate ``model`` on ``te_loader``.

        Returns:
            ``(accuracy_percent, mean_batch_loss)``, or ``None`` when ``te_loader``
            has no batches.
        """
        model.eval()
        with torch.no_grad():
            hidden = model.init_hidden(self.args.batch_size)
            print("# of test batches:", len(self.te_loader))
            if len(self.te_loader) == 0:
                return None

            total_acc = 0.0
            num = 0
            std_loss = 0.
            iteration = 0.
            for data, label in self.te_loader:
                data, label = self._prepare_batch(data, label)
                hidden = self.repackage_hidden(hidden)

                output, hidden = model(data, hidden)
                _, predicted = torch.max(output.data, 1)
                total_acc += (predicted.cpu().numpy() == label.cpu().numpy()).astype(np.float32).sum()
                num += label.shape[0]

                std_loss += self.loss(output, label)
                iteration += 1
        return total_acc / num * 100., std_loss / iteration

    @staticmethod
    def _lstm_flops(cell, epochs, num_samples, masked):
        """Forward-pass FLOP estimate for one LSTM cell over ``epochs`` passes of ``num_samples`` samples.

        With ``masked`` the input/hidden terms are scaled by the cell's
        ``weight_ratio`` / ``hidden_ratio``.
        """
        if masked:
            per_unit = cell.input_size * cell.weight_ratio + cell.hidden_size * cell.hidden_ratio
        else:
            per_unit = cell.input_size + cell.hidden_size
        return 4 * epochs * cell.hidden_size * per_unit * num_samples

    # --------------------------------------------------------------- public API

    def local_training(self, comm_rounds):
        """Train ``self.model`` in place for ``args.local_epoch`` epochs on ``tr_loader``.

        With ``args.mask`` set, epochs 1..local_epoch-1 train only weights
        (thresholds frozen, grad-norm clip 10) and the final epoch trains only
        thresholds with an added ``th_coeff * sum(exp(-threshold))`` penalty
        (grad-norm clip 5). Otherwise every epoch trains all parameters without
        clipping. Estimated FLOPs are added to ``self.FLOPs``. Returns None.

        NOTE(review): the hidden state is initialised with ``args.batch_size`` and
        carried across batches; a smaller final batch would presumably mismatch
        unless the loaders drop it — confirm.

        Args:
            comm_rounds: current communication round (not used here).
        """
        self.optimizer = self._build_optimizer()
        if self.args.mask:
            for epoch in range(1, self.args.local_epoch + 1):
                if epoch < self.args.local_epoch:
                    self._set_trainable(train_thresholds=False)
                    # NOTE(review): retain_graph=True is kept from the original; the
                    # hidden state is detached every batch, so it may be unnecessary.
                    self._run_epoch(max_norm=10, retain_graph=True)
                else:
                    self._set_trainable(train_thresholds=True)
                    self._run_epoch(regularize=True, max_norm=5)
            self.FLOP_count_weight()
            self.FLOP_count_threshold()
        else:  # plain FedAvg local update
            for _ in range(self.args.local_epoch):
                self._run_epoch()
            self.FLOP_count_weight()

    def local_test(self):
        """Evaluate the local model on ``te_loader``.

        Returns:
            ``(acc, loss, density)`` when ``args.mask`` is set, else ``(acc, loss)``.
            ``acc`` and ``loss`` are the string ``'zero'`` when ``te_loader`` is empty.
        """
        result = self._evaluate(self.model)
        if result is None:
            if self.args.mask:
                return 'zero', 'zero', self.get_density()
            return 'zero', 'zero'

        std_acc, std_loss = result
        if self.args.mask:
            density = self.get_density()
            print("-"*30, "Client" , self.client_id , "acc" , std_acc,  "density", density, "-"*30)
            return std_acc, std_loss, density
        return std_acc, std_loss

    def local_fedavg_test(self, global_model):
        """Evaluate a deep copy of ``global_model`` on this client's test data.

        ``self.model`` is never replaced or modified (including when the test
        loader is empty), and ``global_model`` itself is not run.

        Returns:
            ``(acc, loss)``, or ``('zero', 'zero')`` when ``te_loader`` is empty.
        """
        eval_model = copy.deepcopy(global_model)
        result = self._evaluate(eval_model)
        if result is None:
            return 'zero', 'zero'
        return result

    def get_model_numbers(self):
        """Return ``(total_weights, total_thresholds)``.

        Weights are counted over both LSTM cells plus a hard-coded 128*10 fully
        connected layer; thresholds are counted only when ``args.mask == 1``.
        """
        total_weights = (self.model.lstm1.cell.total_number
                         + self.model.lstm2.cell.total_number
                         + 128 * 10)  # fc layer (size hard-coded)
        total_thresholds = 0
        if self.args.mask == 1:
            total_thresholds = (self.model.lstm1.cell.total_threshold_number
                                + self.model.lstm2.cell.total_threshold_number)
        return total_weights, total_thresholds

    def get_density(self):
        """Return the mean ``keep_ratio`` of the two LSTM cells."""
        density_1 = self.model.lstm1.cell.keep_ratio
        density_2 = self.model.lstm2.cell.keep_ratio
        return 1/2 * (density_1 + density_2)

    def th_update(self, global_difference):
        """Shift weights in place according to the global change of their thresholds.

        For each parameter whose name contains 'threshold', the most recently
        seen 'weight' parameter (in ``named_parameters()`` order) is decreased by
        ``global_difference[threshold_name]`` multiplied by the sign of the
        weight's column sums and divided by the number of weight columns.
        Assumes every threshold is preceded by its weight in that order.
        """
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if 'weight' in name:
                    weight = param
                    weight_shape = weight.shape
                    transposed_shape = weight.transpose(0, 1).shape
                    weight_sum_sign = torch.sign(torch.sum(weight, 0)).view(weight_shape[1], -1)
                    weight_sum_sign = torch.mul(weight_sum_sign, torch.ones(transposed_shape).to(self.device))
                if 'threshold' in name:
                    threshold_dir = global_difference[name].view(weight_shape[1], -1)
                    threshold_dir = torch.mul(threshold_dir, weight_sum_sign)
                    update_direction = threshold_dir / weight_shape[1]
                    # NOTE(review): .view() reinterprets the (cols, rows)-shaped tensor
                    # as (rows, cols) instead of transposing it; confirm whether a
                    # transpose was intended here.
                    update_direction = update_direction.view(weight_shape)
                    weight += update_direction * (-1)

    def FLOP_count_weight(self):
        """Add the estimated FLOPs of the weight-training epochs to ``self.FLOPs``."""
        num_samples = len(self.tr_loader.dataset)
        cell1 = self.model.lstm1.cell
        cell2 = self.model.lstm2.cell
        if self.args.mask == 1:
            epochs = self.args.local_epoch - 1  # the last epoch trains thresholds
            lstm1_flop = self._lstm_flops(cell1, epochs, num_samples, masked=True)
            lstm2_flop = self._lstm_flops(cell2, epochs, num_samples, masked=True)
            lstm1_flop_back = (1 + cell1.keep_ratio) * lstm1_flop
            lstm2_flop_back = (1 + cell2.keep_ratio) * lstm2_flop
            self.FLOPs += lstm1_flop + lstm2_flop + lstm1_flop_back + lstm2_flop_back
            # weight update with importance extraction
            self.FLOPs += (self.num_weights * self.density) * len(self.tr_loader) + self.num_weights*1.5
        else:
            epochs = self.args.local_epoch
            lstm1_flop = self._lstm_flops(cell1, epochs, num_samples, masked=False)
            lstm2_flop = self._lstm_flops(cell2, epochs, num_samples, masked=False)
            lstm1_flop_back = 2 * lstm1_flop
            lstm2_flop_back = 2 * lstm2_flop
            self.FLOPs += lstm1_flop + lstm2_flop + lstm1_flop_back + lstm2_flop_back
            # weight update
            self.FLOPs += (self.num_weights * 1) * len(self.tr_loader)

    def FLOP_count_threshold(self):
        """Add the estimated FLOPs of the threshold-training epoch to ``self.FLOPs``.

        Only counts when ``args.mask == 1``. The threshold epoch is one full pass
        over the training set, so its forward cost is scaled by the number of
        training samples, matching FLOP_count_weight.
        """
        if self.args.mask != 1:
            return
        num_samples = len(self.tr_loader.dataset)
        cell1 = self.model.lstm1.cell
        cell2 = self.model.lstm2.cell
        lstm1_flop = self._lstm_flops(cell1, 1, num_samples, masked=True)
        lstm2_flop = self._lstm_flops(cell2, 1, num_samples, masked=True)
        lstm1_flop_back = (1 + cell1.keep_ratio) * lstm1_flop
        lstm2_flop_back = (1 + cell2.keep_ratio) * lstm2_flop
        self.FLOPs += lstm1_flop + lstm2_flop + lstm1_flop_back + lstm2_flop_back
        # threshold update
        self.FLOPs += self.num_thresholds * len(self.tr_loader) + self.num_weights