import os
import sys
sys.path.append('..')
sys.path.append('.')
import warnings
warnings.filterwarnings("ignore")
import torch
import numpy as np
import argparse
from typing import Dict
from collections import OrderedDict
import flwr as fl
from flwr.server.strategy import FedAvg
from federated_basetrain.train import test
from dataloader.dataloader import get_server_test_dataloader
from utils.logger import get_log
from utils.tool import get_device, get_model


if __name__ == "__main__":
    # Command-line options for the federated-learning server.
    parser = argparse.ArgumentParser(description="server")
    option = parser.add_argument

    # What to train and with which optimiser settings.
    option("--model", type=str, default="vgg11",
           help="(bif-)vgg11 (bif-)resnet18 (bif-)resnet34")
    option("--dataset", type=str, default="svhn",
           help="svhn, cifar10, cifar100")
    option("--lr", type=float, default=0.01,
           help="learning rate")

    # Federated schedule. --num_client is also the minimum number of clients
    # the FedAvg strategy below is configured to wait for.
    option("--rounds", type=int, default=100, help="rounds")
    option("--epochs", type=int, default=5, help="epochs")
    option("--batch_size", type=int, default=64, help="batch_size")
    option("--num_client", type=int, default=10, choices=range(2, 100),
           help="number of clients")
    # Any non-zero value saves <log_dir>/results/<round>.pt after each evaluation.
    option("--save_round", type=int, default=0,
           help="whether to save checkpoints in each round")

    # Runtime / output settings; --gpu is forwarded to get_device.
    option("--gpu", type=int, default=1, help="-1 0 1")
    option("--ip", type=str, default="0.0.0.0:10000", help="server address")
    # NOTE: a log_name containing "dadaquant" or "feddq" enables the matching
    # bit-width bookkeeping later in this script.
    option("--log_dir", type=str, default="../log/debug/", help="dir")
    option("--log_name", type=str, default="debug", help="log")
    args = parser.parse_args()

    # Mutable state for the DAdaQuant bit-width schedule, read and updated by
    # set_dada_bit: smoothing factor m, window phi (a tenth of the rounds),
    # smoothed-loss history G, global bit budget, per-client bit list (one
    # entry per client, starting at 1), a counter of rounds since the budget
    # last changed, and the running sum of mean bit widths.
    dada = dict(
        m=0.9,
        phi=args.rounds / 10,
        G=[0],
        bit=1,
        bit_list=np.ones(args.num_client),
        used_times=0,
        bit_sum=0,
    )

    def set_dada_bit(server_round):
        """Update and persist the per-client quantization bit widths (DAdaQuant).

        Round <= 0: ensure ``<log_dir>/results/loss_npy`` exists and write the
        initial ``dada["bit_list"]`` to ``bit.npy`` there.

        Later rounds:
          1. add the mean of the bit list currently in effect to ``bit_sum``;
          2. load each client's ``<i>.npy`` (indexed as ``[num_samples, loss]``)
             and form a sample-weighted global loss;
          3. smooth that loss into the history ``dada["G"]``;
          4. double the global budget ``dada["bit"]`` (only while it is < 4,
             so 1 -> 2 -> 4) once more than ``phi`` rounds have passed since
             the last change and the smoothed loss has not decreased;
          5. distribute bits per client proportional to ``weight ** (2/3)``,
             with a floor of 1 bit, and save them to ``bit.npy``.

        Args:
            server_round: round number received by the server evaluate hook.

        Raises:
            ValueError: if the clients' reported sample counts sum to zero, in
                which case the loss weighting would be undefined (NaN).
        """
        npy_dir = os.path.join(args.log_dir, "results", "loss_npy")
        bit_path = os.path.join(npy_dir, "bit.npy")

        if server_round <= 0:
            os.makedirs(npy_dir, exist_ok=True)
            np.save(bit_path, dada["bit_list"])
            return

        dada["bit_sum"] += np.mean(dada["bit_list"])

        num_list = np.zeros(args.num_client)
        loss_list = np.zeros(args.num_client)
        for i in range(args.num_client):
            # NOTE(review): allow_pickle=True on client-written files is only
            # safe while the log directory is trusted.
            data = np.load(os.path.join(npy_dir, f"{i}.npy"), allow_pickle=True)
            num_list[i] = data[0]
            loss_list[i] = data[1]

        total = np.sum(num_list)
        if total <= 0:
            # Without this check the division yields NaN bits, and the global
            # warnings filter would hide the RuntimeWarning.
            raise ValueError("DAdaQuant: clients reported no training samples; "
                             "cannot weight their losses")
        weights = num_list / total

        g = np.sum(weights * loss_list)
        if dada["G"][0] == 0:
            # First real round: seed the history with the raw weighted loss.
            dada["G"][0] = g
        # NOTE(review): this weights the NEW loss by m (0.9) and the history by
        # 1 - m; a conventional EMA with m = 0.9 does the reverse. Confirm intent.
        g = g * dada["m"] + (1 - dada["m"]) * dada["G"][-1]
        dada["G"].append(g)
        dada["used_times"] += 1

        if dada["used_times"] > dada["phi"] and dada["bit"] < 4:
            # NOTE(review): G[-k] lies k-1 entries before the current G[-1], and
            # for phi < 1 int(phi) == 0, so G[-0] is G[0]. Confirm the lookback.
            if dada["G"][-1] >= dada["G"][-int(dada["phi"])]:
                dada["bit"] *= 2
                dada["used_times"] = 0

        # Scale so the weighted bits match the global budget, then round per client.
        co = np.sqrt(np.sum(weights ** (2 / 3)) / np.sum(weights ** 2 / dada["bit"] ** 2))
        bits = np.round(co * weights ** (2 / 3))
        bits[bits < 1] = 1  # every client keeps at least one bit
        dada["bit_list"] = bits
        np.save(bit_path, dada["bit_list"])

    device = get_device(args.gpu)
    logger = get_log(args.log_dir, args.log_name)
    logger.info(args)

    # Output directory for checkpoints (<round>.pt, best.pt) and result files.
    # exist_ok avoids the check-then-create race (other processes may create it).
    pt_path = os.path.join(args.log_dir, "results")
    os.makedirs(pt_path, exist_ok=True)

    model = get_model(args.model)
    # Initial global weights as CPU NumPy arrays, in model.state_dict() order.
    model_parameters = [tensor.cpu().numpy() for tensor in model.state_dict().values()]
    print(model)
    
    def fit_config(server_round: int):
        """Return the training config FedAvg sends to clients for ``server_round``.

        Contains the round number plus the CLI's epochs, batch size and learning rate.
        """
        return dict(
            round=server_round,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lr=args.lr,
        )

    def evaluate_config(server_round: int):
        """Return the evaluation config FedAvg sends to clients for ``server_round``.

        Contains the round number and the CLI batch size.
        """
        return dict(round=server_round, batch_size=args.batch_size)
    
    # History of server-side evaluations, one entry per evaluate() call,
    # appended in call order; consumed after training finishes.
    record = {"accuracy": [], "loss": []}

    def get_evaluate_fn(model: torch.nn.Module, dataset: str):
        """Build the centralized evaluation callback passed to FedAvg.

        The server test loader is built once here and reused every round.

        Args:
            model: network whose weights are overwritten with each round's
                global parameters before testing.
            dataset: dataset name forwarded to ``get_server_test_dataloader``.

        Returns:
            ``evaluate(server_round, parameters, config)`` returning
            ``(loss, {"accuracy": accuracy})``.
        """
        test_loader = get_server_test_dataloader(dataset, batch_size=args.batch_size)

        def evaluate(server_round: int, parameters: fl.common.NDArrays, config: Dict[str, fl.common.Scalar]):
            # Assumes `parameters` follows model.state_dict() order, which is
            # how the initial parameters (model_parameters) were built.
            params_dict = zip(model.state_dict().keys(), parameters)
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            model.load_state_dict(state_dict, strict=True)
            print("Starting server evaluation...")
            loss, accuracy = test(model, test_loader, None, device)

            record["accuracy"].append(accuracy)
            record["loss"].append(loss)
            logger.info("Round %d - server test loss:%.4f; acc:%.4f" % (server_round, loss, accuracy))

            if args.save_round:
                torch.save(model.state_dict(), os.path.join(pt_path, f"{server_round}.pt"))

            # The current accuracy is already in the history, so ">=" also fires
            # on the first call and on ties (the later round wins).
            if accuracy >= max(record["accuracy"]):
                torch.save(model.state_dict(), os.path.join(pt_path, "best.pt"))

            if "dadaquant" in args.log_name:
                set_dada_bit(server_round)
                if server_round > 0:
                    logger.info("Average bit: %.4f" % (dada["bit_sum"] / server_round))

            return loss, {"accuracy": accuracy}

        return evaluate

    # Every client participates in every fit/evaluate round; the server also
    # evaluates the global model centrally via get_evaluate_fn.
    strategy = FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_fit_clients=args.num_client,
        min_evaluate_clients=args.num_client,
        min_available_clients=args.num_client,
        evaluate_fn=get_evaluate_fn(model, args.dataset),
        on_fit_config_fn=fit_config,
        on_evaluate_config_fn=evaluate_config,
        initial_parameters=fl.common.ndarrays_to_parameters(model_parameters),
    )

    fl.server.start_server(
        server_address=args.ip,
        config=fl.server.ServerConfig(num_rounds=args.rounds),
        strategy=strategy,
    )

    # np.save appends ".npy", so this writes "<log_dir>/results.npy" next to
    # the results directory; the dict is stored as a pickled object array.
    np.save(pt_path, record)

    if record["accuracy"]:
        # Index into the evaluation history. It equals the server round only if
        # round 0 (initial weights) was evaluated too, as set_dada_bit's
        # round <= 0 branch suggests -- confirm.
        best_round = int(np.argmax(np.array(record["accuracy"])))
        best_acc = record["accuracy"][best_round]
        best_loss = record["loss"][best_round]
        logger.info("Best round: %d; Best acc: %.4f; Best loss: %.4f" % (best_round, best_acc, best_loss))
    else:
        # np.argmax raises on an empty sequence; report instead of crashing.
        logger.info("No server evaluation results were recorded.")

    if "feddq" in args.log_name:
        # Presumably each client file holds that client's accumulated bit count
        # over all rounds -- confirm against the client code.
        bit_sum = 0
        for i in range(args.num_client):
            path = os.path.join(pt_path, f"{i}.npy")
            data = np.load(path, allow_pickle=True)
            bit_sum += data
        logger.info("Avg bit: %.4f." % (bit_sum / args.num_client / args.rounds))