"""
Algorithm: adapt SSB to SSFL setting
Details:
    - OpenMatch + ova negative mining
    - Feature disentanglement (feature projection): separate the features used for inlier classification and outlier detection;
Notes:
    - In the original paper, the training consists of two stages: 
        - train the inlier classifier with labeled data and confident unlabeled data, 
        - train the OVA classifier with labeled data and confident unlabeled data. 
    This two-stage schedule performed poorly in our experiments, so our implementation trains the inlier classifier and the OVA classifier simultaneously from scratch.

Reference: https://github.com/YUE-FAN/SSB
"""
import os
import copy
import time
import torch
import numpy as np
import torch.nn as nn
import torch.cuda as cuda
import torch.nn.functional as F
from sklearn.metrics import balanced_accuracy_score, accuracy_score
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.data import DataLoader
from algorithm.base import ClientBase, ServerBase
from utils import AverageMeter



class SSB(ServerBase):
    """Server side of SSB adapted to the SSFL setting.

    The server owns the global ``SSBNet``. It trains it on
    ``self.train_loader`` with cross-entropy on the inlier head, plus the
    supervised OVA loss (``mb_sup_loss``) scaled by ``lambda_mb``. Local
    training is delegated to :class:`Client`.
    """

    def __init__(self, args):
        super().__init__(args, Client)
        # Weight of the supervised OVA loss in server-side training.
        self.lambda_mb = args.mb_loss_ratio

    def make_model(self):
        """Wrap the model built by ``ServerBase.make_model`` in an ``SSBNet``."""
        model = super().make_model()
        return SSBNet(model, self.num_classes).to(self.device)

    def training_stats(self, round_idx):
        """Print per-client pseudo-labelling statistics and log weighted averages.

        Reads the ``logs`` dict written by ``Client.train`` for every selected
        client. Each statistic is first averaged over local epochs, then
        weighted by the client's ``samples`` count and normalised by the total
        sample count.

        Args:
            round_idx: communication round, used as the logger step.
        """
        fix_util, ood_util, ova_neg_util, ova_neg_prec, fix_acc = 0, 0, 0, 0, 0
        samples = 0
        for client_id in self.selected_clients:
            logs = self.clients[client_id].logs
            s = logs['samples']
            samples += s
            fix_util += np.array(logs['fix_utils']).mean() * s
            fix_acc += np.array(logs['fix_accs']).mean() * s
            ood_util += np.array(logs['ood_utils']).mean() * s
            ova_neg_util += np.array(logs['ova_neg_utils']).mean() * s
            ova_neg_prec += np.array(logs['ova_neg_precs']).mean() * s
            print_log = f'client{client_id}:'
            for i in range(self.local_steps):
                print_log += f'fix_u={logs["fix_utils"][i]:.2f},fix_acc={logs["fix_accs"][i]:.2f},ood_u={logs["ood_utils"][i]:.2f},ova_neg_u={logs["ova_neg_utils"][i]:.2f},ova_neg_acc={logs["ova_neg_precs"][i]:.2f}\n'
            self.printer.info(print_log)
        if samples == 0:
            # No selected clients (or no local data): there is nothing to
            # average, and dividing would raise ZeroDivisionError.
            return
        self.logger.log({'fix_util': fix_util / samples}, step=round_idx)
        self.logger.log({'ood_util': ood_util / samples}, step=round_idx)
        self.logger.log({'fix_acc': fix_acc / samples}, step=round_idx)
        # OVA negative-mining statistics aggregated in the loop above.
        self.logger.log({'ova_neg_util': ova_neg_util / samples}, step=round_idx)
        self.logger.log({'ova_neg_prec': ova_neg_prec / samples}, step=round_idx)

    def warmup(self):
        """Train the global model for ``warmup_epochs`` with cross-entropy only.

        The OVA head receives no loss here and the scheduler is not stepped.
        """
        self.global_model.train(True)
        for epoch in range(self.warmup_epochs):
            for i, data in enumerate(self.train_loader):
                x, y = data['x'].to(self.device), data['y'].to(self.device)
                self.optimizer.zero_grad()
                logits = self.global_model(x)['logits']
                loss = self.ce_loss(logits, y)
                loss.backward()
                if self.clip_grad > 0:
                    clip_grad_norm_(self.global_model.parameters(), self.clip_grad)
                self.optimizer.step()

    def train(self, round):
        """Run ``local_steps`` epochs of server-side supervised training.

        The loss is cross-entropy on ``logits`` plus ``lambda_mb * mb_sup_loss``
        on ``logits_ova``. The scheduler is stepped once per call. Average
        losses and training accuracy (in percent) are logged at step ``round``.
        """
        st = time.time()
        self.global_model.train(True)
        ce_loss_meter, ova_loss_meter = AverageMeter(), AverageMeter()
        acc_meter = AverageMeter()
        for epoch in range(self.local_steps):
            for i, data in enumerate(self.train_loader):
                x, y = data['x'].to(self.device), data['y'].to(self.device)
                self.optimizer.zero_grad()
                outputs = self.global_model(x)
                logits = outputs['logits']
                logits_ova = outputs['logits_ova']
                ce_loss = self.ce_loss(logits, y)
                # open-set detection
                ova_loss = self.lambda_mb * mb_sup_loss(logits_ova, y)
                ova_loss_meter.update(ova_loss.item(), y.shape[0])
                loss = ce_loss + ova_loss
                loss.backward()
                if self.clip_grad > 0:
                    clip_grad_norm_(self.global_model.parameters(), self.clip_grad)
                self.optimizer.step()
                ce_loss_meter.update(ce_loss.item(), y.shape[0])
                acc = (logits.argmax(dim=1) == y).float().mean().item()
                acc_meter.update(acc, y.shape[0])
        self.scheduler.step()
        self.logger.log({'ce_loss': ce_loss_meter.avg}, step=round)
        self.logger.log({'ova_loss': ova_loss_meter.avg}, step=round)
        self.logger.log({'server@train_acc': acc_meter.avg * 100}, step=round)
        self.printer.info(f"server train cost {(time.time() - st) / 60:.2f} min")

    @torch.no_grad()
    def test(self, loader, round=0):
        """Return closed-set top-1 accuracy (percent) of the inlier head on ``loader``.

        Accuracy is computed over every sample in ``loader``.
        NOTE(review): presumably ``loader`` holds only seen classes; samples
        labelled ``num_classes`` (unknown) would always count as wrong.
        """
        # closed-set test
        self.printer.debug(f'-----------------testing-----------------')
        all_y, all_logits = [], []
        self.global_model.train(False)
        self.global_model.to(self.device)
        for data in loader:
            x, y = data['x'].to(self.device), data['y'].to(self.device)
            logits = self.global_model(x)['logits']
            all_y.append(y)
            all_logits.append(logits)
        y = torch.cat(all_y, dim=0)
        logits = torch.cat(all_logits, dim=0)
        test_acc = (y == torch.argmax(logits, dim=1)).float().mean().item() * 100
        return test_acc

    @torch.no_grad()
    def test_openset(self, loader, round=0):
        """Evaluate open-set and closed-set accuracy on ``loader``.

        A sample is rejected as unknown (predicted label ``num_classes``) when
        the OVA outlier probability (channel 0) of its predicted inlier class
        exceeds 0.5.

        Returns:
            (open_set_acc, close_set_acc), both in percent. Open-set accuracy is
            the balanced accuracy over all labels including the unknown label.
            Closed-set accuracy is plain accuracy on samples with
            ``y < num_classes``, before rejection.
        """
        self.printer.debug(f'-----------------open-set testing-----------------')
        all_y, all_logits, all_logits_ova = [], [], []
        self.global_model.train(False)
        self.global_model.to(self.device)
        for data in loader:
            x, y = data['x'].to(self.device), data['y'].to(self.device)
            outputs = self.global_model(x)
            logits, logits_ova = outputs['logits'], outputs['logits_ova']
            all_y.append(y)
            all_logits.append(logits)
            all_logits_ova.append(logits_ova)
        y = torch.cat(all_y, dim=0)
        logits = torch.cat(all_logits, dim=0)
        logits_ova = torch.cat(all_logits_ova, dim=0)
        # (N, 2, C): channel 0 = outlier probability, channel 1 = inlier probability
        logits_ova = F.softmax(logits_ova.view(logits_ova.size(0), 2, -1), 1)
        inlier_pred = logits.argmax(dim=1)
        outlier_score = logits_ova[torch.arange(0, inlier_pred.size(0)).long().to(self.device), 0, inlier_pred]
        ood_mask = outlier_score > 0.5
        seen_idxs = y < self.num_classes
        close_set_acc = accuracy_score(y[seen_idxs].cpu().numpy(), inlier_pred[seen_idxs].cpu().numpy()) * 100
        inlier_pred[ood_mask] = self.num_classes
        open_set_acc = balanced_accuracy_score(y.cpu().numpy(), inlier_pred.cpu().numpy()) * 100

        return open_set_acc, close_set_acc


