import copy
import os.path

from .ufedbase import UnlearnBasicClient, UnlearnBasicServer
import numpy as np
from utils import fmodule
import torch
import torch.nn as nn
import collections
import json
from tqdm import tqdm
from utils.utils_unlearn import agg_func
from utils.finch import FINCH

class Server(UnlearnBasicServer):
    """Server for this unlearning variant; adds nothing beyond ``UnlearnBasicServer``."""

    def __init__(self, option, model, clients, data_loader, device=None):
        # Pure pass-through: all construction logic lives in the base class.
        super().__init__(option, model, clients, data_loader, device)


class Client(UnlearnBasicClient):
    """Client that trains normally, or by gradient ascent during the unlearning stage."""

    # Max gradient norm applied only in the unlearning (gradient-ascent) stage,
    # presumably to keep ascent steps from diverging — TODO confirm the value.
    UNLEARN_MAX_GRAD_NORM = 5.0

    def __init__(self, option, id, model=None):
        super().__init__(option, id, model)

    def train(self):
        """Run ``self.epochs`` local epochs over ``self.train_data``.

        Each batch is a mapping with ``'image'`` (inputs) and ``'label'``
        (targets) entries, moved to ``self.device`` via ``self.data_to_device``.
        When ``self.unlearn`` is truthy and ``self.stage == 'Unlearn'``, the
        loss is negated (gradient ascent) and gradients are clipped to
        ``UNLEARN_MAX_GRAD_NORM`` before each optimizer step.

        Returns:
            The sum of ``loss.item() * len(batch_y)`` over all batches and
            epochs, divided by ``self.datavol * self.epochs`` (0.0 when that
            product is 0). In the unlearning stage the accumulated values are
            the negated losses. Presumably ``self.datavol`` is the local sample
            count and the criterion uses mean reduction, making this a
            per-sample average — TODO confirm.
        """
        self.model.train()
        # Stage does not change inside this method, so decide once.
        ascend = self.unlearn and self.stage == 'Unlearn'
        optimizer = self.get_optimizer(self.model)
        total_loss = 0.0
        for _ in range(self.epochs):
            for batch_data in self.train_data:
                self.model.zero_grad()
                batch_x = self.data_to_device(batch_data['image'], device=self.device)
                batch_y = self.data_to_device(batch_data['label'], device=self.device)
                outputs = self.model(batch_x)
                loss = self.criterion(outputs, batch_y)
                if ascend:
                    # Out-of-place negation: an in-place ``loss *= -1`` can break
                    # autograd if the criterion's last op saved its output.
                    loss = -loss
                loss.backward()
                if ascend:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=self.UNLEARN_MAX_GRAD_NORM,
                    )
                optimizer.step()
                # Weight the batch loss by batch size so the final average is per sample.
                total_loss += loss.item() * len(batch_y)
        del optimizer
        denom = self.datavol * self.epochs
        return total_loss / denom if denom else 0.0
