"""
Algorithm: OpenMatch adapted to the semi-supervised federated learning (SSFL) setting
Details:
    - OVA (one-vs-all) heads for open-set detection; the supervised OVA loss (mb_sup_loss) uses labeled samples only
    - weak-strong soft consistency regularization for the OVA classifier
    - entropy minimization for the OVA classifier
    - FixMatch-style training of the inlier classifier on predicted inliers only, excluding detected outliers (main difference from SSB)
Reference: https://github.com/VisionLearningGroup/OP_Match
"""
import os
import copy
import time
import torch
import numpy as np
import torch.nn as nn
import torch.cuda as cuda
import torch.nn.functional as F
from sklearn.metrics import balanced_accuracy_score, accuracy_score
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.data import DataLoader
from algorithm.base import ClientBase, ServerBase
from utils import AverageMeter



class OPMatch(ServerBase):
    """Server side of OpenMatch adapted to semi-supervised federated learning.

    The server trains the global model on its own loader with cross-entropy
    for the closed-set head and ``mb_sup_loss`` for the OVA head (on weak and
    strong views), aggregates client statistics, and evaluates closed-set and
    open-set accuracy.
    """

    def __init__(self, args):
        super().__init__(args, Client)
        # weight of the supervised OVA (open-set) loss in the server objective
        self.lambda_mb = args.mb_loss_ratio

    def make_model(self):
        """Wrap the model built by ``ServerBase.make_model`` with the OpenMatch heads."""
        model = super().make_model()
        return OPNet(model, self.num_classes).to(self.device)

    def training_stats(self, round_idx):
        """Aggregate the logs of the selected clients and log weighted averages.

        Weights: 'filter_acc', 'inlier_used' and 'inlier_ratio' are weighted by
        client sample count; 'fix_util' and 'ood_util' by the estimated number
        of inliers used; 'fix_acc' by the estimated number of pseudo-labeled
        samples. A 1e-9 term guards every denominator against zero.
        """
        fix_utils, ood_utils, fix_accs, filter_accs = 0, 0, 0, 0
        samples, in_samples, true_inliers = 0, 0, 0
        for cid in self.selected_clients:
            logs = self.clients[cid].logs
            s = logs['samples']
            samples += s
            inlier_sample = np.array(logs['id_ratio']).mean() * s
            true_inliers += np.array(logs['true_id_ratio']).mean() * s
            filter_accs += np.array(logs['filter_acc']).mean() * s
            in_samples += inlier_sample
            fix_util = np.array(logs['fix_utils']).mean() * inlier_sample
            fix_utils += fix_util
            fix_accs += np.array(logs['fix_accs']).mean() * fix_util
            ood_utils += np.array(logs['ood_utils']).mean() * inlier_sample
            # one line per local step; zip avoids IndexError if a list is short
            step_lines = [
                f'fix_u={fu:.2f},fix_acc={fa:.2f},ood_u={ou:.2f}\n'
                for fu, fa, ou in zip(logs['fix_utils'], logs['fix_accs'], logs['ood_utils'])
            ]
            self.printer.info(f'client{cid}:' + ''.join(step_lines))
        self.logger.log({
            'fix_util': fix_utils / (in_samples + 1e-9),
            'ood_util': ood_utils / (in_samples + 1e-9),
            'fix_acc': fix_accs / (fix_utils + 1e-9),
            'filter_acc': filter_accs / (samples + 1e-9),
            'inlier_used': in_samples / (samples + 1e-9),
            'inlier_ratio': true_inliers / (samples + 1e-9),
        }, step=round_idx)

    def warmup(self):
        """Cross-entropy warm-up of the global model for ``warmup_epochs`` epochs."""
        self.global_model.train(True)
        for _ in range(self.warmup_epochs):
            for data in self.train_loader:
                x, y = data['x'].to(self.device), data['y'].to(self.device)
                self.optimizer.zero_grad()
                logits = self.global_model(x)['logits']
                loss = self.ce_loss(logits, y)
                loss.backward()
                if self.clip_grad > 0:
                    clip_grad_norm_(self.global_model.parameters(), self.clip_grad)
                self.optimizer.step()

    def train(self, round):
        """Run ``local_steps`` epochs of server training, then step the LR scheduler once.

        Weak (``x``) and strong (``x_s``) views are stacked and trained with
        cross-entropy on the closed-set head plus ``lambda_mb * mb_sup_loss``
        on the OVA head. Average losses are logged at step ``round``.
        """
        st = time.time()
        self.global_model.train(True)
        ce_loss_meter, ova_loss_meter = AverageMeter(), AverageMeter()
        for _ in range(self.local_steps):
            for data in self.train_loader:
                x, xs, y = data['x'].to(self.device), data['x_s'].to(self.device), data['y'].to(self.device)
                self.optimizer.zero_grad()
                x, y = torch.cat([x, xs], dim=0), torch.cat([y, y], dim=0)
                outputs = self.global_model(x)
                logits = outputs['logits']
                logits_ova = outputs['logits_ova']
                ce_loss = self.ce_loss(logits, y)
                # open-set detection
                ova_loss = self.lambda_mb * mb_sup_loss(logits_ova, y)
                loss = ce_loss + ova_loss
                loss.backward()
                if self.clip_grad > 0:
                    clip_grad_norm_(self.global_model.parameters(), self.clip_grad)
                self.optimizer.step()
                ce_loss_meter.update(ce_loss.item(), y.shape[0])
                ova_loss_meter.update(ova_loss.item(), y.shape[0])
        self.scheduler.step()
        self.logger.log({'ce_loss': ce_loss_meter.avg}, step=round)
        self.logger.log({'ova_loss': ova_loss_meter.avg}, step=round)
        self.printer.debug(f"server train cost {(time.time() - st) / 60:.2f} min")

    @torch.no_grad()
    def test(self, loader, round=0):
        """Closed-set test: return top-1 accuracy (%) of the closed-set head on ``loader``."""
        self.printer.debug(f'-----------------testing-----------------')
        all_y, all_logits = [], []
        self.global_model.train(False)
        self.global_model.to(self.device)
        for data in loader:
            x, y = data['x'].to(self.device), data['y'].to(self.device)
            all_y.append(y)
            all_logits.append(self.global_model(x)['logits'])
        y = torch.cat(all_y, dim=0)
        logits = torch.cat(all_logits, dim=0)
        test_acc = (y == torch.argmax(logits, dim=1)).float().mean().item() * 100
        return test_acc

    @torch.no_grad()
    def test_openset(self, loader, round=0):
        """Evaluate the global model in the open-set setting.

        A sample is rejected as unknown when the OVA outlier probability of
        its predicted class exceeds 0.5. Ground-truth labels
        ``>= num_classes`` denote unseen classes.

        Returns:
            (open_set_acc, close_set_acc) in percent: balanced accuracy over
            the known classes plus one "unknown" class, and plain accuracy of
            the closed-set head on the seen samples only.
        """
        self.printer.debug(f'-----------------open-set testing-----------------')
        all_y, all_logits, all_logits_ova = [], [], []
        self.global_model.train(False)
        self.global_model.to(self.device)
        for data in loader:
            x, y = data['x'].to(self.device), data['y'].to(self.device)
            outputs = self.global_model(x)
            all_y.append(y)
            all_logits.append(outputs['logits'])
            all_logits_ova.append(outputs['logits_ova'])
        y = torch.cat(all_y, dim=0)
        logits = torch.cat(all_logits, dim=0)
        logits_ova = torch.cat(all_logits_ova, dim=0)
        # (N, 2, C) probabilities; channel 0 is read as the per-class outlier score
        probs_ova = F.softmax(logits_ova.view(logits_ova.size(0), 2, -1), 1)
        inlier_pred = logits.argmax(dim=1)
        rows = torch.arange(inlier_pred.size(0), device=self.device)
        outlier_score = probs_ova[rows, 0, inlier_pred]
        ood_mask = outlier_score > 0.5
        seen_idxs = y < self.num_classes
        close_set_acc = accuracy_score(y[seen_idxs].cpu().numpy(), inlier_pred[seen_idxs].cpu().numpy()) * 100
        inlier_pred[ood_mask] = self.num_classes
        # Predictions use the single label num_classes for "unknown"; map every
        # unseen ground-truth label onto it too. Otherwise each distinct unseen
        # label would count as its own class with recall 0 in balanced accuracy.
        y_open = torch.where(seen_idxs, y, torch.full_like(y, self.num_classes))
        open_set_acc = balanced_accuracy_score(y_open.cpu().numpy(), inlier_pred.cpu().numpy()) * 100
        return open_set_acc, close_set_acc


