"""
Algorithm: adapt IOMatch to SSFL setting
Details:
    - IOMatch has three classifiers: a closed-set classifier (which seen class the sample belongs to), an open-set classifier (which of the seen classes or
    the unknown class the sample belongs to), and a multi-binary classifier (OVA classifier: the probability that the sample belongs to each target class).
    - The OVA classifier is trained only on labeled data; the closed-set classifier is trained on labeled data and high-confidence unlabeled data; the
    open-set classifier is trained on labeled data and moderate-confidence unlabeled data.
    - The original IOMatch implementation uses a distribution alignment stage that aligns p_model with a uniform distribution, which is not suitable
    in the heterogeneous FL setting, so it is not used here.
Limitation: 
    - During local training, the OVA classifier stays fixed while the backbone keeps updating, which may make their parameters incompatible.
    - The fused open-set pseudo-labels are not accurate on complex tasks.
Reference: https://github.com/nukezil/IOMatch
"""

import os
import copy
import time
import torch
import numpy as np
import torch.nn as nn
import torch.cuda as cuda
import torch.nn.functional as F
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.data import DataLoader
from sklearn.metrics import balanced_accuracy_score, accuracy_score
from algorithm.base import ClientBase, ServerBase
from utils import AverageMeter


class IOMatch(ServerBase):
    """Server side of IOMatch adapted to federated semi-supervised learning.

    The server trains the global ``IOMatchNet`` on its own ``train_loader`` using
    the labels ``y`` (closed-set CE + multi-binary/OVA loss). Clients (see
    ``Client``) train on their local data with pseudo-labels.
    """

    def __init__(self, args):
        super().__init__(args, Client)
        # weight of the supervised multi-binary (OVA) loss
        self.lambda_mb = 1

    def make_model(self):
        """Wrap the base model built by ``ServerBase`` into an ``IOMatchNet``."""
        model = super().make_model()
        return IOMatchNet(model, self.num_classes).to(self.device)

    def training_stats(self, round_idx):
        """Aggregate and log the selected clients' pseudo-label statistics.

        Utilization is weighted by each client's sample count; precision is
        weighted by the number of utilized samples (mean utilization * samples).
        Also prints the per-local-step statistics of every selected client.
        """
        fix_util, os_util, os_acc, fix_acc = 0, 0, 0, 0
        samples = 0
        for client_id in self.selected_clients:
            logs = self.clients[client_id].logs
            s = logs['samples']
            samples += s
            n_fix = np.array(logs['fix_utils']).mean() * s
            fix_util += n_fix
            fix_acc += np.array(logs['fix_accs']).mean() * n_fix
            n_os = np.array(logs['os_utils']).mean() * s
            os_util += n_os
            os_acc += np.array(logs['os_accs']).mean() * n_os

            print_log = f'client-{client_id}:'
            # iterate the logged steps directly so a length mismatch cannot IndexError
            per_step = zip(logs['fix_utils'], logs['fix_accs'], logs['os_utils'], logs['os_accs'])
            for f_u, f_a, o_u, o_a in per_step:
                print_log += f'fix_u={f_u:.2f},fix_acc={f_a:.2f},os_utils={o_u:.2f},os_acc={o_a:.2f}\n'
            self.printer.info(print_log)
        # small epsilons avoid division by zero when nothing was sampled/utilized
        self.logger.log({'fix_util': fix_util / (samples + 1e-9)}, step=round_idx)
        self.logger.log({'fix_acc': fix_acc / (fix_util + 1e-9)}, step=round_idx)
        self.logger.log({'os_util': os_util / (samples + 1e-9)}, step=round_idx)
        self.logger.log({'os_acc': os_acc / (os_util + 1e-9)}, step=round_idx)

    def warmup(self):
        """Supervised warm-up of the global model: closed-set CE for ``self.warmup_epochs`` epochs."""
        self.global_model.train(True)
        for _ in range(self.warmup_epochs):
            for data in self.train_loader:
                x, y = data['x'].to(self.device), data['y'].to(self.device)
                self.optimizer.zero_grad()
                logits = self.global_model(x)['logits']
                loss = self.ce_loss(logits, y)
                loss.backward()
                if self.clip_grad > 0:
                    clip_grad_norm_(self.global_model.parameters(), self.clip_grad)
                self.optimizer.step()

    def train(self, round):
        """Train the global model on the server data for ``self.local_steps`` epochs.

        Loss = closed-set CE + ``self.lambda_mb`` * multi-binary (OVA) loss; the
        open-set head's logits are not part of the loss, so it gets no gradient
        here. Steps the scheduler once and logs the mean total loss and accuracy.
        """
        st = time.time()
        self.global_model.train(True)
        loss_meter, acc_meter = AverageMeter(), AverageMeter()
        for _ in range(self.local_steps):
            for data in self.train_loader:
                x, y = data['x'].to(self.device), data['y'].to(self.device)
                self.optimizer.zero_grad()
                outputs = self.global_model(x)
                logits = outputs['logits']
                sup_closed_loss = self.ce_loss(logits, y)
                sup_mb_loss = self.lambda_mb * mb_sup_loss(outputs['logits_ova'], y)
                loss = sup_closed_loss + sup_mb_loss
                loss.backward()
                if self.clip_grad > 0:
                    clip_grad_norm_(self.global_model.parameters(), self.clip_grad)
                self.optimizer.step()
                loss_meter.update(loss.item(), y.shape[0])
                acc = (logits.argmax(dim=1) == y).float().mean().item()
                acc_meter.update(acc, y.shape[0])
        self.scheduler.step()
        self.logger.log({'train_loss': loss_meter.avg}, step=round)
        self.logger.log({'server@train_acc': acc_meter.avg * 100}, step=round)
        self.printer.info(f"server train cost {(time.time() - st) / 60:.2f} min")

    @torch.no_grad()
    def test(self, loader, round=0):
        """Return closed-set accuracy (in %) of the global model on ``loader``."""
        self.printer.debug('-----------------testing----------------')
        all_y, all_logits = [], []
        self.global_model.train(False)
        self.global_model.to(self.device)
        for data in loader:
            x, y = data['x'].to(self.device), data['y'].to(self.device)
            all_y.append(y)
            all_logits.append(self.global_model(x)['logits'])
        y = torch.cat(all_y, dim=0)
        logits = torch.cat(all_logits, dim=0)
        return (y == torch.argmax(logits, dim=1)).float().mean().item() * 100

    @torch.no_grad()
    def test_openset(self, loader, round=0):
        """Return ``(open_set_acc, close_set_acc)`` in % on ``loader``.

        open_set_acc: balanced accuracy of the (K+1)-way open-set head over all
        samples (presumably unknown samples are labeled ``num_classes`` — confirm
        against the dataset, since the head can only predict 0..num_classes).
        close_set_acc: accuracy of the closed-set head on samples with
        ``y < num_classes``.
        """
        # TODO is there any performance difference between open-set and close-set?
        self.printer.debug('-----------------testing-----------------')
        all_y, all_logits, all_logits_op = [], [], []
        self.global_model.train(False)
        self.global_model.to(self.device)
        for data in loader:
            x, y = data['x'].to(self.device), data['y'].to(self.device)
            outputs = self.global_model(x)
            all_y.append(y)
            all_logits.append(outputs['logits'])
            all_logits_op.append(outputs['logits_op'])
        y = torch.cat(all_y, dim=0)
        logits = torch.cat(all_logits, dim=0)
        logits_op = torch.cat(all_logits_op, dim=0)
        op_pred = logits_op.argmax(dim=1)
        in_pred = logits.argmax(dim=1)
        open_set_acc = balanced_accuracy_score(y.cpu().numpy(), op_pred.cpu().numpy()) * 100
        seen_idxs = y < self.num_classes
        # sklearn convention: (y_true, y_pred)
        close_set_acc = accuracy_score(y[seen_idxs].cpu().numpy(), in_pred[seen_idxs].cpu().numpy()) * 100
        return open_set_acc, close_set_acc


