import copy
import time
from unittest import skip
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import pairwise_distances

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import utils
from averaging import get_averaged_params, weighted_average_numpy, weighted_average_torch
from client import init_model, init_optimizer


# Constants
# =========
# Activation modules normalizing over the last dimension.
Softmax = nn.Softmax(dim=-1)
LogSoftmax = nn.LogSoftmax(dim=-1)

# Loss modules. CE_Loss (batch mean) and CE_Loss_sample (one loss per sample,
# weighted by the caller) are used by the training loops in this file.
KL_Loss = nn.KLDivLoss(reduction="batchmean")
CE_Loss = nn.CrossEntropyLoss()
CE_Loss_sample = nn.CrossEntropyLoss(reduction="none")


# Trainers
# ========
class Trainer(object):
    """Pick the training routine matching ``args.algorithm`` and expose it as ``self.train``.

    'Regular' and 'Joint' share train_regular; every name beginning with 'FedAvg'
    (e.g. 'FedAvg', 'FedAvg+') uses train_avg. Any other unsupported name raises
    ValueError.
    """

    def __init__(self, args):
        algorithm = args.algorithm
        if algorithm in ('Regular', 'Joint'):
            selected = train_regular
        elif algorithm.startswith('FedAvg'):
            selected = train_avg
        elif algorithm == 'DFedEM':
            selected = train_dfedem
        elif algorithm == 'Federico':
            selected = train_federico
        elif algorithm == 'FedFomo':
            selected = train_fedfomo
        elif algorithm == 'CFL':
            selected = train_clustered
        else:
            raise ValueError("Unknown training method")
        self.train = selected


class ClientMetrics:
    """Per-round metric arrays for one client, gathered onto rank 0 at the end.

    By default four float32 arrays of length ``args.n_rounds`` are tracked
    (standard/customized/train accuracy and privacy budget); callers may pass
    their own arrays as keyword arguments instead, or add entries to
    ``local_metrics`` afterwards.
    """

    def __init__(self, comm, logger, args, **local_metrics):
        self.comm = comm
        self.logger = logger
        self.args = args
        self.initialize_local_metrics(**local_metrics)

    def initialize_local_metrics(self, **local_metrics):
        """Adopt the given metric arrays, or create the default zero-filled ones."""
        if local_metrics:
            self.local_metrics = local_metrics
            return
        default_names = ('standard_accuracy', 'customized_accuracy',
                         'train_accuracy', 'privacy_budget')
        self.local_metrics = {
            name: np.zeros(self.args.n_rounds, dtype=np.float32)
            for name in default_names
        }

    def update_local_metrics(self, round, **update):
        """Record ``update[name]`` at index ``round`` for every tracked metric.

        Raises KeyError if ``update`` lacks a tracked metric. Logs the round when
        ``args.verbose`` is set.
        """
        for name in self.local_metrics:
            self.local_metrics[name][round] = update[name]
        if self.args.verbose:
            self.log_round(round, update)

    def log_round(self, round, update):
        """Log this client's accuracies (and ε when private SGD is enabled)."""
        pieces = [
            f"Round {round}, Client {self.comm.rank}:  "
            f"Pvt std acc={update['standard_accuracy']:.4f}",
            f"Pvt cust acc={update['customized_accuracy']:.4f}",
        ]
        if self.args.use_private_SGD:
            pieces.append(f"ε={update['privacy_budget']:.2f}")
        self.logger.info(" | ".join(pieces))

    def gather_metrics(self):
        """Gather every metric array onto rank 0 into ``self.gathered_metrics``.

        On rank 0 each entry is an ``(args.n_clients, *shape)`` float32 array
        (assumes the communicator size equals ``args.n_clients``); on other
        ranks each entry is None.
        """
        gathered = {}
        self.gathered_metrics = gathered
        for name, local_array in self.local_metrics.items():
            if self.comm.rank == 0:
                receive_buffer = np.empty((self.args.n_clients,) + tuple(local_array.shape),
                                          dtype=np.float32)
            else:
                receive_buffer = None
            gathered[name] = receive_buffer
            self.comm.Gather(local_array, receive_buffer, root=0)


