"""This script demonstrates the OOD fine-tuning process of EM+LNOIB. We hope it gives a clearer understanding of how LNOIB serves as a versatile framework that integrates with existing OOD fine-tuning strategies employing global-level objectives. We will share the complete code upon publication."""

import argparse
import time
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

from config import config_training_setup
from src.feat_utils import MNN, calculate_ood_prototypes, calculate_id_sem_prototypes
from src.imageaugmentations import Compose, Normalize, ToTensor, RandomCrop, RandomHorizontalFlip
from src.model_utils import load_network
from torch.utils.data import DataLoader


def cross_entropy_glob(logits, targets):
    """Global-level objective of EM.

    Weights the negative log-softmax of ``logits`` (softmax over dim 1)
    element-wise by ``targets`` and averages the product over every element.

    Args:
        logits: raw network outputs; dim 1 is the class dimension.
        targets: target weights multiplied element-wise with the
            log-probabilities (presumably the same shape as ``logits`` --
            TODO confirm against the dataset).

    Returns:
        Scalar tensor: mean of ``targets * -log_softmax(logits, 1)``.
    """
    log_probs = F.log_softmax(logits, 1)
    weighted_nll = targets.float() * log_probs.neg()
    return weighted_nll.mean()


def cross_entropy_ins(logits, ins_tar_list):
    """Instance-level objective of LNOIB.

    For each instance target, weights the negative log-softmax of ``logits``
    (softmax over dim 1) by the target and sums the result. It then
    normalises by the target's total weight and averages these
    per-instance values over all instances.

    Args:
        logits: raw network outputs; dim 1 is the class dimension.
        ins_tar_list: iterable of per-instance targets, each multiplied
            element-wise with the log-probabilities (presumably binary masks
            shaped like ``logits`` -- TODO confirm).

    Returns:
        Tensor with the mean per-instance normalised loss. When
        ``ins_tar_list`` is empty, a zero scalar tensor on ``logits``'
        device/dtype is returned instead of raising ``ZeroDivisionError``.
    """
    # The log-probabilities do not depend on the instance; compute them once.
    neg_log_like = -F.log_softmax(logits, 1)
    losses = []
    for ins_tar in ins_tar_list:
        weights = ins_tar.float()
        # Guard all-zero masks: 0/0 would yield NaN and poison the total loss.
        # For any mask whose total weight is >= 1e-12 the value is unchanged.
        denom = weights.sum().clamp_min(1e-12)
        losses.append((weights * neg_log_like).sum() / denom)
    if not losses:
        return logits.new_zeros(())
    return sum(losses) / len(losses)


def cosine_feat_sem(protos_ood, protos_id_sem):
    """ID Semantic objective of LNOIB.

    Averages ``F.cosine_similarity`` (default ``dim=1``) over every
    (OOD prototype, ID-semantic prototype) pair.

    Args:
        protos_ood: sized iterable of OOD prototype tensors.
        protos_id_sem: sized iterable of ID-semantic prototype tensors.
            The similarity is taken over dim 1, so each prototype needs at
            least two dims (presumably ``(1, C)`` -- TODO confirm).

    Returns:
        Mean pairwise cosine similarity. Returns ``0.0`` when either input
        is empty (there are no pairs to average) instead of raising
        ``ZeroDivisionError``.
    """
    num_pairs = len(protos_ood) * len(protos_id_sem)
    if num_pairs == 0:
        return 0.0
    loss_sem = sum(
        F.cosine_similarity(proto_ood, proto_id_sem)
        for proto_ood in protos_ood
        for proto_id_sem in protos_id_sem
    )
    return loss_sem / num_pairs


def cosine_feat_near(protos_ood, protos_id_sem, M):
    """Nearest Neighbor objective of LNOIB.

    For each OOD prototype, asks ``MNN`` for up to ``M`` neighbours among the
    ID-semantic prototypes. It then averages ``F.cosine_similarity``
    (default ``dim=1``) over every (OOD prototype, neighbour) pair that
    ``MNN`` actually returned.

    Args:
        protos_ood: iterable of OOD prototype tensors.
        protos_id_sem: ID-semantic prototypes searched by ``MNN``.
        M: number of neighbours requested from ``MNN`` per OOD prototype.

    Returns:
        Mean cosine similarity over the visited pairs. Returns ``0.0`` when
        there are no pairs (no OOD prototypes, or ``MNN`` returned nothing)
        instead of raising ``ZeroDivisionError``.
    """
    loss_near = 0
    num_pairs = 0
    for proto_ood in protos_ood:
        for neighbor in MNN(proto_ood, protos_id_sem, M):
            loss_near = loss_near + F.cosine_similarity(proto_ood, neighbor)
            num_pairs += 1
    if num_pairs == 0:
        return 0.0
    # Normalise by the pairs actually visited instead of len(protos_ood) * M.
    # The two are identical when MNN yields exactly M neighbours. This form
    # stays a true mean if MNN yields fewer (whether it can is not visible
    # here -- depends on MNN's implementation).
    return loss_near / num_pairs
            

def _param_or_default(params, name, default):
    """Return ``params.<name>`` when present and not None, else ``default``."""
    value = getattr(params, name, None)
    return default if value is None else value


def _targets_to_device(glob_tar, ins_tar_list, device):
    """Move the global target and all instance targets onto ``device``.

    A tensor ``ins_tar_list`` stays a tensor; any other iterable of tensors
    is returned as a list.
    """
    glob_tar = glob_tar.to(device)
    if torch.is_tensor(ins_tar_list):
        ins_tar_list = ins_tar_list.to(device)
    else:
        ins_tar_list = [ins_tar.to(device) for ins_tar in ins_tar_list]
    return glob_tar, ins_tar_list


