from fedavg import FedAvgClient
from src.utils.tools import trainable_params
import torch.nn.functional as F
import torch
import torch.nn as nn
from argparse import Namespace
from collections import OrderedDict
from typing import Dict, List, Tuple, Union
import torch.autograd as autograd


class FedIIRClient(FedAvgClient):
    """FedAvg client that adds the FedIIR gradient-alignment penalty to local training.

    ``fit`` augments the local cross-entropy loss with the squared distance
    between this client's classifier gradient and ``self.grad_mean`` (plus
    optional orthogonality / preservation regularizers when ``args.local_reg``
    is set). ``grad`` returns the summed classifier gradient over the local
    training set together with the number of batches used.
    """

    def __init__(self, model, args, logger, device, class_num, indexs_tensor):
        super().__init__(model, args, logger, device, class_num, indexs_tensor)
        # Penalty weight copied from args; note that ``fit`` reads
        # ``self.args.penalty`` directly rather than this attribute.
        self.penalty = args.penalty
        # Not read anywhere in this class; presumably an EMA coefficient used
        # by the server when updating ``grad_mean`` — TODO confirm.
        self.ema = args.ema

    def _extract_features(self, x):
        """Return ``self.model.base(x)``, re-raising any failure after a hint.

        The original code swallowed the exception, which left ``features``
        unbound on the first batch or, worse, silently reused the previous
        batch's features on later batches. Re-raising keeps the error visible.
        """
        try:
            return self.model.base(x)
        except Exception:
            print(
                "model may have no feature extractor + classifier architecture"
            )
            raise

    def fit(self):
        """Train locally for ``self.local_epoch`` epochs with the FedIIR penalty.

        Per batch the minimized loss is::

            CE(logit, y)
            + args.penalty * sum_k ||g_k - grad_mean_k||^2
            + orthogonal_loss + preservation_loss

        where ``g_k`` are gradients of the CE loss w.r.t.
        ``self.model.classifier.parameters()``. They are built with
        ``create_graph=True`` so the penalty itself can be back-propagated.

        ``self.grad_mean`` is not assigned in this class; it must be set
        (presumably by the server, one tensor per classifier parameter) before
        ``fit`` is called.
        """
        self.indexs_tensor = self.indexs_tensor.to(self.device)
        self.model.train()
        penalty_weight = self.args.penalty
        for _ in range(self.local_epoch):
            for x, y in self.trainloader:
                # Single-sample batches are skipped (presumably to avoid
                # BatchNorm errors in train mode — confirm).
                if len(x) <= 1:
                    continue

                x, y = x.to(self.device), y.to(self.device)
                features = self._extract_features(x)
                if self.args.local_reg:
                    project_z = self.model.projection(features)
                    orthogonal_loss = self.orthogonal_loss(
                        project_z, self.indexs_tensor, self.args.local_reg_weight
                    )
                    logit = self.model.classifier_forward(features)
                    # The projection head is trained to reproduce the (frozen)
                    # main-head logits.
                    preservation_loss = self.preservation_loss(
                        self.model.projection_classifier_forward(project_z),
                        logit.detach(),
                        self.args.local_reg_weight,
                    )
                else:
                    logit = self.model.classifier_forward(features)
                    orthogonal_loss = 0.0
                    preservation_loss = 0.0

                loss_erm = F.cross_entropy(logit, y)
                # create_graph=True keeps these gradients differentiable so the
                # alignment penalty below contributes to loss.backward().
                grad_client = autograd.grad(
                    loss_erm, self.model.classifier.parameters(), create_graph=True
                )
                # NOTE: zip truncates silently if grad_mean has a different
                # number of tensors than the classifier has parameters.
                penalty_value = 0
                for g_client, g_mean in zip(grad_client, self.grad_mean):
                    penalty_value += (g_client - g_mean).pow(2).sum()

                loss = (
                    loss_erm
                    + penalty_weight * penalty_value
                    + orthogonal_loss
                    + preservation_loss
                )
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def grad(self, client_id, new_parameters):
        """Sum the cross-entropy gradient w.r.t. the classifier over the local data.

        Loads the dataset for ``client_id``, installs ``new_parameters`` and
        accumulates, batch by batch, the gradient of the CE loss with respect
        to ``self.model.classifier.parameters()``.

        Returns:
            tuple: ``(grad_sum, total_batch)`` where ``grad_sum`` is a tuple of
            tensors shaped like the classifier parameters and ``total_batch``
            is the number of batches that contributed (batches of size <= 1
            are skipped and not counted; may be 0).
        """
        self.client_id = client_id
        self.load_dataset()
        self.set_parameters(new_parameters)
        grad_sum = tuple(
            torch.zeros_like(p).to(self.device)
            for p in self.model.classifier.parameters()
        )
        total_batch = 0
        for x, y in self.trainloader:
            if len(x) <= 1:
                continue

            x, y = x.to(self.device), y.to(self.device)
            # Only classifier gradients are requested, so the feature extractor
            # does not need an autograd graph (saves activation memory).
            with torch.no_grad():
                features = self._extract_features(x)
            # NOTE(review): ``fit`` uses ``classifier_forward(features)`` while
            # this uses ``classifier(F.relu(features))``; confirm the two
            # forward paths agree, otherwise grad_mean and the client gradient
            # in ``fit`` are computed from different functions.
            logits = self.model.classifier(F.relu(features))
            loss = F.cross_entropy(logits, y)
            grad_batch = autograd.grad(
                loss, self.model.classifier.parameters(), create_graph=False
            )
            grad_sum = tuple(g1 + g2 for g1, g2 in zip(grad_sum, grad_batch))
            total_batch += 1
        return grad_sum, total_batch