
from __future__ import print_function
import argparse
import torch
import data_loader
import calculate_log as callog
import models
import os
import lib_generation
import numpy as np

from torchvision import transforms
#from torch.autograd import Variable

parser = argparse.ArgumentParser(description='PyTorch code: Mahalanobis detector')
parser.add_argument('--batch_size', type=int, default=32, metavar='N', help='batch size for data loader')
# The part of --dataset_full before the first "_" selects the in-distribution
# dataset; the full value is used to build the experiment/output directory names.
parser.add_argument('--dataset_full', required=True,
                    help='cifar10 | cifar100 | svhn, optionally followed by "_<suffix>"')
parser.add_argument('--dataroot', default='./data', help='path to dataset')
# Only these architectures have a construction path in main(); anything else
# would otherwise fail later with an UnboundLocalError.
parser.add_argument('--net_type', required=True,
                    choices=['densenetbc100', 'resnet34', 'resnet32', 'resnet110', 'efficientnetb0'],
                    help='densenetbc100 | resnet34 | resnet32 | resnet110 | efficientnetb0')
parser.add_argument('--gpu', type=int, default=0, help='gpu index')
# main() only defines detection scores for losses whose name starts with
# "soft" or "iso"; a prefix constraint cannot be expressed with `choices`.
parser.add_argument('--loss', required=True, help='the loss used (must start with "soft" or "iso")')
parser.add_argument('--dir', default="", type=str, help='Part of the dir to use')


args = parser.parse_args()
print(args)

# Fixed seeds for reproducibility. manual_seed_all seeds every visible GPU,
# not only the current one, so the seed also holds when --gpu != 0.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
########


# Set to True to also sweep ODIN temperature / input-noise magnitude settings
# (this replaces un-commenting the former "#liberar ODIN" string blocks).
# With the default False only T=1, m=0 (the plain baseline) is evaluated.
ENABLE_ODIN = False

# Grid swept only when ENABLE_ODIN is True.
_ODIN_MAGNITUDES = [0, 0.0005, 0.001, 0.0014, 0.002, 0.0024, 0.005, 0.01, 0.05, 0.1, 0.2]
_ODIN_TEMPERATURES = [0.001, 0.01, 0.1, 0.2, 0.3, 0.5, 1, 2, 3, 5, 10, 100, 1000]

# In-distribution dataset -> out-of-distribution datasets evaluated against it.
# (A wider list, also adding 'fooling_images', 'gaussian_noise', 'uniform_noise',
# was used in earlier experiments.)
_OUT_DIST_LISTS = {
    'cifar10': ['svhn', 'imagenet_resize', 'lsun_resize'],
    'cifar100': ['svhn', 'imagenet_resize', 'lsun_resize'],
    'svhn': ['cifar10', 'imagenet_resize', 'lsun_resize'],
}

# Per-channel (mean, std) passed to transforms.Normalize for each in-distribution dataset.
_NORMALIZATION = {
    'cifar10': ((0.491, 0.482, 0.446), (0.247, 0.243, 0.261)),
    'cifar100': ((0.507, 0.486, 0.440), (0.267, 0.256, 0.276)),
    'svhn': ((0.437, 0.443, 0.472), (0.198, 0.201, 0.197)),
}

# Number of classes per dataset; any dataset not listed uses 10.
_NUM_CLASSES = {'cifar100': 100, 'imagenet32': 1000}

_SUPPORTED_NETS = ('densenetbc100', 'resnet34', 'resnet32', 'resnet110', 'efficientnetb0')

# Detection scores are only defined for losses with these name prefixes.
_SCORED_LOSS_PREFIXES = ('soft', 'iso')
_SCORES = ["MPS", "ES", "MIDS"]
# Every supported loss currently evaluates only inference_learn='NO' and
# inference_transform=False (True was also tried in earlier experiments).
_INFER_LEARNS = ['NO']
_INFER_TRANSFORMS = [False]

_METRIC_TYPES = ['TNR', 'AUROC', 'DTACC', 'AUIN', 'AUOUT']

_CSV_HEADER = ("EXECUTION,MODEL,IN-DATA,OUT-DATA,LOSS,AD-HOC,SCORE,INFER-LEARN,INFER-TRANS,"
               "TNR,AUROC,DTACC,AUIN,AUOUT,CPU_FALSE,CPU_TRUE,GPU_FALSE,GPU_TRUE,TEMPERATURE,MAGNITUDE\n")


def _validate_config(dataset, net_type, loss):
    """Raise ValueError for a dataset / net_type / loss that main() cannot handle.

    Checked before any side effect, so a bad configuration no longer truncates
    the results CSV and then dies with an UnboundLocalError part-way through.
    """
    if dataset not in _OUT_DIST_LISTS:
        raise ValueError("unsupported in-distribution dataset {!r}; expected one of {}".format(
            dataset, sorted(_OUT_DIST_LISTS)))
    if net_type not in _SUPPORTED_NETS:
        raise ValueError("unsupported net_type {!r}; expected one of {}".format(
            net_type, list(_SUPPORTED_NETS)))
    if not loss.startswith(_SCORED_LOSS_PREFIXES):
        raise ValueError("unsupported loss {!r}; its name must start with one of {}".format(
            loss, list(_SCORED_LOSS_PREFIXES)))


