from http import client
import copy
import flwr as fl
import argparse
from collections import OrderedDict
import warnings
import sys
sys.path.append('..')
warnings.filterwarnings("ignore")
import torch
from torch.utils.data import DataLoader
from dataloader.dataloader import get_client_data
from federated_train.train import train, test
from utils.logger import get_log
from utils.tool import get_device, get_model


# Fixed seed for the module-level RNG that supplies the stochastic-rounding
# noise in cal_bits, so weight encodings are reproducible across runs.
seed = 2023
generator = torch.Generator()
generator.manual_seed(seed)

def cal_weight(bits, alpha):
    """Decode a stack of (latent) bit planes back into quantized weights.

    Only the sign of each entry matters: entries > 0 are read as bit 1,
    entries <= 0 as bit 0. Plane 0 is the most significant bit. The unsigned
    code is shifted down by ``2**(num_bit-1)`` so the decoded integers lie in
    ``[-2**(num_bit-1), 2**(num_bit-1) - 1]``; the result is scaled by ``alpha``.

    Args:
        bits: tensor whose first dimension indexes the ``num_bit`` bit planes;
            the remaining dimensions are the weight shape.
        alpha: scale (quantization step) multiplied into the decoded integers.

    Returns:
        Tensor of shape ``bits.shape[1:]`` equal to ``integer_code * alpha``.
    """
    bits = torch.sign(torch.relu(bits))  # >0 -> 1, <=0 -> 0
    num_bit = bits.size()[0]
    # Place values (2**(num_bit-1), ..., 2, 1) shaped to broadcast over the
    # weight dims. Built on bits' device so GPU-resident tensors decode too.
    unit_mask = 2 ** torch.arange(num_bit - 1, -1, -1, device=bits.device).view(
        num_bit, *([1] * (bits.ndim - 1))
    )
    magnitude = torch.sum(bits * unit_mask, dim=0) - 2 ** (num_bit - 1)
    return magnitude * alpha

def cal_bits(weight, num_bit, alpha=None):
    """Stochastically quantize ``weight`` and split it into ``num_bit`` bit planes.

    Each entry is divided by ``alpha`` and stochastically rounded: floor of the
    value plus U[0, 1) noise drawn from the module-level ``generator``. The
    result is clamped to the signed range ``[-2**(num_bit-1), 2**(num_bit-1) - 1]``,
    offset by ``2**(num_bit-1)`` and written out in binary, most significant
    bit first.

    Args:
        weight: real-valued tensor to quantize.
        num_bit: number of bit planes to produce.
        alpha: quantization step. When None it is computed as
            ``max|weight| / 2**(num_bit-1)``. For an all-zero ``weight`` it is
            set to 1 instead, because 0 would give 0/0 = NaN.

    Returns:
        ``(bits, alpha)``. ``bits`` has shape ``(num_bit, *weight.shape)`` and
        values in {-1, +1} (+1 means the bit is set). Its dtype is that of the
        rounded weights.
    """
    offset = 2 ** (num_bit - 1)
    if alpha is None:
        alpha = weight.abs().max() / offset
        if alpha == 0:
            alpha = torch.ones_like(alpha)

    levels = torch.floor(weight / alpha + torch.rand(weight.shape, generator=generator))
    # The largest weight lands exactly on +offset (or above, after rounding up),
    # which needs num_bit + 1 bits. Without the clamp its top bit would be
    # dropped and it would decode as -offset, i.e. with the wrong sign.
    codes = (levels.clamp(-offset, offset - 1) + offset).long()

    # Extract bit (num_bit-1-i) into plane i, so plane 0 is the MSB.
    shifts = torch.arange(num_bit - 1, -1, -1, device=codes.device).view(
        num_bit, *([1] * codes.ndim)
    )
    bits = ((codes.unsqueeze(0) >> shifts) & 1).to(levels.dtype)
    return bits * 2 - 1, alpha

