"""
Date: 2025.4.30
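Algorithm: The first open-set semi-supervised federated learning algorithm
Details: 
    - OVA (One-vs-All) training for open-set detection (labeled data + confident negatives + positives);
    - Stop-gradient: the OVA head is trained on detached backbone features;
    - Imbalance margin (scaled by `tau`) added to the OVA logits during supervised training;
    - Weak-strong logit consistency regularization (MSE between weak- and strong-view logits).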
Reference: https://github.com/YUE-FAN/SSB
"""
import os
import math
import copy
import time
import torch
import numpy as np
import torch.nn as nn
import torch.cuda as cuda
import torch.nn.functional as F
from PIL import Image
from datasets import BasicDataset
from sklearn.metrics import balanced_accuracy_score, accuracy_score
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.data import DataLoader
from algorithm.base import ClientBase, ServerBase
from utils import AverageMeter, make_batchnorm_stats


class FedOpenMatch(ServerBase):
    """Server for federated open-set semi-supervised learning (OpenMatch-style).

    The global model is a ``FedNet`` exposing closed-set ``logits`` and One-vs-All
    ``logits_ova`` (viewed as ``(batch, 2, num_classes)``; index 1 is used as the
    inlier score and index 0 as the outlier score). The server warms up on its
    own labeled ``train_loader``. Then, every round, it:

    1. collects client updates;
    2. aggregates them;
    3. fine-tunes on the labeled data with CE plus the multi-binary OVA loss;
    4. evaluates.
    """

    def __init__(self, args):
        super().__init__(args, Client)
        self.lambda_mb = args.mb_loss_ratio
        # Additive margin broadcast over OVA logits of shape (B, 2, num_classes):
        # row 0 is added to the outlier logit, row 1 to the inlier logit. The
        # values are tau * log((K-1)/K) and tau * log(1/K) -- presumably the
        # negative/positive ratio of each one-vs-all head. Stays zero when tau
        # is falsy.
        self.margin = torch.zeros((2, 1), dtype=torch.float).to(self.device)
        self.tau = args.tau
        if self.tau:
            self.margin[0] = args.tau * math.log(float(self.num_classes - 1) / float(self.num_classes))
            self.margin[1] = args.tau * math.log(1. / float(self.num_classes))

    def make_model(self):
        """Wrap the base model in a ``FedNet`` (closed-set + OVA heads)."""
        model = super().make_model()
        return FedNet(model, self.num_classes).to(self.device)

    def training_stats(self, round_idx, stats):
        """Aggregate per-client pseudo-label stats, weighted by sample count, and log them.

        Utilisations are averaged over all samples. Accuracies are weighted by
        the number of samples each client actually selected (util * samples).
        """
        samples, open_set_util, open_set_acc, conf_inlier_util, conf_inlier_acc = 0, 0, 0, 0, 0
        for stat in stats:
            sample = stat['samples']
            ova_sample = stat['open_set_util'] * sample
            open_set_util += ova_sample
            open_set_acc += stat['open_set_acc'] * ova_sample
            conf_inlier_sample = stat['conf_inlier_util'] * sample
            conf_inlier_util += conf_inlier_sample
            conf_inlier_acc += stat['conf_inlier_acc'] * conf_inlier_sample
            samples += sample

        # 1e-9 guards against division by zero when nothing was selected.
        open_set_acc = open_set_acc / (open_set_util + 1e-9)
        open_set_util = open_set_util / (samples + 1e-9)
        conf_inlier_acc = conf_inlier_acc / (conf_inlier_util + 1e-9)
        conf_inlier_util = conf_inlier_util / (samples + 1e-9)

        self.logger.log({
            'conf_inlier_util': conf_inlier_util,
            'conf_inlier_acc': conf_inlier_acc,
            'ova_util': open_set_util,
            'ova_acc': open_set_acc,
        }, step=round_idx)

    def warmup(self):
        """Train the global model with plain CE on the labeled loader for ``warmup_epochs``."""
        self.global_model.train(True)
        for _ in range(self.warmup_epochs):
            for data in self.train_loader:
                x, y = data['x'].to(self.device), data['y'].to(self.device)
                self.optimizer.zero_grad()
                logits = self.global_model(x)['logits']
                loss = self.ce_loss(logits, y)
                loss.backward()
                if self.clip_grad > 0:
                    clip_grad_norm_(self.global_model.parameters(), self.clip_grad)
                self.optimizer.step()

    def run(self):
        """Run warm-up and then ``global_rounds`` federated rounds."""
        self.warmup()
        if self.sBN:
            make_batchnorm_stats(self.batchnorm_dataset, self.global_model, self.device)
        for round_idx in range(self.global_rounds):
            st_time = time.time()
            self.selected_clients = self.select_clients(round_idx)
            self.selected_clients.sort()
            lr = self.scheduler.get_last_lr()[0]
            model_dict = self.global_model.state_dict()
            stats = []
            for client_id in self.selected_clients:
                client = self.clients[client_id]
                stats.append(client.train(round_idx, lr, model_dict))

            self.aggregate_models(round_idx)
            self.training_stats(round_idx, stats)
            self.train(round_idx)
            if self.sBN:
                make_batchnorm_stats(self.batchnorm_dataset, self.global_model, self.device)
            cuda.empty_cache()
            self.evaluate(round_idx)
            self.printer.info(f'{round_idx}/{self.global_rounds} cost: {(time.time() - st_time) / 60:.2f} min')
            self.printer.info('-' * 30)

    def train(self, round):
        """Fine-tune the global model on the server's labeled loader.

        Each batch is concatenated with its strong view (``x_s``). The loss is
        closed-set CE plus ``lambda_mb`` times the multi-binary OVA loss; the
        imbalance margin is added to the OVA logits when ``tau`` is set. The
        method steps the scheduler once and logs the mean loss as ``train_loss``.
        """
        self.global_model.train(True)
        loss_meter = AverageMeter()
        for _ in range(self.local_steps):
            for data in self.train_loader:
                x, xs, y = data['x'].to(self.device), data['x_s'].to(self.device), data['y'].to(self.device)
                self.optimizer.zero_grad()
                x = torch.cat([x, xs], dim=0)
                y = torch.cat([y, y], dim=0)
                outputs = self.global_model(x)
                logits = outputs['logits']
                logits_ova = outputs['logits_ova'].view(logits.size(0), 2, -1)
                sup_closed_loss = self.ce_loss(logits, y)
                if self.tau:
                    # Out-of-place: do not mutate a view of the model's output tensor.
                    logits_ova = logits_ova + self.margin
                sup_mb_loss = self.lambda_mb * mb_sup_loss(logits_ova, y)
                loss = sup_closed_loss + sup_mb_loss
                loss.backward()
                if self.clip_grad > 0:
                    clip_grad_norm_(self.global_model.parameters(), self.clip_grad)
                self.optimizer.step()
                loss_meter.update(loss.item(), y.shape[0])
        self.scheduler.step()
        self.logger.log({'train_loss': loss_meter.avg}, step=round)

    @torch.no_grad()
    def test(self, loader, round=0):
        """Return closed-set accuracy (fraction in [0, 1]) of the global model on ``loader``."""
        self.printer.debug(f'-----------------testing-----------------')
        all_y, all_logits = [], []
        self.global_model.train(False)
        self.global_model.to(self.device)
        for data in loader:
            x, y = data['x'].to(self.device), data['y'].to(self.device)
            all_y.append(y)
            all_logits.append(self.global_model(x)['logits'])
        y = torch.cat(all_y, dim=0)
        logits = torch.cat(all_logits, dim=0)
        return (y == torch.argmax(logits, dim=1)).float().mean().item()

    @torch.no_grad()
    def test_openset(self, loader, round=0, log=True):
        """Evaluate open-set and closed-set accuracy (both in percent) on ``loader``.

        A sample is rejected as unknown when the OVA outlier probability of its
        predicted class exceeds 0.5; rejected samples are predicted as
        ``num_classes``. Labels ``>= num_classes`` are treated as unknown.

        Returns:
            ``(open_set_acc, close_set_acc)``. ``open_set_acc`` is balanced
            accuracy over the known classes plus one "unknown" class.
            ``close_set_acc`` is plain accuracy on known-class samples only.
        """
        self.printer.debug(f'-----------------testing-----------------')
        all_y, all_logits, all_logits_op = [], [], []
        self.global_model.train(False)
        self.global_model.to(self.device)
        for data in loader:
            x, y = data['x'].to(self.device), data['y'].to(self.device)
            outputs = self.global_model(x)
            all_y.append(y)
            all_logits.append(outputs['logits'])
            all_logits_op.append(outputs['logits_ova'])
        y = torch.cat(all_y, dim=0)
        logits = torch.cat(all_logits, dim=0)
        logits_ova = torch.cat(all_logits_op, dim=0)
        probs_ova = F.softmax(logits_ova.view(logits_ova.size(0), 2, -1), 1)
        inlier_pred = logits.argmax(dim=1)
        rows = torch.arange(0, inlier_pred.size(0)).long().to(self.device)
        outlier_score = probs_ova[rows, 0, inlier_pred]
        ood_mask = outlier_score > 0.5
        seen_idxs = y < self.num_classes
        close_set_acc = accuracy_score(y[seen_idxs].cpu().numpy(), inlier_pred[seen_idxs].cpu().numpy()) * 100
        inlier_pred[ood_mask] = self.num_classes
        # Predictions collapse every rejected sample into the single label
        # num_classes, so collapse all unknown ground-truth labels the same way.
        # Otherwise, distinct unknown ids would each become a class that can
        # never be predicted correctly. No-op when unknowns already use
        # num_classes.
        y_open = y.clamp(max=self.num_classes)
        open_set_acc = balanced_accuracy_score(y_open.cpu().numpy(), inlier_pred.cpu().numpy()) * 100
        return open_set_acc, close_set_acc


class Client(ClientBase):
    """Client that pseudo-labels its local data each round and trains on it.

    Per round (see ``train``), the model is trained on strong views with:

    * an MSE consistency loss between weak- and strong-view logits (weight ``cr_weight``);
    * cross-entropy on confident inliers (``fix_mask``);
    * binary cross-entropy on confident OVA positives/negatives (``ova_mask``),
      scaled by ``num_classes``.
    """

    def __init__(self, args, id, trainset):
        super().__init__(args, id, trainset)
        # Cutoffs on the OVA inlier probability for confident positives / negatives.
        self.p_cutoff_pos = args.p_cutoff_pos
        self.p_cutoff_neg = args.p_cutoff_neg
        # Weight of the weak/strong consistency loss.
        self.cr_weight = args.cr_weight

    def make_model(self):
        """Wrap the base model in a ``FedNet`` (closed-set + OVA heads)."""
        model = super().make_model()
        return FedNet(model, self.num_classes).to(self.device)

    @torch.no_grad()
    def _predict_all(self):
        """Run the model in eval mode over the whole local trainset, in dataset order.

        Returns:
            Concatenated ``(idx, logits, logits_ova, y)``. ``idx`` is kept as
            yielded by the loader (not moved); the rest are on ``self.device``.
        """
        self.model.train(False)
        loader = DataLoader(self.trainset, batch_size=128, shuffle=False, num_workers=4)
        all_idx, all_logits, all_logits_ova, all_y = [], [], [], []
        for data in loader:
            x, y = data['x'].to(self.device), data['y'].to(self.device)
            outputs = self.model(x)
            all_idx.append(data['idx'])
            all_logits.append(outputs['logits'])
            all_logits_ova.append(outputs['logits_ova'])
            all_y.append(y)
        return (torch.cat(all_idx, dim=0), torch.cat(all_logits, dim=0),
                torch.cat(all_logits_ova, dim=0), torch.cat(all_y, dim=0))

    @torch.no_grad()
    def pseudo_labeling(self):
        """Pseudo-label the whole local trainset with the current model.

        Returns:
            ``(fix_dataset, stats)``. ``fix_dataset`` is a ``MyDataset`` copy of
            the local trainset. Each sample carries:

            * ``pseudo_labels``: closed-set argmax;
            * ``fix_mask``: confident inliers, used for CE;
            * ``ova_pl``: one-hot OVA targets, all-zero unless the sample is a
              confident positive;
            * ``ova_mask``: ``(N, num_classes)`` mask of supervised OVA entries.

            All of these are stored on the CPU. ``stats`` holds pseudo-label
            utilisation/accuracy, measured against the true local labels
            ``y``, plus the sample count.
        """
        all_idx, all_logits, all_logits_ova, all_y = self._predict_all()

        # Closed-set pseudo-labels; fix_mask keeps confident predictions.
        probs = F.softmax(all_logits, dim=1)
        max_prob, pseudo_labels = probs.max(dim=1)
        fix_mask = max_prob.ge(self.threshold)

        # OVA probabilities as (N, 2, num_classes); index 1 is the inlier score.
        probs_ova = F.softmax(all_logits_ova.view(all_logits_ova.size(0), 2, -1), 1)
        inlier_probs = probs_ova[:, 1, :]
        max_prob_ova, pseudo_labels_ova = inlier_probs.max(dim=1)

        # Confident inliers: confident closed-set prediction whose OVA head also says "inlier".
        rows = torch.arange(0, pseudo_labels.size(0)).long().to(self.device)
        in_score = probs_ova[rows, 1, pseudo_labels]
        conf_in_mask = in_score.ge(0.5) & fix_mask
        conf_in_util = conf_in_mask.float().mean().cpu().item()
        if conf_in_mask.any():
            conf_in_acc = (pseudo_labels[conf_in_mask] == all_y[conf_in_mask]).float().mean().cpu().item()
        else:
            conf_in_acc = 0.0

        # OVA targets. A confident positive (max inlier prob >= p_cutoff_pos)
        # supervises every head, one-hot on its argmax head. Otherwise only
        # heads with inlier prob <= p_cutoff_neg are supervised, as negatives.
        pos_mask = max_prob_ova.ge(self.p_cutoff_pos).unsqueeze(1)
        neg_mask = inlier_probs.le(self.p_cutoff_neg)
        ova_mask = pos_mask | neg_mask
        pseudo_labels_oh = F.one_hot(pseudo_labels_ova.long(), num_classes=self.num_classes).to(max_prob_ova.dtype)
        pseudo_labels_oh *= pos_mask.to(max_prob_ova.dtype)

        # The dataset is indexed per sample inside a DataLoader, so keep its
        # labels/masks on the CPU. train() moves each collated batch to the device.
        order = all_idx.tolist()
        fix_dataset = MyDataset(copy.deepcopy(self.trainset))
        fix_dataset.data = [fix_dataset.data[i] for i in order]
        fix_dataset.targets = [fix_dataset.targets[i] for i in order]
        fix_dataset.fix_mask = conf_in_mask.cpu()
        fix_dataset.ova_mask = ova_mask.cpu()
        fix_dataset.pseudo_labels = pseudo_labels.cpu()
        fix_dataset.ova_pl = pseudo_labels_oh.cpu()

        # Open-set accuracy: agreement of the OVA targets with the true one-hot
        # labels on the selected entries. Samples with y >= num_classes are all-zero.
        open_set_util = ova_mask.float().mean().cpu().item()
        if ova_mask.any():
            seen = all_y < self.num_classes
            ova_y = torch.zeros((all_y.shape[0], self.num_classes)).to(self.device)
            ova_y[seen, all_y[seen]] = 1
            open_set_acc = (ova_y[ova_mask] == pseudo_labels_oh[ova_mask]).float().mean().cpu().item()
        else:
            open_set_acc = 0.0
        # Overall closed-set pseudo-label accuracy, returned as a float like the other stats.
        acc = (all_y == pseudo_labels).float().mean().item()
        self.printer.info(f'[{self.id}] pl_acc: {acc:.4f}, conf_in_util: {conf_in_util:.4f}, conf_in_acc: {conf_in_acc:.4f}, open_set_util: {open_set_util:.4f}, open_set_acc: {open_set_acc:.4f}')

        return fix_dataset, {
            'pl_acc': acc,
            'conf_inlier_util': conf_in_util,
            'conf_inlier_acc': conf_in_acc,
            'open_set_util': open_set_util,
            'open_set_acc': open_set_acc,
            'samples': len(all_y)
        }

    def train(self, round_idx, lr, state_dict):
        """Run one round of local training and return the pseudo-label stats.

        ``prepare`` (from ``ClientBase``) receives ``lr`` and the global
        ``state_dict``. The local data is then pseudo-labeled, and the model
        is trained for ``local_steps`` passes over it. When the method
        returns, the optimizer state is saved to ``self.optimizer_dict`` and
        the model has been moved to the CPU.
        """
        self.prepare(lr, state_dict)
        dataset, stats = self.pseudo_labeling()
        if dataset is not None:
            loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
            self.util = True
            self.model.train(True)
            for _ in range(self.local_steps):
                for data in loader:
                    x_w = data['x'].to(self.device)
                    x_s = data['x_s'].to(self.device)
                    pl = data['py'].to(self.device)
                    ova_pl = data['ova_y'].to(self.device)
                    fix_mask = data['fix_m'].to(self.device)
                    ova_mask = data['ova_m'].to(self.device)

                    self.optimizer.zero_grad()
                    # The weak view only provides consistency targets (no gradient).
                    with torch.no_grad():
                        out = self.model(x_w)
                        logits_w, logits_ova_w = out['logits'], out['logits_ova']
                    outputs = self.model(x_s)
                    logits_s = outputs['logits']
                    logits_ova_s = outputs['logits_ova']
                    logits_ova_pos = logits_ova_s.view(logits_ova_s.size(0), 2, -1).softmax(1)[:, 1, :]

                    cr_ls = F.mse_loss(logits_w, logits_s) * 0.5 + F.mse_loss(logits_ova_w, logits_ova_s) * 0.5
                    loss = cr_ls * self.cr_weight
                    if fix_mask.any():
                        loss += F.cross_entropy(logits_s[fix_mask], pl[fix_mask], reduction='mean')
                    if ova_mask.any():
                        loss += F.binary_cross_entropy(logits_ova_pos[ova_mask], ova_pl[ova_mask].float(), reduction='mean') * self.num_classes
                    loss.backward()
                    if self.clip_grad > 0:
                        clip_grad_norm_(self.model.parameters(), self.clip_grad)
                    self.optimizer.step()
        self.optimizer_dict = self.optimizer.state_dict()
        self.model.to('cpu')
        return stats

def mb_sup_loss(logits_ova, label):
    """Supervised multi-binary (One-vs-All) loss.

    Args:
        logits_ova: OVA logits of shape ``(batch, 2, num_classes)``. Softmax is
            taken over dim 1; index 1 is the inlier probability and index 0 the
            outlier probability of each class head.
        label: integer labels of shape ``(batch,)``. Labels ``>= num_classes``
            are outliers: they have no positive head, so every head is negative.

    Returns:
        Scalar tensor. It sums two batch means:

        * the positive term, ``-log p_in`` on the labeled head;
        * the hardest-negative term, the max over non-label heads of ``-log p_out``.
    """
    n_samples, n_classes = logits_ova.size(0), logits_ova.size(2)
    probs = F.softmax(logits_ova, 1)
    prob_out, prob_in = probs[:, 0, :], probs[:, 1, :]

    # One-hot positive targets for known-class samples; outliers stay all-zero.
    target_pos = torch.zeros((n_samples, n_classes)).to(label.device)
    rows = torch.arange(0, n_samples).long().to(label.device)
    seen = label < n_classes
    target_pos[rows[seen], label[seen]] = 1.
    target_neg = 1 - target_pos

    pos_term = (-torch.log(prob_in + 1e-8) * target_pos).sum(1).mean()
    neg_term = (-torch.log(prob_out + 1e-8) * target_neg).max(1)[0].mean()
    return neg_term + pos_term


class FedNet(nn.Module):
    """Backbone plus two heads: a closed-set classifier and a One-vs-All (OVA) head.

    ``forward`` returns a dict with:

    * ``'logits'`` of shape ``(B, num_classes)``;
    * ``'logits_ova'`` of shape ``(B, 2 * num_classes)``.

    The OVA branch consumes *detached* backbone features, so the OVA loss does
    not update the backbone (stop-gradient).

    Args:
        base: backbone module exposing ``fc.in_features`` (used as the feature
            width) and ``get_feature(x)`` (the only method ``forward`` calls).
        num_classes: number of known (inlier) classes.
    """

    @staticmethod
    def _projector(width):
        # Two Linear+ReLU layers that keep the feature width unchanged.
        return nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
        )

    def __init__(self, base, num_classes):
        super().__init__()
        self.backbone = base
        width = base.fc.in_features
        # Keep this construction order (it fixes random-init RNG consumption
        # and state_dict key order): OVA projector, closed-set projector,
        # OVA head, closed-set head.
        self.ova_feat = self._projector(width)
        self.fc_feat = self._projector(width)
        self.ova = nn.Linear(width, 2 * num_classes, bias=False)
        self.fc = nn.Linear(width, num_classes, bias=False)
        for head in (self.fc, self.ova):
            nn.init.xavier_normal_(head.weight.data)

    def forward(self, x, **kwargs):
        feat = self.backbone.get_feature(x)
        cls_logits = self.fc(self.fc_feat(feat))
        # Stop-gradient: the OVA head is trained on a detached copy of the features.
        ova_logits = self.ova(self.ova_feat(feat.detach()))
        return {'logits': cls_logits, 'logits_ova': ova_logits}

    def group_matcher(self, coarse=False):
        # Delegate to the backbone, passing the 'backbone.' prefix under which
        # its parameters are registered in this module.
        return self.backbone.group_matcher(coarse, prefix='backbone.')
    

 
class MyDataset(BasicDataset):
    """Copy of a ``BasicDataset``-like object that also carries per-sample pseudo-labels.

    The constructor reads ``data_name``, ``data``, ``targets``, ``classes`` and
    ``is_train`` from ``basicset``. ``pseudo_labels``, ``ova_pl``,
    ``fix_mask`` and ``ova_mask`` start as all-ones placeholders and are
    meant to be overwritten by the caller (e.g. ``Client.pseudo_labeling``).
    Any of them, or ``targets``, may be ``None``, in which case ``None`` is
    returned for that field.
    """

    def __init__(self, basicset):
        super().__init__(basicset.data_name, basicset.data, basicset.targets,
                         basicset.classes, basicset.is_train)
        n_items = len(basicset)
        self.fix_mask = torch.ones(n_items).bool()
        self.ova_mask = torch.ones(n_items).bool()
        self.ova_pl = torch.ones(n_items)
        self.pseudo_labels = torch.ones(n_items)

    @staticmethod
    def _item_or_none(values, idx):
        # Index `values` unless the whole field is absent.
        return None if values is None else values[idx]

    def __sample__(self, idx):
        """Return ``(img, target, pseudo_label, ova_pl, fix_mask, ova_mask)`` for ``idx``."""
        pick = self._item_or_none
        target = pick(self.targets, idx)
        pseudo_label = pick(self.pseudo_labels, idx)
        ova_pl = pick(self.ova_pl, idx)
        fix_mask = pick(self.fix_mask, idx)
        ova_mask = pick(self.ova_mask, idx)
        img = self.data[idx]
        return img, target, pseudo_label, ova_pl, fix_mask, ova_mask

    def __getitem__(self, idx):
        """Return a dict with weak (``x``) and strong (``x_s``) views plus labels and masks.

        ``self.transform`` and ``self.strong`` are presumably provided by
        ``BasicDataset`` (TODO confirm).
        """
        img, target, pseudo_label, ova_pl, fix_mask, ova_mask = self.__sample__(idx)

        if len(img.shape) == 2:
            # Replicate a single-channel image into three channel-first planes.
            img = np.stack([img] * 3, axis=0)
        is_cifar = self.data_name in ('CIFAR10', 'CIFAR100')
        # CIFAR arrays are passed to PIL as-is; others are channel-first and are
        # transposed to (H, W, C) first.
        pil_img = Image.fromarray(img if is_cifar else np.transpose(img, (1, 2, 0)))

        # Weak transform is applied before the strong one (same call order as before).
        x_weak = self.transform(pil_img)
        x_strong = self.strong(pil_img)
        return {
            'idx': idx,
            'x': x_weak,
            'x_s': x_strong,
            'y': target,
            'py': pseudo_label,
            'ova_y': ova_pl,
            'fix_m': fix_mask,
            'ova_m': ova_mask,
        }

    def __len__(self):
        return len(self.data)