from typing import OrderedDict

import torch

from fedavg import FedAvgClient
from src.config.utils import trainable_params, vectorize
import torch.nn.functional as F


class FedDynClient(FedAvgClient):
    """Client whose local objective adds a parameter-dependent linear term.

    On top of the loss built from ``self.criterion`` (plus two optional
    regularizers enabled by ``args.local_reg``), every batch adds

        args.alpha * sum(theta * (-theta_global + nabla))

    where ``theta`` is the flattened vector of the model's trainable
    parameters, ``theta_global`` is the flattened copy of the parameters handed
    to :meth:`train`, and ``nabla`` is a tensor supplied by the caller.
    NOTE(review): ``nabla`` is presumably the server-side FedDyn state
    vector -- confirm against the server implementation.
    """

    def __init__(self, model, args, logger, device, class_num, indexs_tensor):
        super().__init__(model, args, logger, device, class_num, indexs_tensor)

        # All three are placeholders until train()/fit() fill them in.
        # Set by train() from its ``nabla`` argument.
        self.nabla: torch.Tensor = None
        # Set by train(): detached, vectorized copy of the received parameters.
        self.vectorized_global_params: torch.Tensor = None
        # Refreshed by fit() on every processed batch.
        self.vectorized_curr_params: torch.Tensor = None

    def train(
        self,
        client_id: int,
        local_epoch: int,
        new_parameters: OrderedDict[str, torch.Tensor],
        nabla: torch.Tensor,
        return_diff=False,
        verbose=False,
    ):
        """Store the penalty inputs, then delegate to the parent ``train``.

        Args:
            client_id: Forwarded unchanged to ``FedAvgClient.train``.
            local_epoch: Forwarded unchanged to ``FedAvgClient.train``.
            new_parameters: Parameters to train from; a detached vectorized
                copy is kept for the penalty term used in :meth:`fit`.
            nabla: Tensor combined with the vectorized parameters in the
                penalty term of :meth:`fit`.
            return_diff: Forwarded unchanged to ``FedAvgClient.train``.
            verbose: Forwarded unchanged to ``FedAvgClient.train``.

        Returns:
            Whatever ``FedAvgClient.train`` returns.
        """
        self.vectorized_global_params = vectorize(new_parameters, detach=True)
        self.nabla = nabla
        return super().train(
            client_id, local_epoch, new_parameters, return_diff, verbose
        )

    def fit(self):
        """Run ``self.local_epoch`` passes over ``self.trainloader``.

        Each processed batch performs one optimizer step on
        ``criterion + orthogonal + preservation + linear penalty``, with the
        gradient norm clipped to ``args.max_grad_norm``.
        """
        self.indexs_tensor = self.indexs_tensor.to(self.device)
        self.model.train()

        for _epoch in range(self.local_epoch):
            for inputs, targets in self.trainloader:
                # Batches holding fewer than two samples are skipped.
                # NOTE(review): presumably to keep batch-norm layers from
                # seeing a single sample -- confirm.
                if len(inputs) < 2:
                    continue

                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                if not self.args.local_reg:
                    logits = self.model(inputs)
                    ortho_term = 0.0
                    preserve_term = 0.0
                else:
                    features = self.model.get_final_features(inputs, detach=False)
                    projected = self.model.projection(features)
                    ortho_term = self.orthogonal_loss(
                        projected,
                        self.indexs_tensor,
                        self.args.local_reg_weight,
                    )
                    logits = self.model.classifier_forward(features)

                    # The classifier logits are detached, so they act as a
                    # fixed target for the projection-branch logits.
                    projected_logits = self.model.projection_classifier_forward(
                        projected
                    )
                    preserve_term = self.preservation_loss(
                        projected_logits,
                        logits.detach(),
                        self.args.local_reg_weight,
                    )

                task_loss = self.criterion(logits, targets) + ortho_term + preserve_term

                # Re-vectorize on every batch: the parameters changed in the
                # previous step, and the penalty below has to be computed from
                # their current values.
                self.vectorized_curr_params = vectorize(trainable_params(self.model))
                linear_direction = -self.vectorized_global_params + self.nabla
                penalty = self.args.alpha * torch.sum(
                    self.vectorized_curr_params * linear_direction
                )

                total_loss = task_loss + penalty

                self.optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    trainable_params(self.model),
                    max_norm=self.args.max_grad_norm,
                )
                self.optimizer.step()