class Client(ClientBase):
    """Client-side trainer for SSB in the SSFL setting.

    Local training combines:
      * FixMatch-style pseudo-labelling of the inlier classifier (weak view
        ``x`` labels the strong view ``x_s``);
      * OVA negative mining: classes whose OVA inlier probability on the weak
        view is ``<= ova_threshold`` are trained as negatives on the strong
        view;
      * OVA entropy minimisation and soft consistency regularisation between
        the two views.

    Ground-truth ``y`` from the local data is read only for diagnostic
    statistics, never for a loss term.
    """

    def __init__(self, args, id, trainset):
        super().__init__(args, id, trainset)
        # weight of the OVA entropy-minimisation term
        self.lambda_oem = 0.1
        # weight of the OVA soft consistency-regularisation term
        self.lambda_cr = 1.0 if args.dataset == 'CIFAR100' else 0.5
        # OVA inlier probability at or below which a class is mined as a negative
        self.ova_threshold = args.p_cutoff_neg
        self.train_loader = DataLoader(self.trainset, batch_size=args.c_batch_size, shuffle=True)

    def make_model(self):
        """Wrap the model built by ``ClientBase.make_model`` in an ``SSBNet``."""
        model = super().make_model()
        return SSBNet(model, self.num_classes).to(self.device)

    def train_fix(self, round_idx, lr, state_dict):
        '''
        Placeholder for training with fixed pseudo-labels.

        NOTE(review): it currently only calls ``self.prepare(lr, state_dict)``
        and switches the model to train mode. No optimisation step is run, so
        callers must not rely on it to update weights.
        '''
        self.prepare(lr, state_dict)
        self.model.train(True)

    def train(self, round_idx, lr, state_dict):
        """Run ``self.local_steps`` passes over the local data.

        Total loss per batch: ``ls_fc + ls_ova_cr + ls_oem + ls_ova``.

        Side effects: stores per-epoch statistics in ``self.logs``, the
        optimizer state in ``self.optimizer_dict``, moves the model to CPU and
        sets ``self.util = 1``. ``self.threshold`` (pseudo-label confidence
        cut-off) is presumably provided by ``ClientBase``.
        """
        self.prepare(lr, state_dict)
        self.model.train(True)
        fix_util, ood_util, ova_neg_util, ova_neg_prec, fix_prec = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
        fix_utils, ood_utils, ova_neg_utils, ova_neg_precs, fix_precs = [], [], [], [], []
        loss_meter = AverageMeter()
        losses = []

        # Two independent shuffles of the same data. `loader` feeds the inlier
        # classifier; `loader_all` feeds the OVA head.
        loader = DataLoader(self.trainset, batch_size=self.args.c_batch_size, shuffle=True)
        loader_all = DataLoader(self.trainset, batch_size=self.args.c_batch_size, shuffle=True)
        for epoch in range(self.local_steps):
            for data, data_all in zip(loader, loader_all):
                self.optimizer.zero_grad()
                x, x_s, y = data['x'].to(self.device), data['x_s'].to(self.device), data['y'].to(self.device)
                outputs = self.model(torch.cat([x, x_s], dim=0))
                logits, logits_s = outputs['logits'].chunk(2)

                x_all, x_s_all, y_all = data_all['x'].to(self.device), data_all['x_s'].to(self.device), data_all['y'].to(self.device)
                outputs_all = self.model(torch.cat([x_all, x_s_all], dim=0))
                logits_ova1, logits_ova2 = outputs_all['logits_ova'].chunk(2)

                # unlabeled OVA loss (negative mining) starts.
                # Use a separate name for the detached probabilities so that
                # `logits_ova1` stays the raw, gradient-carrying logits used by
                # the entropy and consistency terms below.
                with torch.no_grad():
                    probs_ova_w = F.softmax(logits_ova1.view(logits_ova1.size(0), 2, -1), 1)
                    inlier_score = probs_ova_w[:, 1, :]
                    neg_mask = (inlier_score <= self.ova_threshold).float()
                    ova_neg_util.update(neg_mask.mean().item(), inlier_score.shape[0] * inlier_score.shape[1])

                    # Precision of mined negatives, checked against the labels
                    # of the SAME batch (`y_all`); column `num_classes` is the
                    # unknown label and is dropped.
                    gt_mask = torch.zeros(inlier_score.size(0), self.num_classes+1).to(self.device)
                    gt_mask.scatter_(1, y_all.view(-1, 1), 1)
                    gt_neg = 1 - gt_mask[:, :-1]

                    if neg_mask.sum() > 0:
                        prec = ((neg_mask == gt_neg) * neg_mask).sum() / neg_mask.sum()
                        ova_neg_prec.update(prec.item(), neg_mask.sum().item())
                # push the outlier channel (0) up for mined negatives on the strong view
                ls_ova = torch.mean((-F.log_softmax(logits_ova2.view(logits_ova2.size(0), 2, -1), dim=1)[:, 0, :] * neg_mask).sum(dim=1))

                # entropy minimization for ova_logits (both views, raw logits)
                ls_oem = 0.5 * (ova_ent(logits_ova1) + ova_ent(logits_ova2)) * self.lambda_oem

                # soft consistency regularization between the two views' OVA probabilities
                soft_ova1 = F.softmax(logits_ova1.view(logits_ova1.size(0), 2, -1), 1)
                soft_ova2 = F.softmax(logits_ova2.view(logits_ova2.size(0), 2, -1), 1)
                ls_ova_cr = torch.mean(torch.sum(torch.sum(torch.abs(soft_ova1 - soft_ova2)**2, 1), 1)) * self.lambda_cr

                # inlier classifier: confident weak-view predictions label the strong view
                with torch.no_grad():
                    pseudo_label = torch.softmax(logits, dim=-1)
                    max_probs, pl = torch.max(pseudo_label, dim=-1)
                    fix_mask = max_probs.ge(self.threshold).float()
                    ood_mask = (y == self.num_classes).float()
                    ood_util.update((ood_mask * fix_mask).mean().item(), y.shape[0])
                    fix_util.update(fix_mask.mean().item(), y.shape[0])
                    if fix_mask.any():
                        fix_prec.update(pl[fix_mask.bool()].eq(y[fix_mask.bool()]).float().mean().item(), fix_mask.sum().item())

                ls_fc = (F.cross_entropy(logits_s, pl, reduction='none') * fix_mask).mean()

                loss = ls_fc + ls_ova_cr + ls_oem + ls_ova
                loss.backward()
                self.optimizer.step()
                loss_meter.update(loss.item(), y.shape[0])

            # record per-epoch statistics, then reset the meters
            fix_precs.append(fix_prec.avg)
            fix_utils.append(fix_util.avg)
            ood_utils.append(ood_util.avg)
            ova_neg_utils.append(ova_neg_util.avg)
            ova_neg_precs.append(ova_neg_prec.avg)
            fix_util.reset()
            ood_util.reset()
            ova_neg_util.reset()
            ova_neg_prec.reset()
            fix_prec.reset()
            losses.append(loss_meter.avg)
            loss_meter.reset()

        self.logs = {
            'fix_utils': fix_utils,
            'fix_accs': fix_precs,
            'ood_utils': ood_utils,
            'ova_neg_utils': ova_neg_utils,
            'ova_neg_precs': ova_neg_precs,
            "samples": len(self.trainset)
        }
        self.optimizer_dict = self.optimizer.state_dict()
        self.model.to('cpu')
        self.util = 1

