"""
Algorithm: adapt BDMatch to SSFL setting
Details:
    - The open-set problem is converted into a binary classification problem: predicting inlier or outlier for each class
    - Class-imbalance learning and label-specific representation learning are further incorporated
    - Dual-branch architecture is used to elicit good representation from the original data distribution
Limitations:
    - the distribution estimation of unlabeled data is not stable in the FL setting, so it cannot be used reliably.
Reference: http://palm.seu.edu.cn/zhangml/files/BDMatch.rar
"""

import os
import math
import time
import torch
import numpy as np
import torch.nn as nn
import torch.cuda as cuda
import torch.nn.functional as F
from sklearn.metrics import balanced_accuracy_score, accuracy_score
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.data import DataLoader
from algorithm.base import ClientBase, ServerBase
from utils import AverageMeter


class BDMatch(ServerBase):
    """Server side of BDMatch adapted to the SSFL setting.

    The server trains the global ``BDMatch_Net`` on its labeled loader with a
    per-class binary (sigmoid) loss applied to both output branches,
    aggregates the pseudo-label statistics reported by the selected clients,
    and evaluates closed-set and open-set accuracy.
    """

    # (name logged on the server, key inside a client's ``logs`` dict),
    # listed in the order in which the metrics are logged.
    _CLIENT_METRICS = (
        ('fix_util', 'fix_utils'),
        ('fix_acc', 'fix_accs'),
        ('st_util', 'st_utils'),
        ('st_acc', 'st_accs'),
        ('all_util', 'all_utils'),
        ('all_acc', 'all_accs'),
        ('st_all_util', 'st_all_utils'),
        ('st_all_acc', 'st_all_accs'),
    )

    def __init__(self, args):
        super().__init__(args, Client)
        # Constant offset added to the main-branch logits in the supervised loss.
        self.sup_remargin = - self.args.tau * math.log(self.num_classes - 1)
        # Element-wise loss; the reduction is done explicitly by the callers.
        self.ce_loss = nn.BCEWithLogitsLoss(reduction='none')

    def make_model(self):
        """Wrap the base model built by the parent class in ``BDMatch_Net``."""
        base = super().make_model()
        return BDMatch_Net(base, self.num_classes).to(self.device)

    def training_stats(self, round_idx):
        """Log sample-weighted averages of the clients' pseudo-label statistics.

        For every selected client, each statistic array in ``client.logs`` is
        averaged and weighted by ``logs['samples']``; the weighted means are
        then logged at step ``round_idx``. A per-epoch summary is printed for
        each client as well.

        If the selected clients contribute no samples at all, the aggregated
        metrics are skipped instead of dividing by zero.
        """
        totals = {name: 0.0 for name, _ in self._CLIENT_METRICS}
        samples = 0
        for client_id in self.selected_clients:
            logs = self.clients[client_id].logs
            weight = logs['samples']
            samples += weight
            for name, key in self._CLIENT_METRICS:
                totals[name] += np.array(logs[key]).mean() * weight

            lines = [
                f'fix_u={logs["fix_utils"][i]:.2f},fix_acc={logs["fix_accs"][i]:.2f},'
                f'st_utils={logs["st_utils"][i]:.2f},st_acc={logs["st_accs"][i]:.2f}\n'
                for i in range(self.local_steps)
            ]
            self.printer.info(f'client{client_id}:' + ''.join(lines))

        if samples == 0:
            # Nothing to average over (no clients, or only empty clients).
            self.printer.info('training_stats: no client samples to aggregate, skipping')
            return

        for name, _ in self._CLIENT_METRICS:
            self.logger.log({name: totals[name] / samples}, step=round_idx)

    def _supervised_step(self, x, y):
        """Run one optimizer step on a labeled batch.

        Returns:
            ``(loss, logits)`` where ``loss`` is the summed loss of both
            branches and ``logits`` are the main-branch outputs computed
            before the parameter update.
        """
        self.optimizer.zero_grad()
        outs = self.global_model(x, st_pred=True)
        logits, st_logits = outs['logits'], outs['st_logits']
        # One-vs-rest targets shared by both branches.
        targets = F.one_hot(y.long(), num_classes=self.num_classes).float()
        # mean over all (sample, class) entries * num_classes
        #   == per-sample sum over classes, averaged over the batch.
        sup_loss = self.ce_loss(logits + self.sup_remargin, targets).mean() * self.num_classes
        st_sup_loss = self.ce_loss(st_logits, targets).mean() * self.num_classes
        loss = sup_loss + st_sup_loss
        loss.backward()
        if self.clip_grad > 0:
            clip_grad_norm_(self.global_model.parameters(), self.clip_grad)
        self.optimizer.step()
        return loss, logits

    def warmup(self):
        """Train the global model on the labeled loader for ``warmup_epochs`` epochs."""
        self.global_model.train(True)
        for _ in range(self.warmup_epochs):
            for data in self.train_loader:
                x, y = data['x'].to(self.device), data['y'].to(self.device)
                self._supervised_step(x, y)

    def train(self, round):
        """Run ``local_steps`` supervised epochs on the server and log loss/accuracy.

        Args:
            round: Step index used when logging the training metrics.
        """
        st = time.time()
        self.global_model.train(True)
        ce_loss_meter = AverageMeter()
        acc_meter = AverageMeter()
        for _ in range(self.local_steps):
            for data in self.train_loader:
                x, y = data['x'].to(self.device), data['y'].to(self.device)
                loss, logits = self._supervised_step(x, y)
                batch_size = y.shape[0]
                ce_loss_meter.update(loss.item(), batch_size)
                acc = (logits.argmax(dim=1) == y).float().mean().item()
                acc_meter.update(acc, batch_size)
        self.scheduler.step()
        self.logger.log({'train_loss': ce_loss_meter.avg}, step=round)
        self.logger.log({'server@train_acc': acc_meter.avg * 100}, step=round)
        self.printer.info(f"server train cost {(time.time() - st) / 60:.2f} min")

    @torch.no_grad()
    def _predict(self, loader):
        """Return ``(labels, main-branch logits)`` concatenated over ``loader``."""
        all_y, all_logits = [], []
        self.global_model.train(False)
        self.global_model.to(self.device)
        for data in loader:
            x, y = data['x'].to(self.device), data['y'].to(self.device)
            all_y.append(y)
            all_logits.append(self.global_model(x)['logits'])
        return torch.cat(all_y, dim=0), torch.cat(all_logits, dim=0)

    @torch.no_grad()
    def test(self, loader, round=0):
        """Closed-set test: return the argmax accuracy (in percent) over ``loader``."""
        # close set test
        self.printer.debug('-----------------testing-----------------')
        y, logits = self._predict(loader)
        test_acc = (y == torch.argmax(logits, dim=1)).float().mean().item() * 100
        return test_acc

    @torch.no_grad()
    def test_openset(self, loader, round=0):
        """Open-set test.

        Labels ``>= num_classes`` are treated as unseen. A sample whose largest
        sigmoid score is below 0.5 is predicted as the extra class index
        ``num_classes``.

        Returns:
            ``(open_set_acc, close_set_acc)`` in percent: balanced accuracy
            over all samples, and plain accuracy restricted to seen-class
            samples (computed before the rejection step).
        """
        # open-set test
        # TODO is there any performance difference between open-set and close-set?
        self.printer.debug('-----------------testing: openset-----------------')
        y, logits = self._predict(loader)
        conf, y_pred = torch.max(logits.sigmoid(), dim=1)
        seen_idxs = y < self.num_classes
        close_set_acc = accuracy_score(y[seen_idxs].cpu().numpy(), y_pred[seen_idxs].cpu().numpy()) * 100
        # Reject low-confidence predictions as "unknown"; full_like keeps the
        # filler on the same device and dtype as the predictions.
        unknown = torch.full_like(y_pred, self.num_classes)
        y_pred = torch.where(conf.ge(0.5), y_pred, unknown)
        open_set_acc = balanced_accuracy_score(y.cpu().numpy(), y_pred.cpu().numpy()) * 100
        return open_set_acc, close_set_acc



