from torch import nn
import numpy as np
from FedUtils.models.utils import CusDataset
from torch.utils.data import DataLoader

class Client(object):
    """A federated-learning client owning local data loaders and a model wrapper.

    Training and evaluation calls are forwarded to ``model``. The ``solve_*``
    methods return ``((weight, solution), (bytes_w, comp, bytes_r))``:

    - ``bytes_w`` is ``model.size`` read before the local work.
    - ``bytes_r`` is ``model.size`` read after it (presumably the
      communication cost in each direction; confirm against ``model``).
    - ``comp`` is whatever computation measure the model reports.
    """

    def __init__(self, id, group, train_data, eval_data, model, batchsize,
                 train_transform=None, test_transform=None,
                 traincusdataset=None, evalcusdataset=None, eval_batchsize=100):
        """Build the train / train-for-test / eval loaders for this client.

        Args:
            id: Client identifier, stored as ``self.id``.
            group: Group tag, stored as ``self.group``.
            train_data: Raw training split. ``len(train_data["x"])`` is used
                as the number of training samples.
            eval_data: Raw evaluation split; ``len(eval_data["x"])`` is used
                as the number of test samples.
            model: Model wrapper exposing ``set_param``, ``get_param``,
                ``get_gradients``, ``solve_inner``, ``solve_iters``, ``test``
                and a ``size`` attribute.
            batchsize: Batch size of the (shuffled) training loader and of
                the unshuffled train-for-test loader.
            train_transform: Transform given to ``CusDataset`` for training
                when no custom dataset factory is supplied.
            test_transform: Transform given to ``CusDataset`` for the
                train-for-test and eval datasets.
            traincusdataset: Optional factory called as
                ``traincusdataset(train_data)`` to build the training dataset.
            evalcusdataset: Factory for the train-for-test and eval datasets;
                required whenever ``traincusdataset`` is given.
            eval_batchsize: Batch size of the eval loader (default 100).

        Raises:
            ValueError: if ``traincusdataset`` is given without
                ``evalcusdataset``.
        """
        super(Client, self).__init__()
        self.model = model
        self.id = id
        self.group = group
        self.num_train_samples = len(train_data["x"])
        # Legacy alias of num_train_samples; still read by train_error_and_loss
        # and possibly by external callers.
        self.train_samplenum = self.num_train_samples
        self.num_test_samples = len(eval_data["x"])

        if traincusdataset:
            if evalcusdataset is None:
                raise ValueError(
                    "evalcusdataset must be provided when traincusdataset is given")
            train_set = traincusdataset(train_data)
            train_test_set = evalcusdataset(train_data)
            eval_set = evalcusdataset(eval_data)
        else:
            train_set = CusDataset(train_data, transform=train_transform)
            train_test_set = CusDataset(train_data, transform=test_transform)
            eval_set = CusDataset(eval_data, transform=test_transform)

        # Incomplete last batches are kept (drop_last=False) so no sample is lost.
        self.train_data = DataLoader(train_set, batch_size=batchsize,
                                     shuffle=True, drop_last=False)
        self.train_data_fortest = DataLoader(train_test_set, batch_size=batchsize,
                                             shuffle=False)
        self.eval_data = DataLoader(eval_set, batch_size=eval_batchsize,
                                    shuffle=False)
        # Persistent iterator so solve_iters can resume mid-epoch across calls.
        self.train_iter = iter(self.train_data)

    def set_param(self, state_dict):
        """Load ``state_dict`` into the model; always returns True."""
        self.model.set_param(state_dict)
        return True

    def get_param(self):
        """Return the model's current parameters via ``model.get_param()``."""
        return self.model.get_param()

    def solve_grad(self):
        """Compute gradients over the whole training loader.

        Returns:
            ``((num_train_samples, grads), (bytes_w, comp, bytes_r))``.
        """
        bytes_w = self.model.size
        grads, comp = self.model.get_gradients(self.train_data)
        bytes_r = self.model.size
        return ((self.num_train_samples, grads), (bytes_w, comp, bytes_r))

    def solve_inner(self, num_epochs=1, extra_loss=None, step_func=None):
        """Run ``num_epochs`` of local training through ``model.solve_inner``.

        Args:
            num_epochs: Number of local epochs.
            extra_loss: Passed through to ``model.solve_inner``.
            step_func: Passed through to ``model.solve_inner``.

        Returns:
            ``((num_train_samples * weight, soln), (bytes_w, comp, bytes_r))``.
            Here ``weight`` is the third value returned by the model and
            scales the aggregation weight.
        """
        bytes_w = self.model.size
        soln, comp, weight = self.model.solve_inner(
            self.train_data, num_epochs, extra_loss, step_func=step_func)
        bytes_r = self.model.size
        return (self.num_train_samples * weight, soln), (bytes_w, comp, bytes_r)

    def solve_iters(self, num_iters=1):
        """Run ``num_iters`` single-batch updates, cycling through the loader.

        The batch iterator persists across calls. When it is exhausted, a new
        epoch is started (the training loader reshuffles).

        Args:
            num_iters: Number of batches to train on; must be >= 1.

        Returns:
            ``((num_train_samples, soln), (bytes_w, comp, bytes_r))``. Both
            ``soln`` and ``comp`` are from the last iteration.

        Raises:
            ValueError: if ``num_iters`` < 1 (there would be no solution to return).
        """
        if num_iters < 1:
            raise ValueError("num_iters must be >= 1, got {}".format(num_iters))
        bytes_w = self.model.size
        for _ in range(num_iters):
            try:
                data = next(self.train_iter)
            except StopIteration:
                # Only exhaustion restarts the epoch; genuine data-loading
                # errors propagate instead of being silently swallowed.
                self.train_iter = iter(self.train_data)
                data = next(self.train_iter)
            # NOTE(review): comp is overwritten each iteration, not summed.
            # If model.solve_iters reports a per-batch cost, this undercounts
            # when num_iters > 1. Confirm the intended accounting.
            soln, comp = self.model.solve_iters(data)
        bytes_r = self.model.size
        return (self.num_train_samples, soln), (bytes_w, comp, bytes_r)

    def test(self):
        """Evaluate on the eval loader.

        Returns:
            ``(total_correct, num_test_samples)``; the loss is discarded.
        """
        total_correct, _loss = self.model.test(self.eval_data)
        return total_correct, self.num_test_samples

    def train_error_and_loss(self):
        """Evaluate on the unshuffled training loader.

        Returns:
            ``(tot_correct, loss, train_samplenum)``.
        """
        tot_correct, loss = self.model.test(self.train_data_fortest)
        return tot_correct, loss, self.train_samplenum