class SSBNet(nn.Module):
    """Backbone plus two disentangled heads: an inlier classifier and an OVA head.

    ``base`` must expose ``fc.in_features`` (used only to size the heads) and
    ``get_feature(x)``. ``base.fc`` itself is never called in ``forward``.
    Each head has its own two-layer projection MLP, so the features used for
    classification and for outlier detection are kept separate.
    """

    def __init__(self, base, num_classes):
        super().__init__()
        self.backbone = base
        width = base.fc.in_features
        # Creation order is kept (ova_feat, fc_feat, fc, ova) so parameter
        # initialisation and state_dict layout are unchanged.
        self.ova_feat = self._projection(width)
        self.fc_feat = self._projection(width)
        self.fc = nn.Linear(width, num_classes, bias=False)
        self.ova = nn.Linear(width, 2 * num_classes, bias=False)
        for head in (self.fc, self.ova):
            nn.init.xavier_normal_(head.weight.data)

    @staticmethod
    def _projection(width):
        """Build a Linear-ReLU-Linear-ReLU MLP that keeps the width at ``width``."""
        return nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
        )

    def forward(self, x, **kwargs):
        """Return a dict with ``'logits'`` (num_classes outputs) and ``'logits_ova'`` (2*num_classes outputs).

        Extra keyword arguments are accepted and ignored.
        """
        shared = self.backbone.get_feature(x)
        return {
            'logits': self.fc(self.fc_feat(shared)),
            'logits_ova': self.ova(self.ova_feat(shared)),
        }

    def group_matcher(self, coarse=False):
        """Delegate to the backbone's matcher, prefixing names with ``'backbone.'``."""
        return self.backbone.group_matcher(coarse, prefix='backbone.')


