import io
import copy
import time
import math
import logging
import numpy as np
import torch.nn as nn
from .misc import *
from typing import Any
from mpi4py import MPI
from .algorithm import *
from torch.optim import *
from omegaconf import DictConfig
from torch.utils.data import DataLoader

def _dispatch_global_model(
    comm: "MPI.Comm",
    model_bytes: bytes,
    num_clients: int,
    local_steps: int,
    lr: float,
):
    """Send one round header plus the serialized global model to every client rank.

    For each rank ``1..num_clients`` this (in order):
      1. blocking-sends the header ``(len(model_bytes), False, local_steps, lr)``
         with ``tag=rank``;
      2. posts a non-blocking buffer send of the model bytes with
         ``tag=rank + comm_size``;
      3. posts a non-blocking pickled receive (``tag=rank``) for the reply.

    Returns:
        tuple: ``(send_reqs, recv_reqs)``, both ordered by client rank, so the
        request at list index ``i`` belongs to rank ``i + 1``.
    """
    comm_size = comm.Get_size()
    client_ranks = range(1, num_clients + 1)
    for rank in client_ranks:
        comm.send((len(model_bytes), False, local_steps, lr), dest=rank, tag=rank)
    send_reqs = [
        comm.Isend(np.frombuffer(model_bytes, dtype=np.byte), dest=rank, tag=rank + comm_size)
        for rank in client_ranks
    ]
    recv_reqs = [comm.irecv(source=rank, tag=rank) for rank in client_ranks]
    return send_reqs, recv_reqs


