import copy
import os.path

from .ufedbase import UnlearnBasicClient, UnlearnBasicServer
import numpy as np
from utils import fmodule
import torch
import collections
import json
class Server(UnlearnBasicServer):
    """Unlearning server that adds no behavior beyond ``UnlearnBasicServer``.

    Construction arguments are forwarded to the base class unchanged.
    """

    def __init__(self, option, model, clients, data_loader, device=None):
        super().__init__(option, model, clients, data_loader, device)


class Client(UnlearnBasicClient):
    """Unlearning client that adds no behavior beyond ``UnlearnBasicClient``.

    Construction arguments are forwarded to the base class unchanged.
    """

    def __init__(self, option, id, model=None):
        super().__init__(option, id, model)

    # def train(self, ):
    #     # initial_train_model = copy.deepcopy(model)
    #     self.model.train()
    #     total_loss = 0.0
    #     optimizer = self.get_optimizer()
    #     for e in range(self.epochs):
    #         for step, (batch_x, batch_y) in enumerate(self.train_data):
    #             self.model.zero_grad()
    #             batch_x = self.data_to_device(batch_x, device=self.device)
    #             batch_y = self.data_to_device(batch_y, device=self.device)
    #             outputs = self.model(batch_x)
    #             loss = self.criterion(outputs, batch_y)
    #             loss.backward()
    #             optimizer.step()
    #             batch_mean_loss = loss.item()
    #             total_loss += batch_mean_loss * len(batch_y)
    #     return total_loss / (self.datavol * self.epochs)
    # def test(self, model=None, dataflag='test'):
    #     test_model = model if model is not None else self.model
    #     test_model.eval()
    #
    #     if dataflag == 'train':
    #         dataset = self.train_data
    #         datavol = self.datavol
    #     else:
    #         dataset = self.test_data
    #         datavol = self.test_datavol
    #     total_loss = 0.0
    #     num_correct = 0
    #     local_metric = {}
    #     with torch.no_grad():
    #         for batch_id, (batch_x, batch_y) in enumerate(dataset):
    #             batch_x = self.data_to_device(batch_x, device=self.device)
    #             batch_y = self.data_to_device(batch_y, device=self.device)
    #             outputs = test_model(batch_x)
    #             batch_mean_loss = self.criterion(outputs, batch_y).item()
    #             y_pred = outputs.data.max(1, keepdim=True)[1]
    #             correct = y_pred.eq(batch_y.data.view_as(y_pred)).long().cpu().sum()
    #             num_correct += correct.item()
    #             total_loss += batch_mean_loss * len(batch_y)
    #         if not self.unlearn:
    #             local_metric.update({'retain_accuracy': 100 * num_correct / datavol, 'retain_loss': total_loss / datavol})
    #         else:
    #             local_metric.update({'Backdoor_accuracy': 100 * num_correct / datavol,
    #                                 'Backdoor_loss': total_loss / datavol})
    #     if self.unlearn:
    #         # Compute the unlearn-memory (UM) accuracy; this metric is measured on training-set data
    #         BD_correct = 0
    #         BD_loss = 0.0
    #         with torch.no_grad():
    #             for batch_id, (batch_x, batch_y) in enumerate(self.UM_test_data):
    #                 batch_x = self.data_to_device(batch_x, device=self.device)
    #                 batch_y = self.data_to_device(batch_y, device=self.device)
    #                 outputs = test_model(batch_x)
    #                 batch_mean_loss = self.criterion(outputs, batch_y).item()
    #                 y_pred = outputs.data.max(1, keepdim=True)[1]
    #                 correct = y_pred.eq(batch_y.data.view_as(y_pred)).long().cpu().sum()
    #                 BD_correct += correct.item()
    #                 BD_loss += batch_mean_loss * len(batch_y)
    #             local_metric.update({'Unlearn_Memory_accuracy': 100 * BD_correct / self.UM_test_datavol,
    #                                  'Unlearn_Memory_loss': BD_loss / self.UM_test_datavol})
    #     return local_metric