def train_avg(client, eval_data, comm, logger, args):
    """Federated Averaging (FedAvg / FedAvg+) over an MPI communicator.

    Every round each client trains its private model locally. Rank 0 gathers
    all parameters and averages them with fixed dataset-size weights via
    ``get_averaged_params``. The result is broadcast and loaded by every
    client. For ``args.algorithm == 'FedAvg+'`` the client also runs one more
    local training pass before evaluation. The averaged parameters are
    reloaded after evaluation, so the next round starts from the global model.

    Returns:
        dict with the gathered per-round metrics (arrays on rank 0, None
        elsewhere), ``training_time``, ``comm_time`` and, on rank 0 only,
        ``client_test_sizes``.
    """
    metrics = ClientMetrics(comm, logger, args)

    # Fixed client weights based on dataset sizes
    # Client dataset sizes shared non-privately
    private_data_size = len(client.private_train_data[0])
    client_weights = comm.gather(np.array(private_data_size), root=0)
    if comm.rank == 0:
        total_data_size = sum(client_weights)
        client_weights = np.asarray(client_weights) / total_data_size
        logger.info(f"Total data size: {total_data_size}. client_weights: {client_weights}")

    private_test_size = len(client.private_test_data[1])
    client_test_sizes = comm.gather(private_test_size, root=0)

    # Log hypers and iterate
    # ===
    logger.info(f"Hyperparameter setting = {args}")

    start_time = time.time()
    comm_time = 0.

    for r in range(args.n_rounds):
        regular_training_loop(client, None, args)

        comm.Barrier()  # for accurate comm time measurement
        comm_start_time = time.time()
        all_model_weights = comm.gather(utils.extract_numpy_params(client.private_model), root=0)
        comm_time += time.time() - comm_start_time

        if comm.rank == 0:
            fedavg_params, _ = get_averaged_params(all_model_weights, client_weights, args)
        else:
            fedavg_params = None

        # Restart the timer so only the broadcast is added here. Otherwise the
        # gather (and rank 0's averaging) would be counted a second time.
        comm_start_time = time.time()
        fedavg_params = comm.bcast(fedavg_params, root=0)
        comm_time += time.time() - comm_start_time
        client.private_model.load_state_dict(utils.convert_np_params_to_tensor(fedavg_params))

        if args.algorithm == 'FedAvg+':
            # Personalize the global model with one extra local pass before evaluating.
            regular_training_loop(client, None, args)

        metrics.update_local_metrics(r,
                                     standard_accuracy=utils.evaluate_model(client.private_model, eval_data, args),
                                     train_accuracy=utils.evaluate_model(client.private_model, client.private_train_data, args),
                                     customized_accuracy=utils.evaluate_model(client.private_model, client.private_test_data, args),
                                     privacy_budget=client.privacy_budget
                                     )
        # Restore the global model (undoes any FedAvg+ fine-tuning) for the next round.
        client.private_model.load_state_dict(utils.convert_np_params_to_tensor(fedavg_params))
    metrics.gather_metrics()
    res = {
        **metrics.gathered_metrics,
        "training_time": time.time() - start_time,
        "comm_time": comm_time,
    }
    if comm.rank == 0:
        res['client_test_sizes'] = client_test_sizes
    return res

def train_regular(client, eval_data, comm, logger, args):
    """Purely local training: no parameters are exchanged between clients.

    Each round runs ``regular_training_loop`` (``args.n_epochs`` epochs) and
    then evaluates the private model on the shared eval data, the private test
    data and the private train data.

    Returns:
        dict with the gathered per-round metrics, ``training_time`` and, on
        rank 0 only, ``client_test_sizes``.
    """
    metrics = ClientMetrics(comm, logger, args)
    logger.info(f"Hyperparameter setting = {args}")
    client_test_sizes = comm.gather(len(client.private_test_data[1]), root=0)
    start_time = time.time()

    for round_idx in range(args.n_rounds):
        # The per-step accuracy / privacy traces are not needed here.
        _ = regular_training_loop(client, None, args)

        standard = utils.evaluate_model(client.private_model, eval_data, args)
        customized = utils.evaluate_model(client.private_model, client.private_test_data, args)
        train = utils.evaluate_model(client.private_model, client.private_train_data, args)
        metrics.update_local_metrics(
            round_idx,
            standard_accuracy=standard,
            customized_accuracy=customized,
            train_accuracy=train,
            privacy_budget=client.privacy_budget,
        )

    # Collect all results
    metrics.gather_metrics()
    res = dict(metrics.gathered_metrics)
    res["training_time"] = time.time() - start_time
    if comm.rank == 0:
        res['client_test_sizes'] = client_test_sizes
    return res


