import os
import sys
import torch
import torch.nn as nn
import models
import loaders
import losses
import statistics
import math
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchnet as tnt
import numpy as np
import time
import utils
import matplotlib.pyplot as plt
import data_loader
import pickle
import random
#import agents

from tqdm import tqdm
from torchvision import transforms
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics.cluster import adjusted_mutual_info_score
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import homogeneity_completeness_v_measure
#from center_loss import CenterLoss


# Module-level side effect: select seaborn's "darkgrid" style at import time.
sns.set(style="darkgrid")


class CNNAgent:

    def __init__(self, args):
        """Build the data loaders, model, loss and optimisation objects.

        Args:
            args: Experiment configuration. Attributes read here:
                ``partition``, ``dataset_full``, ``execution_seed``,
                ``base_seed``, ``model_name``, ``number_of_model_classes``,
                ``loss``, ``exp_type``, ``experiment_path``,
                ``original_learning_rate``, ``momentum``, ``weight_decay``,
                ``learning_rate_decay_epochs`` and
                ``learning_rate_decay_rate``. The object is also handed to
                ``loaders.ImageLoader``.

        Raises:
            ValueError: If ``args.model_name`` is not one of the supported
                architecture names.
        """
        self.args = args
        self.epoch = None
        self.cluster_predictions_transformation = []

        # create dataset
        image_loaders = loaders.ImageLoader(args)
        (self.trainset_first_partition_loader_for_train,
         self.trainset_second_partition_loader_for_train,
         self.trainset_first_partition_loader_for_infer,
         self.trainset_second_partition_loader_for_infer,
         self.valset_loader, self.ood_loader, self.normalize) = image_loaders.get_loaders()
        self.batch_normalize = loaders.BatchNormalize(
            self.normalize.mean, self.normalize.std, inplace=True, device=torch.cuda.current_device())
        # NOTE(review): any partition value other than "1" or "2" leaves
        # ``trainset_loader_for_train`` unset, and ``train_epoch`` iterates it.
        if self.args.partition == "1":
            self.trainset_loader_for_train = self.trainset_first_partition_loader_for_train
        elif self.args.partition == "2":
            self.trainset_loader_for_train = self.trainset_second_partition_loader_for_train
        print("\nDATASET:", self.args.dataset_full)

        # create model
        # The RNGs are seeded with execution_seed for model construction and
        # re-seeded with base_seed afterwards (presumably so that only the
        # weight initialisation varies between executions -- TODO confirm).
        torch.manual_seed(self.args.execution_seed)
        torch.cuda.manual_seed(self.args.execution_seed)
        print("=> creating model '{}'".format(self.args.model_name))
        num_classes = self.args.number_of_model_classes
        loss_name = self.args.loss
        # Builders are lambdas so that only the selected architecture is
        # looked up on the ``models`` module.
        model_builders = {
            "densenetbc100": lambda: models.DenseNet3(100, int(num_classes), loss=loss_name),
            "resnet32": lambda: models.ResNet32(num_c=num_classes, loss=loss_name),
            "resnet34": lambda: models.ResNet34(num_c=num_classes, loss=loss_name),
            "resnet110": lambda: models.ResNet110(num_c=num_classes, loss=loss_name),
            "resnet56": lambda: models.ResNet56(num_c=num_classes, loss=loss_name),
            "resnet18_": lambda: models.resnet18_(num_classes=num_classes, loss=loss_name),
            "resnet34_": lambda: models.resnet34_(num_classes=num_classes, loss=loss_name),
            "resnet101_": lambda: models.resnet101_(num_classes=num_classes, loss=loss_name),
            "wideresnet3410": lambda: models.Wide_ResNet(
                depth=34, widen_factor=10, num_classes=num_classes, loss=loss_name),
            "efficientnetb0": lambda: models.EfficientNetB0(num_c=num_classes, loss=loss_name),
        }
        try:
            build_model = model_builders[self.args.model_name]
        except KeyError:
            # Fail with a clear message instead of the AttributeError that
            # ``self.model.cuda()`` would raise on an unset attribute.
            raise ValueError(
                "unsupported model_name '{}'; expected one of: {}".format(
                    self.args.model_name, ", ".join(sorted(model_builders)))) from None
        self.model = build_model()
        self.model.cuda()
        torch.manual_seed(self.args.base_seed)
        torch.cuda.manual_seed(self.args.base_seed)

        # print and save model arch...
        if self.args.exp_type == "cnn_train":
            print("\nMODEL:", self.model)
            with open(os.path.join(self.args.experiment_path, 'model.arch'), 'w') as file:
                print(self.model, file=file)

        print("\n$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
        utils.print_num_params(self.model)
        print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n")

        # create loss
        self.criterion = losses.GenericLossSecondPart(self.model.classifier).cuda()
        print(self.model.classifier.weights.size())
        # The 4th "_"-separated field of the loss name enables center loss
        # when it starts with "cl".
        use_center_loss = self.args.loss.split("_")[3].startswith("cl")
        if use_center_loss:
            print("center loss regularization!!!")
            self.center_loss = losses.CenterLoss(
                num_classes=self.args.number_of_model_classes,
                feat_dim=self.model.classifier.weights.size(1),
                use_gpu=True)

        # create train
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.args.original_learning_rate,
            momentum=self.args.momentum,
            nesterov=True,
            weight_decay=self.args.weight_decay)

        if use_center_loss:
            print("center loss regularization!!!")
            # The center-loss parameters get their own optimizer with a
            # fixed learning rate of 0.5.
            self.optimizer_centloss = torch.optim.SGD(
                self.center_loss.parameters(),
                lr=0.5,
                momentum=self.args.momentum,
                nesterov=True,
                weight_decay=self.args.weight_decay)

        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=self.args.learning_rate_decay_epochs,
            gamma=self.args.learning_rate_decay_rate)

        print("\nTRAIN:", self.criterion, self.optimizer, self.scheduler)


    def train_validate(self):
        # template for others procedures of this class...
        # building results and raw results files...
        if self.args.execution == 1:
            with open(self.args.executions_best_results_file_path, "w") as best_results:
                best_results.write(
                    "DATA,MODEL,LOSS,EXECUTION,EPOCH,"
                    "TRAIN LOSS,TRAIN ACC1,TRAIN ODD_ACC,"
                    #"TRAIN LOSS,TRAIN ODD_LOSS,TRAIN ACC1,TRAIN ODD_ACC,"
                    "TRAIN INTRA_LOGITS MEAN,TRAIN INTRA_LOGITS STD,TRAIN INTER_LOGITS MEAN,TRAIN INTER_LOGITS STD,"
                    "TRAIN MAX_PROBS MEAN,TRAIN MAX_PROBS STD,"
                    "TRAIN ENTROPIES MEAN,TRAIN ENTROPIES STD,"
                    "VALID LOSS,VALID ACC1,VALID ODD_ACC,"
                    #"VALID LOSS,VALID ODD_LOSS,VALID ACC1,VALID ODD_ACC,"
                    "VALID INTRA_LOGITS MEAN,VALID INTRA_LOGITS STD,VALID INTER_LOGITS MEAN,VALID INTER_LOGITS STD,"
                    "VALID MAX_PROBS MEAN,VALID MAX_PROBS STD,"
                    "VALID ENTROPIES MEAN,VALID ENTROPIES STD\n"
                )
            with open(self.args.executions_raw_results_file_path, "w") as raw_results:
                #raw_results.write("EXECUTION,EPOCH,SET,TYPE,VALUE\n")
                raw_results.write("DATA,MODEL,LOSS,EXECUTION,EPOCH,SET,METRIC,VALUE\n")

        print("\n################ TRAINING ################")

        best_model_results = {"VALID ACC1": 0}

        for self.epoch in range(1, self.args.epochs + 1):
            print("\n######## EPOCH:", self.epoch, "OF", self.args.epochs, "########")

            # Adjusting learning rate (if not using reduce on plateau)...
            # self.scheduler.step()

            # Print current learning rate...
            for param_group in self.optimizer.param_groups:
                print("\nLEARNING RATE:\t\t", param_group["lr"])

            train_loss, train_acc1, train_odd_acc, train_epoch_logits, train_epoch_metrics = self.train_epoch()
            #train_loss, train_odd_loss, train_acc1, train_odd_acc, train_epoch_logits, train_epoch_metrics = self.train_epoch()


            # Adjusting learning rate (if not using reduce on plateau)...
            self.scheduler.step()
            
            valid_loss, valid_acc1, valid_odd_acc, valid_epoch_logits, valid_epoch_metrics = self.validate_epoch()
            #valid_loss, valid_odd_loss, valid_acc1, valid_odd_acc, valid_epoch_logits, valid_epoch_metrics = self.validate_epoch()

            train_intra_logits_mean = statistics.mean(train_epoch_logits["intra"])
            train_intra_logits_std = statistics.pstdev(train_epoch_logits["intra"])
            train_inter_logits_mean = statistics.mean(train_epoch_logits["inter"])
            train_inter_logits_std = statistics.pstdev(train_epoch_logits["inter"])
            #######################################################################
            train_max_probs_mean = statistics.mean(train_epoch_metrics["max_probs"])
            train_max_probs_std = statistics.pstdev(train_epoch_metrics["max_probs"])
            #train_entropies_mean = statistics.mean(train_epoch_metrics["entropies"])
            #train_entropies_std = statistics.pstdev(train_epoch_metrics["entropies"])
            train_entropies_mean = statistics.mean(train_epoch_metrics["entropies"])/math.log(self.args.number_of_model_classes)
            train_entropies_std = statistics.pstdev(train_epoch_metrics["entropies"])/math.log(self.args.number_of_model_classes)
            #######################################################################
            valid_intra_logits_mean = statistics.mean(valid_epoch_logits["intra"])
            valid_intra_logits_std = statistics.pstdev(valid_epoch_logits["intra"])
            valid_inter_logits_mean = statistics.mean(valid_epoch_logits["inter"])
            valid_inter_logits_std = statistics.pstdev(valid_epoch_logits["inter"])
            #######################################################################
            valid_max_probs_mean = statistics.mean(valid_epoch_metrics["max_probs"])
            valid_max_probs_std = statistics.pstdev(valid_epoch_metrics["max_probs"])
            #valid_entropies_mean = statistics.mean(valid_epoch_metrics["entropies"])
            #valid_entropies_std = statistics.pstdev(valid_epoch_metrics["entropies"])
            valid_entropies_mean = statistics.mean(valid_epoch_metrics["entropies"])/math.log(self.args.number_of_model_classes)
            valid_entropies_std = statistics.pstdev(valid_epoch_metrics["entropies"])/math.log(self.args.number_of_model_classes)
            #######################################################################

            print("\n####################################################")
            print("TRAIN MAX PROB MEAN:\t", train_max_probs_mean)
            print("TRAIN MAX PROB STD:\t", train_max_probs_std)
            print("VALID MAX PROB MEAN:\t", valid_max_probs_mean)
            print("VALID MAX PROB STD:\t", valid_max_probs_std)
            print("####################################################\n")

            print("\n####################################################")
            print("TRAIN ENTROPY MEAN:\t", train_entropies_mean)
            print("TRAIN ENTROPY STD:\t", train_entropies_std)
            print("VALID ENTROPY MEAN:\t", valid_entropies_mean)
            print("VALID ENTROPY STD:\t", valid_entropies_std)
            #print("TRAIN ENTROPY MEAN:\t", train_entropies_mean/math.log(self.args.number_of_model_classes))
            #print("TRAIN ENTROPY STD:\t", train_entropies_std/math.log(self.args.number_of_model_classes))
            #print("VALID ENTROPY MEAN:\t", valid_entropies_mean/math.log(self.args.number_of_model_classes))
            #print("VALID ENTROPY STD:\t", valid_entropies_std/math.log(self.args.number_of_model_classes))
            print("####################################################\n")

            with open(self.args.executions_raw_results_file_path, "a") as raw_results:
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "TRAIN", "LOSS", train_loss))
                #raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                #    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                #    "TRAIN", "ODD_LOSS", train_odd_loss))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "TRAIN", "ACC1", train_acc1))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "TRAIN", "ODD_ACC", train_odd_acc))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "TRAIN", "INTRA_LOGITS MEAN", train_intra_logits_mean))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "TRAIN", "INTRA_LOGITS STD", train_intra_logits_std))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "TRAIN", "INTER_LOGITS MEAN", train_inter_logits_mean))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "TRAIN", "INTER_LOGITS STD", train_inter_logits_std))
                #########################################################
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "TRAIN", "MAX_PROBS MEAN", train_max_probs_mean))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "TRAIN", "MAX_PROBS STD", train_max_probs_std))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "TRAIN", "ENTROPIES MEAN", train_entropies_mean))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "TRAIN", "ENTROPIES STD", train_entropies_std))
                #########################################################               
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "VALID", "LOSS", valid_loss))
                #raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                #    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                #    "VALID", "ODD_LOSS", valid_odd_loss))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "VALID", "ACC1", valid_acc1))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "VALID", "ODD_ACC", valid_odd_acc))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "VALID", "INTRA_LOGITS MEAN", valid_intra_logits_mean))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "VALID", "INTRA_LOGITS STD", valid_intra_logits_std))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "VALID", "INTER_LOGITS MEAN", valid_inter_logits_mean))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "VALID", "INTER_LOGITS STD", valid_inter_logits_std))
                #########################################################
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "VALID", "MAX_PROBS MEAN", valid_max_probs_mean))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "VALID", "MAX_PROBS STD", valid_max_probs_std))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "VALID", "ENTROPIES MEAN", valid_entropies_mean))
                raw_results.write("{},{},{},{},{},{},{},{}\n".format(
                    self.args.dataset_full, self.args.model_name, self.args.loss, self.args.execution, self.epoch,
                    "VALID", "ENTROPIES STD", valid_entropies_std))
                #########################################################               

            print()
            print("TRAIN ==>>\tIADM: {0:.8f}\tIADS: {1:.8f}\tIEDM: {2:.8f}\tIEDS: {3:.8f}".format(
                train_intra_logits_mean, train_intra_logits_std, train_inter_logits_mean, train_inter_logits_std))
            print("VALID ==>>\tIADM: {0:.8f}\tIADS: {1:.8f}\tIEDM: {2:.8f}\tIEDS: {3:.8f}".format(
                valid_intra_logits_mean, valid_intra_logits_std, valid_inter_logits_mean, valid_inter_logits_std))
            print()

            #############################################
            print("\nDATA:", self.args.dataset_full)
            print("MODEL:", self.args.model_name)
            print("LOSS:", self.args.loss, "\n")
            #############################################

            # if is best...
            if valid_acc1 > best_model_results["VALID ACC1"]:
                print("!+NEW BEST MODEL VALID ACC1!")
                best_model_results = {
                    "DATA": self.args.dataset_full,
                    "MODEL": self.args.model_name,
                    "LOSS": self.args.loss,
                    "EXECUTION": self.args.execution,
                    "EPOCH": self.epoch,
                    ###########################################################################
                    "TRAIN LOSS": train_loss,
                    #"TRAIN ODD_LOSS": train_odd_loss,
                    "TRAIN ACC1": train_acc1,
                    "TRAIN ODD_ACC": train_odd_acc,
                    "TRAIN INTRA_LOGITS MEAN": train_intra_logits_mean,
                    "TRAIN INTRA_LOGITS STD": train_intra_logits_std,
                    "TRAIN INTER_LOGITS MEAN": train_inter_logits_mean,
                    "TRAIN INTER_LOGITS STD": train_inter_logits_std,
                    ###########################################################################
                    "TRAIN MAX_PROBS MEAN": train_max_probs_mean,
                    "TRAIN MAX_PROBS STD": train_max_probs_std,
                    "TRAIN ENTROPIES MEAN": train_entropies_mean,
                    "TRAIN ENTROPIES STD": train_entropies_std,
                    ###########################################################################
                    "VALID LOSS": valid_loss,
                    #"VALID ODD_LOSS": valid_odd_loss,
                    "VALID ACC1": valid_acc1,
                    "VALID ODD_ACC": valid_odd_acc,
                    "VALID INTRA_LOGITS MEAN": valid_intra_logits_mean,
                    "VALID INTRA_LOGITS STD": valid_intra_logits_std,
                    "VALID INTER_LOGITS MEAN": valid_inter_logits_mean,
                    "VALID INTER_LOGITS STD": valid_inter_logits_std,
                    ###########################################################################
                    "VALID MAX_PROBS MEAN": valid_max_probs_mean,
                    "VALID MAX_PROBS STD": valid_max_probs_std,
                    "VALID ENTROPIES MEAN": valid_entropies_mean,
                    "VALID ENTROPIES STD": valid_entropies_std,
                    ###########################################################################
                }

                print("!+NEW BEST MODEL VALID ACC1:\t\t{0:.4f} IN EPOCH {1}! SAVING {2}\n".format(
                    valid_acc1, self.epoch, self.args.best_model_file_path))

                torch.save(self.model.state_dict(), self.args.best_model_file_path)
                #torch.save(self.model.state_dict(), self.args.best_model_file_alternative_path)

                np.save(os.path.join(
                    self.args.experiment_path, "best_model"+str(self.args.execution)+"_train_epoch_logits.npy"), train_epoch_logits)
                np.save(os.path.join(
                    self.args.experiment_path, "best_model"+str(self.args.execution)+"_train_epoch_metrics.npy"), train_epoch_metrics)
                np.save(os.path.join(
                    self.args.experiment_path, "best_model"+str(self.args.execution)+"_valid_epoch_logits.npy"), valid_epoch_logits)
                np.save(os.path.join(
                    self.args.experiment_path, "best_model"+str(self.args.execution)+"_valid_epoch_metrics.npy"), valid_epoch_metrics)

            print('!$$$$ BEST MODEL TRAIN ACC1:\t\t{0:.4f}'.format(best_model_results["TRAIN ACC1"]))
            print('!$$$$ BEST MODEL VALID ACC1:\t\t{0:.4f}'.format(best_model_results["VALID ACC1"]))

            # Adjusting learning rate (if using reduce on plateau)...
            # scheduler.step(valid_acc1)

        with open(self.args.executions_best_results_file_path, "a") as best_results:
            #best_results.write("{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{}\n".format(
            best_results.write("{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{}\n".format(
                best_model_results["DATA"],
                best_model_results["MODEL"],
                best_model_results["LOSS"],
                best_model_results["EXECUTION"],
                best_model_results["EPOCH"],
                #############################################
                best_model_results["TRAIN LOSS"],
                #best_model_results["TRAIN ODD_LOSS"],
                best_model_results["TRAIN ACC1"],
                best_model_results["TRAIN ODD_ACC"],
                best_model_results["TRAIN INTRA_LOGITS MEAN"],
                best_model_results["TRAIN INTRA_LOGITS STD"],
                best_model_results["TRAIN INTER_LOGITS MEAN"],
                best_model_results["TRAIN INTER_LOGITS STD"],
                #############################################
                best_model_results["TRAIN MAX_PROBS MEAN"],
                best_model_results["TRAIN MAX_PROBS STD"],
                best_model_results["TRAIN ENTROPIES MEAN"],
                best_model_results["TRAIN ENTROPIES STD"],
                #############################################
                best_model_results["VALID LOSS"],
                #best_model_results["VALID ODD_LOSS"],
                best_model_results["VALID ACC1"],
                best_model_results["VALID ODD_ACC"],
                best_model_results["VALID INTRA_LOGITS MEAN"],
                best_model_results["VALID INTRA_LOGITS STD"],
                best_model_results["VALID INTER_LOGITS MEAN"],
                best_model_results["VALID INTER_LOGITS STD"],
                #############################################
                best_model_results["VALID MAX_PROBS MEAN"],
                best_model_results["VALID MAX_PROBS STD"],
                best_model_results["VALID ENTROPIES MEAN"],
                best_model_results["VALID ENTROPIES STD"],
                #############################################
                )
            )

        # extracting features from best model...
        self.extract_features_for_all_sets(self.args.best_model_file_path)
        print()

    def train_epoch(self):
        """Run one optimisation pass over ``self.trainset_loader_for_train``.

        For every batch the model and criterion are evaluated, the optional
        center-loss term is added, running meters are updated and one SGD
        step is taken. Progress is printed every ``args.print_freq`` batches.

        Returns:
            tuple: ``(mean_loss, acc1, odd_acc1, epoch_logits, epoch_metrics)``
            where ``acc1`` / ``odd_acc1`` are the top-1 accuracies measured
            from the criterion's ``cls_probabilities`` / ``odd_probabilities``
            outputs, ``epoch_logits`` has the keys ``"intra"`` and
            ``"inter"``, and ``epoch_metrics`` has the keys ``"max_probs"``,
            ``"entropies"`` and ``"max_logits"`` (all plain Python lists).
            When ``args.number_of_model_classes > 100`` the logits lists hold
            only the last batch rather than the whole epoch.
        """
        print()
        # switch to train mode
        self.model.train()
        self.criterion.train()

        # Meters...
        loss_meter = utils.MeanMeter()
        accuracy_meter = tnt.meter.ClassErrorMeter(topk=[1], accuracy=True)
        odd_accuracy_meter = tnt.meter.ClassErrorMeter(topk=[1], accuracy=True)
        epoch_logits = {"intra": [], "inter": []}
        epoch_metrics = {"max_probs": [], "entropies": [], "max_logits": []}

        # Parse the center-loss part of the loss name once instead of on
        # every batch: 4th "_"-separated field, "cl" followed by a number.
        center_loss_field = self.args.loss.split("_")[3]
        use_center_loss = center_loss_field.startswith("cl")
        # Weight applied to the center-loss term (half the configured value).
        center_loss_weight = float(center_loss_field.strip("cl")) / 2 if use_center_loss else None

        def batch_stdev(values):
            """Sample standard deviation, or NaN when it is undefined."""
            # statistics.stdev raises StatisticsError for fewer than two values.
            return statistics.stdev(values) if len(values) > 1 else float("nan")

        for batch_index, in_data in enumerate(self.trainset_loader_for_train, start=1):
            # moving to GPU...
            inputs = in_data[0].cuda()
            targets = in_data[1].cuda(non_blocking=True)

            # forward pass
            features = self.model(inputs)

            # compute loss
            (loss, cls_probabilities, odd_probabilities,
             max_logits, intra_logits, inter_logits) = self.criterion(features, targets)
            if use_center_loss:
                print("center loss regularization!!!")
                loss += center_loss_weight * self.center_loss(features, targets)

            intra_logits = intra_logits.tolist()
            inter_logits = inter_logits.tolist()
            max_probs = odd_probabilities.max(dim=1)[0]
            entropies = utils.entropies_from_probabilities(odd_probabilities)

            # accumulate metrics over batches...
            loss_meter.add(loss.item(), targets.size(0))
            accuracy_meter.add(cls_probabilities.detach(), targets.detach())
            odd_accuracy_meter.add(odd_probabilities.detach(), targets.detach())
            if self.args.number_of_model_classes > 100:
                # With many classes only the most recent batch is kept, so
                # the epoch-level logits statistics are not representative.
                print("WARMING!!! DO NOT BLINDLY TRUST EPOCH LOGITS STATISTICS!!!")
                epoch_logits["intra"] = intra_logits
                epoch_logits["inter"] = inter_logits
            else:
                epoch_logits["intra"] += intra_logits
                epoch_logits["inter"] += inter_logits
            epoch_metrics["max_probs"] += max_probs.tolist()
            epoch_metrics["max_logits"] += max_logits.tolist()
            epoch_metrics["entropies"] += entropies.tolist()

            # zero grads, compute gradients and do optimizer step
            self.optimizer.zero_grad()
            if use_center_loss:
                print("center loss regularization!!!")
                self.optimizer_centloss.zero_grad()
            loss.backward()
            self.optimizer.step()
            if use_center_loss:
                print("center loss regularization!!!")
                # Divide the loss weight back out of the center gradients
                # before stepping the dedicated center-loss optimizer.
                for param in self.center_loss.parameters():
                    param.grad.data *= (1. / center_loss_weight)
                self.optimizer_centloss.step()

            if batch_index % self.args.print_freq == 0:
                print('Train Epoch: [{0}][{1:3}/{2}]\t'
                      'Loss {loss:.8f}\t\t'
                      'Acc1 {acc1_meter:.2f}\t'
                      'IADM {intra_logits_mean:.4f}\t'
                      'IADS {intra_logits_std:.8f}\t\t'
                      'IEDM {inter_logits_mean:.4f}\t'
                      'IEDS {inter_logits_std:.8f}'
                      .format(self.epoch, batch_index, len(self.trainset_loader_for_train),
                              loss=loss_meter.avg,
                              acc1_meter=accuracy_meter.value()[0],
                              intra_logits_mean=statistics.mean(intra_logits),
                              intra_logits_std=batch_stdev(intra_logits),
                              inter_logits_mean=statistics.mean(inter_logits),
                              inter_logits_std=batch_stdev(inter_logits),
                              )
                      )

        print('\n#### TRAIN ACC1:\t{0:.4f}\n\n'.format(accuracy_meter.value()[0]))

        return loss_meter.avg, accuracy_meter.value()[0], odd_accuracy_meter.value()[0], epoch_logits, epoch_metrics

    def validate_epoch(self):
        """Run one evaluation pass over ``self.valset_loader``.

        The model and the criterion are switched to eval mode and every batch
        is processed under ``torch.no_grad()``. While a batch is forwarded,
        ``self.model.classifier.metrics_evaluation_mode`` is set to ``True``;
        it is always reset to ``False`` afterwards, even when the forward pass
        or the criterion raises.

        Returns:
            A 5-tuple ``(loss, acc1, odd_acc1, epoch_logits, epoch_metrics)``:

            * ``loss`` -- ``loss_meter.avg`` over the processed batches.
            * ``acc1`` -- first value of the accuracy meter fed with the
              criterion's ``cls_probabilities``.
            * ``odd_acc1`` -- first value of the accuracy meter fed with the
              criterion's ``odd_probabilities``.
            * ``epoch_logits`` -- dict with the keys ``"intra"`` and
              ``"inter"``. It accumulates the logits of every batch, except
              when ``self.args.number_of_model_classes > 100``, in which case
              it only holds the logits of the last batch.
            * ``epoch_metrics`` -- dict with the keys ``"max_probs"``,
              ``"entropies"`` and ``"max_logits"``, accumulated over batches.
        """
        print()
        # switch to evaluate mode
        self.model.eval()
        self.criterion.eval()

        # Meters...
        loss_meter = utils.MeanMeter()
        accuracy_meter = tnt.meter.ClassErrorMeter(topk=[1], accuracy=True)
        odd_accuracy_meter = tnt.meter.ClassErrorMeter(topk=[1], accuracy=True)
        epoch_logits = {"intra": [], "inter": []}
        epoch_metrics = {"max_probs": [], "entropies": [], "max_logits": []}

        with torch.no_grad():

            # batch_index is 1-based so that it matches the "[i/N]" progress print.
            for batch_index, in_data in enumerate(self.valset_loader, start=1):

                # moving to GPU...
                inputs = in_data[0].cuda()
                targets = in_data[1].cuda(non_blocking=True)

                # compute output and loss with the classifier in metrics
                # evaluation mode; the flag must not leak out if anything fails.
                self.model.classifier.metrics_evaluation_mode = True
                try:
                    features = self.model(inputs)
                    (loss, cls_probabilities, odd_probabilities, max_logits,
                     intra_logits, inter_logits) = self.criterion(features, targets)
                finally:
                    self.model.classifier.metrics_evaluation_mode = False

                intra_logits = intra_logits.tolist()
                inter_logits = inter_logits.tolist()
                max_probs = odd_probabilities.max(dim=1)[0]
                entropies = utils.entropies_from_probabilities(odd_probabilities)

                # accumulate metrics over batches...
                loss_meter.add(loss.item(), inputs.size(0))
                accuracy_meter.add(cls_probabilities.detach(), targets.detach())
                odd_accuracy_meter.add(odd_probabilities.detach(), targets.detach())

                if self.args.number_of_model_classes > 100:
                    # Only the logits of the current batch are kept here, so the
                    # epoch-level logits statistics are not representative.
                    print("WARMING!!! DO NOT BLINDLY TRUST EPOCH LOGITS STATISTICS!!!")
                    epoch_logits["intra"] = intra_logits
                    epoch_logits["inter"] = inter_logits
                else:
                    epoch_logits["intra"] += intra_logits
                    epoch_logits["inter"] += inter_logits
                epoch_metrics["max_probs"] += max_probs.tolist()
                epoch_metrics["max_logits"] += max_logits.tolist()
                epoch_metrics["entropies"] += entropies.tolist()

                if batch_index % self.args.print_freq == 0:
                    # statistics.stdev() raises StatisticsError when given fewer
                    # than two values; report NaN instead of aborting the epoch.
                    intra_logits_std = (
                        statistics.stdev(intra_logits) if len(intra_logits) > 1 else float("nan"))
                    inter_logits_std = (
                        statistics.stdev(inter_logits) if len(inter_logits) > 1 else float("nan"))
                    print('Valid Epoch: [{0}][{1:3}/{2}]\t'
                          'Loss {loss:.8f}\t\t'
                          'Acc1 {acc1_meter:.2f}\t'
                          'IADM {intra_logits_mean:.4f}\t'
                          'IADS {intra_logits_std:.8f}\t\t'
                          'IEDM {inter_logits_mean:.4f}\t'
                          'IEDS {inter_logits_std:.8f}'
                          .format(self.epoch, batch_index, len(self.valset_loader),
                                  loss=loss_meter.avg,
                                  acc1_meter=accuracy_meter.value()[0],
                                  intra_logits_mean=statistics.mean(intra_logits),
                                  intra_logits_std=intra_logits_std,
                                  inter_logits_mean=statistics.mean(inter_logits),
                                  inter_logits_std=inter_logits_std,
                                  )
                          )

        print('\n#### VALID ACC1:\t{0:.4f}\n\n'.format(accuracy_meter.value()[0]))

        return loss_meter.avg, accuracy_meter.value()[0], odd_accuracy_meter.value()[0], epoch_logits, epoch_metrics

    def extract_features_for_all_sets(self, model_file_path):
        """Load the weights stored at ``model_file_path`` and dump features for each set.

        When ``model_file_path`` is not an existing file, a message is printed
        and the method returns without doing anything else. Otherwise the state
        dict is loaded into ``self.model`` (mapped to ``"cuda:<args.gpu_id>"``)
        and ``extract_features_from_loader`` is called for:

        * the first and second train-set partitions, each only when its
          inference loader has a non-zero length;
        * the validation loader, unconditionally.

        Output files are named after the checkpoint: its path without extension
        plus ``_trainset_first_partition.pth``, ``_trainset_second_partition.pth``
        or ``_valset.pth``.
        """
        print("\n################ EXTRACTING FEATURES ################")

        # Guard clause: without a checkpoint there is nothing to extract.
        if not os.path.isfile(model_file_path):
            print("=> no checkpoint found at '{}'".format(model_file_path))
            return

        # Loading best model...
        print("\n=> loading checkpoint '{}'".format(model_file_path))
        target_device = "cuda:" + str(self.args.gpu_id)
        state_dict = torch.load(model_file_path, map_location=target_device)
        self.model.load_state_dict(state_dict)

        print("=> loaded checkpoint '{}'".format(model_file_path))

        # Every output file shares the checkpoint path stripped of its extension.
        stem = os.path.splitext(model_file_path)[0]
        optional_sets = (
            (self.trainset_first_partition_loader_for_infer,
             stem + '_trainset_first_partition' + '.pth'),
            (self.trainset_second_partition_loader_for_infer,
             stem + '_trainset_second_partition' + '.pth'),
        )

        # Train-set partitions are skipped when their loader is empty.
        for partition_loader, output_path in optional_sets:
            if len(partition_loader) == 0:
                continue
            self.extract_features_from_loader(partition_loader, output_path)

        # The validation set is always processed.
        self.extract_features_from_loader(self.valset_loader, stem + '_valset' + '.pth')

    def extract_features_from_loader(self, loader, file_path):
        """Compute logits and features for every batch of ``loader`` and save them.

        The model is put in eval mode and each batch is forwarded through
        ``self.model.logits_features`` under ``torch.no_grad()``. Results are
        copied to CPU buffers with ``len(loader.sampler)`` rows and finally
        stored with ``torch.save`` as the tuple ``(logits, features, targets)``.

        Args:
            loader: iterable of ``(input_tensor, target_tensor)`` batches that
                also exposes ``dataset`` and ``sampler``.
            file_path: destination of the saved tuple; its parent directory is
                created when missing.

        Returns:
            The ``(logits, features, targets)`` tensors that were saved.

        Raises:
            ValueError: if ``loader`` yields no batch, since there is then
                nothing to size the buffers from and nothing to save.
        """
        # switch to evaluate mode
        self.model.eval()
        print('Extract features on {}'.format(loader.dataset))

        number_of_samples = len(loader.sampler)
        logits = features = targets = None
        # Running write position. Tracking it explicitly (instead of deriving it
        # from batch_id * loader.batch_size) stays correct when loader.batch_size
        # is None or when batches do not all have the same size.
        offset = 0

        with torch.no_grad():
            for input_tensor, target_tensor in tqdm(loader):
                # moving to GPU...
                input_tensor = input_tensor.cuda()
                # compute batch logits and features...
                batch_logits, batch_features = self.model.logits_features(input_tensor)
                if logits is None:
                    # Buffers are allocated lazily because the feature width is
                    # only known after the first forward pass. They are
                    # zero-filled so rows never written are deterministic.
                    logits = torch.zeros(number_of_samples, self.args.number_of_model_classes)
                    features = torch.zeros(number_of_samples, batch_features.size()[1])
                    targets = torch.zeros(number_of_samples)
                    print("LOGITS:", logits.size())
                    print("FEATURES:", features.size())
                    print("TARGETS:", targets.size())
                end = offset + input_tensor.size(0)
                logits[offset:end] = batch_logits.cpu()
                features[offset:end] = batch_features.cpu()
                targets[offset:end] = target_tensor
                offset = end

        if logits is None:
            raise ValueError(
                "cannot extract features for '{}': the loader yielded no batches".format(file_path))

        # os.makedirs avoids spawning a shell (safe with spaces and special
        # characters in the path); a bare file name has no directory to create.
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        print('save ' + file_path)
        torch.save((logits, features, targets), file_path)
        return logits, features, targets

    def odd_infer(self):
        """Evaluate the best model on out-of-distribution datasets and save the results.

        The checkpoint at ``self.args.best_model_file_path`` is loaded into
        ``self.model``; if that file does not exist a message is printed and
        the method returns. For every out-distribution dataset associated with
        ``self.args.dataset``, ``self.valset_loader`` is temporarily replaced
        by the loader returned by ``data_loader.getNonTargetDataSet`` so that
        ``validate_epoch`` runs on it, and the resulting epoch logits and epoch
        metrics are written with ``np.save`` into ``self.args.experiment_path``.

        ``self.valset_loader`` is restored to its previous value before
        returning, even if an evaluation fails.

        Raises:
            ValueError: if ``self.args.dataset`` is not one of ``'cifar10'``,
                ``'cifar100'`` or ``'svhn'``.
        """
        print("\n################ INFERING ################")

        # Loading best model...
        checkpoint_path = self.args.best_model_file_path
        if not os.path.isfile(checkpoint_path):
            print("=> no checkpoint found at '{}'".format(checkpoint_path))
            return
        print("\n=> loading checkpoint '{}'".format(checkpoint_path))
        self.model.load_state_dict(
            torch.load(checkpoint_path, map_location="cuda:" + str(self.args.gpu_id)))
        print("=> loaded checkpoint '{}'".format(checkpoint_path))

        # (mean, std) passed to transforms.Normalize for each in-distribution dataset.
        normalization_stats = {
            'cifar10': ((0.491, 0.482, 0.446), (0.247, 0.243, 0.261)),
            'cifar100': ((0.507, 0.486, 0.440), (0.267, 0.256, 0.276)),
            'svhn': ((0.437, 0.443, 0.472), (0.198, 0.201, 0.197)),
        }
        # Out-distribution datasets evaluated for each in-distribution dataset.
        out_dist_lists = {
            'cifar10': ['svhn', 'imagenet_resize', 'lsun_resize'],
            'cifar100': ['svhn', 'imagenet_resize', 'lsun_resize'],
            'svhn': ['cifar10', 'imagenet_resize', 'lsun_resize'],
        }

        dataset = self.args.dataset
        if dataset not in normalization_stats:
            raise ValueError(
                "odd_infer does not support dataset '{}'; expected one of {}".format(
                    dataset, sorted(normalization_stats)))

        # preparing and normalizing data
        mean, std = normalization_stats[dataset]
        in_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean, std)])
        out_dist_list = out_dist_lists[dataset]

        # validate_epoch() reads self.valset_loader, so it is swapped for each
        # out-distribution loader and put back afterwards; otherwise later
        # validation or feature extraction would silently run on the last
        # out-distribution dataset.
        original_valset_loader = self.valset_loader
        try:
            # Storing logits and metrics in out-distribution...
            for out_dist in out_dist_list:
                print('Out-distribution: ' + out_dist)
                self.valset_loader = data_loader.getNonTargetDataSet(
                    out_dist, self.args.batch_size, in_transform, "data")
                _, _, _, valid_epoch_logits, valid_epoch_metrics = self.validate_epoch()
                np.save(os.path.join(
                    self.args.experiment_path,
                    "best_model"+str(self.args.execution)+"_valid_epoch_logits_"+out_dist+".npy"),
                    valid_epoch_logits)
                np.save(os.path.join(
                    self.args.experiment_path,
                    "best_model"+str(self.args.execution)+"_valid_epoch_metrics_"+out_dist+".npy"),
                    valid_epoch_metrics)
        finally:
            self.valset_loader = original_valset_loader