def run_server(
    cfg: DictConfig,
    comm: MPI.Comm,
    model: nn.Module,
    loss_fn: nn.Module,
    num_clients: int,
    test_dataset: Dataset = Dataset(),
    dataset_name: str = "appfl",
    flamby_metric: Any = None,
):
    """Run the FedAT server (aggregator) loop on MPI rank 0.

    The function performs, in order:

    1. Gathers the per-client sample counts from all ranks, turns them into
       per-client fractions of the total, and scatters each rank's fractions
       back to it.
    2. If ``cfg.fed.args.do_simulation`` is set, draws one simulated
       time-per-batch value per rank (``normal``, ``homo`` or ``exp``
       distribution) and scatters them.
    3. Builds a ``ServerFedAT`` around a deep copy of ``model`` and serializes
       its ``state_dict`` with ``torch.save``.
    4. Optionally (``cfg.fed.args.delta_warmup``) runs a warm-up round with
       ``local_steps`` steps and waits for every client's reply.
    5. Runs a probing round with ``max(floor(0.2 * local_steps), 1)`` steps and
       passes each reply, with the elapsed time, to
       ``SchedulerFedAT.speed_record``.
    6. Enters the main loop: waits for any client's size announcement, hands it
       to ``SchedulerFedAT.local_update``, re-posts the receive for that
       client, validates/logs when due, and stops once ``cfg.num_epochs``
       updates have been processed.
    7. Cancels outstanding receives, sends the finish header
       ``(0, True, -1, -1)`` to every client, and writes the summary/metrics.

    Args:
        cfg (DictConfig): run configuration; ``cfg.validation`` and
            ``cfg["logginginfo"]`` are written to by this function.
        comm (MPI.Comm): MPI communicator; this process acts as root 0.
        model (nn.Module): model to aggregate (deep-copied, kept on CPU).
        loss_fn (nn.Module): loss function forwarded to ``ServerFedAT``.
        num_clients (int): number of clients; client ``i`` is addressed as
            rank ``i`` (1-based), i.e. one client per non-root rank.
        test_dataset (Dataset): data used for server-side validation.
        dataset_name (str): name recorded in the logging info.
        flamby_metric (Any): forwarded unchanged to ``validation``.

    Raises:
        NotImplementedError: if ``cfg.fed.args.simulation_distrib`` is not one
            of ``'normal'``, ``'homo'`` or ``'exp'`` while simulation is on.
    """
    comm_size = comm.Get_size()
    device = "cpu"

    # Logger for the server
    logger = logging.getLogger(__name__)
    logger = create_custom_logger(logger, cfg)
    cfg["logginginfo"]["comm_size"] = comm_size
    cfg["logginginfo"]["DataSet_name"] = dataset_name

    # Run validation only if it is enabled in the configuration AND test data
    # is given. The explicit `== True` is kept so that only values equal to
    # True enable validation.
    if cfg.validation == True and len(test_dataset) > 0:
        test_dataloader = DataLoader(
            test_dataset,
            num_workers=cfg.num_workers,
            batch_size=cfg.test_data_batch_size,
            shuffle=cfg.test_data_shuffle,
        )
    else:
        cfg.validation = False

    # Each client rank contributes a {client_id: num_samples} dict; the root
    # contributes the placeholder 0.
    num_data = comm.gather(0, root=0)
    total_num_data = sum(
        count for rank in range(1, comm_size) for count in num_data[rank].values()
    )

    # `weight` is the per-rank list to scatter (slot 0 is the root's dummy),
    # `weights` is the flat {client_id: fraction} mapping used by the server.
    weight = [0]
    weights = {}
    for rank in range(1, comm_size):
        rank_weights = {
            cid: count / total_num_data for cid, count in num_data[rank].items()
        }
        weights.update(rank_weights)
        weight.append(rank_weights)

    weight = comm.scatter(weight, root=0)

    fed_args = cfg.fed.args

    if fed_args.do_simulation:
        # Use the configured seed only when `use_hetero_seed` exists and is
        # truthy; otherwise fall back to the fixed seed 42.
        use_hetero_seed = getattr(fed_args, 'use_hetero_seed', False)
        np.random.seed(fed_args.seed if use_hetero_seed else 42)

        if fed_args.simulation_distrib == 'normal':
            # Rejection sampling: redraw the whole vector until all entries are positive.
            while True:
                tpb = np.random.normal(
                    loc=fed_args.avg_tpb,
                    scale=fed_args.avg_tpb * fed_args.global_std_scale,
                    size=comm_size,
                )
                if np.all(tpb > 0):
                    tpb = list(tpb)
                    break
        elif fed_args.simulation_distrib == 'homo':
            # scale=0 makes every entry equal to avg_tpb.
            tpb = np.random.normal(loc=fed_args.avg_tpb, scale=0, size=comm_size)
        elif fed_args.simulation_distrib == 'exp':
            random_numbers = np.random.exponential(scale=fed_args.exp_scale, size=comm_size)
            # Shift by one bin and round to a multiple of the bin size.
            rounded_numbers = (
                np.round((random_numbers + fed_args.exp_bin_size) / fed_args.exp_bin_size)
                * fed_args.exp_bin_size
            )
            tpb = list(rounded_numbers * (fed_args.avg_tpb / fed_args.exp_scale))
        else:
            raise NotImplementedError
        # The root's own share of the scatter is not used.
        comm.scatter(tpb, root=0)

    # Asynchronous federated learning server (aggregator)
    server = ServerFedAT(
        weights,
        copy.deepcopy(model),
        loss_fn,
        num_clients,
        device,
        **fed_args
    )
    server.model.to("cpu")

    # Convert the model to bytes
    global_model = server.model.state_dict()
    global_model_buffer = io.BytesIO()
    torch.save(global_model, global_model_buffer)
    global_model_bytes = global_model_buffer.getvalue()

    lr = fed_args.optim_args.lr

    if fed_args.delta_warmup:
        # Warm-up round: full `local_steps`, wait for every client's reply.
        # The Isend requests are not waited on explicitly; each client replies
        # only after it has received the model bytes.
        send_reqs, recv_reqs = _dispatch_global_model(
            comm, global_model_bytes, num_clients, fed_args.local_steps, lr
        )
        MPI.Request.waitall(recv_reqs)
        logger.info('Finish warming up')

    # Probing round: a shortened local run (20% of local_steps, at least 1)
    # whose reply times are recorded by the scheduler.
    warmup_steps = max(math.floor(0.2 * fed_args.local_steps), 1)
    send_reqs, recv_reqs = _dispatch_global_model(
        comm, global_model_bytes, num_clients, warmup_steps, lr
    )
    scheduler = SchedulerFedAT(
        comm,
        server,
        fed_args.local_steps,
        num_clients,
        cfg.num_epochs,
        lr,
        fed_args.speed_ratio,
        logger,
    )
    start_time = time.time()
    client_counter = 0
    while client_counter < num_clients:
        client_idx, local_model_size = MPI.Request.waitany(recv_reqs)
        if client_idx != MPI.UNDEFINED:
            scheduler.speed_record(local_model_size, client_idx, time.time() - start_time)
            client_counter += 1

    # Main loop: one pending receive per client, indexed by (rank - 1).
    recv_reqs = [comm.irecv(source=i, tag=i) for i in range(1, num_clients + 1)]
    global_step, test_loss, test_accuracy, best_accuracy = 0, 0.0, 0.0, 0.0
    # metric[0]: elapsed times, metric[1]: test accuracies (both start with 0).
    metric = [[0], [0]]
    while True:
        client_idx, local_model_size = MPI.Request.waitany(recv_reqs)
        if client_idx == MPI.UNDEFINED:
            continue

        global_step += 1
        scheduler.local_update(local_model_size, client_idx)

        # Replace the completed request in place so list index == rank - 1
        # keeps holding; on the last step nothing is re-posted.
        recv_reqs.pop(client_idx)
        if global_step < cfg.num_epochs:
            recv_reqs.insert(client_idx, comm.irecv(source=client_idx + 1, tag=client_idx + 1))

        if (scheduler.validation_flag and global_step % fed_args.val_range == 0) or global_step >= cfg.num_epochs:
            validation_start = time.time()
            if cfg.validation == True:
                test_loss, test_accuracy = validation(server, test_dataloader, flamby_metric)
                if test_accuracy > best_accuracy:
                    best_accuracy = test_accuracy
            cfg["logginginfo"]["Validation_time"] = time.time() - validation_start
            cfg["logginginfo"]["PerIter_time"] = 0  # TODO
            cfg["logginginfo"]["Elapsed_time"] = time.time() - start_time
            cfg["logginginfo"]["test_loss"] = test_loss
            cfg["logginginfo"]["test_accuracy"] = test_accuracy
            cfg["logginginfo"]["BestAccuracy"] = best_accuracy
            cfg["logginginfo"]["LocalUpdate_time"] = 0  # TODO
            cfg["logginginfo"]["GlobalUpdate_time"] = 0  # TODO
            if global_step != 1:
                logger.info(server.log_title())
            server.logging_iteration(cfg, logger, global_step - 1)
            metric[0].append(cfg["logginginfo"]["Elapsed_time"])
            metric[1].append(test_accuracy)

        # `>=` (rather than `==`) matches the validation trigger above and
        # guarantees termination even when cfg.num_epochs is <= 0.
        if global_step >= cfg.num_epochs:
            break

    # Cancel outstanding requests
    for recv_req in recv_reqs:
        recv_req.cancel()

    # Send a finished indicator to all clients
    send_reqs = [comm.isend((0, True, -1, -1), dest=i, tag=i) for i in range(1, num_clients + 1)]
    MPI.Request.waitall(send_reqs)

    server.logging_summary(cfg, logger)
    save_training_metric(metric, cfg)

