from __future__ import print_function, division
import csv
import os
import sys
import time
import torch
import torch.nn.functional as F
from .util import AverageMeter, accuracy
from .JPEG_layer import *
import helper.imagenet_utils as imagenet_utils
import warnings
from helper.cmi import CMILoss, MCMILoss
import torch.optim as optim
import torch.backends.cudnn as cudnn
import matplotlib.pyplot as plt

#====================================================== cifar100 ========================================================#

def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, args):
    """Run one epoch of knowledge-distillation training for the student.

    Args:
        epoch: Current epoch number (printed in logs and passed to the DKD loss).
        train_loader: Yields ``(input, target, index)`` batches, or
            ``(input, target, index, contrast_idx)`` when ``args.distill == 'crd'``.
        module_list: ``module_list[0]`` is the student, ``module_list[-1]`` the
            teacher; entries in between are method-specific auxiliary modules.
        criterion_list: ``[criterion_cls, criterion_div, criterion_kd]``.
        optimizer: Optimizer stepped once per batch.
        args: Options namespace; reads ``distill``, ``train_mode``, ``JPEG_enable``,
            ``gamma``, ``alpha``, ``beta``, ``print_freq`` and method-specific fields
            (``dkd_alpha``, ``dkd_beta``, ``hint_layer``, ``lambda_corr``, ``lambda_mutual``).

    Returns:
        ``(top1_student_avg, loss_avg)`` accumulated over the epoch.

    Raises:
        NotImplementedError: If ``args.distill`` names an unsupported method.
    """
    # Student and auxiliary modules train by default.
    for module in module_list:
        module.train()

    # Teacher runs in eval mode unless args.train_mode asks for train-mode forward
    # passes (its forward pass is still wrapped in torch.no_grad() below).
    if args.train_mode:
        module_list[-1].train()
    else:
        module_list[-1].eval()

    # These auxiliary modules are kept in eval mode during distillation.
    if args.distill == 'ab':
        module_list[1].eval()
    elif args.distill == 'ft':
        module_list[2].eval()

    criterion_cls = criterion_list[0]
    criterion_div = criterion_list[1]
    criterion_kd = criterion_list[2]

    model_s = module_list[0]
    model_t = module_list[-1]

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1_t = AverageMeter()
    top5_t = AverageMeter()
    top1_s = AverageMeter()
    top5_s = AverageMeter()

    # Student-side preprocessing when JPEG_enable: divide by 255 (presumably the raw
    # input is in [0, 255] — TODO confirm), then per-channel mean/std normalization.
    transform = transforms.Compose([
        transforms.Normalize(mean=[0, 0, 0], std=[255., 255., 255.]),
        transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),])

    # CUDA when available, otherwise Apple MPS (same fallback as before, hoisted).
    run_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps')
    is_crd = args.distill == 'crd'
    # These methods consume pre-activation features.
    preact = args.distill in ('itrd', 'ab')

    end = time.time()
    for idx, data in enumerate(train_loader):
        if is_crd:
            input, target, index, contrast_idx = data
        else:
            input, target, index = data
        data_time.update(time.time() - end)

        input = input.float().to(run_device)
        target = target.to(run_device)
        index = index.to(run_device)
        if is_crd:
            # Must live on the same device as the features on every backend,
            # not only on CUDA.
            contrast_idx = contrast_idx.to(run_device)

        # ===================forward=====================
        student_input = transform(input) if args.JPEG_enable else input
        feat_s, logit_s = model_s(student_input, is_feat=True, preact=preact)
        acc1_s, acc5_s = accuracy(logit_s, target, topk=(1, 5))

        # The teacher receives the untransformed input; its features are detached.
        with torch.no_grad():
            feat_t, logit_t = model_t(input, is_feat=True, preact=preact)
            feat_t = [f.detach() for f in feat_t]
        acc1_t, acc5_t = accuracy(logit_t, target, topk=(1, 5))

        # cls + kl div
        loss_cls = criterion_cls(logit_s, target)
        loss_div = criterion_div(logit_s, logit_t)

        # method-specific distillation term
        distill = args.distill
        if distill in ('kd', 'lskd', 'ab', 'fsp'):
            # No extra loss term at this stage for these methods.
            loss_kd_method = 0
        elif distill in ('ttm', 'wttm', 'dist'):
            loss_kd_method = criterion_kd(logit_s, logit_t)
        elif distill == 'dkd':
            loss_kd_method = criterion_kd(logit_s, logit_t, target, args.dkd_alpha, args.dkd_beta, epoch)
        elif distill == 'crd':
            loss_kd_method = criterion_kd(feat_s[-1], feat_t[-1], index, contrast_idx)
        elif distill == 'itrd':
            f_s = feat_s[-1]
            f_t = feat_t[-1]
            loss_correlation = args.lambda_corr * criterion_kd.forward_correlation_it(f_s, f_t)
            loss_mutual = args.lambda_mutual * criterion_kd.forward_mutual_it(f_s, f_t)
            loss_kd_method = loss_mutual + loss_correlation
        elif distill in ('rkd', 'pkt'):
            loss_kd_method = criterion_kd(feat_s[-1], feat_t[-1])
        elif distill == 'fitnet':
            f_s = module_list[1](feat_s[args.hint_layer])
            f_t = feat_t[args.hint_layer]
            loss_kd_method = criterion_kd(f_s, f_t)
        elif distill in ('at', 'nst', 'kdsvd'):
            # Intermediate feature maps (drop the first and last entries).
            loss_kd_method = sum(criterion_kd(feat_s[1:-1], feat_t[1:-1]))
        elif distill == 'sp':
            loss_kd_method = sum(criterion_kd([feat_s[-2]], [feat_t[-2]]))
        elif distill == 'cc':
            f_s = module_list[1](feat_s[-1])
            f_t = module_list[2](feat_t[-1])
            loss_kd_method = criterion_kd(f_s, f_t)
        elif distill == 'vid':
            # criterion_kd is a sequence of per-layer criteria here.
            loss_kd_method = sum(c(f_s, f_t) for f_s, f_t, c in zip(feat_s[1:-1], feat_t[1:-1], criterion_kd))
        elif distill == 'ft':
            factor_s = module_list[1](feat_s[-2])
            factor_t = module_list[2](feat_t[-2], is_factor=True)
            loss_kd_method = criterion_kd(factor_s, factor_t)
        else:
            raise NotImplementedError(distill)

        # ===================backward=====================
        loss = args.gamma * loss_cls + args.alpha * loss_div + args.beta * loss_kd_method
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ===================meters=====================
        batch_time.update(time.time() - end)
        losses.update(loss.item(), input.size(0))
        top1_s.update(acc1_s[0], input.size(0))
        top5_s.update(acc5_s[0], input.size(0))
        top1_t.update(acc1_t[0], input.size(0))
        top5_t.update(acc5_t[0], input.size(0))
        if idx % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1_t {top1_t.val:.3f} ({top1_t.avg:.3f})\t'
                  'Acc@5_t {top5_t.val:.3f} ({top5_t.avg:.3f})\t'
                  'Acc@1_s {top1_s.val:.3f} ({top1_s.avg:.3f})\t'
                  'Acc@5_s {top5_s.val:.3f} ({top5_s.avg:.3f})'.format(
                epoch, idx, len(train_loader), batch_time=batch_time,
                loss=losses, top1_s=top1_s, top5_s=top5_s, top1_t=top1_t, top5_t=top5_t))
            sys.stdout.flush()
        end = time.time()

    print(' * Acc@1 {top1_s.avg:.3f} Acc@5 {top5_s.avg:.3f}'.format(top1_s=top1_s, top5_s=top5_s))

    return top1_s.avg, losses.avg