def train_dfedem(client, eval_data, comm, logger, args):
    """
    Implements
     `Federated Multi-Task Learning under a Mixture of Distributions`.

     Follows implementation from https://github.com/omarfoq/FedEM

    Each client holds ``client.private_components`` (component models) and
    ``client.private_component_weights`` (mixture weights). Every round it:
      1. (e-step) computes each component's responsibility for each private
         training sample;
      2. (m-step) sets the mixture weights to the mean responsibilities and
         trains every component for one epoch with per-sample weights;
      3. all-gathers every client's components and replaces each component with
         the cross-client average from ``get_averaged_params``, weighted by
         clients' training-set sizes.

    Returns:
        dict with the gathered per-round metrics (including
        ``component_weights``), ``training_time``, ``comm time`` and, on rank 0
        only, ``client_test_sizes``.
    """
    start_time = time.time()
    metrics = ClientMetrics(comm, logger, args)
    # Extra per-round metric: this client's mixture weights (one per component).
    metrics.local_metrics['component_weights'] = np.zeros((args.n_rounds, args.n_components), dtype=np.float32)
    logger.info(f"Hyperparameter setting = {args}")

    comm_time = 0.
    private_data_size = len(client.private_train_data[0])

    # Every rank learns every client's training-set size (shared non-privately).
    comm_start_time = time.time()
    clients_data_size = comm.allgather(private_data_size)
    comm_time += time.time() - comm_start_time
    # Each client's share of all training samples; used as averaging weights below.
    total_data_size = sum(clients_data_size)
    clients_size_weights = [data_size / total_data_size for data_size in clients_data_size]
    clients_size_weights = np.array(clients_size_weights, dtype=np.float32)

    private_test_size = len(client.private_test_data[1])
    client_test_sizes = comm.gather(private_test_size, root=0)

    for r in range(args.n_rounds):
        # e-step
        # shape: (# of components, # of samples)
        all_losses = utils.calculate_all_components_losses(client.private_components, client.private_train_data, args)
        # Responsibilities: softmax over components of log(pi_k) - loss_k(sample),
        # transposed back to (# of components, # of samples).
        samples_weights = F.softmax((torch.log(client.private_component_weights) - all_losses.T), dim=1).T

        # m-step
        # shape: (# of components)
        client.private_component_weights = samples_weights.mean(dim=1)
        training_loop_fedem(args, client, samples_weights, epochs=1)

        # send and receive updates ([# clients, # of components])
        comm_start_time = time.time()
        component_params = [utils.extract_numpy_params(component) for component in client.private_components]
        all_model_params = comm.allgather(component_params)  # n_clients x n_components list of lists of params
        comm_time += time.time() - comm_start_time
        # Average each component across clients (data-size weights) and load it back.
        update_component_models(client, all_model_params, clients_size_weights, args)

        metrics.update_local_metrics(r,
                                     standard_accuracy=utils.evaluate_component_model(client, eval_data, args),
                                     customized_accuracy=utils.evaluate_component_model(client, client.private_test_data, args),
                                     train_accuracy=utils.evaluate_component_model(client, client.private_train_data, args),
                                     privacy_budget=client.privacy_budget,
                                     component_weights=client.private_component_weights.cpu().detach().numpy().copy()
                                     )

    # Collect all results
    metrics.gather_metrics()
    res = {
        **metrics.gathered_metrics,
        "training_time": time.time() - start_time,
        # NOTE(review): key is "comm time" (with a space) here, while train_avg and
        # train_clustered use "comm_time"; consumers must handle both spellings.
        "comm time": comm_time
    }
    if comm.rank == 0:
        res['client_test_sizes'] = client_test_sizes
    return res


