from copy import deepcopy
import os
dirname = os.path.dirname(__file__)
import sys
sys.path.append(os.path.join(dirname, '../'))

from utils import Worker_Trainer
import torch
from collections import defaultdict

class Worker(Worker_Trainer):
    """Worker that trains locally with an optional prototype regularizer.

    Extends ``Worker_Trainer`` (imported from ``utils``). The helpers
    ``get_next_K_batch``, ``flatten_gradient`` and ``compute_prototype_loss``
    are not defined in this class; presumably they come from the base class
    — TODO confirm.
    """

    def __init__(self, rank, args, model=None, train_data_loader=None, 
                 test_data_loader=None, multiprocessing=True, cpu=None, 
                 gpu=None):
        super().__init__(rank, args, model, train_data_loader, 
                         test_data_loader, multiprocessing, cpu, gpu)

        # Learning rate for the local SGD optimizer.
        self.lr = self.args.lr
        # Not read in this class; presumably used by the base class or
        # callers — verify.
        self.lr_mask = self.args.lr_mask
        # Keyword arguments forwarded to get_next_K_batch, i.e. {K_unit: K};
        # the meaning of K_unit is defined outside this file.
        self.local_updates_args = {self.args.K_unit: self.args.K}
        # When not None, the prototype loss is added during local training.
        # Presumably assigned by the caller/server between rounds.
        self.global_prototypes = None

    def local_training(self, model:torch.nn.Module, model_size:float):
        """Run local SGD steps and return the model update and prototypes.

        Args:
            model: Model trained in place. Its forward pass must return a
                ``(logits, embeddings)`` pair; the logits feed
                ``CrossEntropyLoss``.
            model_size: Only used in the progress log line.

        Returns:
            A 4-tuple ``(flat_gradient, avg_loss, iterations, prototypes)``:

            * ``flat_gradient``: ``flatten_gradient`` applied to the per-entry
              difference ``initial - final`` of the model's ``state_dict``.
            * ``avg_loss``: mean training loss per step (0.0 if no batch was
              produced).
            * ``iterations``: number of training steps performed.
            * ``prototypes``: dict mapping each integer class label to the mean
              embedding of that class. The embeddings come from a second,
              no-grad pass over ``get_next_K_batch`` in eval mode.
        """
        # Snapshot of the starting weights, used to compute the model update.
        cache_model = deepcopy(model.state_dict())
        client_model = model

        criterion = torch.nn.CrossEntropyLoss()
        tot_loss = 0.0
        # Counted explicitly. Relying on the enumerate loop variable left it
        # unbound (UnboundLocalError) when no batch was produced.
        iteration = 0

        client_model.train()
        optimizer = torch.optim.SGD(client_model.parameters(), lr=self.lr)

        for inputs, targets in self.get_next_K_batch(**self.local_updates_args):
            inputs, targets = inputs.to(self.gpu), targets.to(self.gpu)
            optimizer.zero_grad()
            outputs, embs = client_model(inputs)
            loss = criterion(outputs, targets)

            if self.global_prototypes is not None:
                proto_loss = self.compute_prototype_loss(embs, targets)
                # Out-of-place add, so the criterion's output tensor is not
                # mutated in place.
                loss = loss + self.args.c_lamda * proto_loss

            loss.backward()

            tot_loss = tot_loss + loss.detach()
            optimizer.step()
            iteration += 1

        # Guard the average against division by zero on an empty stream.
        avg_loss = tot_loss / max(iteration, 1)

        # Pseudo-gradient: start weights minus end weights. Entries are matched
        # by name rather than by zip position.
        gradient = {
            name: cache_model[name] - cur_param
            for name, cur_param in client_model.state_dict().items()
        }

        print(f'Worker: {self.rank}\tIterations: {iteration}\t'
              f'Loss: {avg_loss}\tModel size: {model_size}')
        sys.stdout.flush()

        flat_gradient = self.flatten_gradient(gradient)

        # Second pass (eval mode, no grad) collecting embeddings per class.
        # NOTE(review): this draws batches from get_next_K_batch again. Whether
        # they are the same batches used for training depends on that helper.
        features_by_class = defaultdict(list)
        client_model.eval()
        with torch.no_grad():
            for inputs, targets in self.get_next_K_batch(**self.local_updates_args):
                inputs, targets = inputs.to(self.gpu), targets.to(self.gpu)
                _, embs = client_model(inputs)
                # One host transfer per batch instead of one .item() per sample.
                # dict.fromkeys keeps first-seen label order, and boolean-mask
                # selection keeps batch order. So concatenating the chunks gives
                # the same rows as stacking the samples one by one.
                for label in dict.fromkeys(targets.tolist()):
                    features_by_class[int(label)].append(embs[targets == label])

        local_prototypes = {
            label: torch.cat(feats).mean(dim=0)
            for label, feats in features_by_class.items()
        }

        return flat_gradient, avg_loss, iteration, local_prototypes