class Client(ClientBase):
    """IOMatch client: trains on its local data with closed-set and open-set pseudo-labels.

    For every batch, the weak view ``data['x']`` produces pseudo-labels without
    gradients and the strong view ``data['x_s']`` is trained to match them.
    ``data['y']`` is only used to log pseudo-label precision.
    """

    def __init__(self, args, id, trainset):
        super().__init__(args, id, trainset)
        # weight of the open-set consistency loss
        self.lambda_op = args.op_loss_ratio
        # confidence threshold for open-set pseudo-labels
        self.open_cutoff = args.open_threshold
        self.train_loader = DataLoader(self.trainset, batch_size=args.c_batch_size, shuffle=True)

    def make_model(self):
        """Wrap the base model built by ``ClientBase`` into an ``IOMatchNet``."""
        model = super().make_model()
        return IOMatchNet(model, self.num_classes).to(self.device)

    @torch.no_grad()
    def _pseudo_labels(self, x_w):
        """Build closed-set and open-set pseudo-labels from the weak view ``x_w``.

        The open-set distribution fuses the closed-set softmax with the OVA heads:
        inlier score of class k = p_close[k] * p_pos[k]; outlier score =
        sum_k p_close[k] * p_neg[k].

        Returns:
            close_lb: argmax of the closed-set softmax, shape (B,).
            fix_mask: bool (B,), predicted inlier (outlier score < 0.5) AND
                closed-set confidence >= ``self.threshold``.
            open_lb: argmax of the fused (K+1)-way distribution, shape (B,).
            open_mask: bool (B,), fused confidence >= ``self.open_cutoff``.
        """
        outputs = self.model(x_w)
        pl_close = F.softmax(outputs['logits'], dim=-1)                    # (B, K)
        num_ulb = pl_close.size(0)
        ova = F.softmax(outputs['logits_ova'].view(num_ulb, 2, -1), dim=1)  # (B, 2, K)
        o_neg, o_pos = ova[:, 0, :], ova[:, 1, :]
        out_scores = torch.sum(pl_close * o_neg, dim=1)                     # outlier score
        pl_open = torch.cat([pl_close * o_pos, out_scores.unsqueeze(1)], dim=1)  # (B, K+1)
        close_conf, close_lb = torch.max(pl_close, dim=1)
        open_conf, open_lb = torch.max(pl_open, dim=1)
        fix_mask = (out_scores < 0.5) & close_conf.ge(self.threshold)
        open_mask = open_conf.ge(self.open_cutoff)
        return close_lb, fix_mask, open_lb, open_mask

    def train(self, round_idx, lr, state_dict):
        """Run ``self.local_steps`` passes over the local data, starting from ``state_dict``.

        Per batch:
          - closed-set loss: CE of the closed-set head on ``fix_mask`` samples;
          - open-set loss: CE of the (K+1)-way head on ``open_mask`` samples,
            weighted by ``self.lambda_op``.
        Both losses are averaged over the whole batch (unselected samples count as 0).
        Batches in which no sample is selected skip backward and the optimizer step.

        Side effects: sets ``self.logs`` (per-step utilization/precision/loss),
        ``self.optimizer_dict`` and ``self.util``, and moves the model to CPU.
        """
        self.prepare(lr, state_dict)
        self.model.train(True)
        loss_meter = AverageMeter()
        fix_util_meter, fix_prec_meter = AverageMeter(), AverageMeter()
        os_util_meter, os_prec_meter = AverageMeter(), AverageMeter()
        losses, os_acc, fix_acc, os_util, fix_util = [], [], [], [], []

        for _ in range(self.local_steps):
            for data in self.train_loader:
                x_w = data['x'].to(self.device)
                x_s = data['x_s'].to(self.device)
                y = data['y'].to(self.device)
                num_ulb = y.size(0)
                self.optimizer.zero_grad()
                close_lb, fix_mask, open_lb, open_mask = self._pseudo_labels(x_w)

                outputs = self.model(x_s)
                logits_s, logits_open_s = outputs['logits'], outputs['logits_op']
                # accumulate out of place: an in-place add on a grad-requiring leaf raises
                loss = logits_s.new_zeros(())
                if fix_mask.any():
                    ui_loss = (F.cross_entropy(logits_s, close_lb, reduction='none') * fix_mask.float()).mean()
                    loss = loss + ui_loss
                    fix_prec_meter.update((y[fix_mask] == close_lb[fix_mask]).float().mean().item(),
                                          int(fix_mask.sum().item()))
                if open_mask.any():
                    op_loss = (F.cross_entropy(logits_open_s, open_lb, reduction='none') * open_mask.float()).mean()
                    loss = loss + self.lambda_op * op_loss
                    os_prec_meter.update((y[open_mask] == open_lb[open_mask]).float().mean().item(),
                                         int(open_mask.sum().item()))
                # utilization = fraction of samples selected, recorded for every batch
                fix_util_meter.update(fix_mask.float().mean().item(), num_ulb)
                os_util_meter.update(open_mask.float().mean().item(), num_ulb)

                if loss.requires_grad:  # False when no sample was selected in this batch
                    loss.backward()
                    if self.clip_grad > 0:
                        clip_grad_norm_(self.model.parameters(), self.clip_grad)
                    self.optimizer.step()
                loss_meter.update(loss.item(), num_ulb)

            os_acc.append(os_prec_meter.avg)
            fix_acc.append(fix_prec_meter.avg)
            os_util.append(os_util_meter.avg)
            fix_util.append(fix_util_meter.avg)
            losses.append(loss_meter.avg)
            for meter in (os_prec_meter, fix_prec_meter, os_util_meter, fix_util_meter, loss_meter):
                meter.reset()

        self.logs = {
            "os_accs": np.array(os_acc),
            "fix_accs": np.array(fix_acc),
            "os_utils": np.array(os_util),
            "fix_utils": np.array(fix_util),
            "losses": np.array(losses),
            "samples": len(self.trainset)
        }
        self.optimizer_dict = self.optimizer.state_dict()
        self.model.to('cpu')
        # True if any pseudo-label was used during this round
        self.util = bool(np.mean(os_util) > 0 or np.mean(fix_util) > 0)

        

