'''
Train with only the labeled data, serving as a lower-bound baseline.
'''

import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.clip_grad import clip_grad_norm_
from datasets import fetch_dataset, fetch_os_dataset
from sklearn.metrics import balanced_accuracy_score, accuracy_score
from utils import AverageMeter, get_optimizer, get_net_builder

class TrainAloneOS(object):
    """Supervised-only open-set baseline trained on the labeled split alone.

    The model (``FedNet``) has a closed-set classifier head (``'logits'``) and a
    one-vs-all head (``'logits_ova'``). Training minimises cross-entropy on the
    closed-set logits, plus ``mb_sup_loss`` on the OVA logits when
    ``args.train_ova`` is set. ``run`` alternates one training epoch with one
    open-set evaluation on the test split.
    """

    def __init__(self, args) -> None:
        self.args = args
        self.algorithm = args.algorithm
        # Fall back to CPU so the baseline still runs on machines without CUDA;
        # on GPU machines this is unchanged ('cuda:0').
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.epochs = args.global_rounds
        self.logger = args.logger
        self.printer = args.printer
        self.save_dir = args.save_dir
        self.exp_tag = args.exp_tag
        self.data_shape = args.data_shape
        self.num_classes = args.num_classes
        self.clip_grad = args.clip_grad
        # make_dataset may overwrite self.num_classes with the number of seen
        # classes, so it must run before make_model.
        self.lb_set, self.testset = self.make_dataset(args)
        self.train_loader = DataLoader(self.lb_set, batch_size=args.batch_size, shuffle=True)
        self.model = self.make_model()
        self.optimizer = get_optimizer(self.model, args.optim, args.lr, args.momentum, args.weight_decay)
        # The scheduler is stepped once per batch in train(), so T_max counts batches.
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.epochs * len(self.train_loader), eta_min=0)
        self.ce_loss = nn.CrossEntropyLoss()
        self.best_acc = 0.
        self.train_ova = args.train_ova

    def make_dataset(self, args):
        """Return ``(labeled_train_set, test_set)``.

        When ``args.num_seen_class < args.num_classes`` the open-set fetchers are
        used and ``self.num_classes`` is set to ``args.num_seen_class``.
        NOTE(review): the test samples are read from the ``'ulb_set'`` key of the
        test-split result — confirm this matches the dataset helpers.
        """
        self.printer.debug('make dataset')
        if args.num_seen_class < args.num_classes:
            datasets = fetch_os_dataset(args.data_dir, args.dataset, args.num_labels, args.num_seen_class, close_set=args.close_train, train=True)
            testset = fetch_os_dataset(args.data_dir, args.dataset, args.num_labels, args.num_seen_class, close_set=args.close_test, train=False)
            self.num_classes = args.num_seen_class
        else:
            datasets = fetch_dataset(args.data_dir, args.dataset, args.num_labels, train=True)
            testset = fetch_dataset(args.data_dir, args.dataset, args.num_labels, train=False)
        train, test = datasets['lb_set'], testset['ulb_set']
        return train, test

    def make_model(self):
        """Build the ``args.net`` backbone, wrap it in ``FedNet`` and move it to ``self.device``."""
        net_builder = get_net_builder(self.args.net)
        return FedNet(net_builder(self.data_shape, self.num_classes), self.num_classes).to(self.device)

    def train(self, round):
        """Run one epoch over the labeled loader and log mean loss/accuracy at step ``round``."""
        st = time.time()
        self.model.train(True)
        loss_meter = AverageMeter()  # total loss (CE, plus OVA loss when enabled)
        acc_meter = AverageMeter()
        for data in self.train_loader:
            x, y = data['x'].to(self.device), data['y'].to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(x)
            logits = outputs['logits']
            loss = self.ce_loss(logits, y)
            if self.train_ova:
                logits_ova = outputs['logits_ova']
                # (bs, 2 * C) -> (bs, 2, C); mb_sup_loss treats index 1 as the
                # "inlier" score and index 0 as the "outlier" score per class.
                logits_ova = logits_ova.view(logits_ova.size(0), 2, -1)
                # Out-of-place add so no tensor that autograd may have saved is mutated.
                loss = loss + mb_sup_loss(logits_ova, y)

            loss.backward()
            if self.clip_grad > 0:
                clip_grad_norm_(self.model.parameters(), self.clip_grad)
            self.optimizer.step()
            self.scheduler.step()
            loss_meter.update(loss.item(), y.shape[0])
            acc = (logits.argmax(dim=1) == y).float().mean().item()
            acc_meter.update(acc, y.shape[0])
        self.logger.log({'train_loss': loss_meter.avg}, step=round)
        self.logger.log({'server@train_acc': acc_meter.avg}, step=round)
        self.printer.info(f"server train cost {(time.time() - st) / 60:.2f} min")

    @staticmethod
    def _open_set_predictions(logits, logits_ova, num_classes):
        """Turn closed-set and OVA logits into ``(closed_pred, open_pred)``.

        ``closed_pred`` is the argmax of ``logits``. ``open_pred`` is a copy in
        which every sample whose OVA outlier probability (softmax index 0) for its
        predicted class exceeds 0.5 is relabelled as ``num_classes`` ("unknown").
        ``logits_ova`` is the flat ``(bs, 2 * C)`` head output.
        """
        probs_ova = F.softmax(logits_ova.view(logits_ova.size(0), 2, -1), 1)  # (bs, 2, C)
        closed_pred = logits.argmax(dim=1)
        rows = torch.arange(closed_pred.size(0), device=closed_pred.device)
        outlier_score = probs_ova[rows, 0, closed_pred]
        open_pred = closed_pred.clone()
        open_pred[outlier_score > 0.5] = num_classes
        return closed_pred, open_pred

    @torch.no_grad()
    def test_openset(self, loader, round=0, log=True):
        """Evaluate on ``loader`` and return ``(open_set_acc, close_set_acc)`` in percent.

        ``close_set_acc`` is plain accuracy over samples with ``y < num_classes``.
        ``open_set_acc`` is the balanced accuracy over seen classes plus a single
        "unknown" class. ``round`` and ``log`` are kept for interface
        compatibility and are not used.
        """
        self.printer.debug('-----------------testing-----------------')
        all_y, all_logits, all_logits_ova = [], [], []
        self.model.train(False)
        self.model.to(self.device)
        for data in loader:
            x, y = data['x'].to(self.device), data['y'].to(self.device)
            outputs = self.model(x)
            all_y.append(y)
            all_logits.append(outputs['logits'])
            all_logits_ova.append(outputs['logits_ova'])
        y = torch.cat(all_y, dim=0)
        logits = torch.cat(all_logits, dim=0)
        logits_ova = torch.cat(all_logits_ova, dim=0)

        closed_pred, open_pred = self._open_set_predictions(logits, logits_ova, self.num_classes)
        seen_idxs = y < self.num_classes
        close_set_acc = (y[seen_idxs] == closed_pred[seen_idxs]).float().mean().item() * 100

        # Rejected samples are all predicted as the single label num_classes, so
        # collapse every unseen ground-truth label onto that same label. This is a
        # no-op when the dataset already maps unknowns to num_classes.
        y_open = y.clamp(max=self.num_classes)
        open_set_acc = balanced_accuracy_score(y_open.cpu().numpy(), open_pred.cpu().numpy()) * 100

        self.printer.info(f'close_set_acc: {close_set_acc}, openset_acc: {open_set_acc}')

        return open_set_acc, close_set_acc

    def evaluate(self, round_idx):
        """Evaluate on the test set, log metrics and checkpoint.

        The "best" model is tracked by closed-set accuracy.
        """
        st = time.time()
        print_log = 'Evaluation: '
        loader = DataLoader(self.testset, batch_size=512)
        open_set_acc, acc = self.test_openset(loader, round=round_idx)
        self.logger.log({"open_set_acc": open_set_acc}, step=round_idx)
        best = False
        if acc > self.best_acc:
            self.best_acc = acc
            best = True
        log_dict = {'test_acc': acc, 'best_acc': self.best_acc}
        self.logger.log(log_dict, step=round_idx)
        print_log += f'{acc:.2f}%; '
        self.save_model(round_idx, best=best)
        self.printer.info(print_log)
        self.printer.info(f'evaluation cost: {(time.time() - st)/60:.2f} min')

    def run(self):
        """Train for ``self.epochs`` epochs, evaluating after each one."""
        for epoch in range(self.epochs):
            self.train(epoch)
            self.evaluate(epoch)

    def save_model(self, round, best=False):
        """Save ``models_best.pth`` when ``best``, and ``models_final.pth`` after the last epoch."""
        if best:
            self._write_checkpoint('models_best.pth', round)
        if round == self.epochs - 1:
            model_path = self._write_checkpoint('models_final.pth', round)
            self.printer.info(f'save model to {model_path}')

    def _write_checkpoint(self, filename, round):
        """Save the model state dict and epoch index to ``save_dir/filename`` and return the path."""
        # Create the directory on demand; torch.save fails if it is missing.
        os.makedirs(self.save_dir, exist_ok=True)
        model_path = os.path.join(self.save_dir, filename)
        ckpt = {
            'ckpt_model': self.model.state_dict(),
            'ckpt_round': round,
        }
        torch.save(ckpt, model_path)
        return model_path


