from http import client
import flwr as fl
import argparse
from collections import OrderedDict
import warnings
import sys
sys.path.append('..')
warnings.filterwarnings("ignore")
import torch
from torch.utils.data import DataLoader
from dataloader.dataloader import get_client_data
from federated_basetrain.train import train, test
from utils.logger import get_log
from utils.tool import get_device, get_model

class Client_FedAvg(fl.client.NumPyClient):
    """Flower NumPy client that trains a local model copy for FedAvg.

    Each ``fit`` call loads the weights sent by the server, trains on this
    client's local training split, and returns the updated weights together
    with the number of local training examples.
    """

    def __init__(self, logger, trainset, valset, testset, device, id, model):
        """Store the client's data splits and build its local model.

        Args:
            logger: logger used to report per-round results.
            trainset: local training dataset (``len`` is taken in ``fit``).
            valset: local validation dataset; not read by ``fit`` at present.
            testset: local test dataset; not read by this class at present.
            device: device handed to ``train`` for local training.
            id: client identifier, logged with ``%d`` so it must be an int.
            model: model identifier forwarded to ``get_model``.
        """
        # NOTE: ``id`` shadows the builtin but is kept for caller compatibility.
        self.logger, self.device, self.id = logger, device, id
        self.trainset, self.valset, self.testset = trainset, valset, testset
        self.model = get_model(model)

    def set_parameters(self, parameters):
        """Load a list of NumPy arrays into the local model.

        Args:
            parameters: arrays in the same order as ``self.model.state_dict()``
                (the order produced by ``get_parameters``).

        Raises:
            ValueError: if the number of arrays differs from the number of
                ``state_dict`` entries. Without this check ``zip`` would
                silently drop surplus arrays.
        """
        keys = list(self.model.state_dict().keys())
        if len(parameters) != len(keys):
            raise ValueError(
                f"expected {len(keys)} parameter arrays, got {len(parameters)}"
            )
        state_dict = OrderedDict(
            (key, torch.tensor(value)) for key, value in zip(keys, parameters)
        )
        self.model.load_state_dict(state_dict, strict=True)

    def get_parameters(self, config=None):
        """Return the model state as a list of NumPy arrays.

        Arrays follow ``state_dict()`` order and are moved to CPU first.
        ``config`` is accepted and ignored so the method works with both the
        older ``get_parameters()`` and newer ``get_parameters(config)`` Flower
        call signatures.
        """
        return [tensor.cpu().numpy() for tensor in self.model.state_dict().values()]

    def fit(self, parameters, config):
        """Run one round of local training starting from the server's weights.

        Args:
            parameters: model weights as NumPy arrays, in ``get_parameters``
                order.
            config: round configuration; ``"batch_size"`` and ``"round"`` are
                read here and the whole dict is passed on to ``train``.

        Returns:
            ``(weights, num_examples, metrics)``. ``metrics`` is the dict
            returned by ``train``; it must contain ``"val_loss"`` and
            ``"val_accuracy"``, which are logged here.
        """
        self.set_parameters(parameters)
        train_loader = DataLoader(
            self.trainset, batch_size=config["batch_size"], shuffle=True
        )
        # Validation during local training is disabled. To enable it, build a
        # DataLoader over self.valset (which must then be non-None).
        val_loader = None
        metrics = train(self.model, train_loader, val_loader, config, self.device)
        self.logger.info(
            "Round %d client #%d, val  loss: %.4f, val  acc: %.4f"
            % (config["round"], self.id, metrics["val_loss"], metrics["val_accuracy"])
        )
        return self.get_parameters(), len(self.trainset), metrics

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="client")
    parser.add_argument("--model", type=str, default="vgg11", help="model")
    parser.add_argument("--dataset", type=str, default="svhn", help="dataset")
    parser.add_argument("--part_strategy", type=str, default="labeldir0.5", help="iid labeldir0.5 labeldir0.1")
    parser.add_argument("--num_client", type=int, default=10, choices=range(2, 100), help="num_client")
    parser.add_argument("--id", type=int, default=0, choices=range(0, 100), help="client id")
    parser.add_argument("--val_ratio", type=float, default=0.1, help="dataset")        
    parser.add_argument("--gpu", type=int, default=0, help="-1 0 1")
    parser.add_argument("--ip", type=str, default="0.0.0.0:10000", help="server address")
    parser.add_argument("--log_dir", type=str, default="../log/debug/", help="dir")
    parser.add_argument("--log_name", type=str, default="debug", help="log")
    args = parser.parse_args()

    logger = get_log(args.log_dir, args.log_name+"-"+str(args.id))
    logger.info(args)
    device = get_device(args.gpu)
    trainset, valset, testset = get_client_data(args.dataset, args.part_strategy, args.num_client, args.id, args.val_ratio)

    client = Client_FedAvg(logger, trainset, None, None, device, args.id, args.model)
    fl.client.start_numpy_client(server_address=args.ip, client=client)

        