def _build_in_transform(dataset):
    """Return the ToTensor + per-channel Normalize transform for *dataset*."""
    mean, std = _NORMALIZATION[dataset]
    return transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])


def _build_model(net_type, num_classes, loss):
    """Instantiate the (untrained) network named *net_type*."""
    # if/elif keeps each models.* attribute lookup lazy, as in the original code.
    if net_type == 'densenetbc100':
        return models.DenseNet3(100, int(num_classes), loss=loss)
    if net_type == 'resnet34':
        return models.ResNet34(num_c=num_classes, loss=loss)
    if net_type == 'resnet32':
        return models.ResNet32(num_c=num_classes, loss=loss)
    if net_type == 'resnet110':
        return models.ResNet110(num_c=num_classes, loss=loss)
    if net_type == 'efficientnetb0':
        return models.EfficientNetB0(num_c=num_classes, loss=loss)
    raise ValueError("unsupported net_type {!r}".format(net_type))


def _print_metrics(out_dist, pot):
    """Print the metric-name header and percentage values for one out-distribution.

    *pot* is the ``'PoT'`` entry of a ``callog.metric`` result; each metric is
    printed multiplied by 100. The caller prints the trailing blank line.
    """
    print('out_distribution: ' + out_dist)
    for mtype in _METRIC_TYPES:
        print(' {mtype:6s}'.format(mtype=mtype), end='')
    print('\n{val:6.2f}'.format(val=100. * pot['TNR']), end='')
    for mtype in ('AUROC', 'DTACC', 'AUIN'):
        print(' {val:6.2f}'.format(val=100. * pot[mtype]), end='')
    print(' {val:6.2f}\n'.format(val=100. * pot['AUOUT']), end='')


def _result_row(net_type, dataset_full, out_dist, loss, ad_hoc, score,
                infer_learn, infer_trans, pot, temperature, magnitude):
    """Return one newline-terminated CSV line matching ``_CSV_HEADER`` (20 columns)."""
    fields = [
        "1",  # EXECUTION: only model1.pth is evaluated by this script
        net_type,
        dataset_full,
        out_dist,
        str(loss),
        ad_hoc,
        score,
        infer_learn,
        infer_trans,
    ]
    fields += ['{:.2f}'.format(100. * pot[mtype]) for mtype in _METRIC_TYPES]
    # CPU_FALSE, CPU_TRUE, GPU_FALSE, GPU_TRUE: inference time is not measured here.
    fields += [0, 0, 0, 0]
    fields += [temperature, magnitude]
    return ",".join(str(field) for field in fields) + "\n"


def _evaluate_setting(model, test_loader, in_transform, out_dist_list, args_outf, file_path,
                      score, infer_learn, infer_trans, temperatures, magnitudes):
    """Sweep (temperature, magnitude), then print results and append CSV rows.

    For each grid point, ``lib_generation.get_posterior`` is run on the
    in-distribution loader (flag True) and on each out-distribution loader
    (flag False), all sharing the working directory *args_outf*; metrics are
    then read back with ``callog.metric(args_outf, ...)``. The T=1, m=0 point
    is recorded as the baseline ("NATIVE"); other points (only present when
    ENABLE_ODIN widens the grid) compete for the best ODIN setting.
    """
    n_out = len(out_dist_list)
    base_line_list = []
    odin_best_tnr = [0] * n_out
    odin_best_results = [None] * n_out
    odin_best_temperature = [-1] * n_out
    odin_best_magnitude = [-1] * n_out

    for temperature in temperatures:
        for magnitude in magnitudes:
            lib_generation.get_posterior(
                model, args.dataset, test_loader, magnitude, temperature, args_outf, True, args.loss, score)
            print('Temperature: ' + str(temperature) + ' / noise: ' + str(magnitude))
            for out_count, out_dist in enumerate(out_dist_list):
                out_test_loader = data_loader.getNonTargetDataSet(
                    out_dist, args.batch_size, in_transform, args.dataroot)
                print('Out-distribution: ' + out_dist)
                lib_generation.get_posterior(
                    model, args.dataset, out_test_loader, magnitude, temperature, args_outf, False,
                    args.loss, score)
                if temperature == 1 and magnitude == 0:
                    base_line_list.append(callog.metric(args_outf, ['PoT']))
                else:
                    # PoT is TRAIN (from a metric-learning point of view) and PoV is
                    # VALIDATION: select the setting on PoV, report it on PoT.
                    val_results = callog.metric(args_outf, ['PoV'])
                    if odin_best_tnr[out_count] < val_results['PoV']['TNR']:
                        odin_best_tnr[out_count] = val_results['PoV']['TNR']
                        odin_best_results[out_count] = callog.metric(args_outf, ['PoT'])
                        odin_best_temperature[out_count] = temperature
                        odin_best_magnitude[out_count] = magnitude

    print('Baseline method: train in_distribution: ' + args.dataset_full + '==========')
    with open(file_path, "a") as results_file:
        for out_dist, results in zip(out_dist_list, base_line_list):
            _print_metrics(out_dist, results['PoT'])
            print('')
            results_file.write(_result_row(
                args.net_type, args.dataset_full, out_dist, args.loss, "NATIVE", score,
                infer_learn, infer_trans, results['PoT'], 1, 0))

    if not ENABLE_ODIN:
        return

    print('ODIN method: in_distribution: ' + args.dataset + '==========')
    with open(file_path, "a") as results_file:
        for out_count, out_dist in enumerate(out_dist_list):
            results = odin_best_results[out_count]
            if results is None:
                # No grid point beat a validation TNR of 0: nothing to report.
                print('out_distribution: ' + out_dist + ' (no ODIN setting selected)')
                print('')
                continue
            _print_metrics(out_dist, results['PoT'])
            print('temperature: ' + str(odin_best_temperature[out_count]))
            print('magnitude: ' + str(odin_best_magnitude[out_count]))
            print('')
            results_file.write(_result_row(
                args.net_type, args.dataset_full, out_dist, args.loss, "ODIN", score,
                infer_learn, infer_trans, results['PoT'],
                str(odin_best_temperature[out_count]), str(odin_best_magnitude[out_count])))