def train_federico(client, eval_data, comm, logger, args):
    """Decentralized training with epsilon-greedy neighbour sampling ("Federico").

    Each round a client:
      1. samples ``args.n_neighbors`` other clients: uniformly with probability
         ``args.greedy_eps``, otherwise in proportion to its current client
         weights;
      2. exchanges models point-to-point. It receives the models of the clients
         it sampled into ``client.dummy_models`` and sends its own model to the
         clients that sampled it;
      3. re-evaluates the training loss of every held model, updates an
         exponential moving average of per-client losses (``args.cw_momentum``),
         and sets the client weights to ``softmax(-ema_losses)``;
      4. trains its own and the received models for one epoch, sends each
         weighted update back to that model's owner, and combines the mean of
         the updates it receives with its pre-round parameters via
         ``weighted_average_numpy``.

    Returns:
        dict with the gathered per-round metrics, ``training_time``,
        ``comm time`` and, on rank 0 only, ``client_test_sizes``.
    """
    metrics = ClientMetrics(comm, logger, args)
    metrics.local_metrics['component_weights'] = np.zeros((args.n_rounds, args.n_clients), dtype=np.float32)
    metrics.local_metrics['component_losses'] = np.zeros((args.n_rounds, args.n_clients), dtype=np.float32)
    metrics.local_metrics['column_component_weights'] = np.zeros((args.n_rounds, args.n_clients), dtype=np.float32)
    logger.info(f"Hyperparameter setting = {args}")

    comm_time = 0.
    start_time = time.time()
    # sample clients from which to receive models for m-step (per-rank seed)
    np.random.seed(comm.rank)

    with torch.no_grad():
        init_loss = utils.calculate_losses(client.private_model, client.private_train_data, args).mean()
        losses = torch.stack([init_loss] * args.n_clients)
        client_weights = (torch.ones(args.n_clients) / args.n_clients).to(args.device)
        # Independent copy used as the EMA accumulator; `losses` is updated in place
        # below. (torch.tensor(tensor) also copies but emits a UserWarning.)
        acc_losses = losses.clone()

    private_test_size = len(client.private_test_data[1])
    client_test_sizes = comm.gather(private_test_size, root=0)

    # Every client except this one is a candidate neighbour.
    candidate_ids = [i for i in range(args.n_clients) if i != comm.rank]

    for r in range(args.n_rounds):
        column_client_weights = np.zeros(args.n_clients)
        torch.cuda.empty_cache()
        private_model_old_params = utils.extract_numpy_params(client.private_model)
        # e-step: estimate q(z)
        # randomly pick neighbors with epsilon-greedy
        if np.random.rand() < args.greedy_eps:
            incoming_ids = np.random.choice(candidate_ids,
                                            size=args.n_neighbors,
                                            replace=False)
        else:
            prob = client_weights.cpu().detach().numpy()
            prob = np.delete(prob, comm.rank)
            prob = prob / np.sum(prob)
            incoming_ids = np.random.choice(candidate_ids,
                                            size=args.n_neighbors,
                                            replace=False,
                                            p=prob)

        # tell other clients who sampled who
        comm_start_time = time.time()
        all_sampled_ids = comm.allgather(incoming_ids)
        comm_time += time.time() - comm_start_time
        # Clients that sampled this one and therefore expect its model.
        outgoing_ids = [i for i, sampled_ids in enumerate(all_sampled_ids)
                        if comm.rank in sampled_ids]

        # send/receive for m-step (https://courses.cs.ut.ee/MTAT.08.020/2019_fall/uploads/Main/MPI_p2p-slides.pdf)
        comm_start_time = time.time()
        send_requests = []
        for dest in outgoing_ids:
            send_requests.append(comm.isend(utils.extract_numpy_params(client.private_model), dest=dest))
        component_params = [comm.recv(source=i) for i in incoming_ids]
        for req in send_requests:
            req.wait()
        comm.Barrier()
        comm_time += time.time() - comm_start_time

        # load retrieved params into dummy models
        for i, params in enumerate(component_params):
            params = utils.convert_np_params_to_tensor(params, device=args.device)
            client.dummy_models[i].load_state_dict(params)

        # update stored losses for corresponding models
        with torch.no_grad():
            losses[comm.rank] = utils.calculate_losses(client.private_model, client.private_train_data, args).mean()
            for model_id, model in zip(incoming_ids, client.dummy_models):
                losses[model_id] = utils.calculate_losses(
                    model, client.private_train_data, args).mean()
            # using ema(exponential moving average)
            acc_losses = (1 - args.cw_momentum) * acc_losses + args.cw_momentum * losses
            client_weights = F.softmax(- acc_losses, dim=0)
            # The non-ema implemetation: client_weights = F.softmax(torch.log(client_weights) - losses, dim=0)
        column_client_weights[comm.rank] = client_weights[comm.rank]

        # train own model + sampled models, each update scaled by its client weight
        update_weights = [client_weights[comm.rank]] + [client_weights[i] for i in incoming_ids]
        weighted_model_updates = training_loop_federico(args, client, update_weights,
                                                        [client.private_model] + client.dummy_models, epochs=1)

        # send each trained update back to its model's owner; receive updates of ours
        comm_start_time = time.time()
        send_requests = []
        for dest, weighted_model_update in zip(incoming_ids, weighted_model_updates[1:]):
            send_requests.append(comm.isend(weighted_model_update, dest=dest))
        incoming_model_updates = [weighted_model_updates[0]] + [comm.recv(source=i) for i in outgoing_ids]
        for s in send_requests:
            s.wait()
        comm.Barrier()
        comm_time += time.time() - comm_start_time

        # Tell each sampled client the weight given to it; learn the weights others
        # gave to this client ("column" of the weight matrix).
        # NOTE(review): this exchange is not counted in comm_time.
        send_requests = []
        for dest in incoming_ids:
            send_requests.append(comm.isend(client_weights[dest], dest=dest))
        for i in outgoing_ids:
            column_client_weights[i] = comm.recv(source=i)
        for s in send_requests:
            s.wait()

        avg_weights = np.ones((len(incoming_model_updates))) / len(incoming_model_updates)
        avg_param_updates, _ = weighted_average_numpy(incoming_model_updates, avg_weights)
        new_params, _ = weighted_average_numpy([avg_param_updates, private_model_old_params], np.ones(2))
        with torch.no_grad():
            client.private_model.load_state_dict(utils.convert_np_params_to_tensor(new_params, device=args.device))

        client.private_components = comm.allgather(client.private_model)
        for component in client.private_components:
            component.to(args.device)  # nn.Module.to moves the module in place
        # Tensor.to is NOT in place: assign its result (no copy if already on device).
        client.private_component_weights = client_weights.to(args.device)

        metrics.update_local_metrics(r,
                                     standard_accuracy=utils.evaluate_component_model(client, eval_data, args),
                                     customized_accuracy=utils.evaluate_component_model(client, client.private_test_data, args),
                                     train_accuracy=utils.evaluate_component_model(client, client.private_train_data, args),
                                     privacy_budget=client.privacy_budget,
                                     component_weights=client.private_component_weights.cpu().detach().numpy().copy(),
                                     component_losses=losses.cpu().detach().numpy().copy(),
                                     column_component_weights=column_client_weights
                                     )

    # Collect all results
    metrics.gather_metrics()
    res = {
        **metrics.gathered_metrics,
        "training_time": time.time() - start_time,
        "comm time": comm_time
    }
    if comm.rank == 0:
        res['client_test_sizes'] = client_test_sizes
    return res

