import copy
import logging
import math
import time

import numpy as np
import torch

class Client:
    """Local participant that trains a masked model and searches for new masks.

    The client holds its local data, a shared ``args`` configuration object,
    the compute ``device`` and a ``model_trainer`` to which all model
    operations (setting parameters/masks, training, testing, FLOP and
    communication counting) are delegated.
    """

    def __init__(self, client_idx, local_training_data, local_test_data, local_sample_number, args, device,
                 model_trainer):
        """Store the client's identity, data, configuration and trainer."""
        self.client_idx = client_idx
        self.local_training_data = local_training_data
        self.local_test_data = local_test_data
        self.local_sample_number = local_sample_number
        logging.info("self.local_sample_number = " + str(self.local_sample_number))
        self.args = args
        self.device = device
        self.model_trainer = model_trainer

    def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number):
        """Re-point this client object at another client's index and data."""
        self.client_idx = client_idx
        self.local_training_data = local_training_data
        self.local_test_data = local_test_data
        self.local_sample_number = local_sample_number

    def get_sample_number(self):
        """Return the number of local samples recorded for this client."""
        return self.local_sample_number

    def train(self, w, masks, round):
        """Run one round of local training followed by an optional mask search.

        Args:
            w: mapping from parameter name to the weights to start from.
            masks: mapping from parameter name to the mask handed to the
                trainer (assumed to hold 0/1 entries -- TODO confirm).
            round: index of the current communication round.

        Returns:
            A tuple ``(masks, update, training_flops, num_comm_params)`` where
            ``update[name]`` is ``w[name]`` minus the trained weight,
            ``masks`` is the (possibly re-searched) mask dict and
            ``num_comm_params`` sums the trainer's communication counts for
            ``w`` (downlink) and ``update`` (uplink).
        """
        # downlink params
        num_comm_params = self.model_trainer.count_communication_params(w)
        # apply global model weights according to the new masks
        self.model_trainer.set_model_params(w)
        self.model_trainer.set_masks(masks)
        self.model_trainer.set_id(self.client_idx)
        begin = time.time()
        self.model_trainer.train(self.local_training_data, self.device, self.args, round)
        end = time.time()
        # record train time (only logged for round 1)
        if round == 1:
            logging.info("train time elapse: {}".format(end - begin))

        weights = self.model_trainer.get_model_params()
        update = {name: w[name] - weights[name] for name in weights}
        num_nonzero = sum(torch.count_nonzero(weights[name]) for name in weights)
        num_total = sum(torch.numel(weights[name]) for name in weights)
        logging.info("after train{}".format(num_nonzero / num_total))
        logging.info("-----------------------------------")

        gradient = None
        search_begin = time.time()
        if not self.args.static:
            # do mask searching
            if not self.args.dis_gradient_check:
                gradient = self.model_trainer.screen_gradients(self.local_training_data, self.device)
            masks, num_remove = self.fire_mask(masks, weights, round)
            masks = self.regrow_mask(masks, num_remove, gradient)
        search_end = time.time()
        # record mask searching time (only logged for round 1)
        if round == 1:
            logging.info("mask searching time elapse: {}".format(search_end - search_begin))

        sparse_flops_per_data = self.model_trainer.count_training_flops_per_sample()
        full_flops = self.model_trainer.count_full_flops_per_sample()
        logging.info("training flops per data {}".format(sparse_flops_per_data))
        logging.info("full flops for search {}".format(full_flops))
        # we train the data for `self.args.epochs` epochs, and forward one batch of data with full density to screen gradient.
        # NOTE(review): the full-density term is added even when no gradient
        # screening ran (static masks or dis_gradient_check) -- confirm intended.
        training_flops = (self.args.epochs * self.local_sample_number * sparse_flops_per_data
                          + self.args.batch_size * full_flops)

        # uplink params
        num_comm_params += self.model_trainer.count_communication_params(update)
        return masks, update, training_flops, num_comm_params

    # prune out a portion of weights
    def fire_mask(self, masks, weights, round):
        """Deactivate the smallest-magnitude active weights of every tensor.

        The fraction dropped follows a cosine schedule:
        ``anneal_factor / 2 * (1 + cos(round * pi / comm_round))``.

        Args:
            masks: mapping from name to mask tensor; entries ``> 0`` are
                treated as active. Not modified.
            weights: mapping from name to weight tensor (same shapes as masks).
            round: index of the current communication round.

        Returns:
            ``(new_masks, num_remove)``: a deep copy of ``masks`` with the
            selected entries set to 0, and a dict giving, per name, how many
            entries were selected for removal.
        """
        drop_ratio = self.args.anneal_factor / 2 * (1 + np.cos((round * np.pi) / self.args.comm_round))
        new_masks = copy.deepcopy(masks)
        num_remove = {}
        for name in masks:
            # Take the count as a Python scalar so the product stays a plain number.
            num_non_zeros = torch.sum(masks[name]).item()
            num_remove[name] = math.ceil(drop_ratio * num_non_zeros)
            # Inactive positions get +inf so they always sort after active ones.
            magnitudes = torch.where(
                masks[name] > 0,
                torch.abs(weights[name]),
                torch.full_like(weights[name], float("inf")),
            )
            _, idx = torch.sort(magnitudes.view(-1).to(self.device))
            flat_mask = new_masks[name].view(-1)
            # The sort ran on self.device; index with the mask's own device.
            flat_mask[idx[:num_remove[name]].to(flat_mask.device)] = 0
        return new_masks, num_remove

    # regrow a portion of weights
    def regrow_mask(self, masks, num_remove, gradient=None):
        """Re-activate ``num_remove[name]`` currently-zero entries of every mask.

        When ``gradient`` is given and ``args.dis_gradient_check`` is false, the
        zero-mask positions with the largest absolute gradient are chosen.
        Otherwise positions are drawn uniformly at random (without
        replacement) among the zero-mask positions.

        Args:
            masks: mapping from name to mask tensor. Not modified.
            num_remove: mapping from name to the number of entries to regrow.
            gradient: optional mapping from name to gradient tensor.

        Returns:
            A deep copy of ``masks`` with the selected entries set to 1.
        """
        use_gradient = gradient is not None and not self.args.dis_gradient_check
        new_masks = copy.deepcopy(masks)
        for name in masks:
            flat_mask = new_masks[name].view(-1)
            if use_gradient:
                # Active positions get -inf so they are only picked once every
                # zero position has been used (re-setting them to 1 is a no-op).
                scores = torch.where(
                    masks[name] == 0,
                    torch.abs(gradient[name]),
                    torch.full_like(gradient[name], float("-inf")),
                )
                _, idx = torch.sort(scores.view(-1).to(self.device), descending=True)
                idx = idx[:num_remove[name]]
            else:
                # multinomial needs floating-point weights, at least one sample,
                # and no more samples than there are non-zero categories.
                candidates = (masks[name] == 0).flatten().to(device=self.device, dtype=torch.float)
                num_grow = min(num_remove[name], int(candidates.sum().item()))
                if num_grow <= 0:
                    continue
                idx = torch.multinomial(candidates, num_grow, replacement=False)
            flat_mask[idx.to(flat_mask.device)] = 1
        return new_masks

    def local_test(self, w_per, b_use_test_dataset):
        """Load ``w_per`` into the trainer and evaluate it.

        Evaluates on the local test data when ``b_use_test_dataset`` is truthy,
        otherwise on the local training data, and returns whatever the
        trainer's ``test`` method returns.
        """
        test_data = self.local_test_data if b_use_test_dataset else self.local_training_data
        self.model_trainer.set_model_params(w_per)
        metrics = self.model_trainer.test(test_data, self.device, self.args)
        return metrics