class Client(ClientBase):
    """Federated client running OpenMatch-style local training.

    Every local epoch the client (1) filters its local set with the OVA head,
    (2) trains the OVA head on all local samples with entropy minimization and
    weak/strong soft consistency, and (3) trains the closed-set head with a
    FixMatch loss on the samples kept as inliers.
    """

    def __init__(self, args, id, trainset):
        super().__init__(args, id, trainset)
        # weight of the OVA entropy-minimization term
        self.lambda_oem = 0.1
        # weight of the OVA soft-consistency term
        self.lambda_cr = 1.0 if args.dataset == 'CIFAR100' else 0.5

    def make_model(self):
        """Wrap the model built by ``ClientBase.make_model`` with the OpenMatch heads."""
        model = super().make_model()
        return OPNet(model, self.num_classes).to(self.device)

    @torch.no_grad()
    def filter_outliers(self):
        """Select the local samples the OVA head considers inliers.

        A sample is kept when the OVA outlier probability (channel 0) of its
        predicted class is below 0.5. Ground-truth labels (``y < num_classes``
        means inlier) are used only for the returned diagnostics.

        Returns:
            (trainset, stats): a deep copy of ``self.trainset`` restricted to
            the kept samples, and a dict with 'filter_ratio' (fraction kept),
            'filter_acc' (agreement of the filter with ground truth) and
            'true_ratio' (ground-truth inlier fraction).
        """
        self.printer.debug(f"client {self.id} filtering outliers...")
        self.model.train(False)
        loader = DataLoader(self.trainset, batch_size=128, shuffle=False)
        filter_ratio, filter_acc, true_ratio = AverageMeter(), AverageMeter(), AverageMeter()
        inlier_idx = []
        for data in loader:
            idx, x, y = data['idx'].to(self.device), data['x'].to(self.device), data['y'].to(self.device)
            bs = x.shape[0]
            outputs = self.model(x)
            probs = outputs['logits'].softmax(-1)
            probs_ova = outputs['logits_ova'].view(bs, 2, self.num_classes).softmax(1)
            _, pl = probs.max(dim=-1)
            # build the row index on the same device as the other indices
            rows = torch.arange(bs, device=probs_ova.device)
            inlier_mask = probs_ova[rows, 0, pl] < 0.5
            is_inlier = y < self.num_classes
            filter_acc.update((inlier_mask == is_inlier).float().mean().item(), bs)
            filter_ratio.update(inlier_mask.float().mean().item(), bs)
            true_ratio.update(is_inlier.float().mean().item(), bs)
            inlier_idx.append(idx[inlier_mask].cpu().numpy())
        # np.concatenate([]) raises, so handle an empty local set explicitly
        inlier_idx = np.concatenate(inlier_idx, axis=0) if inlier_idx else np.array([], dtype=np.int64)
        trainset = copy.deepcopy(self.trainset)
        trainset.data = [trainset.data[i] for i in inlier_idx]
        trainset.targets = [trainset.targets[i] for i in inlier_idx]
        self.model.train(True)
        return trainset, {
            'filter_ratio': filter_ratio.avg,
            'filter_acc': filter_acc.avg,
            'true_ratio': true_ratio.avg
        }

    def _ova_unlabeled_loss(self, logits_ova):
        """Return the unsupervised OVA loss for a stacked [weak; strong] batch.

        Combines entropy minimization on both views (weighted
        ``0.5 * lambda_oem``) and the squared-difference soft consistency
        between the views (weighted ``lambda_cr``).
        """
        weak, strong = logits_ova.chunk(2)
        weak = F.softmax(weak.view(weak.size(0), 2, -1), 1)
        strong = F.softmax(strong.view(strong.size(0), 2, -1), 1)
        # entropy minimization for ova_logits
        ls_oem = torch.mean(torch.sum(-weak * torch.log(weak + 1e-8), 1)) + torch.mean(torch.sum(-strong * torch.log(strong + 1e-8), 1))
        # consistency regularization
        ls_ova_cr = torch.mean(torch.sum(torch.sum((weak - strong) ** 2, 1), 1))
        return ls_ova_cr * self.lambda_cr + ls_oem * 0.5 * self.lambda_oem

    def train(self, round_idx, lr, state_dict):
        """Run ``local_steps`` epochs of local training and record ``self.logs``.

        Per epoch, inliers are re-filtered first. Each batch of the full local
        set then gets the unlabeled OVA loss. If any inliers were kept, the
        next inlier batch also gets a FixMatch loss, applied to samples whose
        confidence is >= ``self.threshold``. The model is moved to CPU at the end.
        """
        self.prepare(lr, state_dict)
        fix_util, ood_util, fix_prec = AverageMeter(), AverageMeter(), AverageMeter()
        filter_ratio, filter_acc, true_ratio = [], [], []
        fix_utils, ood_utils, fix_precs = [], [], []
        loader_all = DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True)

        for _ in range(self.local_steps):
            inlier_set, logs = self.filter_outliers()
            has_inliers = len(inlier_set) > 0
            if has_inliers:
                loader = DataLoader(inlier_set, batch_size=self.batch_size, shuffle=True)
                loader_iter = iter(loader)
            for data_all in loader_all:
                x_all, x_s_all = data_all['x'].to(self.device), data_all['x_s'].to(self.device)
                outputs_all = self.model(torch.cat([x_all, x_s_all], dim=0))
                loss = self._ova_unlabeled_loss(outputs_all['logits_ova'])
                if has_inliers:
                    # inlier classifier: the inlier subset is never larger than
                    # the full set, so restart its iterator when exhausted
                    try:
                        data = next(loader_iter)
                    except StopIteration:
                        loader_iter = iter(loader)
                        data = next(loader_iter)
                    x, x_s, y = data['x'].to(self.device), data['x_s'].to(self.device), data['y'].to(self.device)
                    outputs = self.model(torch.cat([x, x_s], dim=0))
                    logits, logits_s = outputs['logits'].chunk(2)

                    with torch.no_grad():
                        pseudo_label = torch.softmax(logits, dim=-1)
                        max_probs, pl = torch.max(pseudo_label, dim=-1)
                        fix_mask = max_probs.ge(self.threshold).float()
                        fix_util.update(fix_mask.mean().item(), y.shape[0])

                    selected = fix_mask.bool()
                    if selected.any():
                        # y is used only to measure pseudo-label precision
                        fix_prec.update(pl[selected].eq(y[selected]).float().mean().item(), fix_mask.sum().item())
                        ls_fc = (F.cross_entropy(logits_s, pl, reduction='none') * fix_mask).mean()
                        loss = loss + ls_fc

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            fix_precs.append(fix_prec.avg)
            fix_utils.append(fix_util.avg)
            # NOTE(review): ood_util is never updated in this method, so
            # 'ood_utils' only reflects an untouched AverageMeter — confirm intent
            ood_utils.append(ood_util.avg)
            fix_util.reset()
            ood_util.reset()
            fix_prec.reset()
            filter_ratio.append(logs['filter_ratio'])
            filter_acc.append(logs['filter_acc'])
            true_ratio.append(logs['true_ratio'])

        self.logs = {
            'id_ratio': filter_ratio,
            'filter_acc': filter_acc,
            'true_id_ratio': true_ratio,
            'fix_utils': fix_utils,
            'fix_accs': fix_precs,
            'ood_utils': ood_utils,
            "samples": len(self.trainset)
        }
        self.optimizer_dict = self.optimizer.state_dict()
        self.model.to('cpu')
        self.util = 1