def training_routine(config):
    """Run EM+LNOIB OOD fine-tuning as described by ``config``.

    Reads from ``config``:

    - ``config.params``: starting epoch, number of epochs, learning rate,
      crop size, batch size and OOD subsampling factor. It may optionally
      carry the LNOIB hyper-parameters ``MNeighbor``, ``alpha``, ``beta``
      and ``tau``.
    - ``config.roots``: model name, initial checkpoint and dataset roots.
    - ``config.dataset``: the dataset factory.

    Every batch minimises the overall LNOIB loss of Eq. (14) with Adam.
    """
    print("START OOD FINE-TUNING...")
    params = config.params
    roots = config.roots
    dataset = config.dataset()
    start_epoch = params.training_starting_epoch
    epochs = params.num_training_epochs
    # LNOIB hyper-parameters. The CLI exposes -M/-alpha/-beta/-tau; they were
    # previously ignored here. Use the config value when one is set, else the
    # original defaults. (Whether config_training_setup copies these CLI values
    # onto params is not visible here -- TODO confirm.)
    M = _param_or_default(params, "MNeighbor", 1)  # neighbours per OOD prototype, Eq. (12)
    alpha = _param_or_default(params, "alpha", 0.5)  # global vs instance weight, Eq. (6)
    beta = _param_or_default(params, "beta", 0.5)  # semantic vs nearest-neighbour weight, Eq. (13)
    tau = _param_or_default(params, "tau", 0.7)  # passed to calculate_id_sem_prototypes, Eq. (8)

    model = load_network(model_name=roots.model_name, num_classes=dataset.num_classes,
                         ckpt_path=roots.init_ckpt, train=True)

    transform = Compose([RandomHorizontalFlip(), RandomCrop(params.crop_size), ToTensor(),
                         Normalize(dataset.mean, dataset.std)])

    for epoch in range(start_epoch, start_epoch + epochs):
        print('\nEpoch {}/{}'.format(epoch + 1, start_epoch + epochs))
        # NOTE(review): the optimizer is rebuilt every epoch, so Adam's moment
        # estimates reset at each epoch boundary. Kept as-is; confirm intent.
        optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
        # The training set is rebuilt every epoch with ood_subsampling_factor.
        trainloader = config.dataset('train', transform, roots.cs_root, roots.coco_root,
                                     params.ood_subsampling_factor)
        dataloader = DataLoader(trainloader, batch_size=params.batch_size, shuffle=True)
        for i, (x, target) in enumerate(dataloader):
            optimizer.zero_grad()
            feats, logits = model(x.cuda())
            # Targets must share the logits' device. Previously only the images
            # were moved to CUDA, so the loss terms mixed CPU and CUDA tensors.
            glob_tar, ins_tar_list = _targets_to_device(target[0], target[1], logits.device)

            loss_glob = cross_entropy_glob(logits, glob_tar)    # global-level objective in Eq. (1)
            loss_ins = cross_entropy_ins(logits, ins_tar_list)  # instance-level objective in Eq. (5)
            loss_pred = alpha * loss_glob + (1 - alpha) * loss_ins  # Eq. (6)

            # prototypes for each anomaly, Eq. (7)
            protos_ood = calculate_ood_prototypes(feats, ins_tar_list)
            # ID prototypes for each ID category, Eq. (8)
            protos_id_sem = calculate_id_sem_prototypes(feats, logits, glob_tar, tau)
            loss_sem = cosine_feat_sem(protos_ood, protos_id_sem)  # ID semantic loss, Eq. (9)
            loss_near = cosine_feat_near(protos_ood, protos_id_sem, M)  # nearest neighbour loss, Eq. (12)
            loss_feat = beta * loss_sem + (1 - beta) * loss_near  # Eq. (13)

            loss_LNOIB = loss_pred + loss_feat  # overall loss of LNOIB in Eq. (14)
            loss_LNOIB.backward()
            optimizer.step()
            print('{} Loss: {}'.format(i, loss_LNOIB.item()))
    # NOTE(review): the fine-tuned weights are never persisted by this routine;
    # add e.g. torch.save(model.state_dict(), <path>) if it is used beyond a demo.


def main(args):
    """Perform OOD fine-tuning.

    Args:
        args: dict of parsed command-line options (the ``__main__`` guard
            passes ``vars(parser.parse_args())``). It is handed unchanged to
            ``config_training_setup``, whose result drives
            ``training_routine``.
    """
    training_routine(config_training_setup(args))


if __name__ == '__main__':
    # Every option is optional (nargs="?"). Options left unset parse to None
    # and the whole namespace is forwarded to main() as a plain dict.
    cli_options = (
        ("-train", "--TRAINSET", str),
        ("-model", "--MODEL", str),
        ("-epoch", "--training_starting_epoch", int),
        ("-nepochs", "--num_training_epochs", int),
        ("-lr", "--learning_rate", float),
        ("-crop", "--crop_size", int),
        ("-M", "--MNeighbor", int),
        ("-tau", "--tau", float),
        ("-alpha", "--alpha", float),
        ("-beta", "--beta", float),
    )
    parser = argparse.ArgumentParser(description='OPTIONAL argument setting')
    for short_flag, long_flag, value_type in cli_options:
        parser.add_argument(short_flag, long_flag, nargs="?", type=value_type)
    main(vars(parser.parse_args()))
