"""
SemiFL Nips22: FedAvg + FixMatch + server fine-tuning + fixing pseudo labels
2023.5.16 
"""
import copy
import os
import time

import numpy as np
import numpy as npss
import torch
import torch.cuda as cuda
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.data import DataLoader

from algorithm.base import ClientBase, ServerBase
from datasets import MixDataset


class SemiFL(ServerBase):
    """SemiFL server.

    Adds nothing on top of ``ServerBase`` except handing it the SemiFL
    ``Client`` class defined in this module.
    """

    def __init__(self, args):
        """Initialise ``ServerBase`` with *args* and the ``Client`` class."""
        super().__init__(args, Client)

    def training_stats(self, round_idx):
        """Do nothing; no per-round training statistics are collected here.

        NOTE(review): presumably overrides a ``ServerBase`` hook -- confirm
        against the base class.
        """
        return None
        

class Client(ClientBase):
    """SemiFL client: pseudo-labels its local data, then trains on it.

    Each call to :meth:`train` first labels the whole local training set with
    the current model (:meth:`make_pseudo_dataset`) and then optimises the
    model on the confidently labelled samples, optionally adding a mixup term.

    The attributes ``model``, ``device``, ``trainset``, ``threshold``,
    ``mixup``, ``batch_size``, ``local_steps``, ``clip_grad`` and
    ``optimizer`` are read but never created here; presumably they are set up
    by ``ClientBase`` / ``ClientBase.prepare`` -- confirm against the base
    class.
    """

    def __init__(self, args, id, trainset):
        """Initialise the base client and, when mixup is on, its Beta sampler.

        Args:
            args: Run configuration, forwarded unchanged to ``ClientBase``.
            id: Client identifier, forwarded unchanged to ``ClientBase``.
            trainset: Local training dataset, forwarded to ``ClientBase``.
        """
        super().__init__(args, id, trainset)
        if self.mixup:  # mixup data augmentation
            # Beta(0.75, 0.75) supplies the interpolation weight used in train().
            self.beta = torch.distributions.beta.Beta(torch.tensor(0.75), torch.tensor(0.75))  # type: ignore

    def make_pseudo_dataset(self):
        """Pseudo-label the local training set with the current model.

        The model is run in eval mode over ``self.trainset`` (batches are
        dicts with keys ``'idx'``, ``'x'`` and ``'y'``). The arg-max of the
        softmax output becomes the pseudo-label, and a sample counts as
        confident when its maximum probability is ``>= self.threshold``.

        Side effects:
            ``self.cm`` is set to the confusion matrix of true labels versus
            pseudo-labels over all samples; ``self.mask_cm`` is the same
            matrix restricted to confident samples (all zeros if there are
            none).

        Returns:
            A ``(fix_dataset, mix_dataset)`` pair.

            ``fix_dataset`` is a deep copy of ``self.trainset`` whose
            ``data`` / ``targets`` are restricted to the confident samples
            and which carries their ``pseudo_labels``; it is ``None`` when no
            sample is confident (or the dataset is empty).

            ``mix_dataset`` is a ``MixDataset`` built from
            ``len(fix_dataset)`` and a pseudo-labelled copy of all samples.
            It is only built when ``self.mixup`` is truthy and
            ``fix_dataset`` exists; otherwise it is ``None``.
        """
        num_classes = self.trainset.classes
        eval_loader = DataLoader(self.trainset, batch_size=256)
        id_chunks, label_chunks, logit_chunks = [], [], []

        self.model.eval()
        with torch.no_grad():
            for batch in eval_loader:
                id_chunks.append(batch['idx'])
                # Ground-truth labels only feed the statistics below, so they
                # are not shipped to the compute device.
                label_chunks.append(batch['y'])
                logit_chunks.append(self.model(batch['x'].to(self.device)))

            if not logit_chunks:
                # Empty local dataset: torch.cat would fail on an empty list.
                self.cm = np.zeros((num_classes, num_classes), dtype=int)
                self.mask_cm = np.zeros((num_classes, num_classes), dtype=int)
                return None, None

            probs = F.softmax(torch.cat(logit_chunks, dim=0), dim=1)
            max_prob, pseudo_labels = probs.max(dim=1)
            mask = max_prob.ge(self.threshold)

        # Everything below (statistics, list slicing) runs on the CPU. Moving
        # the predictions over once keeps the mask on the same device as the
        # sample ids it indexes and avoids repeated device-to-host copies.
        sample_ids = torch.cat(id_chunks, dim=0).cpu()
        labels = torch.cat(label_chunks, dim=0).cpu()
        pseudo_labels = pseudo_labels.cpu()
        mask = mask.cpu()

        # compute the confusion matrix
        class_range = range(num_classes)
        self.cm = confusion_matrix(labels.numpy(), pseudo_labels.numpy(), labels=class_range)

        if not torch.any(mask):
            self.mask_cm = np.zeros((num_classes, num_classes), dtype=int)
            return None, None

        self.mask_cm = confusion_matrix(
            labels[mask].numpy(), pseudo_labels[mask].numpy(), labels=class_range
        )

        # Confident subset, used for the supervised-style loss on pseudo-labels.
        confident_ids = sample_ids[mask].tolist()
        fix_dataset = copy.deepcopy(self.trainset)
        fix_dataset.data = [fix_dataset.data[i] for i in confident_ids]
        fix_dataset.targets = [fix_dataset.targets[i] for i in confident_ids]
        fix_dataset.pseudo_labels = pseudo_labels[mask].tolist()

        mix_dataset = None
        if self.mixup:
            # The mixup partner pool holds every sample, confident or not.
            all_ids = sample_ids.tolist()
            mix_source = copy.deepcopy(self.trainset)
            mix_source.data = [mix_source.data[i] for i in all_ids]
            mix_source.targets = [mix_source.targets[i] for i in all_ids]
            mix_source.pseudo_labels = pseudo_labels.tolist()
            mix_dataset = MixDataset(len(fix_dataset), mix_source)

        return fix_dataset, mix_dataset

    def _backward_and_step(self, loss):
        """Back-propagate *loss*, clip gradients if enabled, and step the optimizer."""
        loss.backward()
        if self.clip_grad > 0:
            clip_grad_norm_(self.model.parameters(), self.clip_grad)
        self.optimizer.step()

    def train(self, round_idx, lr, state_dict):
        """Run one round of local training on pseudo-labelled data.

        After ``self.prepare(lr, state_dict)`` the local data is
        pseudo-labelled. If confident samples exist, the model is trained for
        ``self.local_steps`` passes over them with cross-entropy between the
        prediction on ``'x_s'`` and the pseudo-label ``'py'``. When a mix
        dataset is available as well, a mixup term is added: the ``'x'``
        inputs of both loaders are interpolated with a Beta-sampled weight
        and the loss is the equally weighted combination of cross-entropies
        against both pseudo-labels. If nothing is confident, no optimisation
        step is taken.

        Side effects:
            ``self.util`` is set to the number of confident samples (0 if
            none), ``self.optimizer_dict`` to the optimizer state, and the
            model is moved to the CPU.

        Args:
            round_idx: Index of the communication round (unused here).
            lr: Learning rate, forwarded to ``self.prepare``.
            state_dict: Model weights, forwarded to ``self.prepare``.
        """
        ls_func = F.cross_entropy
        self.prepare(lr, state_dict)
        fix_set, mix_set = self.make_pseudo_dataset()

        if fix_set is not None and mix_set is not None:
            fix_loader = DataLoader(fix_set, batch_size=self.batch_size, shuffle=True)
            mix_loader = DataLoader(mix_set, batch_size=self.batch_size, shuffle=True)
            for _ in range(self.local_steps):
                self.model.train(True)
                for fix_batch, mix_batch in zip(fix_loader, mix_loader):
                    fix_w = fix_batch['x'].to(self.device)
                    fix_s = fix_batch['x_s'].to(self.device)
                    fix_y = fix_batch['py'].to(self.device)
                    mix_w = mix_batch['x'].to(self.device)
                    mix_y = mix_batch['py'].to(self.device)

                    lam = self.beta.sample().item()
                    mixed_x = (lam * fix_w + (1 - lam) * mix_w).detach()

                    self.optimizer.zero_grad()
                    logits_s = self.model(fix_s)
                    logits_mix = self.model(mixed_x)
                    ce_loss = ls_func(logits_s, fix_y)
                    mix_loss = lam * ls_func(logits_mix, fix_y) + (1 - lam) * ls_func(logits_mix, mix_y)
                    self._backward_and_step(ce_loss + mix_loss)

        elif fix_set is not None:
            fix_loader = DataLoader(fix_set, batch_size=self.batch_size, shuffle=True)
            for _ in range(self.local_steps):
                self.model.train(True)
                for batch in fix_loader:
                    fix_s = batch['x_s'].to(self.device)
                    fix_y = batch['py'].to(self.device)
                    self.optimizer.zero_grad()
                    self._backward_and_step(ls_func(self.model(fix_s), fix_y))

        self.util = len(fix_set) if fix_set is not None else 0
        self.optimizer_dict = self.optimizer.state_dict()
        self.model.to('cpu')