class OPNet(nn.Module):
    """Backbone plus two bias-free linear heads on its features.

    ``fc`` produces ``num_classes`` closed-set logits; ``ova`` produces
    ``2 * num_classes`` one-vs-all logits (callers reshape them to
    ``(batch, 2, num_classes)``). The backbone must expose ``fc.in_features``
    and ``get_feature(x)``.
    """

    def __init__(self, base, num_classes):
        super().__init__()
        self.backbone = base
        in_dim = base.fc.in_features
        self.feat_planes = in_dim

        self.fc = nn.Linear(in_dim, num_classes, bias=False)
        self.ova = nn.Linear(in_dim, 2 * num_classes, bias=False)

        # Xavier-normal init, closed-set head first, then the OVA head
        for head in (self.fc, self.ova):
            nn.init.xavier_normal_(head.weight.data)

    def forward(self, x, **kwargs):
        """Return a dict with closed-set 'logits' and one-vs-all 'logits_ova'."""
        features = self.backbone.get_feature(x)
        return {'logits': self.fc(features), 'logits_ova': self.ova(features)}

    def group_matcher(self, coarse=False):
        """Delegate parameter grouping to the backbone, prefixing names with 'backbone.'."""
        return self.backbone.group_matcher(coarse, prefix='backbone.')
    
    
def mb_sup_loss(logits_ova, label):
    """Supervised one-vs-all loss of OpenMatch.

    ``logits_ova`` is viewed as ``(batch, 2, C)`` and softmaxed over the pair
    dimension. Samples with ``label < C`` get a positive term
    ``-log p[:, 1, label]``. Every sample gets the hardest negative term:
    the max over non-target classes of ``-log p[:, 0, c]``. Samples with
    ``label >= C`` have no target, so all C classes count as negatives.
    Both terms are averaged over the batch and summed.
    """
    n = logits_ova.size(0)
    probs = F.softmax(logits_ova.view(n, 2, -1), 1)
    n_cls = probs.size(2)

    known = label < n_cls
    rows = torch.arange(0, n).long().to(label.device)
    pos_mask = torch.zeros((n, n_cls)).long().to(label.device)
    pos_mask[rows[known], label[known]] = 1
    neg_mask = 1 - pos_mask

    pos_term = (-torch.log(probs[:, 1, :] + 1e-8) * pos_mask).sum(dim=1).mean()
    neg_term = (-torch.log(probs[:, 0, :] + 1e-8) * neg_mask).max(dim=1).values.mean()
    return neg_term + pos_term
