import torch
import torch.optim as optim
import torch.nn as nn
import copy
from torch.optim.lr_scheduler import StepLR


class client_FLwF():
    """Local trainer that combines a classification loss with distillation.

    The client owns a private deep copy of the model passed in and trains it
    on its own data loader. During training the local model (student) is
    pulled towards the softened predictions of a global model and, when one
    is supplied, of a previous global model as well.

    ``args`` must expose ``batch_size``, ``lr_local`` and ``epochs``.
    """

    # Softmax temperature applied to BOTH student and teacher logits in the
    # distillation terms.
    DISTILL_TEMPERATURE = 2.0

    def __init__(self, model, train_dataset, name, args):
        self.args = args
        self.name = name
        # Work on a private copy so the caller's model is never mutated.
        self.model = copy.deepcopy(model)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.args.batch_size,
            shuffle=True,
            drop_last=True,
        )
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.lr_local, momentum=0.9)
        self.scheduler = StepLR(self.optimizer, step_size=5, gamma=0.96)
        self.theta_reg = None
        # init_weights is None here, so the initialize_model() call below
        # always re-initializes the copied model randomly.
        self.init_weights = None
        self.initialize_model()

    def set_theta_reg(self, theta_reg):
        """Store ``theta_reg`` on the client (not read elsewhere in this class)."""
        self.theta_reg = theta_reg

    def _random_init_weights(self, m):
        """Kaiming-normal init for Conv2d/Linear weights; biases are zeroed."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            torch.nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    def initialize_model(self):
        """Load ``self.init_weights`` if set, otherwise randomly re-initialize."""
        if self.init_weights is not None:
            self.model.load_state_dict(self.init_weights)
        else:
            self.model.apply(self._random_init_weights)

    def update_model(self, new_model):
        """Overwrite the local model's parameters with those of ``new_model``."""
        self.model.load_state_dict(new_model.state_dict())

    @staticmethod
    def _distillation_loss(student_logits, teacher_logits, temperature=2.0):
        """Return the soft-target cross-entropy, summed over the whole batch.

        Both logit tensors are divided by ``temperature`` before the softmax
        over dim 1. The teacher side is detached so no gradient reaches it.
        """
        # log_softmax / softmax are numerically stable, unlike a manual
        # exp(x) / sum(exp(x)) which overflows for large logits.
        log_student = torch.log_softmax(student_logits / temperature, dim=1)
        teacher_probs = torch.softmax(teacher_logits.detach() / temperature, dim=1)
        return -(teacher_probs * log_student).sum()

    def train(self, global_model=None, previous_global_model=None):
        """Run ``args.epochs`` local epochs of SGD on the client's data.

        Loss per batch:
          * with ``previous_global_model``:
            ``0.5 * CE + 0.25 * distill(global) + 0.25 * distill(previous)``
          * without it: ``0.99 * CE + 0.01 * distill(global)``

        The cross-entropy term is batch-averaged while the distillation
        terms are summed over the batch.

        Args:
            global_model: teacher model; required despite the ``None``
                default. It receives inputs already moved to
                ``self.device``, so it must be usable on that device.
            previous_global_model: optional second teacher, same
                requirements as ``global_model``.

        Raises:
            TypeError: if ``global_model`` is ``None``.
        """
        if global_model is None:
            # Fail early with a clear message instead of the opaque
            # "'NoneType' object is not callable" from inside the loop.
            raise TypeError("global_model is required for the distillation loss, got None")

        self.model = self.model.to(self.device)
        self.model.train()

        criterion = nn.CrossEntropyLoss()
        temperature = self.DISTILL_TEMPERATURE

        for _ in range(self.args.epochs):
            for images, labels in self.train_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                self.optimizer.zero_grad()
                outputs = self.model(images)

                # classification loss
                loss_class = criterion(outputs, labels)

                # Teachers are fixed targets: run them without autograd so
                # backward() neither stores their graphs nor accumulates
                # gradients on their parameters.
                with torch.no_grad():
                    global_output = global_model(images)
                    previous_global_output = (
                        previous_global_model(images)
                        if previous_global_model is not None
                        else None
                    )

                # dis_cl loss
                loss_dis_client = self._distillation_loss(outputs, global_output, temperature)

                if previous_global_output is not None:
                    # dis_server loss
                    alpha = 0.5
                    beta = 0.25
                    loss_dis_server = self._distillation_loss(
                        outputs, previous_global_output, temperature
                    )
                    loss = (
                        alpha * loss_class
                        + beta * loss_dis_client
                        + (1 - alpha - beta) * loss_dis_server
                    )
                else:
                    alpha = 0.99
                    loss = alpha * loss_class + (1 - alpha) * loss_dis_client

                loss.backward()
                self.optimizer.step()

        # The scheduler advances once per train() call, not once per local
        # epoch, so StepLR(step_size=5) decays the LR every 5 calls.
        self.scheduler.step()