class Client(ClientBase):
    """Client side of BDMatch: trains on unlabeled data with pseudo-labels.

    Pseudo-labels are produced from the weakly augmented view (``data['x']``)
    by both branches of ``BDMatch_Net`` and used as binary targets for the
    strongly augmented view (``data['x_s']``). The labels ``data['y']`` are
    only used to compute monitoring statistics.
    """

    # Keys of the per-epoch statistics exposed through ``self.logs``.
    _STAT_KEYS = (
        'all_utils', 'st_all_utils', 'all_accs', 'st_all_accs',
        'fix_utils', 'fix_accs', 'st_utils', 'st_accs',
    )

    def __init__(self, args, id, trainset):
        super().__init__(args, id, trainset)
        self.tau = args.tau
        self.ema = args.ema
        self.p_cutoff_pos = args.p_cutoff_pos
        self.p_cutoff_neg = args.p_cutoff_neg
        # Running estimate of the per-class positive probability, started at 0.5.
        self.p_model = torch.ones((self.num_classes)) * 0.5
        self.train_loader = DataLoader(self.trainset, batch_size=args.c_batch_size, shuffle=True)
        self.ce_loss = nn.BCEWithLogitsLoss(reduction='none')

    def make_model(self):
        """Wrap the base model built by the parent class in ``BDMatch_Net``."""
        base = super().make_model()
        return BDMatch_Net(base, self.num_classes).to(self.device)

    @torch.no_grad()
    def update_p(self, probs_x_ulb):   ## this distribution estimation is not stable in FL setting
        """Update the EMA of the per-class mean probability with a new batch."""
        # ``Tensor.to`` returns the tensor itself when it already lives on the
        # target device, so this is safe to call unconditionally.
        self.p_model = self.p_model.to(self.device)
        self.p_model = self.p_model * self.ema + (1 - self.ema) * probs_x_ulb.mean(dim=0)

    @torch.no_grad()
    def get_logits_adj(self, probs_x_ulb):
        """Update ``p_model`` with the batch and return its per-class log-odds."""
        self.update_p(probs_x_ulb)
        logits_adj = torch.log(self.p_model + 1e-8) - torch.log(1.0 - self.p_model + 1e-8)  # imbalance ratio
        return logits_adj

    def pseudo_label(self, probs_x_ulb):
        """Turn per-class probabilities (B * C) into binary pseudo-labels.

        A sample is "confident" when its largest probability reaches
        ``p_cutoff_pos``; all of its entries are then used, with a one-hot
        target at the arg-max class. For the remaining samples only entries
        whose probability is at most ``p_cutoff_neg`` are used, as negatives.

        Returns:
            ``(mask, pseudo_label, sample_mask, pl)``: the element-wise loss
            mask, the binary targets, the per-sample confidence mask and the
            per-sample arg-max class.
        """
        # B * C
        max_probs, pl = torch.max(probs_x_ulb, dim=-1)
        sample_mask = max_probs.ge(self.p_cutoff_pos)
        pos_mask = sample_mask.unsqueeze(1)
        neg_mask = probs_x_ulb.le(self.p_cutoff_neg)
        mask = (pos_mask | neg_mask).to(max_probs.dtype)

        pseudo_label = F.one_hot(pl.long(), num_classes=self.num_classes).to(max_probs.dtype)
        # Samples that are not confident keep an all-zero target row.
        pseudo_label = pseudo_label * pos_mask.to(max_probs.dtype)

        return mask, pseudo_label, sample_mask, pl

    def train(self, round_idx, lr, state_dict):
        """Run ``local_steps`` epochs of pseudo-label training on this client.

        After training, ``self.logs`` holds ``samples`` plus one array per key
        in ``_STAT_KEYS`` (one entry per epoch) and the per-epoch mean loss
        under ``losses``. ``self.util`` tells whether any pseudo-label entry
        was used at all.
        """
        self.prepare(lr, state_dict)
        self.model.train(True)
        loss_meter = AverageMeter()
        meters = {key: AverageMeter() for key in self._STAT_KEYS}
        history = {key: [] for key in self._STAT_KEYS}
        losses = []

        for _ in range(self.local_steps):
            for data in self.train_loader:
                x_w, x_s, y = data['x'].to(self.device), data['x_s'].to(self.device), data['y'].to(self.device)
                batch_size = y.shape[0]
                self.optimizer.zero_grad()
                with torch.no_grad():
                    # generate closed-set and open-set targets (pseudo-labels)
                    outputs = self.model(x_w, st_pred=True)
                    logits_w, st_logits_w = outputs['logits'], outputs['st_logits']
                    prob_w = logits_w.sigmoid()

                    # adjust the logits with the imbalance ratio
                    logits_adj = self.get_logits_adj(probs_x_ulb=prob_w)
                    prob_adj = (logits_w - self.tau * logits_adj).sigmoid()  ## B * C
                    prob_st = st_logits_w.sigmoid()
                    mask, pl, sample_mask, sample_pl = self.pseudo_label(prob_adj)
                    st_mask, st_pl, st_sample_mask, st_sample_pl = self.pseudo_label(prob_st)

                outputs = self.model(x_s, st_pred=True)
                logits_s, st_logits_s = outputs['logits'], outputs['st_logits']
                ## [B, C]
                loss = (self.ce_loss(logits_s, pl) * mask).mean() * self.num_classes
                loss += (self.ce_loss(st_logits_s, st_pl) * st_mask).mean() * self.num_classes
                loss.backward()
                if self.clip_grad > 0:
                    clip_grad_norm_(self.model.parameters(), self.clip_grad)
                self.optimizer.step()

                # Utilisation: fraction of entries / samples that carry a pseudo-label.
                loss_meter.update(loss.item(), batch_size)
                meters['all_utils'].update(mask.float().mean().item(), batch_size)
                meters['st_all_utils'].update(st_mask.float().mean().item(), batch_size)
                meters['fix_utils'].update(sample_mask.float().mean().item(), batch_size)
                meters['st_utils'].update(st_sample_mask.float().mean().item(), batch_size)

                # Binary ground truth; samples with y >= num_classes keep an all-zero row.
                ground_truth = torch.zeros_like(pl)
                in_mask = y < self.num_classes
                ground_truth[in_mask, y[in_mask]] = 1
                mask, st_mask = mask.bool(), st_mask.bool()

                # Accuracy of the pseudo-labels that were actually used.
                if mask.any():
                    meters['all_accs'].update(
                        pl[mask].eq(ground_truth[mask]).float().mean().item(), mask.float().sum().item())
                if st_mask.any():
                    meters['st_all_accs'].update(
                        st_pl[st_mask].eq(ground_truth[st_mask]).float().mean().item(), st_mask.float().sum().item())
                if sample_mask.any():
                    meters['fix_accs'].update(
                        sample_pl[sample_mask].eq(y[sample_mask]).float().mean().item(),
                        sample_mask.float().sum().item())
                if st_sample_mask.any():
                    meters['st_accs'].update(
                        st_sample_pl[st_sample_mask].eq(y[st_sample_mask]).float().mean().item(),
                        st_sample_mask.float().sum().item())

            # End of epoch: record every average *before* resetting its meter.
            for key in self._STAT_KEYS:
                history[key].append(meters[key].avg)
                meters[key].reset()
            losses.append(loss_meter.avg)
            loss_meter.reset()

        self.optimizer_dict = self.optimizer.state_dict()
        self.model.to('cpu')
        self.util = np.mean(history['all_utils']) > 0
        self.logs = {"samples": len(self.trainset)}
        for key in self._STAT_KEYS:
            self.logs[key] = np.array(history[key])
        self.logs["losses"] = np.array(losses)



