import os
import copy
import time
import torch
from torch.optim import *
from .algorithm import BaseClient
from collections import OrderedDict

class FedCompassClientOptim(BaseClient):
    """FedCompass client that performs a fixed number of local optimizer steps.

    :meth:`update` runs exactly ``self.local_steps`` mini-batch updates,
    restarting ``self.dataloader`` whenever it is exhausted, so one call may
    span several (possibly partial) local epochs.

    Besides the ``BaseClient`` arguments, every extra keyword argument becomes
    an instance attribute. :meth:`update` reads ``optim``, ``optim_args``,
    ``local_steps``, ``clip_value``, ``clip_norm`` and ``epsilon`` from them.
    """

    def __init__(
        self, id, weight, model, loss_fn, dataloader, cfg, outfile, test_dataloader, **kwargs
    ):
        super().__init__(
            id, weight, model, loss_fn, dataloader, cfg, outfile, test_dataloader
        )
        # Algorithm-specific settings (optimizer, step count, clipping, DP).
        self.__dict__.update(kwargs)
        # Number of finished update() calls; used in checkpoint file names.
        self.round = 0
        super().client_log_title()

    def _validate_and_log(self, epoch, start_time, train_loss, train_accuracy):
        """Evaluate on ``test_dataloader`` and log one row of metrics.

        Does nothing unless ``cfg.validation`` is True and a test dataloader
        was given. The logged time runs from ``start_time`` to the end of the
        validation pass. The model is put back into training mode afterwards.
        """
        if not (self.cfg.validation == True and self.test_dataloader is not None):
            return
        test_loss, test_accuracy = super().client_validation(self.test_dataloader)
        per_iter_time = time.time() - start_time
        super().client_log_content(
            epoch, per_iter_time, train_loss, train_accuracy, test_loss, test_accuracy
        )
        # Return to train mode after evaluation.
        self.model.train()

    def _save_state_dict(self, epoch):
        """Save the model weights to ``<output_dirname>/client_<id>/<round>_<epoch>.pt``.

        Does nothing unless ``cfg.save_model_state_dict`` is True.
        """
        if self.cfg.save_model_state_dict != True:
            return
        path = self.cfg.output_dirname + "/client_%s" % (self.id)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(path, exist_ok=True)
        torch.save(
            self.model.state_dict(),
            os.path.join(path, "%s_%s.pt" % (self.round, epoch)),
        )

    def update(self):
        """Run ``self.local_steps`` local training steps and return the local state.

        Validation metrics are logged before training (as epoch 0) and at the
        end of every full pass over ``self.dataloader``. When clipping is
        enabled, gradients are clipped to ``self.clip_value`` (norm type
        ``self.clip_norm``) before each optimizer step. When ``self.epsilon`` is
        set, Laplace output perturbation is applied after training.

        Returns:
            OrderedDict with ``"primal"`` (a copy of the model state dict, moved
            to CPU when training on a CUDA device, possibly modified by the DP
            step), ``"dual"`` (empty) and ``"penalty"`` (``{self.id: 0.0}``).

        Raises:
            ValueError: if ``self.dataloader`` yields no batches.
        """
        self.model.to(self.cfg.device)
        # NOTE(review): eval() resolves the optimizer class by name among this
        # module's globals (e.g. names from ``from torch.optim import *``). This
        # is only safe while ``self.optim`` comes from trusted configuration.
        optimizer = eval(self.optim)(self.model.parameters(), **self.optim_args)

        ## Initial evaluation, logged as epoch 0 with zero training metrics.
        self._validate_and_log(0, time.time(), 0, 0)
        # Ensure training mode even when validation (which restores it) is off.
        self.model.train()

        ## Local training
        data_iter = iter(self.dataloader)
        start_time = time.time()
        epoch = 1
        train_loss, train_correct, tmptotal = 0, 0, 0
        for _ in range(self.local_steps):
            try:
                data, target = next(data_iter)
            except StopIteration:  # End of one local epoch
                if tmptotal == 0:
                    # A full pass produced no samples: the dataloader is empty.
                    raise ValueError(
                        "client %s: dataloader yielded no batches" % self.id
                    ) from None
                self._validate_and_log(
                    epoch,
                    start_time,
                    train_loss / len(self.dataloader),
                    100.0 * train_correct / tmptotal,
                )
                start_time = time.time()
                train_loss, train_correct, tmptotal = 0, 0, 0
                epoch += 1
                # NOTE(review): epoch is already incremented here, so the
                # checkpoint of the epoch just finished is named
                # "<round>_<epoch + 1>.pt"; kept for on-disk compatibility.
                self._save_state_dict(epoch)
                ## Restart the data iterator for the next local epoch
                data_iter = iter(self.dataloader)
                data, target = next(data_iter)

            tmptotal += len(target)
            data = data.to(self.cfg.device)
            target = target.to(self.cfg.device)
            optimizer.zero_grad()
            output = self.model(data)
            loss = self.loss_fn(output, target)
            loss.backward()
            # Clip BEFORE stepping: clipping after optimizer.step() does not
            # affect the update, and the DP sensitivity below is derived from
            # clip_value.
            if self.clip_value != False:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.clip_value, norm_type=self.clip_norm
                )
            optimizer.step()

            train_loss += loss.item()
            if output.shape[1] == 1:
                # Single output column: round the output to get a 0/1 prediction.
                pred = torch.round(output)
            else:
                pred = output.argmax(dim=1, keepdim=True)
            train_correct += pred.eq(target.view_as(pred)).sum().item()

        self.round += 1

        self.primal_state = copy.deepcopy(self.model.state_dict())
        # Move the copied weights to CPU when training on any CUDA device,
        # including indexed ones such as "cuda:0" or a torch.device object.
        if str(self.cfg.device).startswith("cuda"):
            for k in self.primal_state:
                self.primal_state[k] = self.primal_state[k].cpu()

        ## Differential privacy (Laplace output perturbation)
        if self.epsilon != False:
            sensitivity = 0
            if self.clip_value != False:
                # optim_args is unpacked with ** above, so index it as a mapping.
                sensitivity = 2.0 * self.clip_value * self.optim_args["lr"]
            scale_value = sensitivity / self.epsilon
            super().laplace_mechanism_output_perturb(scale_value)

        ## Build local_state
        self.local_state = OrderedDict()
        self.local_state["primal"] = self.primal_state
        self.local_state["dual"] = OrderedDict()
        self.local_state["penalty"] = OrderedDict()
        self.local_state["penalty"][self.id] = 0.0

        return self.local_state
 
