from typing import Optional
import torch
from torch.utils.data import DataLoader
from src.algo.fed_clients.base_client import Client
from src.utils import TrainAnalyzer, load_tensor_list, store_tensor_list


class FedHBMClient(Client):
    """Federated client that adds a FedHBM-style correction to its local gradients.

    The client keeps a snapshot of its model parameters taken in ``cleanup()``
    (i.e. after its local work for a round). At the next ``setup()`` the
    per-parameter difference ``snapshot - current_parameters`` becomes the
    correction term, which is scaled and added to every local gradient in
    ``client_update()``.
    """

    def __init__(self, client_id: int, dataloader: Optional[DataLoader], num_classes: int, device: str,
                 C: float, save_memory: bool, analyzer: Optional[TrainAnalyzer] = None):
        """Create the client.

        Args:
            client_id: identifier of this client; also used to name the stored
                last-model tensors (``last_model_client{client_id}``).
            dataloader: local training data, iterated as ``(img, target)`` batches.
            num_classes: forwarded to the base ``Client``.
            device: device every batch is moved to.
            C: numerator of the correction scaling factor, stored as
                ``self.fraction``. NOTE(review): presumably the client
                participation fraction — confirm against the caller.
            save_memory: forwarded to the base ``Client`` and to the tensor
                store/load helpers.
            analyzer: optional training analyzer forwarded to the base ``Client``.
        """
        super().__init__(client_id, dataloader, num_classes, device, save_memory, analyzer)
        # Snapshot of the parameters taken in cleanup(); None until the first setup().
        self.__last_trained_model = None
        # Per-parameter correction tensors; only populated between setup() and
        # the end of client_update().
        self.correction = None
        self.fraction = C

    def __calculate_correction(self):
        """Set ``self.correction`` to ``snapshot - current`` for every model parameter."""
        self.correction = [last_train.data - server.data for last_train, server in
                           zip(self.__last_trained_model, self.model.parameters())]

    def setup(self):
        """Prepare the client for a round and compute the initial correction.

        On the first call the snapshot is initialised by cloning the current
        parameters, so the initial correction is zero. On later calls the
        snapshot saved by ``cleanup()`` is restored via ``load_tensor_list``.
        NOTE(review): assumes ``super().setup()`` has placed the model received
        from the server in ``self.model`` — confirm.
        """
        super().setup()
        if self.__last_trained_model is None:
            self.__last_trained_model = [torch.clone(p.data) for p in
                                         self.model.parameters()]
        else:
            load_tensor_list(self.__last_trained_model, self.save_memory, self.device, f'last_model_client{self.client_id}')
        self.__calculate_correction()

    def cleanup(self):
        """Snapshot the current parameters, persist them and release round state."""
        self.__last_trained_model = [torch.clone(p.data) for p in
                                     self.model.parameters()]
        store_tensor_list(self.__last_trained_model, self.save_memory, f'last_model_client{self.client_id}')
        self.correction = None
        self.starting_model = None
        super().cleanup()

    def client_update(self, optimizer: type, optimizer_args, local_epoch: int, loss_fn: torch.nn.Module, s_round: int):
        """Run ``local_epoch`` epochs of local training with the correction term.

        Before every optimizer step, each parameter's gradient is augmented with
        ``factor * correction`` where
        ``factor = C / (lr * len(dataloader) * local_epoch)``. Gradients are then
        clipped to a total norm of 50. After each step the correction is
        recomputed against the just-updated parameters.

        Args:
            optimizer: optimizer class, instantiated on ``self.model.parameters()``.
            optimizer_args: keyword arguments for ``optimizer``; must contain ``'lr'``
                (a zero ``lr`` makes ``factor`` undefined).
            local_epoch: number of passes over the local dataloader.
            loss_fn: loss computed as ``loss_fn(logits, target)``.
            s_round: current server round (not used by this method).
        """
        self.model.train()
        num_steps = len(self.dataloader) * local_epoch
        if num_steps <= 0:
            # No optimizer step will run; bail out instead of dividing by zero
            # below, leaving the client in the same state as a normal exit.
            self.correction = None
            return

        op = optimizer(self.model.parameters(), **optimizer_args)
        # Normalise the correction by the total displacement budget of the round.
        factor = self.fraction / (optimizer_args['lr'] * num_steps)

        for _ in range(local_epoch):
            for img, target in self.dataloader:
                img = img.to(self.device)
                target = target.to(self.device)
                op.zero_grad()
                logits = self.model(img)
                loss = loss_fn(logits, target)
                loss.backward()

                # Add the scaled correction to each gradient. Parameters without a
                # gradient (frozen, or unused in this forward pass) are skipped;
                # standard torch optimizers skip them as well.
                for p, c in zip(self.model.parameters(), self.correction):
                    if p.grad is not None:
                        p.grad.add_(c, alpha=factor)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 50)
                op.step()

                # Refresh the correction against the parameters just updated.
                self.__calculate_correction()

        # correction will be recalculated by the next setup()
        self.correction = None