def validate(val_loader, model, criterion, args):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Prints running statistics every ``args.print_freq`` batches and a final
    top-1/top-5 summary.

    Returns:
        ``(top1_avg, top5_avg, loss_avg)`` over the whole loader.
    """
    meter_time = AverageMeter()
    meter_loss = AverageMeter()
    meter_top1 = AverageMeter()
    meter_top5 = AverageMeter()

    model.eval()

    # CUDA when available, otherwise Apple MPS.
    target_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps')

    with torch.no_grad():
        tick = time.time()
        for step, (images, labels) in enumerate(val_loader):
            images = images.float().to(target_device)
            labels = labels.to(target_device)

            logits = model(images)
            batch_loss = criterion(logits, labels)

            top1_acc, top5_acc = accuracy(logits, labels, topk=(1, 5))
            count = images.size(0)
            meter_loss.update(batch_loss.item(), count)
            meter_top1.update(top1_acc[0], count)
            meter_top5.update(top5_acc[0], count)

            meter_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                message = ('Test: [{0}/{1}]\t'
                           'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                           'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                           'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                           'Acc@5 {top5.val:.3f} ({top5.avg:.3f})')
                print(message.format(step, len(val_loader), batch_time=meter_time,
                                     loss=meter_loss, top1=meter_top1, top5=meter_top5))

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=meter_top1, top5=meter_top5))

    return meter_top1.avg, meter_top5.avg, meter_loss.avg


def init(model_s, model_t, init_modules, criterion, train_loader, args):
    """Warm up auxiliary distillation modules before the main distillation run.

    The student and teacher are put in eval mode; only ``init_modules`` is trained,
    and only its parameters are given to the SGD optimizer. Supported methods are
    'ab', 'ft' and 'fsp'.

    Args:
        model_s: Student network, called as ``model_s(x, is_feat=True, preact=...)``.
        model_t: Teacher network with the same call convention (run under no_grad).
        init_modules: Container of auxiliary modules to train (``init_modules[0]`` is used
            for 'ab' and 'ft').
        criterion: Method-specific loss.
        train_loader: Yields ``(input, target, index)`` batches, or
            ``(input, target, index, contrast_idx)`` when ``args.distill == 'crd'``.
        args: Options namespace (``distill``, ``model_s``, ``init_epochs``, ``JPEG_enable``,
            ``learning_rate``, ``momentum``, ``weight_decay``).

    Raises:
        NotImplementedError: If ``args.distill`` is not supported (raised on the first batch).
    """
    model_t.eval()
    model_s.eval()
    init_modules.train()

    if torch.cuda.is_available():
        model_s.cuda()
        model_t.cuda()
        init_modules.cuda()
        cudnn.benchmark = True

    # Student-side preprocessing when JPEG_enable: divide by 255 (presumably the raw
    # input is in [0, 255] — TODO confirm), then per-channel mean/std normalization.
    transform = transforms.Compose([
        transforms.Normalize(mean=[0, 0, 0], std=[255., 255., 255.]),
        transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),])

    # FT on these small CIFAR-style students uses a fixed warm-up learning rate.
    if args.model_s in ['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110',
                       'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2'] and args.distill == 'ft':
        lr = 0.01
    else:
        lr = args.learning_rate
    optimizer = optim.SGD(init_modules.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    for epoch in range(1, args.init_epochs + 1):
        batch_time.reset()
        data_time.reset()
        losses.reset()
        end = time.time()
        for idx, data in enumerate(train_loader):
            if args.distill in ['crd']:
                input, target, index, contrast_idx = data
            else:
                input, target, index = data
            data_time.update(time.time() - end)

            input = input.float()
            if torch.cuda.is_available():
                input = input.cuda()
                target = target.cuda()
                index = index.cuda()
                if args.distill in ['crd']:
                    contrast_idx = contrast_idx.cuda()

            # ============= forward ==============
            preact = (args.distill == 'ab')
            if args.JPEG_enable:
                feat_s, _ = model_s(transform(input), is_feat=True, preact=preact)
            else:
                feat_s, _ = model_s(input, is_feat=True, preact=preact)
            with torch.no_grad():
                feat_t, _ = model_t(input, is_feat=True, preact=preact)
                feat_t = [f.detach() for f in feat_t]

            if args.distill == 'ab':
                g_s = init_modules[0](feat_s[1:-1])
                g_t = feat_t[1:-1]
                loss_group = criterion(g_s, g_t)
                loss = sum(loss_group)
            elif args.distill == 'ft':
                # Train the teacher-side module to reconstruct the teacher feature.
                f_t = feat_t[-2]
                _, f_t_rec = init_modules[0](f_t)
                loss = criterion(f_t_rec, f_t)
            elif args.distill == 'fsp':
                loss_group = criterion(feat_s[:-1], feat_t[:-1])
                loss = sum(loss_group)
            else:
                # NotImplemented is a constant, not an exception; calling it raised TypeError.
                raise NotImplementedError('Not supported in init training: {}'.format(args.distill))

            losses.update(loss.item(), input.size(0))

            # ===================backward=====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

        # end of epoch
        print('Epoch: [{0}/{1}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'losses: {losses.val:.5f} ({losses.avg:.5f})'.format(epoch, args.init_epochs, batch_time=batch_time, losses=losses))
        sys.stdout.flush()


def train_cifar100(opt, model, centroid_helper, optimizer, train_loader, train_original_loader, val_loader, epoch, backward=True):
    """Run one CIFAR-100 training epoch with CE plus a (M)CMI regularizer.

    When ``opt.centroid_update_freq > 0`` the centroids are refreshed every
    ``centroid_update_freq`` batches and the loss is ``CE + lambda_MCMI * MCMI``.
    Otherwise the centroids are never refreshed and the loss is
    ``CE - lambda_MCMI * CMI`` (i.e. CMI is maximized).

    Args:
        opt: Options namespace (``device``, ``print_freq``, ``centroid_update_freq``,
            ``lambda_MCMI``, ``freeze_model``, ``train_mode``, ``JPEG_enable``,
            ``JPEG_alpha_trainable``, ``dataset``).
        model: Network to train. With ``JPEG_enable`` it is accessed via ``model.module``,
            so presumably a DataParallel/DDP wrapper — TODO confirm.
        centroid_helper: Provides ``update_centroids(model, loader)`` and ``get_centroids(target)``.
        optimizer: Stepped once per batch when ``backward`` is True.
        train_loader: Yields ``(input, target)`` batches.
        train_original_loader: Unused here; kept for API compatibility.
        val_loader: Unused here; kept for API compatibility.
        epoch: Epoch number, used in the log header.
        backward: If False, only forward passes and metrics are computed.

    Returns:
        ``(acc1, total_loss, ce_loss, mcmi_loss, cmi_value)`` global averages for the epoch.
    """
    metric_logger = imagenet_utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", imagenet_utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", imagenet_utils.SmoothedValue(window_size=10, fmt="{value}"))

    # Loss functions live on the same device as the inputs (opt.device) on every
    # backend, instead of the undeclared module-level `device`.
    ce_criterion = nn.CrossEntropyLoss().to(opt.device)
    cmi_criterion = CMILoss().to(opt.device)
    mcmi_criterion = MCMILoss(opt.dataset).to(opt.device)

    num_processed_samples = 0
    header = f"Train: [{epoch}]"
    for idx, (input, target) in enumerate(metric_logger.log_every(train_loader, opt.print_freq, header)):
        start_time = time.time()
        input = input.float()
        input, target = input.to(opt.device), target.to(opt.device)

        # ===============centroid update=================
        # centroid_update_freq <= 0 means the centroids are never updated.
        # NOTE(review): update_centroids re-iterates train_loader while it is being
        # iterated here; confirm this is intended (train_original_loader is unused).
        if opt.centroid_update_freq > 0 and idx % opt.centroid_update_freq == 0:
            centroid_helper.update_centroids(model, train_loader)
        centroids = centroid_helper.get_centroids(target).float().to(opt.device).detach()

        # ===================forward=====================
        # Train mode unless the model is frozen and train_mode is off.
        if not opt.freeze_model or opt.train_mode:
            model.train()
        else:
            model.eval()
        output = model(input)
        ce_loss = ce_criterion(output, target)
        cmi_value = cmi_criterion(output, centroids)
        mcmi_loss = mcmi_criterion(output, centroids)

        # minimizing CE and maximizing CMI
        if opt.centroid_update_freq > 0:
            loss = ce_loss + (opt.lambda_MCMI * mcmi_loss)
        else:
            loss = ce_loss - (opt.lambda_MCMI * cmi_value)
        acc1, acc5 = imagenet_utils.accuracy(output, target, topk=(1, 5))

        # ===================backward=====================
        if backward:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # ===================meters=====================
        batch_size = input.shape[0]
        num_processed_samples += batch_size
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["losses"].update(loss.item(), n=batch_size)
        metric_logger.meters["ce_losses"].update(ce_loss.item(), n=batch_size)
        metric_logger.meters["mcmi_loss"].update(mcmi_loss.item(), n=batch_size)
        metric_logger.meters["cmi_value"].update(cmi_value.item(), n=batch_size)
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

        # Track the range of the learnable JPEG quantization tables.
        if opt.JPEG_enable:
            lum_qtable = model.module.jpeg_layer.lum_qtable.squeeze().unsqueeze(0).clone().detach()
            chrom_qtable = model.module.jpeg_layer.chrom_qtable.squeeze().unsqueeze(0).clone().detach()
            qTable = torch.cat((lum_qtable, chrom_qtable), 0)
            metric_logger.meters["q_min"].update(qTable.min().item(), n=batch_size)
            metric_logger.meters["q_max"].update(qTable.max().item(), n=batch_size)

        if opt.JPEG_enable and opt.JPEG_alpha_trainable:
            lum_alpha = model.module.jpeg_layer.lum_alpha.squeeze().unsqueeze(0).clone().detach()
            chrom_alpha = model.module.jpeg_layer.chrom_alpha.squeeze().unsqueeze(0).clone().detach()
            alphaTable = torch.cat((lum_alpha, chrom_alpha), 0)
            metric_logger.meters["alpha_min"].update(alphaTable.min().item(), n=batch_size)
            metric_logger.meters["alpha_max"].update(alphaTable.max().item(), n=batch_size)

    print('Acc@1 {top1.global_avg:.15f} Acc@5 {top5.global_avg:.15f}'.format(top1=metric_logger.acc1, top5=metric_logger.acc5))

    return metric_logger.acc1.global_avg, metric_logger.losses.global_avg, \
           metric_logger.ce_losses.global_avg, metric_logger.mcmi_loss.global_avg, metric_logger.cmi_value.global_avg


def evaluate_cifar100(opt, model, test_loader, log_suffix=""):
    """Evaluate ``model`` on ``test_loader`` in eval mode under ``torch.inference_mode``.

    Args:
        opt: Options namespace (``device``, ``print_freq``).
        model: Network to evaluate.
        test_loader: Yields ``(input, target)`` batches.
        log_suffix: Appended to the ``"Test: "`` log header.

    Returns:
        ``(acc1, acc5, ce_loss)`` global averages after cross-process synchronization.
    """
    model.eval()

    # Keep the criterion on the same device as the inputs.
    ce_criterion = nn.CrossEntropyLoss().to(opt.device)

    header = f"Test: {log_suffix}"
    num_processed_samples = 0
    metric_logger = imagenet_utils.MetricLogger(delimiter=" ")
    with torch.inference_mode():
        for input, target in metric_logger.log_every(test_loader, opt.print_freq, header):
            input = input.float()
            input = input.to(opt.device, non_blocking=True)
            target = target.to(opt.device, non_blocking=True)

            # ===================inference=====================
            output = model(input)
            ce_loss = ce_criterion(output, target)
            acc1, acc5 = imagenet_utils.accuracy(output, target, topk=(1, 5))

            # ===================meters=====================
            batch_size = input.shape[0]
            num_processed_samples += batch_size
            metric_logger.meters["ce_losses"].update(ce_loss.item(), n=batch_size)
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    num_processed_samples = imagenet_utils.reduce_across_processes(num_processed_samples)
    # get_rank() raises without an initialized process group; treat that case as rank 0.
    dist_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
    is_main_process = not dist_ready or torch.distributed.get_rank() == 0
    if (hasattr(test_loader.dataset, "__len__") and len(test_loader.dataset) != num_processed_samples and is_main_process):
        warnings.warn(f"It looks like the dataset has {len(test_loader.dataset)} samples, but {num_processed_samples} samples were used for the validation, which might bias the results.")
    metric_logger.synchronize_between_processes()

    return metric_logger.acc1.global_avg, metric_logger.acc5.global_avg, metric_logger.ce_losses.global_avg


#====================================================== imagenet ========================================================#

def _append_tables_txt(path, idx, tables, num_layers):
    """Append ``num_layers`` per-layer tables to ``path`` under an ``idx_<idx>`` header.

    Each ``tables[layer]`` is written with ``np.savetxt`` (comma-separated, 4 decimals)
    followed by a blank line.
    """
    with open(path, 'a') as f:
        f.write('idx_{}\n'.format(idx))
        for layer_index in range(num_layers):
            np.savetxt(f, tables[layer_index].cpu().numpy(), delimiter=',', fmt='%.04f')
            f.write('\n')


def train_imagenet(opt, model, centroid_helper, optimizer, train_loader, val_loader, mini_loader, epoch, backward=True):
    """Run one ImageNet training epoch with loss ``CE + lambda_MCMI * MCMI``.

    Every ``opt.print_freq`` batches (except batch 0) the rank-0 process appends a
    CSV row to ``opt.logs_fname`` and, with ``JPEG_enable``, dumps the current
    quantization (and optionally alpha) tables. With ``opt.analysis_mode``,
    ``evaluate_centorid`` is also run on ``val_loader``.

    Args:
        opt: Options namespace (``device``, ``print_freq``, ``centroid_update_freq``,
            ``lambda_MCMI``, ``freeze_model``, ``train_mode``, JPEG/logging options).
        model: Network to train. With ``JPEG_enable`` it is accessed via ``model.module``,
            so presumably a DataParallel/DDP wrapper — TODO confirm.
        centroid_helper: Provides ``update_centroids(model, loader)`` and ``get_centroids(target)``.
        optimizer: Stepped once per batch when ``backward`` is True.
        train_loader: Yields ``(input, target)`` batches.
        val_loader: Used only by the analysis-mode centroid evaluation.
        mini_loader: Loader used to recompute centroids.
        epoch: Epoch number (log header and CSV row).
        backward: If False, only forward passes and metrics are computed.

    Returns:
        ``(acc1, total_loss, ce_loss, mcmi_loss, cmi_value)`` global averages for the epoch.
    """
    metric_logger = imagenet_utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", imagenet_utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", imagenet_utils.SmoothedValue(window_size=10, fmt="{value}"))

    # Loss functions live on the same device as the inputs (opt.device).
    ce_criterion = nn.CrossEntropyLoss().to(opt.device)
    cmi_criterion = CMILoss().to(opt.device)
    mcmi_criterion = MCMILoss(opt.dataset).to(opt.device)

    num_processed_samples = 0
    header = f"Train: [{epoch}]"
    for idx, (input, target) in enumerate(metric_logger.log_every(train_loader, opt.print_freq, header)):
        start_time = time.time()
        input = input.float()
        input, target = input.to(opt.device), target.to(opt.device)

        # ===============centroid update=================
        # centroid_update_freq <= 0 means the centroids are never updated
        # (previously a 0 frequency raised ZeroDivisionError).
        if opt.centroid_update_freq > 0 and idx % opt.centroid_update_freq == 0:
            centroid_helper.update_centroids(model, mini_loader)
        centroids = centroid_helper.get_centroids(target).float().to(opt.device).detach()

        # ===================forward=====================
        # Train mode unless the model is frozen and train_mode is off. In that case, set
        # eval explicitly so the mode does not depend on what update_centroids left behind.
        if not opt.freeze_model or opt.train_mode:
            model.train()
        else:
            model.eval()
        output = model(input)
        ce_loss = ce_criterion(output, target)
        cmi_value = cmi_criterion(output, centroids)
        mcmi_loss = mcmi_criterion(output, centroids)

        # minimizing CE and maximizing CMI
        loss = ce_loss + (opt.lambda_MCMI * mcmi_loss)
        acc1, acc5 = imagenet_utils.accuracy(output, target, topk=(1, 5))

        # ===================backward=====================
        if backward:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # ===================meters=====================
        batch_size = input.shape[0]
        num_processed_samples += batch_size
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["losses"].update(loss.item(), n=batch_size)
        metric_logger.meters["ce_losses"].update(ce_loss.item(), n=batch_size)
        metric_logger.meters["mcmi_loss"].update(mcmi_loss.item(), n=batch_size)
        metric_logger.meters["cmi_value"].update(cmi_value.item(), n=batch_size)
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

        # Track the range of the learnable JPEG quantization tables.
        if opt.JPEG_enable:
            lum_qtable = model.module.jpeg_layer.lum_qtable.squeeze().unsqueeze(0).clone().detach()
            chrom_qtable = model.module.jpeg_layer.chrom_qtable.squeeze().unsqueeze(0).clone().detach()
            qTable = torch.cat((lum_qtable, chrom_qtable), 0)
            metric_logger.meters["q_min"].update(qTable.min().item(), n=batch_size)
            metric_logger.meters["q_max"].update(qTable.max().item(), n=batch_size)

        if opt.JPEG_enable and opt.JPEG_alpha_trainable:
            lum_alpha = model.module.jpeg_layer.lum_alpha.squeeze().unsqueeze(0).clone().detach()
            chrom_alpha = model.module.jpeg_layer.chrom_alpha.squeeze().unsqueeze(0).clone().detach()
            alphaTable = torch.cat((lum_alpha, chrom_alpha), 0)
            metric_logger.meters["alpha_min"].update(alphaTable.min().item(), n=batch_size)
            metric_logger.meters["alpha_max"].update(alphaTable.max().item(), n=batch_size)

        # save training results
        if idx % opt.print_freq == 0 and idx != 0:
            if (not hasattr(opt, 'rank') or opt.rank == 0):
                # save log (-1 marks metrics not computed at this point)
                new_log = {'epoch': epoch,
                           'test_acc': -1., 'test_loss': -1., 'val_acc': -1., 'val_mcmi_loss': -1., 'val_cmi_value': -1., 'val_ce_loss': -1., 
                           'train_acc': metric_logger.acc1.global_avg,
                           'train_mcmi_loss': metric_logger.mcmi_loss.global_avg, 'train_cmi_value': metric_logger.cmi_value.global_avg,
                           'train_ce_loss': metric_logger.ce_losses.global_avg, 'train_total_loss': metric_logger.losses.global_avg, }
                if opt.JPEG_enable:
                    new_log['Q_min'] = qTable.min().item()
                    new_log['Q_max'] = qTable.max().item()
                with open(opt.logs_fname, 'a', newline='') as f:
                    writer = csv.writer(f)
                    writer.writerow(new_log.values())

                # save q_table
                if opt.JPEG_enable:
                    # save the exact q_table
                    q_tables_file = os.path.join(opt.q_tables_folder, 'q_table_idx_{}.pt'.format(idx))
                    torch.save(qTable, q_tables_file)

                    if not opt.JPEG_layer_blockwise:
                        # one 8x8 table per JPEG layer, appended as text for plotting trends
                        lum_tables = lum_qtable.reshape(opt.num_jpeg_layers, 8, 8)
                        chrom_tables = chrom_qtable.reshape(opt.num_jpeg_layers, 8, 8)
                        _append_tables_txt(os.path.join(opt.q_tables_folder, 'lum_q_table.txt'),
                                           idx, lum_tables, opt.num_jpeg_layers)
                        _append_tables_txt(os.path.join(opt.q_tables_folder, 'chrom_q_table.txt'),
                                           idx, chrom_tables, opt.num_jpeg_layers)

                if opt.JPEG_enable and opt.JPEG_alpha_trainable:
                    # save the exact alpha_table
                    alpha_tables_file = os.path.join(opt.alpha_folder, 'alpha_table_idx_{}.pt'.format(idx))
                    torch.save(alphaTable, alpha_tables_file)

                    if not opt.JPEG_layer_blockwise:
                        # save lum and chrom alpha_table for plotting trend
                        lum_alpha_tables = lum_alpha.reshape(opt.num_jpeg_layers, 8, 8)
                        chrom_alpha_tables = chrom_alpha.reshape(opt.num_jpeg_layers, 8, 8)
                        _append_tables_txt(os.path.join(opt.alpha_folder, 'lum_alpha_table.txt'),
                                           idx, lum_alpha_tables, opt.num_jpeg_layers)
                        _append_tables_txt(os.path.join(opt.alpha_folder, 'chrom_alpha_table.txt'),
                                           idx, chrom_alpha_tables, opt.num_jpeg_layers)

            # check the cmi value across the train dataset without DA
            if opt.analysis_mode:
                evaluate_centorid(opt, centroid_helper, model, val_loader)

    print('Acc@1 {top1.global_avg:.15f} Acc@5 {top5.global_avg:.15f}'.format(top1=metric_logger.acc1, top5=metric_logger.acc5))

    return metric_logger.acc1.global_avg, metric_logger.losses.global_avg, \
           metric_logger.ce_losses.global_avg, metric_logger.mcmi_loss.global_avg, metric_logger.cmi_value.global_avg


def evaluate_imagenet(opt, model, test_loader, log_suffix=""):
    """Evaluate ``model`` on ``test_loader`` in eval mode under ``torch.inference_mode``.

    Args:
        opt: Options namespace (``device``, ``print_freq``).
        model: Network to evaluate.
        test_loader: Yields ``(input, target)`` batches.
        log_suffix: Appended to the ``"Test: "`` log header.

    Returns:
        ``(acc1, acc5, ce_loss)`` global averages after cross-process synchronization.
    """
    model.eval()

    # Keep the criterion on the same device as the inputs.
    ce_criterion = nn.CrossEntropyLoss().to(opt.device)

    header = f"Test: {log_suffix}"
    num_processed_samples = 0
    metric_logger = imagenet_utils.MetricLogger(delimiter=" ")
    with torch.inference_mode():
        for input, target in metric_logger.log_every(test_loader, opt.print_freq, header):
            input = input.float()
            input = input.to(opt.device, non_blocking=True)
            target = target.to(opt.device, non_blocking=True)

            # ===================inference=====================
            output = model(input)
            ce_loss = ce_criterion(output, target)
            acc1, acc5 = imagenet_utils.accuracy(output, target, topk=(1, 5))

            # ===================meters=====================
            batch_size = input.shape[0]
            num_processed_samples += batch_size
            metric_logger.meters["ce_losses"].update(ce_loss.item(), n=batch_size)
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    num_processed_samples = imagenet_utils.reduce_across_processes(num_processed_samples)
    # get_rank() raises without an initialized process group; treat that case as rank 0.
    dist_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
    is_main_process = not dist_ready or torch.distributed.get_rank() == 0
    if (hasattr(test_loader.dataset, "__len__") and len(test_loader.dataset) != num_processed_samples and is_main_process):
        warnings.warn(f"It looks like the dataset has {len(test_loader.dataset)} samples, but {num_processed_samples} samples were used for the validation, which might bias the results.")
    metric_logger.synchronize_between_processes()

    return metric_logger.acc1.global_avg, metric_logger.acc5.global_avg, metric_logger.ce_losses.global_avg


def evaluate_centorid(opt, centroid_helper, model, val_loader):
    """Recompute class centroids on ``val_loader``, then report CE / CMI / MCMI / top-1.

    Centroids are computed with ``centroid_helper.compute_centroids(..., save=True)``.
    Each sample's centroid row is then selected by its target label. The model runs in
    train mode if ``opt.train_mode`` is set, otherwise in eval mode; the mode is not
    restored afterwards.

    Args:
        opt: Options namespace (``device``, ``dataset``, ``train_mode``).
        centroid_helper: Provides ``compute_centroids(model, loader, save=...)``.
        model: Network to evaluate.
        val_loader: Yields ``(inputs, target)`` batches.

    Returns:
        ``(mcmi_loss, cmi_value, ce_loss, acc1)`` global averages after synchronization.
    """
    print("\n==> Computing centroids.")
    all_centroids, _ = centroid_helper.compute_centroids(model, val_loader, save=True)
    # index_select below needs the table on the same device as the targets.
    all_centroids = all_centroids.to(opt.device)

    # Loss functions live on the same device as the inputs (opt.device).
    ce_criterion = nn.CrossEntropyLoss().to(opt.device)
    cmi_criterion = CMILoss().to(opt.device)
    mcmi_criterion = MCMILoss(opt.dataset).to(opt.device)

    print("\n==> Validation on model.")
    header = "val: "
    num_processed_samples = 0
    metric_logger = imagenet_utils.MetricLogger(delimiter=" ")
    if opt.train_mode:
        model.train()
    else:
        model.eval()
    with torch.inference_mode():
        for inputs, target in metric_logger.log_every(val_loader, 500, header):
            inputs = inputs.float().to(opt.device)
            target = target.to(opt.device)
            # One centroid row per sample, selected by its label.
            centroids = torch.index_select(all_centroids, 0, target).float().detach()

            # ===================inference=====================
            output = model(inputs)
            ce_loss = ce_criterion(output, target)
            cmi_value = cmi_criterion(output, centroids)
            mcmi_loss = mcmi_criterion(output, centroids)
            acc1, acc5 = imagenet_utils.accuracy(output, target, topk=(1, 5))

            # ===================meters=======================
            batch_size = inputs.shape[0]
            num_processed_samples += batch_size
            metric_logger.meters['mcmi_losses'].update(mcmi_loss.item(), batch_size)
            metric_logger.meters['cmi_values'].update(cmi_value.item(), batch_size)
            metric_logger.meters['ce_losses'].update(ce_loss.item(), batch_size)
            metric_logger.meters['acc_values'].update(acc1.item(), batch_size)

    # gather the stats from all processes
    num_processed_samples = imagenet_utils.reduce_across_processes(num_processed_samples)
    # get_rank() raises without an initialized process group; treat that case as rank 0.
    dist_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
    is_main_process = not dist_ready or torch.distributed.get_rank() == 0
    if (hasattr(val_loader.dataset, "__len__") and len(val_loader.dataset) != num_processed_samples and is_main_process):
        warnings.warn(f"It looks like the dataset has {len(val_loader.dataset)} samples, but {num_processed_samples} samples were used for the validation, which might bias the results.")
    metric_logger.synchronize_between_processes()

    print("====> val_mcmi_loss: {}".format(str(metric_logger.mcmi_losses.global_avg)))
    print("====> val_cmi_value: {}".format(str(metric_logger.cmi_values.global_avg)))
    print("====> val_ce_loss: {}".format(str(metric_logger.ce_losses.global_avg)))
    print("====> val_acc: {}".format(str(metric_logger.acc_values.global_avg)))

    return metric_logger.mcmi_losses.global_avg, metric_logger.cmi_values.global_avg, \
        metric_logger.ce_losses.global_avg, metric_logger.acc_values.global_avg
