from copy import deepcopy
import torch.autograd.profiler as profiler

import os
import numpy as np
dirname = os.path.dirname(__file__)
import sys
sys.path.append(os.path.join(dirname, '../'))

from utils import Worker_Trainer, log, count_params
import torch
import time

class Worker(Worker_Trainer):
    """Worker that runs local SGD with extra linear and proximal gradient terms.

    On top of the cross-entropy gradient, every parameter of the sub-modules
    listed in ``_REGULARIZED_LAYERS`` receives
    ``-lambdai_coef * lambdai[name] + prox_coef * delta * (param - start)``
    where ``start`` is the value the parameter had when ``local_training``
    was entered.
    """

    # Top-level sub-modules whose parameters get the extra gradient terms.
    # Parameters living elsewhere on the model (if any) are trained with the
    # plain loss gradient only.
    _REGULARIZED_LAYERS = ('conv1', 'layer1', 'layer2', 'layer3', 'layer4', 'linear')

    def __init__(self, rank, args, model=None, train_data_loader=None, 
                 test_data_loader=None, multiprocessing=True, cpu=None, 
                 gpu=None):
        """Forward all arguments to ``Worker_Trainer`` and cache hyper-parameters.

        Reads ``lr``, ``momentum``, ``lr_mask``, ``K_unit`` and ``K`` from
        ``self.args`` (presumably set by the base-class constructor from
        ``args`` -- confirm in ``utils.Worker_Trainer``).
        """
        super().__init__(rank, args, model, train_data_loader, 
                         test_data_loader, multiprocessing, cpu, gpu)

        self.lr = self.args.lr
        self.momentum = self.args.momentum
        # ``args.lr_mask`` is given as a base-10 exponent.
        self.lr_mask = 10**self.args.lr_mask
        # Keyword arguments forwarded to ``get_next_K_batch``; the keyword
        # name itself comes from ``args.K_unit``.
        self.local_updates_args = {self.args.K_unit: self.args.K}

    def run_with_multiprocessing(self):
        """Not supported by this worker; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def local_training(self, model:torch.nn.Module, lambdai):
        """Train ``model`` in place on the batches from ``get_next_K_batch``.

        Args:
            model: network to train. It is modified in place and must expose
                the attributes named in ``_REGULARIZED_LAYERS``.
            lambdai: mapping from full parameter name (e.g. ``"conv1.weight"``)
                to the tensor subtracted (scaled by ``args.lambdai_coef``)
                from that parameter's gradient. Assumed to be on the same
                device as the model and broadcastable to the parameter --
                TODO confirm against the caller.

        Returns:
            A tuple ``(gradient, mean_loss, iterations)``. ``gradient`` is
            always an empty dict, ``mean_loss`` is the summed batch loss
            divided by the number of batches, and ``iterations`` is the
            number of batches processed. When no batch was produced,
            ``mean_loss`` is ``0.0`` and ``iterations`` is ``0``.
        """
        time_total_st = time.time()

        client_model = model
        # Snapshot of the starting weights; the proximal term pulls the
        # parameters back towards these values.
        anchor_params = deepcopy(model.state_dict())
        criterion = torch.nn.CrossEntropyLoss()
        tot_loss = 0.0
        client_model.train()
        optimizer = torch.optim.SGD(client_model.parameters(), lr=self.lr, momentum=self.momentum)

        # Everything below is invariant across batches, so resolve it once
        # instead of re-walking the module tree on every iteration.
        lambdai_coef = self.args.lambdai_coef
        prox_weight = self.args.prox_coef * self.args.delta
        regularized = []
        for layer_name in self._REGULARIZED_LAYERS:
            layer = getattr(client_model, layer_name)
            for name, param in layer.named_parameters(recurse=True):
                full_name = f"{layer_name}.{name}"
                regularized.append((param, lambdai[full_name], anchor_params[full_name]))

        time_train_st = time.time()

        # Explicit counter: stays well-defined even if no batch is yielded.
        iteration = 0
        for inputs, targets in self.get_next_K_batch(**self.local_updates_args):
            inputs, targets = inputs.to(self.gpu), targets.to(self.gpu)
            optimizer.zero_grad()
            outputs = client_model(inputs)
            loss = criterion(outputs, targets)
            # detach() keeps the running sum out of the autograd graph.
            tot_loss = tot_loss + loss.detach()
            loss.backward(retain_graph=False)

            with torch.no_grad():
                for param, dual, anchor in regularized:
                    # Frozen or unused parameters have no gradient to amend
                    # and are skipped by the optimizer anyway.
                    if param.grad is None:
                        continue
                    param.grad += -lambdai_coef * dual + prox_weight * (param - anchor)

            optimizer.step()
            iteration += 1
        time_train_end = time.time() - time_train_st

        # No per-parameter gradient is collected; the timing is kept only so
        # the log line keeps its existing fields.
        time_grad_st = time.time()
        gradient = {}
        time_grad_end = time.time() - time_grad_st
        sys.stdout.flush()
        time_total_end = time.time() - time_total_st

        # Avoid dividing by zero when the data source produced no batches.
        mean_loss = tot_loss / iteration if iteration else tot_loss
        log(self.args.save_worker, f'Worker: {self.rank-1}; Iterations: {iteration}\tLoss: {mean_loss:.6f}; time_total: {time_total_end:.4f}s; time_train: {time_train_end:.4f}s; time_it: {time_grad_end:.4f}s')
        return gradient, mean_loss, iteration