class FedNet(nn.Module):
    """Wrap a backbone with a closed-set classifier head and a one-vs-all head.

    ``base`` must provide ``fc.in_features`` (the feature width) and
    ``get_feature(x)``. ``forward`` returns a dict whose ``'logits'`` have last
    dimension ``num_classes`` and whose ``'logits_ova'`` have last dimension
    ``2 * num_classes``. The backbone's own ``fc`` is not called.
    """

    def __init__(self, base, num_classes):
        super().__init__()
        self.backbone = base
        width = base.fc.in_features

        def two_layer_mlp():
            # Linear -> ReLU -> Linear -> ReLU at constant width.
            return nn.Sequential(
                nn.Linear(width, width), nn.ReLU(),
                nn.Linear(width, width), nn.ReLU(),
            )

        # Creation order (ova_feat, fc_feat, fc, ova) determines the state_dict
        # key order and the sequence of random draws for default initialisation.
        self.ova_feat = two_layer_mlp()
        self.fc_feat = two_layer_mlp()
        self.fc = nn.Linear(width, num_classes, bias=False)
        self.ova = nn.Linear(width, 2 * num_classes, bias=False)

        for head in (self.fc, self.ova):
            nn.init.xavier_normal_(head.weight.data)

    def forward(self, x, **kwargs):
        shared = self.backbone.get_feature(x)
        return {
            'logits': self.fc(self.fc_feat(shared)),
            'logits_ova': self.ova(self.ova_feat(shared)),
        }


