"""
FedLabel: which model should be used for pseudo-labeling -- local or global?
    This method was originally proposed for the label-at-client setting.
    - Both the local and the global model generate pseudo-labels for the unlabeled data; the one with
    higher confidence is selected as the candidate pseudo-label, which is then filtered by a confidence threshold.
    - The discarded pseudo-label may still carry useful information when it is consistent with the candidate;
    this is exploited through a KL-divergence term weighted by the confidence ratio.
"""
import os
import copy
import time
import torch
import numpy as np
import torch.nn as nn
import torch.cuda as cuda
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.clip_grad import clip_grad_norm_
from algorithm.base import ClientBase, ServerBase
from utils import mixup_data, AverageMeter

class FedLabel(ServerBase):
    """Server side of FedLabel; all behavior comes from ``ServerBase``.

    The only customization is passing the FedLabel ``Client`` class to the
    base-class constructor.
    """

    def __init__(self, args):
        """Forward ``args`` and the ``Client`` class to ``ServerBase``."""
        super(FedLabel, self).__init__(args, Client)


class Client(ClientBase):
    """FedLabel client trained on pseudo-labels only.

    For every batch, both a frozen snapshot of the received model (the
    "global" model) and the model being trained (the "local" model) predict
    on ``x``. The more confident of the two predictions is the candidate
    pseudo-label; the other one is kept as a secondary prediction that
    contributes through a weighted KL-divergence term when it agrees with
    the candidate.
    """

    def __init__(self, args, id, trainset):
        """Initialize the base client and read the KL-term scale from ``args.lamda``."""
        super().__init__(args, id, trainset)
        # Multiplier applied to the KL consistency term in ``train``.
        self.lamda = args.lamda

    def train(self, round_idx, lr, state_dict):
        """Run ``self.local_steps`` epochs of pseudo-label training.

        Args:
            round_idx: Not used by this method.
            lr: Forwarded to ``self.prepare``.
            state_dict: Forwarded to ``self.prepare``.

        Side effects:
            Sets ``self.logs`` (per-epoch ``pl_acc``, ``mask_acc``, ``util``
            arrays plus ``samples``), ``self.util``, ``self.optimizer_dict``
            and moves ``self.model`` to the CPU.

        Each batch yielded by the loader must be a mapping with keys ``'x'``,
        ``'x_s'`` and ``'y'``. ``y`` is only used to compute the logged
        accuracies, never for the loss.
        """
        self.prepare(lr, state_dict)
        loader = DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True, drop_last=True)

        # No labeled data: the local model starts from the global model, and a
        # snapshot of that starting point is kept as the global reference.
        # NOTE(review): the snapshot keeps whatever train/eval mode
        # ``self.prepare`` left the model in -- confirm that is intended.
        global_model = copy.deepcopy(self.model)
        pl_meter, mask_meter, util_meter = AverageMeter(), AverageMeter(), AverageMeter()
        pl_acc, mask_acc, util = [], [], []

        # Semi-supervised training.
        for _ in range(self.local_steps):
            self.model.train()
            for data in loader:
                x = data['x'].to(self.device)
                x_s = data['x_s'].to(self.device)
                y = data['y'].to(self.device)

                # Pseudo-labeling: nothing in this section needs gradients.
                with torch.no_grad():
                    global_pred = F.softmax(global_model(x), dim=1)
                    local_pred = F.softmax(self.model(x), dim=1)
                    global_c, global_py = torch.max(global_pred, dim=1)
                    local_c, local_py = torch.max(local_pred, dim=1)

                    # Candidate (suffix _1) is the more confident prediction;
                    # the other one (suffix _2) is the secondary prediction.
                    # Ties go to the global model.
                    global_wins = global_c.ge(local_c)
                    py_1 = torch.where(global_wins, global_py, local_py)
                    c_1 = torch.where(global_wins, global_c, local_c)
                    py_2 = torch.where(global_wins, local_py, global_py)
                    c_2 = torch.where(global_wins, local_c, global_c)
                    # torch.where builds a new tensor instead of overwriting
                    # rows of global_pred in place.
                    pred_2 = torch.where(global_wins.unsqueeze(1), local_pred, global_pred)

                # Thresholding: only confident candidates contribute to the loss.
                confident = c_1.ge(self.threshold)  # shape: [batch]
                logits = self.model(x_s)
                ce = F.cross_entropy(logits, py_1, reduction='none')
                loss = (ce * confident.float()).mean()  # formula (5)

                batch_size = y.shape[0]
                pl_meter.update(py_1.eq(y).float().mean().item(), batch_size)
                # Weight the masked accuracy by the number of selected samples
                # so that the running average is a per-sample accuracy.
                num_confident = int(confident.sum().item())
                if num_confident > 0:
                    selected_acc = py_1[confident].eq(y[confident]).float().mean().item()
                    mask_meter.update(selected_acc, num_confident)
                util_meter.update(confident.float().mean().item(), batch_size)

                # KL-divergence towards the secondary prediction, applied only
                # where it agrees with a confident candidate and weighted by
                # the confidence ratio c_2 / c_1.
                agree = py_1.eq(py_2) & confident
                weight = c_2 / (c_1 + 1e-8) * agree.float() * self.lamda
                # log_softmax keeps log-probabilities finite; softmax().log()
                # can produce -inf, which turns into NaN once multiplied by a
                # zero weight.
                log_probs = F.log_softmax(logits, dim=1)
                kl = F.kl_div(log_probs, pred_2, reduction='none').sum(dim=-1)
                loss = loss + (kl * weight).mean()  # formula (6)(7)

                self.optimizer.zero_grad()
                loss.backward()
                if self.clip_grad > 0:
                    clip_grad_norm_(self.model.parameters(), self.clip_grad)
                self.optimizer.step()

            # Record the per-epoch averages and start fresh for the next epoch.
            pl_acc.append(pl_meter.avg)
            pl_meter.reset()
            mask_acc.append(mask_meter.avg)
            mask_meter.reset()
            util.append(util_meter.avg)
            util_meter.reset()

        self.logs = {
            'pl_acc': np.array(pl_acc),
            'mask_acc': np.array(mask_acc),
            'util': np.array(util),
            'samples': len(self.trainset),
        }
        # Mean utilization over epochs, scaled by the local dataset size.
        self.util = self.logs['util'].mean() * self.logs['samples']
        self.optimizer_dict = self.optimizer.state_dict()
        self.model.to('cpu')






                    
                

        
    