class Client_FedBiF(fl.client.NumPyClient):
    """Flower NumPy client for FedBiF.

    Locally the model holds each quantized weight as bit planes plus an
    ``alpha`` scale. Towards the server, weights are exchanged as real values:
    ``get_parameters`` decodes them with ``cal_weight`` and ``set_parameters``
    re-encodes incoming weights with ``cal_bits``.
    """

    def __init__(self, logger, trainset, valset, testset, device, id, model, train_mask):
        self.logger, self.device, self.id = logger, device, id
        self.trainset, self.valset, self.testset = trainset, valset, testset

        # The mask length is used as the bit width passed to cal_bits.
        self.train_mask, self.bit = train_mask, len(train_mask)
        # State-dict keys containing 'alpha'; filled by the first set_parameters call.
        self.alpha_keys = None
        self.model_name = model
        self.model = get_model(model, train_mask)

    def set_parameters(self, parameters):
        """Load server weights into the bit-plane model.

        On the first call the incoming ``parameters`` are ignored. Instead the
        model is initialised from the checkpoint
        ``../federated_train/pt/<model>-<bit>.pt`` and the alpha keys are
        recorded. On later calls, every array with more than one dimension is
        re-encoded into ``self.bit`` bit planes plus a scale (``cal_bits``).
        Other arrays are used as-is, and the result is loaded strictly.
        """
        if self.alpha_keys is None:
            # map_location lets CPU-only hosts read GPU-saved checkpoints;
            # load_state_dict copies the values onto the model's own device.
            model_dict = torch.load(
                "../federated_train/pt/" + self.model_name + "-" + str(self.bit) + ".pt",
                map_location="cpu",
            )
            self.model.load_state_dict(model_dict)
            self.alpha_keys = [key for key in model_dict.keys() if 'alpha' in key]
            return

        bits_parameters = []
        for val in parameters:
            if val.ndim > 1:
                bits, alpha = cal_bits(torch.tensor(val), self.bit)
                bits_parameters.extend((bits, alpha))
            else:
                bits_parameters.append(torch.tensor(val))
        # NOTE(review): relies on each encoded weight being immediately followed
        # by its alpha entry in the model's state_dict key order.
        # All tensors above are freshly created, so no defensive copy is needed.
        state_dict = OrderedDict(zip(self.model.state_dict().keys(), bits_parameters))
        # please do not change current_dict, as it shares storage with the model
        current_dict = self.model.state_dict()
        for key in self.alpha_keys:
            k = key.replace('alpha', 'weight')
            # Signs (bit values) come from the server; magnitudes of the local
            # latent bits are kept. Align devices in case the model sits on a GPU.
            state_dict[k] *= current_dict[k].abs().to(state_dict[k].device)
        self.model.load_state_dict(state_dict, strict=True)

    def get_parameters(self, config=None):
        """Return the model as a list of NumPy arrays with bit planes decoded.

        Each ``weight`` entry paired with an ``alpha`` entry is replaced by
        ``cal_weight(bits, alpha)``, and the alpha entry is dropped.

        Args:
            config: accepted and ignored; Flower >= 1.0 calls
                ``get_parameters(config)``, older code calls it with no args.
        """
        state_dict = copy.deepcopy(self.model.state_dict())
        alpha_keys = self.alpha_keys
        if alpha_keys is None:
            # Called before set_parameters: derive the keys from the live model.
            alpha_keys = [key for key in state_dict if 'alpha' in key]
        for k in alpha_keys:
            key_w = k.replace('alpha', 'weight')
            state_dict[key_w] = cal_weight(state_dict[key_w], state_dict[k])
            del state_dict[k]

        return [val.cpu().numpy() for val in state_dict.values()]

    def fit(self, parameters, config):
        """Run one round of local training and return the updated weights.

        Reads ``tune_mask``, ``train_mask``, ``batch_size`` and ``round`` from
        ``config``; the whole dict is also forwarded to ``train``. Returns the
        decoded parameters, the number of local training examples, and the
        metrics dict produced by ``train``.
        """
        self.set_parameters(parameters)

        if config["tune_mask"] > 0:
            self.train_mask = config["train_mask"]
            self.model.reset_masks(self.train_mask)
            self.logger.info("mask is set to: " + str(self.train_mask))

        trainLoader = DataLoader(self.trainset, batch_size=config["batch_size"], shuffle=True)
        # Validation is currently disabled: train() receives valLoader=None.
        # valLoader = DataLoader(self.valset, batch_size=config["batch_size"])
        valLoader = None
        results = train(self.model, trainLoader, valLoader, config, self.device)
        self.logger.info("Round %d client #%d, val  loss: %.4f, val  acc: %.4f"
                % (config["round"], self.id, results["val_loss"], results["val_accuracy"]))

        parameters_prime = self.get_parameters()
        num_examples_train = len(self.trainset)
        return parameters_prime, num_examples_train, results

    def evaluate(self, parameters, config):
        """Evaluate the current local model; the incoming ``parameters`` are not loaded.

        Returns ``(loss, len(testset), {"accuracy": accuracy})``.
        """
        # self.set_parameters(parameters) # use global parameters
        # testLoader = DataLoader(self.testset, batch_size=config["batch_size"])
        # NOTE(review): test() is called with testLoader=None; presumably it
        # has its own data fallback -- confirm in federated_train.train.test.
        testLoader = None
        loss, accuracy = test(self.model, testLoader, None, self.device)
        self.logger.info("Round %d client #%d, test loss: %.4f, acc: %.4f" % (config["round"], self.id, loss, accuracy))
        return float(loss), len(self.testset), {"accuracy": float(accuracy)}

if __name__ == "__main__":
    # Parse CLI options, load this client's data partition and connect to the
    # Flower server at --ip.
    parser = argparse.ArgumentParser(description="client")
    parser.add_argument("--model", type=str, default="bif-vgg11", help="model")
    parser.add_argument("--dataset", type=str, default="svhn", help="dataset")
    parser.add_argument("--part_strategy", type=str, default="labeldir0.5", help="iid labeldir0.5 labeldir0.1")
    parser.add_argument("--num_client", type=int, default=10, choices=range(2, 100), help="num_client")
    parser.add_argument("--id", type=int, default=0, choices=range(0, 100), help="client id")
    parser.add_argument("--train_mask", type=str, default="1111",
                        help="train_mask; its length sets the number of bits per weight")
    parser.add_argument("--val_ratio", type=float, default=0.1,
                        help="validation split ratio passed to get_client_data")
    parser.add_argument("--gpu", type=int, default=0, help="-1 0 1")
    parser.add_argument("--ip", type=str, default="0.0.0.0:10000", help="server address")
    parser.add_argument("--log_dir", type=str, default="../log/debug/", help="directory for log files")
    parser.add_argument("--log_name", type=str, default="debug",
                        help="log name prefix; '-<id>' is appended")
    args = parser.parse_args()

    logger = get_log(args.log_dir, args.log_name + "-" + str(args.id))
    logger.info(args)
    device = get_device(args.gpu)
    trainset, valset, testset = get_client_data(args.dataset, args.part_strategy, args.num_client, args.id, args.val_ratio)

    # Named to avoid shadowing the module-level `client` import.
    fedbif_client = Client_FedBiF(logger, trainset, valset, testset, device, args.id, args.model, args.train_mask)
    fl.client.start_numpy_client(server_address=args.ip, client=fedbif_client)