from collections import OrderedDict
from copy import deepcopy
from typing import Dict

import torch
from torch.nn.functional import cosine_similarity, relu

from fedavg import FedAvgClient


class MOONClient(FedAvgClient):
    """FedAvg client with MOON's model-contrastive regularizer.

    Besides the supervised loss, every local step pulls the trained model's
    features toward those of the received global model and pushes them away
    from this client's own model from its previous round. When
    ``args.local_reg`` is set, an orthogonal loss and a preservation loss on a
    projection of the features are added to the supervised term as well.
    """

    def __init__(self, model, args, logger, device, class_num, indexs_tensor):
        super().__init__(model, args, logger, device, class_num, indexs_tensor)
        # client_id -> deep copy of that client's local weights, taken in save_state.
        self.prev_params_dict: Dict[int, OrderedDict[str, torch.Tensor]] = {}
        # Reference models used only for feature extraction in fit; their
        # weights are (re)loaded in set_parameters.
        self.prev_model = deepcopy(self.model)
        self.global_model = deepcopy(self.model)

    def save_state(self):
        """Save base-class state and snapshot the current local weights.

        The snapshot is deep-copied because ``state_dict()`` returns
        references to the live tensors, which later training would mutate.
        """
        super().save_state()
        self.prev_params_dict[self.client_id] = deepcopy(self.model.state_dict())

    def set_parameters(self, new_parameters):
        """Load incoming parameters and refresh the two reference models.

        ``global_model`` receives the weights that were just loaded into
        ``self.model``. ``prev_model`` receives this client's snapshot from
        its last ``save_state`` call. If the client has no snapshot yet
        (presumably its first round), it falls back to those same weights.
        """
        super().set_parameters(new_parameters)
        current_state = self.model.state_dict()
        self.global_model.load_state_dict(current_state)
        self.prev_model.load_state_dict(
            self.prev_params_dict.get(self.client_id, current_state)
        )

    @staticmethod
    def _contrastive_loss(z_curr, z_global, z_prev, tau):
        """Return the per-sample MOON contrastive loss.

        Computes ``-log(exp(s_pos) / (exp(s_pos) + exp(s_neg)))`` with
        ``s_pos = cos(z_curr, z_global) / tau`` (positive pair) and
        ``s_neg = cos(z_curr, z_prev) / tau`` (negative pair). It is evaluated
        as ``logsumexp([s_pos, s_neg]) - s_pos``, which is mathematically
        identical but cannot overflow ``exp`` into ``inf``/``nan`` for small
        ``tau``.

        Args:
            z_curr: features of the model being trained; dim 0 is the sample.
            z_global: features of the global model for the same samples.
            z_prev: features of the previous local model for the same samples.
            tau: temperature.

        Returns:
            A 1-D tensor with one loss value per sample.
        """
        z_curr = z_curr.flatten(1)
        sim_pos = cosine_similarity(z_curr, z_global.flatten(1)) / tau
        sim_neg = cosine_similarity(z_curr, z_prev.flatten(1)) / tau
        logits = torch.stack([sim_pos, sim_neg], dim=1)
        return torch.logsumexp(logits, dim=1) - sim_pos

    def fit(self):
        """Train ``self.model`` locally for ``self.local_epoch`` epochs.

        The per-batch loss is::

            criterion(logit, y) [+ orthogonal_loss + preservation_loss]
                + mu * mean(contrastive_loss)

        The bracketed terms are added only when ``args.local_reg`` is set.
        Batches with at most one sample are skipped (presumably because
        BatchNorm cannot normalize a single sample in training mode).
        """
        self.indexs_tensor = self.indexs_tensor.to(self.device)
        self.model.train()
        # NOTE(review): global_model / prev_model are never switched to eval()
        # in this class, so any BatchNorm/Dropout layers in them run in the
        # mode they were deep-copied in. Confirm this is intended.
        for _ in range(self.local_epoch):
            for x, y in self.trainloader:
                if len(x) <= 1:
                    continue

                x, y = x.to(self.device), y.to(self.device)
                # One forward pass through the trainable model. These features
                # feed the classifier, the local regularizers and the
                # contrastive term alike.
                z_curr = self.model.get_final_features(x, detach=False)
                # Reference features never need gradients; no_grad avoids
                # building (and then discarding) their autograd graphs.
                with torch.no_grad():
                    z_global = self.global_model.get_final_features(x, detach=True)
                    z_prev = self.prev_model.get_final_features(x, detach=True)

                if self.args.local_reg:
                    project_z = self.model.projection(z_curr)
                    orthogonal_loss = self.orthogonal_loss(
                        project_z, self.indexs_tensor, self.args.local_reg_weight
                    )
                    logit = self.model.classifier_forward(z_curr)
                    # The classifier logits act as a fixed target here, hence detach().
                    preservation_loss = self.preservation_loss(
                        self.model.projection_classifier_forward(project_z),
                        logit.detach(),
                        self.args.local_reg_weight,
                    )
                else:
                    logit = self.model.classifier_forward(z_curr)
                    orthogonal_loss = 0.0
                    preservation_loss = 0.0

                loss_sup = self.criterion(logit, y) + orthogonal_loss + preservation_loss
                loss_con = self._contrastive_loss(
                    z_curr, z_global, z_prev, self.args.tau
                )

                loss = loss_sup + self.args.mu * torch.mean(loss_con)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