# # Reference: https://github.com/VisionLearningGroup/OP_Match/blob/main/utils/misc.py
def mb_sup_loss(logits_ova, label):
    """Supervised one-vs-all loss with hard negative mining.

    ``logits_ova`` is reshaped to ``(batch, 2, num_classes)``. Channel 1 is
    the "is this class" probability and channel 0 the "is not this class"
    probability. Labels ``>= num_classes`` (unknown samples) get no positive
    term; every class is a negative for them.

    Returns:
        mean hardest-negative NLL + mean positive NLL (scalar tensor).
    """
    n = logits_ova.size(0)
    probs = F.softmax(logits_ova.view(n, 2, -1), 1)
    k = probs.size(2)
    dev = label.device

    # one-hot positives for known labels only; unknown rows stay all-zero
    is_known = label < k
    rows = torch.arange(0, n).long().to(dev)
    pos_target = torch.zeros((n, k)).long().to(dev)
    pos_target[rows[is_known], label[is_known]] = 1
    neg_target = 1 - pos_target

    nll_pos = -torch.log(probs[:, 1, :] + 1e-8)
    nll_neg = -torch.log(probs[:, 0, :] + 1e-8)
    pos_term = (nll_pos * pos_target).sum(1).mean()
    # keep only the hardest (largest-loss) negative class per sample
    neg_term = (nll_neg * neg_target).max(1)[0].mean()
    return neg_term + pos_term

def ova_ent(logits_open):
    """Mean binary entropy of the OVA head.

    ``logits_open`` is reshaped to ``(batch, 2, num_classes)`` and softmaxed
    over the 2-way axis. The per-class binary entropies are averaged over
    classes, then over the batch.
    """
    probs = F.softmax(logits_open.view(logits_open.size(0), 2, -1), 1)
    per_class_entropy = (-probs * torch.log(probs + 1e-8)).sum(1)
    return per_class_entropy.mean(1).mean()