import logging
import numpy as np
from numpy import dot
from numpy.linalg import norm
import torch
import torch.nn.functional as F
import copy

class Client:
    """Per-client wrapper that drives a shared ``model_trainer`` object.

    The class only stores its constructor arguments and delegates all real
    work to ``model_trainer``; ``args`` and ``device`` are kept as attributes
    but are not read by any method visible in this class.
    """

    def __init__(self, client_idx, local_sample_number, args, device,
                 model_trainer):
        """Record the client's identity, sample count and collaborators.

        Args:
            client_idx: Identifier of this client; later copied onto
                ``model_trainer.id`` by :meth:`train`.
            local_sample_number: Value stored on the instance and logged once
                at construction time.
            args: Stored as-is on the instance.
            device: Stored as-is on the instance.
            model_trainer: Object expected to expose ``set_model_params``,
                ``get_model_params`` and ``PF_train``.
        """
        self.client_idx = client_idx
        self.local_sample_number = local_sample_number
        sample_count_text = str(self.local_sample_number)
        logging.info("self.local_sample_number = " + sample_count_text)
        self.args = args
        self.device = device
        self.model_trainer = model_trainer

    def train(self, w_global, client_idx, client_idx_other):
        """Load ``w_global`` into the trainer, run ``PF_train`` and return params.

        Args:
            w_global: Parameters handed to ``model_trainer.set_model_params``.
            client_idx: Forwarded unchanged to ``PF_train``.
            client_idx_other: Forwarded unchanged to ``PF_train``.

        Returns:
            Whatever ``model_trainer.get_model_params()`` yields after
            ``PF_train`` has finished.
        """
        trainer = self.model_trainer
        trainer.id = self.client_idx
        trainer.set_model_params(w_global)

        # PF_train receives an independent deep copy of the freshly loaded
        # parameters, so it cannot alias the trainer's own parameter object.
        snapshot = copy.deepcopy(trainer.get_model_params())

        # temp=0 selects the mode used for training here; its exact meaning
        # lives in PF_train (not visible in this file) -- TODO confirm.
        trainer.PF_train(
            temp=0,
            weight=snapshot,
            client_idx=client_idx,
            client_idx_other=client_idx_other,
        )
        return trainer.get_model_params()

    def local_test(self, b_use_test_dataset, client_idx, client_idx_other):
        """Run ``PF_train`` with ``temp=1`` and return its result.

        Args:
            b_use_test_dataset: Accepted for interface compatibility; this
                method does not read it.
            client_idx: Forwarded unchanged to ``PF_train``.
            client_idx_other: Forwarded unchanged to ``PF_train``.

        Returns:
            The value returned by ``model_trainer.PF_train`` (presumably
            evaluation metrics -- verify against the trainer).
        """
        trainer = self.model_trainer
        # Deep copy of the trainer's current parameters, passed as ``weight``.
        snapshot = copy.deepcopy(trainer.get_model_params())
        return trainer.PF_train(
            temp=1,
            weight=snapshot,
            client_idx=client_idx,
            client_idx_other=client_idx_other,
        )