def train_fedfomo(client, eval_data, comm, logger, args):
    """
    Implements "Personalized Federated Learning with First Order Model Optimization"
    (FedFomo) over an MPI communicator.

    Every round the client:
      1. snapshots its parameters ``old`` and its validation loss ``L(old)``;
      2. runs ``regular_training_loop`` ``args.n_local_epochs`` times;
      3. all-gathers every client's parameters ``theta_n`` and scores each on
         its private validation data as
         ``w_n = relu((L(old) - L(theta_n)) / ||theta_n - old||_1)``.
         A client whose parameters equal ``old`` gets weight 0;
      4. normalizes the weights and sets its model to
         ``old + sum_n w_n * (theta_n - old)``.

    Returns:
        dict with the gathered per-round metrics (``component_weights`` holds
        the normalized w_n), ``training_time``, ``comm time`` and, on rank 0
        only, ``client_test_sizes``.
    """
    metrics = ClientMetrics(comm, logger, args)
    metrics.local_metrics['component_weights'] = np.zeros((args.n_rounds, args.n_clients), dtype=np.float32)
    logger.info(f"Hyperparameter setting = {args}")

    comm_time = 0.
    start_time = time.time()

    private_test_size = len(client.private_test_data[1])
    client_test_sizes = comm.gather(private_test_size, root=0)

    for r in range(args.n_rounds):
        torch.cuda.empty_cache()
        # Snapshot the pre-training model. No autograd: the loss is only used as a
        # number, and keeping its graph alive through local training wastes memory.
        # clone() guarantees the snapshot is unaffected by the training below.
        with torch.no_grad():
            private_model_old_params = utils.get_param_tensor(client.private_model).clone()
            private_old_loss = utils.calculate_losses(client.private_model, client.private_val_data, args).mean()

        # train model
        for _ in range(args.n_local_epochs):
            regular_training_loop(client, None, args)

        comm_start_time = time.time()
        all_client_params = comm.allgather(utils.extract_numpy_params(client.private_model))
        comm_time += time.time() - comm_start_time

        with torch.no_grad():
            # update client weights
            client_weights_unnormalized = torch.ones(args.n_clients)
            client_updates = []
            for client_id, client_params in enumerate(all_client_params):
                client.dummy_model.load_state_dict(utils.convert_np_params_to_tensor(client_params, device=args.device))
                new_loss = utils.calculate_losses(client.dummy_model, client.private_val_data, args).mean()
                client_update = utils.get_param_tensor(client.dummy_model) - private_model_old_params
                update_norm = client_update.norm(p=1)
                if update_norm > 0:
                    client_weight = F.relu((private_old_loss - new_loss) / update_norm)
                else:
                    # Zero update: the ratio would be nan/inf and poison the
                    # normalization, while the update contributes nothing anyway.
                    client_weight = 0.
                client_weights_unnormalized[client_id] = client_weight
                client_updates.append(client_update)
            normalizing_factor = client_weights_unnormalized.sum()
            if normalizing_factor < 1e-9:
                normalizing_factor += 1e-9
            client_weights = client_weights_unnormalized / normalizing_factor
            client_updates = torch.stack(client_updates, dim=0)

            # update local model: old + sum_n w_n * (theta_n - old)
            new_params = private_model_old_params.clone()
            for c_id in range(args.n_clients):
                new_params += client_weights[c_id] * client_updates[c_id]
        new_params = utils.convert_param_tensor_to_np_params(new_params, client.private_model)
        client.private_model.load_state_dict(utils.convert_np_params_to_tensor(new_params))

        metrics.update_local_metrics(r,
                                     standard_accuracy=utils.evaluate_model(client.private_model, eval_data, args),
                                     customized_accuracy=utils.evaluate_model(client.private_model, client.private_test_data, args),
                                     train_accuracy=utils.evaluate_model(client.private_model, client.private_train_data, args),
                                     privacy_budget=client.privacy_budget,
                                     component_weights=client_weights.cpu().detach().numpy().copy()
                                     )

    # Collect all results
    metrics.gather_metrics()
    res = {
        **metrics.gathered_metrics,
        "training_time": time.time() - start_time,
        "comm time": comm_time
    }
    if comm.rank == 0:
        res['client_test_sizes'] = client_test_sizes
    return res


