import os
import random
import time
import json
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import sys
import logging
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from tqdm import tqdm


from ofa.utils import get_net_info, cross_entropy_loss_with_soft_target, cross_entropy_with_label_smoothing
from ofa.utils import AverageMeter, accuracy, write_log, mix_images, mix_labels, init_models
from ofa.utils import MyRandomResizedCrop
from .utils import *

def Entropy(input_):
    """Return the per-sample Shannon entropy of a batch of probability vectors.

    Computes ``-sum_k p_k * log(p_k + 1e-5)`` along dimension 1. The small
    epsilon keeps ``log`` finite when a probability is exactly zero.

    Args:
        input_: tensor of probabilities (e.g. softmax outputs) with the class
            axis on dimension 1.

    Returns:
        Tensor with dimension 1 summed out, e.g. shape ``(batch,)`` for a 2-D
        input.
    """
    epsilon = 1e-5
    entropy = -input_ * torch.log(input_ + epsilon)
    return torch.sum(entropy, dim=1)

from scipy.spatial.distance import cdist


class LiteDomainAdaption:
    """Source-free domain-adaptation runner (SHOT-style) for a pretrained network.

    Every training epoch first assigns pseudo-labels to the target set by
    clustering the network's features (:meth:`obtain_label`). It then minimises
    an information-maximisation objective: mean per-sample entropy minus the
    entropy of the batch-mean prediction. Two optional terms can be added: a
    pseudo-label cross-entropy term and a knowledge-distillation term from a
    teacher model.

    ``net`` must return ``(logits, features)`` from ``forward`` and expose a
    ``classifier`` sub-module; both are relied on below.
    """

    def __init__(self, path, net, optimizer, run_config, target_loader, test_loader, n_classes, init=True,
                 measure_latency=None, no_gpu=False):
        """Store the training components and move ``net`` to the GPU when possible.

        Args:
            path: output directory; checkpoints and logs are written below it.
            net: model whose forward returns ``(logits, features)``.
            optimizer: optimizer stepped in :meth:`train_one_epoch`.
            run_config: provides ``n_epochs``, ``validation_frequency`` and
                ``dataset``. Some helpers also read ``config`` and
                ``random_sub_train_loader``.
            target_loader: target-domain loader. Its ground-truth labels are
                only used to report accuracy, never as training targets.
            test_loader: loader evaluated by :meth:`validate`.
            n_classes: number of classes.
            init, measure_latency: accepted for interface compatibility; unused.
            no_gpu: force CPU even when CUDA is available.
        """
        self.path = path
        self.net = net
        self.optimizer = optimizer
        self.run_config = run_config
        self.target_loader = target_loader
        self.test_loader = test_loader
        self.n_classes = n_classes
        self.best_acc = 0
        self.start_epoch = 0

        os.makedirs(self.path, exist_ok=True)

        # Move the network to the first GPU when available and not disabled.
        if torch.cuda.is_available() and (not no_gpu):
            self.device = torch.device('cuda:0')
            self.net = self.net.to(self.device)
            cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')

        self.test_criterion = nn.CrossEntropyLoss()

    # ------------------------------------------------------------------ #
    # save path and log path
    # ------------------------------------------------------------------ #
    @property
    def save_path(self):
        """Checkpoint directory ``<path>/checkpoint``, created on first access."""
        if self.__dict__.get('_save_path', None) is None:
            save_path = os.path.join(self.path, 'checkpoint')
            os.makedirs(save_path, exist_ok=True)
            self.__dict__['_save_path'] = save_path
        return self.__dict__['_save_path']

    @property
    def logs_path(self):
        """Log directory ``<path>/logs``, created on first access."""
        if self.__dict__.get('_logs_path', None) is None:
            logs_path = os.path.join(self.path, 'logs')
            os.makedirs(logs_path, exist_ok=True)
            self.__dict__['_logs_path'] = logs_path
        return self.__dict__['_logs_path']

    @property
    def network(self):
        """The underlying model, unwrapping ``nn.DataParallel`` if present."""
        return self.net.module if isinstance(self.net, nn.DataParallel) else self.net

    def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
        """Append ``log_str`` to the ``prefix`` log under :attr:`logs_path`."""
        write_log(self.logs_path, log_str, prefix, should_print, mode)

    # ------------------------------------------------------------------ #
    # save and load models
    # ------------------------------------------------------------------ #
    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        """Save ``checkpoint`` and record its path in ``latest.txt``.

        Args:
            checkpoint: dict to save. Defaults to ``{'state_dict': ...}`` of
                :attr:`network`. A ``'dataset'`` key is added in place.
            is_best: also write the state dict to ``model_best.pth.tar``.
            model_name: file name inside :attr:`save_path`; defaults to
                ``checkpoint.pth.tar``.
        """
        if checkpoint is None:
            checkpoint = {'state_dict': self.network.state_dict()}

        if model_name is None:
            model_name = 'checkpoint.pth.tar'

        checkpoint['dataset'] = self.run_config.dataset  # add `dataset` info to the checkpoint
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        model_path = os.path.join(self.save_path, model_name)
        with open(latest_fname, 'w') as fout:
            fout.write(model_path + '\n')
        torch.save(checkpoint, model_path)

        if is_best:
            best_path = os.path.join(self.save_path, 'model_best.pth.tar')
            torch.save({'state_dict': checkpoint['state_dict']}, best_path)

    def load_model(self, model_fname=None):
        """Restore network, epoch, best accuracy and optimizer state from a checkpoint.

        When ``model_fname`` is None, the path stored in ``latest.txt`` is used.
        If that path is missing or empty, ``<save_path>/checkpoint.pth.tar`` is
        tried instead, and ``latest.txt`` is rewritten to point at it.

        Returns:
            The loaded checkpoint dict, or ``{}`` if it could not be read.
        """
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, 'r') as fin:
                # An empty latest.txt previously crashed with IndexError on
                # ``model_fname[-1]``; treat it as "no recorded checkpoint".
                model_fname = fin.readline().rstrip('\n') or None
        # noinspection PyBroadException
        try:
            if model_fname is None or not os.path.exists(model_fname):
                model_fname = '%s/checkpoint.pth.tar' % self.save_path
                with open(latest_fname, 'w') as fout:
                    fout.write(model_fname + '\n')
            print("=> loading checkpoint '{}'".format(model_fname))
            checkpoint = torch.load(model_fname, map_location='cpu')
        except Exception:
            # Best-effort: a missing/unreadable checkpoint means "start fresh".
            print('fail to load checkpoint from %s' % self.save_path)
            return {}

        self.network.load_state_dict(checkpoint['state_dict'])
        if 'epoch' in checkpoint:
            self.start_epoch = checkpoint['epoch'] + 1
        if 'best_acc' in checkpoint:
            self.best_acc = checkpoint['best_acc']
        if 'optimizer' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        print("=> loaded checkpoint '{}'".format(model_fname))
        return checkpoint

    def save_config(self, extra_run_config=None, extra_net_config=None):
        """Dump run_config and net_config as JSON into ``self.path``.

        ``run.config`` is written only if it does not exist yet. ``net.config``
        is (re)written from ``self.network.config``; if that fails for any
        reason, a message is printed instead.
        """
        run_save_path = os.path.join(self.path, 'run.config')
        if not os.path.isfile(run_save_path):
            run_config = self.run_config.config
            if extra_run_config is not None:
                run_config.update(extra_run_config)
            with open(run_save_path, 'w') as fout:
                json.dump(run_config, fout, indent=4)
            print('Run configs dump to %s' % run_save_path)

        # noinspection PyBroadException
        try:
            net_save_path = os.path.join(self.path, 'net.config')
            net_config = self.network.config
            if extra_net_config is not None:
                net_config.update(extra_net_config)
            with open(net_save_path, 'w') as fout:
                json.dump(net_config, fout, indent=4)
            print('Network configs dump to %s' % net_save_path)
        except Exception:
            print('%s do not support net config' % type(self.network))

    # ------------------------------------------------------------------ #
    # metric related
    # ------------------------------------------------------------------ #
    def get_metric_dict(self):
        """Fresh running meters for top-1 and top-5 accuracy."""
        return {
            'top1': AverageMeter(),
            'top5': AverageMeter(),
        }

    def update_metric(self, metric_dict, output, labels):
        """Update the top-1/top-5 meters with one batch of logits and labels."""
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        metric_dict['top1'].update(acc1.item(), output.size(0))
        metric_dict['top5'].update(acc5.item(), output.size(0))

    def get_metric_vals(self, metric_dict, return_dict=False):
        """Current meter averages, as a list (dict order) or a ``{name: avg}`` dict."""
        if return_dict:
            return {
                key: metric_dict[key].avg for key in metric_dict
            }
        else:
            return [metric_dict[key].avg for key in metric_dict]

    def get_metric_names(self):
        """Names of the tracked metrics, in the order of :meth:`get_metric_vals`."""
        return 'top1', 'top5'

    # ------------------------------------------------------------------ #
    # train and test
    # ------------------------------------------------------------------ #
    def validate(self, epoch=0, run_str='', net=None, no_logs=False, train_mode=False, is_test=False):
        """Evaluate ``net`` (default ``self.net``) on ``self.test_loader``.

        The model is always evaluated in eval mode and is left in eval mode.
        ``train_mode`` and ``is_test`` are accepted for interface compatibility
        but ignored.

        Returns:
            ``(average loss, [top1, top5])``.
        """
        if net is None:
            net = self.net
        if not isinstance(net, nn.DataParallel):
            net = nn.DataParallel(net)
        net.eval()

        losses = AverageMeter()
        metric_dict = self.get_metric_dict()

        with torch.no_grad():
            with tqdm(total=len(self.test_loader),
                      desc='Validate Epoch #{} {}'.format(epoch + 1, run_str), disable=no_logs) as t:
                for images, labels in self.test_loader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    output, _ = net(images)  # forward returns (logits, features)
                    loss = self.test_criterion(output, labels)
                    self.update_metric(metric_dict, output, labels)

                    losses.update(loss.item(), images.size(0))
                    t.set_postfix({
                        'loss': losses.avg,
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        'img_size': images.size(2),
                    })
                    t.update(1)
        return losses.avg, self.get_metric_vals(metric_dict)

    def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
        """Recompute BN running statistics of ``net`` (default :attr:`network`) on a data subset."""
        from ofa.imagenet_classification.elastic_nn.utils import set_running_statistics
        if net is None:
            net = self.network
        if data_loader is None:
            data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
        set_running_statistics(net, data_loader)

    def obtain_label(self, loader):
        """Assign a pseudo-label to every sample of ``loader`` by feature clustering (SHOT).

        The network runs in eval mode over the whole loader. Features get a
        constant 1 appended and are L2-normalised. Soft centroids are the
        softmax-weighted feature means, and each sample is assigned to its
        nearest centroid under cosine distance. Centroids are then recomputed
        from the hard assignment, and samples are reassigned once more.

        The labels are returned in loader iteration order. Callers matching
        them to samples must iterate ``loader`` in the same order (e.g.
        ``shuffle=False``).

        Args:
            loader: iterable of batches whose first two items are
                ``(inputs, labels)``. ``labels`` are only used to print the
                accuracy of the raw network predictions.

        Returns:
            ``np.ndarray`` of integer pseudo-labels, one per sample.

        Raises:
            ValueError: if ``loader`` yields no batches.

        Note:
            ``self.net`` is left in train mode (all sub-modules) afterwards.
        """
        distance = 'cosine'
        threshold = 0  # a class keeps a centroid only if predicted > threshold times
        largest_dist_const = 1.1  # replacement for undefined (NaN) distances
        refine_rounds = 1

        self.net.eval()
        feat_chunks, logit_chunks, label_chunks = [], [], []
        with torch.no_grad():
            for data in loader:
                inputs, labels = data[0], data[1]
                outputs, feas = self.net(inputs.to(self.device))
                feat_chunks.append(feas.float().cpu())
                logit_chunks.append(outputs.float().cpu())
                label_chunks.append(labels.float())
        # Restore train mode, as callers of this method expect.
        self.net.train()

        if not logit_chunks:
            raise ValueError('obtain_label() received an empty loader')

        # Concatenate once instead of torch.cat inside the loop (quadratic).
        all_fea = torch.cat(feat_chunks, 0)
        all_output = F.softmax(torch.cat(logit_chunks, 0), dim=1)
        all_label = torch.cat(label_chunks, 0)

        _, predict = torch.max(all_output, 1)
        predict_acc = torch.sum(predict.float() == all_label).item() / float(all_label.size(0))
        print('Predict Acc: {}'.format(predict_acc))

        # Append a constant bias dimension, then L2-normalise every row.
        all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
        all_fea = (all_fea / torch.norm(all_fea, p=2, dim=1, keepdim=True)).numpy()

        aff = all_output.numpy()
        num_classes = aff.shape[1]
        # Soft (probability-weighted) initial centroids.
        initc = aff.T.dot(all_fea) / (1e-8 + aff.sum(axis=0)[:, None])
        cls_count = np.bincount(predict.numpy(), minlength=num_classes)
        labelset = np.where(cls_count > threshold)[0]

        dd = cdist(all_fea, initc[labelset], distance)
        pred_label = labelset[dd.argmin(axis=1)]

        for _ in range(refine_rounds):
            # Hard-assignment centroids.
            aff = np.eye(num_classes)[pred_label]
            initc = aff.T.dot(all_fea) / (1e-8 + aff.sum(axis=0)[:, None])
            dd = cdist(all_fea, initc[labelset], distance)
            # A class in labelset that received no samples has an all-zero
            # centroid, whose cosine distance is NaN; treat it as "far".
            dd[np.isnan(dd)] = largest_dist_const
            pred_label = labelset[dd.argmin(axis=1)]

        return pred_label.astype('int')

    def train_one_epoch(self, args, epoch, teacher_model=None, warmup_epochs=0, warmup_lr=0):
        """Run one epoch of SHOT-style adaptation over ``self.target_loader``.

        Args:
            args: must provide ``teacher_model`` (or ``None``). When a teacher
                is set, it must also provide ``temperature`` and ``kd_ratio``.
            epoch: zero-based epoch index (only used in the progress-bar text).
            teacher_model, warmup_epochs, warmup_lr: accepted for interface
                compatibility but unused. The teacher is read from
                ``args.teacher_model``.

        Returns:
            ``(average loss, [top1, top5])``. Accuracy is measured against the
            loader's ground-truth labels for monitoring only.
        """
        # SHOT loss configuration (fixed).
        cls_par = 0.0    # weight of the pseudo-label cross-entropy (0 disables it)
        ent_par = 1.0    # weight of the information-maximisation term
        use_ent = True   # minimise per-sample entropy
        use_gent = True  # maximise entropy of the batch-mean prediction
        epsilon = 1e-5

        # Pseudo-labels must be computed BEFORE setting the training modes:
        # obtain_label() ends with self.net.train(), which would otherwise undo
        # the classifier.eval() below. The .long() cast guards against int32
        # from astype('int') on some platforms.
        pseudo_labels = torch.from_numpy(self.obtain_label(self.target_loader)).long()

        self.net.train()
        self.network.classifier.eval()

        teacher = args.teacher_model
        criterion_kl = None
        if teacher is not None:
            teacher.eval()
            criterion_kl = KL_Loss(args.temperature)

        losses = AverageMeter()
        metric_dict = self.get_metric_dict()
        data_time = AverageMeter()
        n_batch = len(self.target_loader)

        with tqdm(total=n_batch,
                  desc='{} Train Epoch #{}'.format('office31', epoch + 1)) as t:
            end = time.time()
            start = 0
            for images, real_labels in self.target_loader:
                # NOTE(review): pseudo-labels are matched to samples by
                # position, which assumes target_loader yields samples in the
                # same order as in obtain_label() (e.g. shuffle=False) - confirm.
                labels = pseudo_labels[start:start + len(real_labels)]
                start += len(real_labels)
                data_time.update(time.time() - end)

                images, labels = images.to(self.device), labels.to(self.device)
                real_labels = real_labels.to(self.device)

                if teacher is not None:
                    with torch.no_grad():
                        soft_logits = teacher(images)

                output, _ = self.net(images)  # forward returns (logits, features)
                softmax_out = F.softmax(output, dim=1)

                if cls_par > 0:
                    loss = cls_par * nn.CrossEntropyLoss()(output, labels)
                else:
                    loss = torch.zeros((), device=output.device)

                if use_ent:
                    entropy_loss = torch.mean(Entropy(softmax_out))
                    if use_gent:
                        msoftmax = softmax_out.mean(dim=0)
                        gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + epsilon))
                        entropy_loss = entropy_loss - gentropy_loss
                    loss = loss + entropy_loss * ent_par

                if teacher is not None:
                    loss_type = 'IM + PL + KD'
                    loss = loss + args.kd_ratio * criterion_kl(output, soft_logits)
                else:
                    loss_type = 'IM + PL'

                self.net.zero_grad()
                loss.backward()
                self.optimizer.step()

                losses.update(loss.item(), images.size(0))
                self.update_metric(metric_dict, output, real_labels)

                t.set_postfix({
                    'loss': losses.avg,
                    **self.get_metric_vals(metric_dict, return_dict=True),
                    'img_size': images.size(2),
                    'loss_type': loss_type,
                    'data_time': data_time.avg,
                })
                t.update(1)
                end = time.time()
        return losses.avg, self.get_metric_vals(metric_dict)

    def train(self, args, teacher_model=None, warmup_epoch=0, warmup_lr=0):
        """Train from ``self.start_epoch``, validating periodically and checkpointing every epoch.

        Validation runs every ``run_config.validation_frequency`` epochs; the
        best top-1 is tracked in ``self.best_acc``. The image size written to
        the log is a fixed 224.
        """
        for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epoch):
            train_loss, (train_top1, train_top5) = self.train_one_epoch(args, epoch,
                                                                        teacher_model, warmup_epoch, warmup_lr)

            if (epoch + 1) % self.run_config.validation_frequency == 0:
                val_loss, (val_acc, val_acc5) = self.validate(epoch=epoch, is_test=False, net=self.network)

                is_best = np.mean(val_acc) > self.best_acc
                self.best_acc = max(self.best_acc, np.mean(val_acc))
                val_log = 'Valid [{0}/{1}]\tloss {2:.3f}\t{5} {3:.3f} ({4:.3f})'. \
                    format(epoch + 1 - warmup_epoch, self.run_config.n_epochs,
                           np.mean(val_loss), np.mean(val_acc), self.best_acc, self.get_metric_names()[0])
                val_log += '\t{2} {0:.3f}\tTrain {1} {top1:.3f}\tloss {train_loss:.3f}\t'. \
                    format(np.mean(val_acc5), *self.get_metric_names(), top1=train_top1, train_loss=train_loss)
                img_size = 224
                val_log += '(%d, %.3f), ' % (img_size, val_acc)
                self.write_log(val_log, prefix='valid', should_print=False)
            else:
                is_best = False

            self.save_model({
                'epoch': epoch,
                'best_acc': self.best_acc,
                'optimizer': self.optimizer.state_dict(),
                'state_dict': self.network.state_dict(),
            }, is_best=is_best)




# class RunManager:
#     def __init__(self, path, net, run_config, init=True, measure_latency=None, no_gpu=False):
#         self.path = path
#         self.net = net
#         self.run_config = run_config
#
#         self.best_acc = 0
#         self.start_epoch = 0
#
#         os.makedirs(self.path, exist_ok=True)
#
#         # move network to GPU if available
#         if torch.cuda.is_available() and (not no_gpu):
#             self.device = torch.device('cuda:0')
#             self.net = self.net.to(self.device)
#             cudnn.benchmark = True
#         else:
#             self.device = torch.device('cpu')
#         # initialize model (default)
#         if init:
#             init_models(run_config.model_init)
#
#         # net info
#         net_info = get_net_info(self.net, self.run_config.data_provider.data_shape, measure_latency, True)
#         with open('%s/net_info.txt' % self.path, 'w') as fout:
#             fout.write(json.dumps(net_info, indent=4) + '\n')
#             # noinspection PyBroadException
#             try:
#                 fout.write(self.network.module_str + '\n')
#             except Exception:
#                 pass
#             fout.write('%s\n' % self.run_config.data_provider.train.dataset.transform)
#             fout.write('%s\n' % self.run_config.data_provider.test.dataset.transform)
#             fout.write('%s\n' % self.network)
#             print(self.network)
#
#         # criterion
#         if isinstance(self.run_config.mixup_alpha, float):
#             self.train_criterion = cross_entropy_loss_with_soft_target
#         elif self.run_config.label_smoothing > 0:
#             self.train_criterion = \
#                 lambda pred, target: cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
#         else:
#             self.train_criterion = nn.CrossEntropyLoss()
#         self.test_criterion = nn.CrossEntropyLoss()
#
#         # optimizer
#         if self.run_config.no_decay_keys:
#             keys = self.run_config.no_decay_keys.split('#')
#             net_params = [
#                 self.network.get_parameters(keys, mode='exclude'),  # parameters with weight decay
#                 self.network.get_parameters(keys, mode='include'),  # parameters without weight decay
#             ]
#         else:
#             # noinspection PyBroadException
#             try:
#                 net_params = self.network.weight_parameters()
#             except Exception:
#                 net_params = []
#                 for param in self.network.parameters():
#                     if param.requires_grad:
#                         net_params.append(param)
#         self.optimizer = self.run_config.build_optimizer(net_params)
#
#         self.net = torch.nn.DataParallel(self.net)
#
#     """ save path and log path """
#     @property
#     def save_path(self):
#         if self.__dict__.get('_save_path', None) is None:
#             save_path = os.path.join(self.path, 'checkpoint')
#             os.makedirs(save_path, exist_ok=True)
#             self.__dict__['_save_path'] = save_path
#         return self.__dict__['_save_path']
#
#     @property
#     def logs_path(self):
#         if self.__dict__.get('_logs_path', None) is None:
#             logs_path = os.path.join(self.path, 'logs')
#             os.makedirs(logs_path, exist_ok=True)
#             self.__dict__['_logs_path'] = logs_path
#         return self.__dict__['_logs_path']
#
#     @property
#     def network(self):
#         return self.net.module if isinstance(self.net, nn.DataParallel) else self.net
#
#     def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
#         write_log(self.logs_path, log_str, prefix, should_print, mode)
#
#     """ save and load models """
#
#     def save_model(self, checkpoint=None, is_best=False, model_name=None):
#         if checkpoint is None:
#             checkpoint = {'state_dict': self.network.state_dict()}
#
#         if model_name is None:
#             model_name = 'checkpoint.pth.tar'
#
#         checkpoint['dataset'] = self.run_config.dataset  # add `dataset` info to the checkpoint
#         latest_fname = os.path.join(self.save_path, 'latest.txt')
#         model_path = os.path.join(self.save_path, model_name)
#         with open(latest_fname, 'w') as fout:
#             fout.write(model_path + '\n')
#         torch.save(checkpoint, model_path)
#
#         if is_best:
#             best_path = os.path.join(self.save_path, 'model_best.pth.tar')
#             torch.save({'state_dict': checkpoint['state_dict']}, best_path)
#
#     def load_model(self, model_fname=None):
#         latest_fname = os.path.join(self.save_path, 'latest.txt')
#         if model_fname is None and os.path.exists(latest_fname):
#             with open(latest_fname, 'r') as fin:
#                 model_fname = fin.readline()
#                 if model_fname[-1] == '\n':
#                     model_fname = model_fname[:-1]
#         # noinspection PyBroadException
#         try:
#             if model_fname is None or not os.path.exists(model_fname):
#                 model_fname = '%s/checkpoint.pth.tar' % self.save_path
#                 with open(latest_fname, 'w') as fout:
#                     fout.write(model_fname + '\n')
#             print("=> loading checkpoint '{}'".format(model_fname))
#             checkpoint = torch.load(model_fname, map_location='cpu')
#         except Exception:
#             print('fail to load checkpoint from %s' % self.save_path)
#             return {}
#
#         self.network.load_state_dict(checkpoint['state_dict'])
#         if 'epoch' in checkpoint:
#             self.start_epoch = checkpoint['epoch'] + 1
#         if 'best_acc' in checkpoint:
#             self.best_acc = checkpoint['best_acc']
#         if 'optimizer' in checkpoint:
#             self.optimizer.load_state_dict(checkpoint['optimizer'])
#
#         print("=> loaded checkpoint '{}'".format(model_fname))
#         return checkpoint
#
#     def save_config(self, extra_run_config=None, extra_net_config=None):
#         """ dump run_config and net_config to the model_folder """
#         run_save_path = os.path.join(self.path, 'run.config')
#         if not os.path.isfile(run_save_path):
#             run_config = self.run_config.config
#             if extra_run_config is not None:
#                 run_config.update(extra_run_config)
#             json.dump(run_config, open(run_save_path, 'w'), indent=4)
#             print('Run configs dump to %s' % run_save_path)
#
#         try:
#             net_save_path = os.path.join(self.path, 'net.config')
#             net_config = self.network.config
#             if extra_net_config is not None:
#                 net_config.update(extra_net_config)
#             json.dump(net_config, open(net_save_path, 'w'), indent=4)
#             print('Network configs dump to %s' % net_save_path)
#         except Exception:
#             print('%s do not support net config' % type(self.network))
#
#     """ metric related """
#
#     def get_metric_dict(self):
#         return {
#             'top1': AverageMeter(),
#             'top5': AverageMeter(),
#         }
#
#     def update_metric(self, metric_dict, output, labels):
#         acc1, acc5 = accuracy(output, labels, topk=(1, 5))
#         metric_dict['top1'].update(acc1[0].item(), output.size(0))
#         metric_dict['top5'].update(acc5[0].item(), output.size(0))
#
#     def get_metric_vals(self, metric_dict, return_dict=False):
#         if return_dict:
#             return {
#                 key: metric_dict[key].avg for key in metric_dict
#             }
#         else:
#             return [metric_dict[key].avg for key in metric_dict]
#
#     def get_metric_names(self):
#         return 'top1', 'top5'
#
#     """ train and test """
#
#     def validate(self, epoch=0, is_test=False, run_str='', net=None, data_loader=None, no_logs=False, train_mode=False):
#         if net is None:
#             net = self.net
#         if not isinstance(net, nn.DataParallel):
#             net = nn.DataParallel(net)
#
#         if data_loader is None:
#             data_loader = self.run_config.test_loader if is_test else self.run_config.valid_loader
#
#         if train_mode:
#             net.train()
#         else:
#             net.eval()
#
#         losses = AverageMeter()
#         metric_dict = self.get_metric_dict()
#
#         with torch.no_grad():
#             with tqdm(total=len(data_loader),
#                       desc='Validate Epoch #{} {}'.format(epoch + 1, run_str), disable=no_logs) as t:
#                 for i, (images, labels) in enumerate(data_loader):
#                     # print()
#                     # print("--------------------------------------")
#                     # print("The size of images: ",images.size(),"datatype: ",images.dtype)
#                     # print("The size of labels: ",labels.size())
#                     # print("--------------------------------------DONE")
#
#                     images, labels = images.to(self.device), labels.to(self.device)
#                     # compute output
#                     output, _ = net(images) ### MODIFIED
#                     loss = self.test_criterion(output, labels)
#                     # measure accuracy and record loss
#                     self.update_metric(metric_dict, output, labels)
#
#                     losses.update(loss.item(), images.size(0))
#                     t.set_postfix({
#                         'loss': losses.avg,
#                         **self.get_metric_vals(metric_dict, return_dict=True),
#                         'img_size': images.size(2),
#                     })
#                     t.update(1)
#         return losses.avg, self.get_metric_vals(metric_dict)
#
#     def validate_all_resolution(self, epoch=0, is_test=False, net=None):
#         if net is None:
#             net = self.network
#         if isinstance(self.run_config.data_provider.image_size, list):
#             img_size_list, loss_list, top1_list, top5_list = [], [], [], []
#             for img_size in self.run_config.data_provider.image_size:
#                 img_size_list.append(img_size)
#                 self.run_config.data_provider.assign_active_img_size(img_size)
#                 self.reset_running_statistics(net=net)
#                 loss, (top1, top5) = self.validate(epoch, is_test, net=net)
#                 loss_list.append(loss)
#                 top1_list.append(top1)
#                 top5_list.append(top5)
#             return img_size_list, loss_list, top1_list, top5_list
#         else:
#             loss, (top1, top5) = self.validate(epoch, is_test, net=net)
#             return [self.run_config.data_provider.active_img_size], [loss], [top1], [top5]
#
#     def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
#         from ofa.imagenet_classification.elastic_nn.utils import set_running_statistics
#         if net is None:
#             net = self.network
#         if data_loader is None:
#             data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
#         set_running_statistics(net, data_loader)
#
#     def obtain_label(self, loader):
#         # a test class
#         class test_args:
#             def __init__(self, class_num):
#                 self.distance = 'cosine'
#                 self.epsilon = 1e-5
#                 self.class_num = class_num
#                 self.threshold = 0
#                 self.largest_dist_const = 1.1
#
#         args = test_args(self.run_config.data_provider.n_classes)
#
#         start_test = True
#         self.net.eval()  # SET TO EVAL MODE
#         with torch.no_grad():
#             image_set = []
#
#             iter_test = iter(loader)
#             for counting in range(len(loader)):
#                 data = next(iter_test)  # DataLoader iterators have no .next() method in Python 3 / recent PyTorch
#                 inputs = data[0]  # IMAGE
#                 labels = data[1]  # LABEL
#
#                 image_set.append((counting, data))  # Insert the image to the list image_set
#
#                 # Use self.device (as validate() does) rather than a hard-coded .cuda(). The former
#                 # `all_inputs` accumulator was dropped: it was never read and kept every batch on the GPU.
#                 inputs = inputs.to(self.device)
#                 outputs, feas = self.net(inputs)
#                 # outputs = netC(feas)
#                 if start_test:
#                     all_fea = feas.float().cpu()
#                     all_output = outputs.float().cpu()
#                     all_label = labels.float()
#                     start_test = False
#                 else:
#                     all_fea = torch.cat((all_fea, feas.float().cpu()), 0)  # concatenate features across batches
#                     all_output = torch.cat((all_output, outputs.float().cpu()), 0)
#                     all_label = torch.cat((all_label, labels.float()), 0)
#
#         all_output = nn.Softmax(dim=1)(all_output)
#         ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1)
#         unknown_weight = 1 - ent / np.log(args.class_num)
#         _, predict = torch.max(all_output, 1)
#
#         accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
#         if args.distance == 'cosine':
#             all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
#             all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
#         else:
#             raise Exception('Distance type error.')
#
#         all_fea = all_fea.float().cpu().numpy()
#         K = all_output.size(1)
#         aff = all_output.float().cpu().numpy()
#         initc = aff.transpose().dot(all_fea)  # initial centroids
#         initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
#         cls_count = np.eye(K)[predict].sum(axis=0)  # sum to count the number of labels in each class, len:102(nclass)
#         labelset = np.where(cls_count > args.threshold)
#         labelset = labelset[0]  # count the number
#         # print(labelset)
#
#         dd = cdist(all_fea, initc[labelset], args.distance)
#         pred_label = dd.argmin(axis=1)
#         pred_label = labelset[pred_label]
#
#         for round in range(1):
#             aff = np.eye(K)[pred_label]
#             initc = aff.transpose().dot(all_fea)
#             initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
#             dd = cdist(all_fea, initc[labelset], args.distance)
#
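#             # Replace NaN distances (e.g. from a class that received no samples, whose centroid
#             # row is all zeros under cosine distance) with args.largest_dist_const to discourage
#             # argmin from picking them. Vectorized equivalent of the former per-element
#             # `str(dd[i_][j_]).lower() == 'nan'` loop.
#             dd[np.isnan(dd)] = args.largest_dist_const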
#             pred_label = dd.argmin(axis=1)
#             pred_label = labelset[pred_label]
#
#         # acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
#         # log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100)
#
#         # args.out_file.write(log_str + '\n')
#         # args.out_file.flush()
#         # print(log_str+'\n')
#
#         # RESET TO TRAIN MODE
#         self.net.train()
#         pred_label = pred_label.astype('int')
#         # for i in range(102):
#         #     if i*np.ones(len(pred_label)) != pred_label.any():
#         #         print("Warning: {}".format(i))
#
#         return pred_label, image_set
#
#     def train_one_epoch(self, args, epoch, teacher_model=None, warmup_epochs=0, warmup_lr=0):
#
#         class convert_args:
#             def __init__(self):
#                 self.cls_par = 0.3
#                 self.ent_par = 1.0
#                 self.gent = True
#                 self.ent = True
#                 self.epsilon = 1e-5
#
#         start_test = True  # ADDED
#         # switch to train mode
#         self.net.train()
#         MyRandomResizedCrop.EPOCH = epoch  # required by elastic resolution
#
#         nBatch = len(self.run_config.train_loader)
#
#         # print(type(self.run_config.train_loader))
#
#         # testing = iter(self.run_config.train_loader)
#         # for i in range(10):
#         #     images, labels = testing.next()
#
#         #     print(images, labels)
#
#         pseudo_labels, image_set = self.obtain_label(self.run_config.train_loader)
#         pseudo_labels = torch.from_numpy(pseudo_labels)
#
#         losses = AverageMeter()
#         metric_dict = self.get_metric_dict()
#         data_time = AverageMeter()
#
#         with tqdm(total=nBatch,
#                   desc='{} Train Epoch #{}'.format(self.run_config.dataset, epoch + 1)) as t:
#             end = time.time()
#             start = 0
#
#             # for i, (images, labels) in enumerate(self.run_config.train_loader):
#             for i, (images, real_labels) in image_set:
#
#                 labels = pseudo_labels[start:(start + len(real_labels))]
#
#                 print("rlabel: ", real_labels)
#                 print("plabel: ", labels)
#                 # if (image_set[i] != images).sum().item() != 0:
#                 #     print("BADBADBADBADBAD!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
#
#                 start += len(real_labels)
#
#                 MyRandomResizedCrop.BATCH = i
#                 data_time.update(time.time() - end)
#                 if epoch < warmup_epochs:
#                     new_lr = self.run_config.warmup_adjust_learning_rate(
#                         self.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
#                     )
#                 else:
#                     new_lr = self.run_config.adjust_learning_rate(self.optimizer, epoch - warmup_epochs, i, nBatch)
#
#                 images, labels = images.to(self.device), labels.to(self.device)
#                 target = labels
#                 if isinstance(self.run_config.mixup_alpha, float):
#                     # transform data
#                     lam = random.betavariate(self.run_config.mixup_alpha, self.run_config.mixup_alpha)
#                     images = mix_images(images, lam)
#                     labels = mix_labels(
#                         labels, lam, self.run_config.data_provider.n_classes, self.run_config.label_smoothing
#                     )
#
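#                 # soft target
#                 # NOTE(review): this reads args.teacher_model; the `teacher_model` parameter of this
#                 # method is never used.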
#                 if args.teacher_model is not None:
#                     args.teacher_model.eval()
#                     with torch.no_grad():
#                         soft_logits = args.teacher_model(images)
#                         soft_label = F.softmax(soft_logits, dim=1)
#
#                 # compute output
#                 output, feas = self.net(images)  ### MODIFIED
#
#                 '''
#                 Added from SHOT:
#                 some transformation
#                 Note: we assume all class_nums are known. See SHOT image_target.py line 213
#                 '''
#                 # outputs_test_known, pred = output, labels
#                 # softmax_out = nn.Softmax(dim=1)(outputs_test_known)
#
#                 # this_args = convert_args()
#
#                 # if this_args.cls_par > 0:
#                 #     classifier_loss = nn.CrossEntropyLoss()(outputs_test_known, pred)
#                 #     classifier_loss *= this_args.cls_par
#                 # else:
#                 #     classifier_loss = torch.tensor(0.0).cuda()
#
#                 # if this_args.ent:
#                 #     softmax_out_known = nn.Softmax(dim=1)(outputs_test_known)
#                 #     entropy_loss = torch.mean(Entropy(softmax_out_known))  # module-level Entropy() defined in this file
#                 #     if this_args.gent:
#                 #         msoftmax = softmax_out.mean(dim=0)
#                 #         gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + this_args.epsilon))
#                 #         entropy_loss -= gentropy_loss
#                 #     classifier_loss += entropy_loss * this_args.ent_par
#
#                 # # FOR DEBUGGING
#                 # print("===========================OUTPUT SIZE")
#                 # print(output.size())
#                 # print("===========================FEATURE")
#                 # print(feas)
#                 # print("===========================ONE ITEM OF RESULT")
#                 # print(output[0])
#                 # print("===========================LABEL:")
#                 # print(labels[0])
#                 # print("===========================END")
#
#                 '''
#                 This is the place for loss modification!
#                 (Completed here)
#
#                 =================================
#                 The original loss
#                 '''
#
#                 loss = self.train_criterion(output, labels)
#
#                 loss_type = 'ce'
#
#                 # if args.teacher_model is None:
#                 #     loss_type = 'ce'
#                 # else:
#                 #     if args.kd_type == 'ce':
#                 #         kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
#                 #     else:
#                 #         kd_loss = F.mse_loss(output, soft_logits)
#                 #     loss = args.kd_ratio * kd_loss + loss
#                 #     loss_type = '%.1fkd+ce' % args.kd_ratio
#
#                 # compute gradient and do SGD step
#                 self.net.zero_grad()  # or self.optimizer.zero_grad()
#                 loss.backward()
#                 # classifier_loss.backward()
#                 self.optimizer.step()
#
#                 # measure accuracy and record loss
#                 losses.update(loss.item(), images.size(0))
#                 # losses.update(classifier_loss.item(), images.size(0))
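#                 # NOTE(review): `target` holds the pseudo-labels, so top1/top5 here measure agreement
#                 # with the pseudo-labels, not accuracy against `real_labels`.
#                 self.update_metric(metric_dict, output, target)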
#
#                 t.set_postfix({
#                     'loss': losses.avg,
#                     **self.get_metric_vals(metric_dict, return_dict=True),
#                     'img_size': images.size(2),
#                     'lr': new_lr,
#                     'loss_type': loss_type,
#                     'data_time': data_time.avg,
#                 })
#                 t.update(1)
#                 end = time.time()
#         return losses.avg, self.get_metric_vals(metric_dict)
#
#     def train(self, args, warmup_epoch=0, warmup_lr=0):
#         for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epoch):
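#             # Pass warm-up settings by keyword: positionally, `warmup_epoch` landed in the
#             # `teacher_model` slot and `warmup_lr` in `warmup_epochs`, so warm-up did not run as intended.
#             train_loss, (train_top1, train_top5) = self.train_one_epoch(
#                 args, epoch, warmup_epochs=warmup_epoch, warmup_lr=warmup_lr)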
#
#             if (epoch + 1) % self.run_config.validation_frequency == 0:
#                 img_size, val_loss, val_acc, val_acc5 = self.validate_all_resolution(epoch=epoch, is_test=False)
#
#                 is_best = np.mean(val_acc) > self.best_acc
#                 self.best_acc = max(self.best_acc, np.mean(val_acc))
#                 val_log = 'Valid [{0}/{1}]\tloss {2:.3f}\t{5} {3:.3f} ({4:.3f})'. \
#                     format(epoch + 1 - warmup_epoch, self.run_config.n_epochs,
#                            np.mean(val_loss), np.mean(val_acc), self.best_acc, self.get_metric_names()[0])
#                 val_log += '\t{2} {0:.3f}\tTrain {1} {top1:.3f}\tloss {train_loss:.3f}\t'. \
#                     format(np.mean(val_acc5), *self.get_metric_names(), top1=train_top1, train_loss=train_loss)
#                 for i_s, v_a in zip(img_size, val_acc):
#                     val_log += '(%d, %.3f), ' % (i_s, v_a)
#                 self.write_log(val_log, prefix='valid', should_print=False)
#             else:
#                 is_best = False
#
#             self.save_model({
#                 'epoch': epoch,
#                 'best_acc': self.best_acc,
#                 'optimizer': self.optimizer.state_dict(),
#                 'state_dict': self.network.state_dict(),
#             }, is_best=is_best)