def main():
    """Evaluate out-of-distribution detection for one trained model and log the results.

    Reads the module-level ``args``. Loads ``model1.pth`` from
    ``expers/<dir>/cnn_train/data~<dataset_full>+model~<net_type>+loss~<loss>/``
    and, for every score in ``_SCORES``, evaluates the in-distribution test split
    against each out-distribution dataset, printing the metrics and appending one
    row per out-distribution to ``results_odd.csv`` in that same directory. The
    CSV is truncated and re-headed at the start of every run.

    Side effects: sets ``args.dataset`` and ``args.num_classes``; creates the
    working directory ``output/<dir>/<loss>/<net_type>+<dataset_full>`` that is
    handed to ``lib_generation.get_posterior`` / ``callog.metric``.

    Raises:
        ValueError: if the dataset, net_type or loss has no code path here.
    """
    # "cifar10_xyz" -> "cifar10"; a value without "_" is used unchanged.
    args.dataset = args.dataset_full.split("_")[0]
    _validate_config(args.dataset, args.net_type, args.loss)

    dir_path = os.path.join(
        "expers", args.dir, "cnn_train",
        "data~" + args.dataset_full + "+model~" + args.net_type + "+loss~" + str(args.loss))
    file_path = os.path.join(dir_path, "results_odd.csv")
    with open(file_path, "w") as results_file:
        results_file.write(_CSV_HEADER)

    pre_trained_net = os.path.join(dir_path, "model" + "1" + ".pth")

    args_outf = os.path.join("output", args.dir, args.loss, args.net_type + '+' + args.dataset_full)
    os.makedirs(args_outf, exist_ok=True)

    # Select the GPU first: torch.cuda.manual_seed seeds only the *current*
    # device, so seeding before set_device would reseed the wrong GPU.
    torch.cuda.set_device(args.gpu)
    torch.cuda.manual_seed(0)

    args.num_classes = _NUM_CLASSES.get(args.dataset, 10)
    out_dist_list = _OUT_DIST_LISTS[args.dataset]
    in_transform = _build_in_transform(args.dataset)

    model = _build_model(args.net_type, args.num_classes, args.loss)
    model.load_state_dict(torch.load(pre_trained_net, map_location="cuda:" + str(args.gpu)))
    model.cuda()
    print('load model: ' + args.net_type)

    print('load target valid data: ', args.dataset)
    _, test_loader = data_loader.getTargetDataSet(args.dataset, args.batch_size, in_transform, args.dataroot)

    if ENABLE_ODIN:
        magnitudes, temperatures = _ODIN_MAGNITUDES, _ODIN_TEMPERATURES
    else:
        magnitudes, temperatures = [0], [1]

    for score in _SCORES:
        print("\n\n\n###############################")
        print("###############################")
        print("SCORE:", score)
        print("###############################")
        print("###############################")

        for infer_learn in _INFER_LEARNS:
            model.classifier.inference_learn = infer_learn
            for infer_trans in _INFER_TRANSFORMS:
                model.classifier.inference_transform = infer_trans
                print()
                print("***********", score)
                print("***********", infer_learn)
                print("***********", infer_trans)
                print()
                print("\n#########################################")
                print("INFER LEARN [PROCESSING...]:", infer_learn)
                print("INFER TRANS [PROCESSING...]:", infer_trans)
                print("#########################################\n")
                _evaluate_setting(
                    model, test_loader, in_transform, out_dist_list, args_outf, file_path,
                    score, infer_learn, infer_trans, temperatures, magnitudes)
    
if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script.
    main()