def _complete_linkage_on_precomputed():
    """Return a 2-cluster complete-linkage AgglomerativeClustering over a precomputed distance matrix.

    scikit-learn 1.2 renamed the ``affinity`` argument to ``metric``, and 1.4
    removed ``affinity``. Use ``metric`` when available; older versions reject
    it with TypeError, in which case fall back to ``affinity``.
    """
    try:
        return AgglomerativeClustering(metric="precomputed", linkage="complete")
    except TypeError:
        return AgglomerativeClustering(affinity="precomputed", linkage="complete")


def train_clustered(client, eval_data, comm, logger, args, tol_1 = 0.4, tol_2 = 1.6):
    """
    Implements
     `Clustered Federated Learning: Model-Agnostic Distributed Multi-Task Optimization under Privacy Constraints`.

     Follows implementation from https://github.com/omarfoq/FedEM

    Every round each client trains locally. Rank 0 then gathers all models and
    parameter updates. A cluster is split in two (complete linkage on pairwise
    cosine distances of updates) when its mean update norm is below ``tol_1``,
    its max update norm exceeds ``tol_2``, and it has more than 2 members.
    Every cluster's members then receive the cluster's dataset-size-weighted
    average model.

    Returns:
        dict with the gathered per-round metrics, ``training_time``,
        ``comm_time`` and, on rank 0 only, ``client_test_sizes``.
    """
    metrics = ClientMetrics(comm, logger, args)
    # Current partition of client ids (maintained on rank 0 only); starts as one cluster.
    clusters_indices = [np.arange(args.n_clients).astype("int")]

    # Fixed client weights based on dataset sizes
    # Client dataset sizes shared non-privately
    private_data_size = len(client.private_train_data[0])
    client_weights = comm.gather(np.array(private_data_size), root=0)
    if comm.rank == 0:
        total_data_size = sum(client_weights)
        client_weights = np.asarray(client_weights) / total_data_size
        logger.info(f"Total data size: {total_data_size}. client_weights: {client_weights}")

    private_test_size = len(client.private_test_data[1])
    client_test_sizes = comm.gather(private_test_size, root=0)

    # Log hypers and iterate
    # ===
    logger.info(f"Hyperparameter setting = {args}")

    start_time = time.time()
    comm_time = 0.

    for r in range(args.n_rounds):
        old_params = utils.get_param_tensor(client.private_model)
        for _ in range(args.n_local_epochs):
            regular_training_loop(client, None, args)
        params_update = (utils.get_param_tensor(client.private_model) - old_params).detach().cpu().numpy()
        comm.Barrier()  # for accurate comm time measurement
        comm_start_time = time.time()
        all_model_weights = comm.gather(utils.extract_numpy_params(client.private_model), root=0)
        clients_updates = comm.gather(params_update, root=0)
        comm_time += time.time() - comm_start_time

        if comm.rank == 0:
            clients_updates = np.array(clients_updates)
            # Pairwise cosine *distances* between updates, used as the precomputed metric.
            distances = pairwise_distances(clients_updates, metric="cosine")
            new_cluster_indices = []
            for indices in clusters_indices:
                max_update_norm = np.linalg.norm(clients_updates[indices], axis=1).max()
                mean_update_norm = np.linalg.norm(np.mean(clients_updates[indices], axis=0))
                if mean_update_norm < tol_1 and max_update_norm > tol_2 and len(indices) > 2:
                    clustering = _complete_linkage_on_precomputed()
                    clustering.fit(distances[indices][:, indices])
                    cluster_1 = np.argwhere(clustering.labels_ == 0).flatten()
                    cluster_2 = np.argwhere(clustering.labels_ == 1).flatten()
                    new_cluster_indices += [indices[cluster_1], indices[cluster_2]]
                else:
                    new_cluster_indices += [indices]
            clusters_indices = new_cluster_indices
            for indices in clusters_indices:
                cluster_params, _ = get_averaged_params([all_model_weights[idx] for idx in indices],
                                                        client_weights[indices] / client_weights[indices].sum(), args)
                print("indices:{}".format(indices))
                for client_id in indices:
                    if client_id == 0:
                        client.private_model.load_state_dict(utils.convert_np_params_to_tensor(cluster_params))
                    else:
                        comm_start_time = time.time()
                        comm.send(cluster_params, dest=client_id)
                        comm_time += time.time() - comm_start_time
        else:
            # Each non-root client belongs to exactly one cluster, so expects one message.
            comm_start_time = time.time()
            cluster_params = comm.recv(source=0)
            comm_time += time.time() - comm_start_time
            client.private_model.load_state_dict(utils.convert_np_params_to_tensor(cluster_params))

        metrics.update_local_metrics(r,
                                     standard_accuracy=utils.evaluate_model(client.private_model, eval_data, args),
                                     customized_accuracy=utils.evaluate_model(client.private_model, client.private_test_data, args),
                                     train_accuracy=utils.evaluate_model(client.private_model, client.private_train_data, args),
                                     privacy_budget=client.privacy_budget
                                     )

    metrics.gather_metrics()
    res = {
        **metrics.gathered_metrics,
        "training_time": time.time() - start_time,
        "comm_time": comm_time,
    }
    if comm.rank == 0:
        res['client_test_sizes'] = client_test_sizes
    return res