def run_client(
    cfg: DictConfig,
    comm: MPI.Comm,
    model: nn.Module,
    loss_fn: nn.Module,
    num_clients: int,
    train_data: Dataset,
    test_data: Dataset = Dataset(),
    flamby_metric: Any = None
):
    """Run one client process that trains locally on global models sent by rank 0.

    Client ids ``0..num_clients-1`` are split across the ``comm_size - 1``
    non-root ranks; log files and sample counts are prepared for every id
    assigned to this rank, but only the FIRST id of the group is actually
    instantiated and trained.

    Protocol with the server (rank 0), per round:
      * receive the pickled header ``(model_size, done, num_local_steps, lr)``
        with ``tag=comm_rank``; stop when ``done`` is true;
      * receive ``model_size`` raw bytes with ``tag=comm_rank + comm_size`` and
        load them with ``torch.load`` as the global ``state_dict``;
      * run ``client.update()``;
      * send back the byte size of the serialized local result
        (``tag=comm_rank``) followed by the bytes themselves
        (``tag=comm_rank + comm_size``).

    If ``cfg.fed.args.delta_warmup`` is set, one extra warm-up round is run
    first, answered with the integer ``0`` instead of a model.

    Args:
        cfg (DictConfig): the configuration for this run; ``cfg.validation``
            is set to ``False`` here when validation is not possible.
        comm (MPI.Comm): MPI communicator (rank 0 is the server).
        model (nn.Module): neural network model to train (deep-copied).
        loss_fn (nn.Module): loss function forwarded to the client object.
        num_clients (int): total number of clients in the run.
        train_data (Dataset): training data, indexable by client id.
        test_data (Dataset): testing data for client-side validation.
        flamby_metric (Any): forwarded to the client as ``metric``.

    Raises:
        NotImplementedError: if a speed change is simulated and
            ``cfg.fed.args.simulation_distrib`` is not ``'normal'``,
            ``'homo'`` or ``'exp'``.
    """
    comm_size = comm.Get_size()
    comm_rank = comm.Get_rank()
    fed_args = cfg.fed.args

    num_client_groups = np.array_split(range(num_clients), comm_size - 1)
    my_cids = num_client_groups[comm_rank - 1]

    # Log files for the clients of this rank
    outfile = {
        cid: client_log(cfg.output_dirname, cfg.output_filename + f"_client_{cid}")
        for cid in my_cids
    }

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    c_handler = logging.StreamHandler()
    logger.addHandler(c_handler)

    # Send the number of data to the server, then receive this rank's
    # {client_id: weight} mapping from it.
    num_data = {cid: len(train_data[cid]) for cid in my_cids}
    comm.gather(num_data, root=0)
    weight = comm.scatter(None, root=0)

    if fed_args.do_simulation:
        # Simulated seconds per batch assigned to this rank by the server.
        time_per_batch = comm.scatter(None, root=0)
        np.random.seed(fed_args.seed * comm_rank)

    batchsize = {}
    for cid in my_cids:
        batchsize[cid] = cfg.train_data_batch_size
        # Explicit comparison kept: only a value equal to False switches to
        # full-batch training.
        if cfg.batch_training == False:
            batchsize[cid] = len(train_data[cid])

    # Run validation only if it is enabled in the configuration AND test data is given.
    if cfg.validation == True and len(test_data) > 0:
        test_dataloader = DataLoader(
            test_data,
            num_workers=cfg.num_workers,
            batch_size=cfg.test_data_batch_size,
            shuffle=cfg.test_data_shuffle,
        )
    else:
        cfg.validation = False
        test_dataloader = None

    # Only the first client id assigned to this rank is trained.
    cid = my_cids[0]

    # NOTE(review): `eval` on a configuration string executes arbitrary code if
    # the config is not trusted; consider an explicit name -> class lookup.
    client = eval(cfg.fed.clientname)(
        cid,
        weight[cid],
        copy.deepcopy(model),
        loss_fn,
        DataLoader(
            train_data[cid],
            num_workers=cfg.num_workers,
            batch_size=batchsize[cid],
            shuffle=True,
            pin_memory=True,
        ),
        cfg,
        outfile[cid],
        test_dataloader,
        metric=flamby_metric,
        **fed_args,
    )

    ######## Warmup on Delta: For fair and reproducible experiment results #########
    if fed_args.delta_warmup:
        warmup_start = time.time()
        global_model_size, done, num_local_steps, lr = comm.recv(source=0, tag=comm_rank)
        client.local_steps = num_local_steps
        client.optim_args.lr = lr
        global_model_bytes = np.empty(global_model_size, dtype=np.byte)
        comm.Recv(global_model_bytes, source=0, tag=comm_rank + comm_size)
        global_model_buffer = io.BytesIO(global_model_bytes.tobytes())
        global_model = torch.load(global_model_buffer)
        client.model.load_state_dict(global_model)
        client.update()
        # The warm-up result is discarded; only a completion marker is sent.
        comm.send(0, dest=0, tag=comm_rank)
        logger.info(f"Clinet {comm_rank-1} finishes the warmup in {time.time()-warmup_start} sec")
    ################################################################################

    # FedAsync: main local training loop
    while True:
        # Receive the round header (model size, finish flag, steps, lr) from the server
        global_model_size, done, num_local_steps, lr = comm.recv(source=0, tag=comm_rank)
        client.local_steps = num_local_steps
        client.optim_args.lr = lr
        if done:
            break

        # Allocate a buffer and receive the byte stream of the global model
        global_model_bytes = np.empty(global_model_size, dtype=np.byte)
        comm.Recv(global_model_bytes, source=0, tag=comm_rank + comm_size)

        if fed_args.do_simulation:
            # With probability `speed_change_prob`, redraw this client's
            # time-per-batch from the configured distribution.
            if fed_args.speed_change_simulation and np.random.uniform(0, 1) <= fed_args.speed_change_prob:
                if fed_args.simulation_distrib == 'normal':
                    # Redraw until the sample is positive.
                    while True:
                        tpb = np.random.normal(
                            loc=fed_args.avg_tpb,
                            scale=fed_args.avg_tpb * fed_args.global_std_scale,
                            size=1,
                        )
                        if np.all(tpb > 0):
                            break
                elif fed_args.simulation_distrib == 'homo':
                    tpb = np.random.normal(loc=fed_args.avg_tpb, scale=0, size=1)
                elif fed_args.simulation_distrib == 'exp':
                    random_numbers = np.random.exponential(scale=fed_args.exp_scale, size=1)
                    rounded_numbers = (
                        np.round((random_numbers + fed_args.exp_bin_size) / fed_args.exp_bin_size)
                        * fed_args.exp_bin_size
                    )
                    tpb = rounded_numbers * (fed_args.avg_tpb / fed_args.exp_scale)
                else:
                    raise NotImplementedError
                time_per_batch = tpb[0]
            # Target wall-clock duration for this round: noisy per-batch time
            # multiplied by the number of local steps.
            local_training_time = np.random.normal(
                loc=time_per_batch, scale=fed_args.local_std_scale * time_per_batch
            )
            local_training_time *= num_local_steps
        start_time = time.time()

        # Load the bytes into a state dict
        global_model_buffer = io.BytesIO(global_model_bytes.tobytes())
        global_model = torch.load(global_model_buffer)
        logger.info(f"[Client Log] [Client #{comm_rank-1}] Global model device {next(iter(global_model.values())).device}")

        # Train the model
        client.model.load_state_dict(global_model)
        client.update()

        if fed_args.gradient_based:
            # Gradient-based algorithms send (global - local) for trainable
            # parameters and the local value for every other state entry.
            param_names = {name for name, _ in client.model.named_parameters()}
            local_model = {
                name: (
                    global_model[name] - client.primal_state[name]
                    if name in param_names
                    else client.primal_state[name]
                )
                for name in global_model
            }
        else:
            local_model = copy.deepcopy(client.primal_state)

        # Convert local model to bytes
        local_model_buffer = io.BytesIO()
        torch.save(local_model, local_model_buffer)
        local_model_bytes = local_model_buffer.getvalue()

        if fed_args.do_simulation:
            # Pad the round up to the simulated duration; polling in 1 s steps
            # means the actual duration can exceed the target by up to ~1 s.
            while time.time() - start_time < local_training_time:
                time.sleep(1)

        # Send the size of local model first
        comm.send(len(local_model_bytes), dest=0, tag=comm_rank)

        # Send the state dict.
        # NOTE(review): the returned request is dropped and never waited on;
        # confirm the send buffer is guaranteed to outlive the transfer.
        comm.Isend(np.frombuffer(local_model_bytes, dtype=np.byte), dest=0, tag=comm_rank + comm_size)

    client.outfile.close()