def mb_sup_loss(logits_ova, label):
    """Supervised multi-binary (one-vs-all) loss.

    ``logits_ova`` is indexed as ``(batch, 2, num_classes)`` and softmax is
    taken over dim 1; index 1 is scored as "belongs to this class" and index 0
    as "does not". For every sample with ``label < num_classes`` the true class
    is a positive. All other classes, and all classes of samples with
    ``label >= num_classes``, are negatives. The loss is the batch mean of the
    positive NLL plus the batch mean of the hardest (max) negative NLL.
    """
    n_rows, n_cls = logits_ova.size(0), logits_ova.size(2)
    probs = F.softmax(logits_ova, 1)

    is_inlier = label < n_cls
    rows = torch.arange(0, n_rows).long().to(label.device)[is_inlier]
    pos_mask = torch.zeros((n_rows, n_cls)).to(label.device)
    pos_mask[rows, label[is_inlier]] = 1.
    neg_mask = 1 - pos_mask

    pos_nll = -torch.log(probs[:, 1, :] + 1e-8) * pos_mask
    neg_nll = -torch.log(probs[:, 0, :] + 1e-8) * neg_mask
    pos_term = pos_nll.sum(1).mean()
    neg_term = neg_nll.max(1)[0].mean()
    return neg_term + pos_term