# TRAINING LOOPS
# ==============
def regular_training_loop(client, logger, args):
    """Run ``args.n_epochs`` epochs of cross-entropy training on ``client.private_model``.

    Uses ``client.private_opt``. When ``args.use_private_SGD`` is set, the
    privacy spent at ``delta = 1 / n_train_samples`` is queried from
    ``client.private_opt.privacy_engine`` after every step and stored in
    ``client.privacy_budget``.

    Returns:
        (train_private_acc, train_privacy_budget): float32 arrays with one entry
        per optimisation step. Accuracy entries are the running count of correct
        predictions in the current epoch divided by the full training-set size.
    """
    client.private_model.train()
    train_private_acc = []
    train_privacy_budget = []

    epsilon = 0
    # Defined up front so the DP log line is well-formed even if no step runs.
    optimal_alpha = float('nan')
    delta = 1.0 / client.private_train_data[0].shape[0]

    for e in range(args.n_epochs):

        train_loader = utils.data_loader(args.dataset,
                                         client.private_train_data[0],
                                         client.private_train_data[1],
                                         args.batch_size)

        correct_private = 0.0
        acc_private = 0.0

        for data, target in train_loader:

            client.private_opt.zero_grad()

            data = torch.from_numpy(data).to(client.device)
            target = torch.from_numpy(target).to(client.device)
            pred_private = client.private_model(data)

            loss_private = CE_Loss(pred_private, target)
            loss_private.backward()

            client.private_opt.step()

            pred_private = pred_private.argmax(dim=-1)
            correct_private += pred_private.eq(
                target.view_as(pred_private)).sum()
            # Running accuracy: correct-so-far over the *whole* training set.
            acc_private = correct_private / \
                client.private_train_data[0].shape[0]
            train_private_acc.append(acc_private.cpu())

            if args.use_private_SGD:
                epsilon, optimal_alpha = client.private_opt.privacy_engine.get_privacy_spent(
                    delta)
                client.privacy_budget = epsilon

            train_privacy_budget.append(epsilon)

        if logger is not None and args.verbose:
            if args.use_private_SGD:
                # Both halves must be f-strings, or ε/δ/α are logged as literal braces.
                logger.info(
                    f"Epoch {e}: train_private_acc={acc_private:.4f}, "
                    f"ε={epsilon:.2f} and δ={delta:.4f} at α={optimal_alpha:.2f}")
            else:
                logger.info(f"Epoch {e}: train_private_acc={acc_private:.4f}")

    return (np.array(train_private_acc, dtype=np.float32),
            np.array(train_privacy_budget, dtype=np.float32))


