# -*- coding: utf-8 -*-
import numpy as np
import time
import torch
from torch import nn
import os
import copy
from .ufedbase import UnlearnBasicClient, UnlearnBasicServer
from utils import fmodule

class Server(UnlearnBasicServer):
    """Server that unlearns clients with a single closed-form model correction.

    ``unlearn_iterate`` does not retrain. It reads per-client statistics saved
    by an earlier pretraining run and subtracts a correction vector from the
    global model. The history directory is named
    ``pretrained_history_fedrecovery``, which presumably means a
    FedRecovery-style method (TODO confirm).
    """

    # Suffix of the saved history CSV files. The default reproduces the file
    # names this class always used; override it (e.g. in a subclass) to load
    # history saved under a different pretraining setting.
    history_suffix = '_s2_c30'

    def __init__(self, option, model, clients, data_loader, device=None):
        # Copy of the model taken before base-class initialisation; it is
        # refreshed at the start of every ``unlearn_iterate`` call.
        self.old_model = copy.deepcopy(model).to(device)
        super(Server, self).__init__(option, model, clients, data_loader, device)
        # Number of unlearning rounds. Presumably consumed by the base class
        # (``unlearn_iterate`` applies its correction in one shot) — TODO confirm.
        self.u_rounds = 1

    def unlearn_iterate(self):
        """Apply the saved-history correction to ``self.model`` in place.

        Steps:
          1. Snapshot the current model into ``self.old_model``.
          2. Load ``global_model_norm_square<suffix>.csv`` and sum its entries.
          3. Load ``client_<id>_mid_value<suffix>.csv`` for every client.
             Accumulate the arrays separately for remaining clients
             (``client.unlearn`` falsy) and unlearned clients.
          4. Turn each sum into a mean.
          5. Divide both means by ``len(self.clients) * sum_norm_square``.
          6. Subtract ``remaining - unlearned`` from the flattened model
             parameters.

        All files are read from
        ``<dirname(self.save_folder)>/fedavg/pretrained_history_fedrecovery``.
        The correction is cast to the dtype and device of the model
        parameters, so the update never changes the parameters' dtype.

        Raises:
            ValueError: if there are no remaining clients or no clients to
                unlearn (one of the means would be undefined), or if the
                saved vectors do not hold one entry per model parameter.
        """
        self.old_model = copy.deepcopy(self.model)
        path = os.path.join(os.path.dirname(self.save_folder), 'fedavg', 'pretrained_history_fedrecovery')
        suffix = self.history_suffix
        global_model_norm_square = np.loadtxt(
            os.path.join(path, 'global_model_norm_square' + suffix + '.csv'), delimiter=',')
        # Sum of the saved global-model squared norms; used as a normaliser below.
        sum_norm_square = float(np.sum(global_model_norm_square))

        mid_value_r, count_r = 0, 0  # running sum / count over remaining clients
        mid_value_u, count_u = 0, 0  # running sum / count over unlearned clients
        for client in self.clients:
            mid_value = np.loadtxt(
                os.path.join(path, 'client_' + str(client.id) + '_mid_value' + suffix + '.csv'),
                delimiter=',')
            if client.unlearn:
                mid_value_u += mid_value
                count_u += 1
            else:
                mid_value_r += mid_value
                count_r += 1

        # Both groups are needed to form the two means; fail with a clear
        # message instead of a bare ZeroDivisionError.
        if count_r == 0 or count_u == 0:
            raise ValueError(
                'unlearn_iterate needs at least one remaining and one unlearned client '
                '(got %d remaining, %d unlearned)' % (count_r, count_u))

        n_clients = len(self.clients)
        mid_value_r = mid_value_r / (count_r * n_clients * sum_norm_square)
        mid_value_u = mid_value_u / (count_u * n_clients * sum_norm_square)
        correction = np.asarray(mid_value_r - mid_value_u)
        print(np.linalg.norm(correction))

        # Pure parameter arithmetic: no autograd graph is needed.
        with torch.no_grad():
            model_params = torch.nn.utils.parameters_to_vector(self.model.parameters())
            if correction.size != model_params.numel():
                raise ValueError(
                    'saved correction has %d entries but the model has %d parameters'
                    % (correction.size, model_params.numel()))
            # Match the model's dtype/device so vector_to_parameters does not
            # silently change the parameter dtype (e.g. half -> float32).
            correction_t = torch.as_tensor(
                correction.reshape(-1), dtype=model_params.dtype, device=model_params.device)
            torch.nn.utils.vector_to_parameters(model_params - correction_t, self.model.parameters())
        return

class Client(UnlearnBasicClient):
    """Client for this unlearning method; adds no behaviour beyond ``UnlearnBasicClient``."""

    def __init__(self, option, id, model=None):
        # Construction is delegated unchanged to the base client.
        super().__init__(option, id, model)