def mb_sup_loss(logits_ova, label):
    """Supervised multi-binary (one-vs-all) loss on labeled data.

    Args:
        logits_ova: tensor of shape ``(B, 2K)``, reshaped to ``(B, 2, K)``; index 0
            along dim 1 is each class's "negative" logit, index 1 its "positive" logit.
        label: integer tensor of shape ``(B,)``. Labels in ``[0, K)`` are inliers;
            any other value (``>= K`` or negative) is treated as an outlier, i.e.
            negative for every binary classifier.

    Returns:
        Scalar tensor: mean positive loss of the true class plus the mean of the
        largest negative loss over the non-target classes.
    """
    batch_size = logits_ova.size(0)
    logits_ova = logits_ova.view(batch_size, 2, -1)
    num_classes = logits_ova.size(2)
    probs_ova = F.softmax(logits_ova, 1)
    # one-hot multi-binary targets; outliers keep an all-zero row
    label_s_sp = torch.zeros((batch_size, num_classes), dtype=torch.long, device=label.device)
    # negative labels must not be treated as inliers (they would index from the end)
    inlier = (label >= 0) & (label < num_classes)
    label_range = torch.arange(batch_size, device=label.device)
    label_s_sp[label_range[inlier], label[inlier]] = 1
    label_sp_neg = 1 - label_s_sp
    open_loss = torch.mean(torch.sum(-torch.log(probs_ova[:, 1, :] + 1e-8) * label_s_sp, 1))
    # hard-negative mining: only the non-target class with the largest loss contributes
    open_loss_neg = torch.mean(torch.max(-torch.log(probs_ova[:, 0, :] + 1e-8) * label_sp_neg, 1)[0])
    return open_loss_neg + open_loss