def training_loop_federico(args, client, update_weights, models, epochs=1):
    """Train every model in ``models`` on the client's private data and collect weighted updates.

    Each model gets a fresh optimizer from ``init_optimizer`` and is trained for
    ``epochs`` passes. Each pass contributes the parameter change scaled by
    the model's entry in ``update_weights`` (see train_model_with_weight).

    NOTE(review): one update is appended per (model, epoch). Callers pass
    epochs=1 and rely on exactly one update per model, in ``models`` order.
    """
    assert len(update_weights) == len(models), \
        f"got update weights of len {len(update_weights)} and models of len {len(models)}"
    collected_updates = []
    for model_idx, current_model in enumerate(models):
        optimizer = init_optimizer(current_model, args)
        epoch = 0
        while epoch < epochs:
            loader = utils.data_loader(args.dataset,
                                       client.private_train_data[0],
                                       client.private_train_data[1],
                                       args.batch_size,
                                       include_idx=False)
            current_model, _, weighted_update = train_model_with_weight(
                loader, current_model, update_weights[model_idx], optimizer, args)
            models[model_idx] = current_model
            collected_updates.append(weighted_update)
            epoch += 1
    return collected_updates


def training_loop_fedem(args, client, weights, epochs=1):
    """Train each mixture component with its own per-sample weights.

    ``weights[k]`` holds component k's per-sample weights, indexed by the
    sample indices that the loader yields (``include_idx=True``).
    ``client.private_opts[k]`` is component k's optimizer. Trained components
    are written back into ``client.private_components``.
    """
    components = client.private_components
    for comp_idx, comp_model in enumerate(components):
        for _ in range(epochs):
            loader = utils.data_loader(args.dataset,
                                       client.private_train_data[0],
                                       client.private_train_data[1],
                                       args.batch_size,
                                       include_idx=True)
            trained = train_model_with_sample_weight(
                loader, comp_model, weights[comp_idx], client.private_opts[comp_idx], args)
            comp_model = trained[0]
            components[comp_idx] = comp_model


def update_component_models(client, all_model_params, weights, args):
    """Average each mixture component across clients and load it into this client.

    ``all_model_params`` is a per-client list of per-component numpy params
    (n_clients x n_components); ``zip(*...)`` regroups it per component. Each
    group is averaged with ``weights`` via get_averaged_params.

    Returns:
        list of the averaged parameter sets, one per component.
    """
    averaged_components = []
    for comp_idx, per_client_params in enumerate(zip(*all_model_params)):
        as_tensors = [utils.convert_np_params_to_tensor(client_params, device=args.device)
                      for client_params in per_client_params]
        averaged, _ = get_averaged_params(as_tensors, weights, args)
        client.private_components[comp_idx].load_state_dict(averaged)
        averaged_components.append(averaged)
    return averaged_components


def train_model_with_sample_weight(train_loader, model, weight, optimizer, args):
    """One pass over ``train_loader`` minimising a per-sample weighted cross-entropy.

    The loader yields ``((data, target), indices)`` with numpy arrays;
    ``weight[indices]`` picks the batch's sample weights. The batch loss is
    ``sum_i w_i * CE_i / batch_size``.

    Returns:
        (model, accumulated loss / n_samples, accuracy as a float).
    """
    model.train()
    seen = 0
    n_correct = 0
    loss_total = 0.
    for batch, indices in train_loader:
        data_np, target_np = batch
        optimizer.zero_grad()
        seen += data_np.shape[0]
        inputs = torch.from_numpy(data_np).to(args.device)
        labels = torch.from_numpy(target_np).to(args.device)
        logits = model(inputs)
        per_sample_loss = CE_Loss_sample(logits, labels)
        batch_size = per_sample_loss.size(0)
        loss = (per_sample_loss.T @ weight[indices]) / batch_size
        loss.backward()
        optimizer.step()

        predicted = logits.argmax(dim=-1)
        n_correct += predicted.eq(labels.view_as(predicted)).sum().item()
        loss_total += loss.detach() * batch_size

    return model, loss_total / seen, n_correct / seen


def train_model_with_weight(train_loader, model, update_weight, optimizer, args):
    """Train ``model`` for one pass with cross-entropy and return its weighted parameter change.

    Returns:
        (model, training accuracy over the pass, numpy params of
        ``update_weight * (params_after - params_before)`` as produced by
        utils.convert_param_tensor_to_np_params).
    """
    model.train()
    params_before = utils.get_param_tensor(model)
    seen = 0
    n_correct = 0
    for data_np, target_np in train_loader:
        optimizer.zero_grad()
        seen += data_np.shape[0]
        inputs = torch.from_numpy(data_np).to(args.device)
        labels = torch.from_numpy(target_np).to(args.device)
        logits = model(inputs)
        loss = CE_Loss(logits, labels)
        loss.backward()
        optimizer.step()

        predicted = logits.argmax(dim=-1)
        n_correct += predicted.eq(labels.view_as(predicted)).sum().item()
    param_delta = utils.get_param_tensor(model) - params_before
    weighted_update = utils.convert_param_tensor_to_np_params(update_weight * param_delta, model)
    return model, n_correct / seen, weighted_update