class Attention(nn.Module):
    """Multi-head dot-product attention between query embeddings and input tokens.

    Queries come from a linear projection of ``query_embed`` and keys from a
    linear projection of ``x``; the values are the input tokens themselves
    (no value or output projection is applied).
    """

    def __init__(self, dim, num_heads=4, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        # Standard 1/sqrt(head_dim) scaling of the attention scores.
        self.scale = (dim // num_heads) ** -0.5

        self.linear_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.linear_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, query_embed):
        """Attend from ``query_embed`` (K queries) over ``x`` of shape (B, N, C).

        Returns a tensor of shape (B, K, C).
        """
        batch, num_tokens, channels = x.shape
        num_queries = query_embed.size(1)
        heads = self.num_heads
        head_dim = channels // heads

        def split_heads(tensor, length):
            # (B, length, C) -> (B, heads, length, head_dim)
            return tensor.reshape(batch, length, heads, head_dim).permute(0, 2, 1, 3)

        projected_queries = self.linear_q(query_embed).expand(batch, -1, -1)
        queries = split_heads(projected_queries, num_queries)
        keys = split_heads(self.linear_k(x), num_tokens)
        values = split_heads(x, num_tokens)

        scores = (queries @ keys.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back: (B, heads, K, head_dim) -> (B, K, C)
        merged = (weights @ values).transpose(1, 2).reshape(batch, num_queries, channels)
        return self.proj_drop(merged)
    

class LayerScale(nn.Module):
    """Multiply the input by a learnable vector ``gamma`` of size ``dim``."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        # Every entry of gamma starts at ``init_values``.
        initial_gamma = init_values * torch.ones(dim)
        self.gamma = nn.Parameter(initial_gamma) # type: ignore

    def forward(self, x):
        """Return ``x`` scaled by ``gamma``; the input is modified in place if ``inplace``."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
    
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Stochastic depth: randomly zero whole samples of ``x``.

    One Bernoulli draw with keep probability ``1 - drop_prob`` is made per
    entry along the first dimension and broadcast over the remaining ones.
    With ``scale_by_keep`` the kept samples are divided by the keep
    probability. The input is returned untouched when ``training`` is false
    or ``drop_prob`` is zero.
    """
    if not training or drop_prob == 0.:
        return x

    survival = 1 - drop_prob
    # One value per sample, broadcastable to any input rank (not just 2D ConvNets).
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep_mask = x.new_empty(mask_shape).bernoulli_(survival)
    if scale_by_keep and survival > 0.0:
        keep_mask.div_(survival)
    return x * keep_mask


class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth).

    Dropping is only active while the module is in training mode.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def extra_repr(self):
        # Shown inside the module's repr, e.g. ``DropPath(drop_prob=0.200)``.
        return f'drop_prob={round(self.drop_prob,3):0.3f}'

    def forward(self, x):
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )
    
class BDMatch_Net(nn.Module):
    """Backbone plus label-specific attention features and two classifier branches.

    One feature vector per class is obtained by attending from learnable
    label embeddings over the backbone's raw feature map. Each branch scores
    class ``c`` as the difference of two linear classifiers evaluated on the
    feature vector of class ``c``. ``use_rot`` is accepted but not used here.
    """

    def __init__(self, base, num_classes, num_heads=4, qkv_bias=True, attn_drop=0., drop=0.,
                 init_values=None, drop_path=0.2, use_rot=False):
        super().__init__()
        self.backbone = base
        self.num_features = base.fc.in_features
        width = self.num_features

        # Multi-head dot-product attention module to extract label-specific features
        self.attn = Attention(width, num_heads=num_heads, qkv_bias=qkv_bias,
                              attn_drop=attn_drop, proj_drop=drop)
        if init_values:
            self.ls1 = LayerScale(width, init_values=init_values)
        else:
            self.ls1 = nn.Identity()
        if drop_path > 0.:
            self.drop_path1 = DropPath(drop_path)
        else:
            self.drop_path1 = nn.Identity()
        self.norm = nn.LayerNorm(width, eps=1e-6)

        # Label embeddings (one query per class), drawn from a standard normal.
        self.query_embed = nn.Parameter(torch.zeros(1, num_classes, width)) # type: ignore
        nn.init.normal_(self.query_embed)

        # Second classifier of the main branch (the first one is backbone.fc).
        self.fc2 = self._make_head(width, num_classes)

        # Standard classifiers used in the dual-branch architecture
        self.st_fc1 = self._make_head(width, num_classes)
        self.st_fc2 = self._make_head(width, num_classes)

    @staticmethod
    def _make_head(in_features, num_classes):
        """Linear layer with Xavier-normal weights and a zero bias."""
        head = nn.Linear(in_features, num_classes)
        nn.init.xavier_normal_(head.weight.data)
        head.bias.data.zero_()
        return head

    @staticmethod
    def _label_scores(x, fc):
        """Score class ``c`` with row ``c`` of ``fc`` applied to the c-th feature vector of ``x``."""
        return (x * fc.weight).sum(dim=-1) + fc.bias

    def forward(self, x, st_pred=False):
        """Return ``{'logits': ...}``, plus ``'st_logits'`` when ``st_pred`` is true."""
        raw = self.backbone.get_raw_feature(x)
        # Globally pooled feature, one vector per sample.
        pooled = F.adaptive_avg_pool2d(raw, 1).view(-1, 1, self.num_features)
        # Flatten the spatial dimensions into a token axis: (B, tokens, channels).
        tokens = raw.reshape((raw.size(0), raw.size(1), -1)).permute(0, 2, 1)
        label_feat = self.attn(tokens, self.query_embed)
        label_feat = self.drop_path1(self.ls1(label_feat))
        # Residual connection from the pooled feature, then layer norm.
        feat = self.norm(pooled + label_feat)

        outputs = {'logits': self.head_forward(feat)}
        if st_pred:
            outputs['st_logits'] = self.head_forward_st(feat)
        return outputs

    def head_forward(self, x):
        """Main-branch logits: ``backbone.fc`` scores minus ``fc2`` scores."""
        positive = self._label_scores(x, self.backbone.fc)
        negative = self._label_scores(x, self.fc2)
        return positive - negative

    def head_forward_st(self, x):
        """Standard-branch logits: ``st_fc1`` scores minus ``st_fc2`` scores."""
        positive = self._label_scores(x, self.st_fc1)
        negative = self._label_scores(x, self.st_fc2)
        return positive - negative

    def group_matcher(self, coarse=False):
        """Delegate to the backbone's matcher, using the ``backbone.`` prefix."""
        return self.backbone.group_matcher(coarse, prefix='backbone.')

    def no_weight_decay(self):
        """Names of parameters containing ``'bn'`` or ``'bias'``."""
        return [name for name, _ in self.named_parameters() if 'bn' in name or 'bias' in name]