class IOMatchNet(nn.Module):
    """Backbone plus the two extra IOMatch heads on a shared projection.

    ``base`` must expose ``fc`` (an ``nn.Linear`` used as the closed-set
    classifier) and ``get_feature(x)``. The features pass through a 2-layer MLP
    projection (``mlp_proj``) that feeds:
      - ``openset_classifier``: (K+1)-way logits (K seen classes + one outlier class);
      - ``mb_classifiers``: 2K logits for the K one-vs-all binary classifiers.
    """

    def __init__(self, base, num_classes, proj_size=128):
        super().__init__()
        self.backbone = base
        self.feat_planes = base.fc.in_features
        planes = self.feat_planes

        # indices 0/1/2 fix the state_dict keys (mlp_proj.0.*, mlp_proj.2.*)
        self.mlp_proj = nn.Sequential(
            nn.Linear(planes, planes),
            nn.ReLU(inplace=False),
            nn.Linear(planes, proj_size),
        )
        self.mb_classifiers = nn.Linear(proj_size, 2 * num_classes, bias=False)
        self.openset_classifier = nn.Linear(proj_size, num_classes + 1)

        # re-initialize the two added heads
        nn.init.xavier_normal_(self.mb_classifiers.weight)
        nn.init.xavier_normal_(self.openset_classifier.weight)
        nn.init.zeros_(self.openset_classifier.bias)

    def forward(self, x, **kwargs):
        """Return a dict with 'logits' (B, K), 'logits_op' (B, K+1) and 'logits_ova' (B, 2K)."""
        feat = self.backbone.get_feature(x)
        projected = self.mlp_proj(feat)
        return {
            'logits': self.backbone.fc(feat),
            'logits_op': self.openset_classifier(projected),
            'logits_ova': self.mb_classifiers(projected),
        }

    def group_matcher(self, coarse=False):
        """Delegate to the backbone's ``group_matcher`` with the 'backbone.' prefix."""
        return self.backbone.group_matcher(coarse, prefix='backbone.')