from __future__ import print_function
import os
import argparse
import time
import numpy as np
import random
import json
import csv
import torch
import timm
from dataset.cifar100 import get_cifar100_dataloaders, get_cifar100_dataloaders_224
from helper.loops import train_cifar100, evaluate_cifar100, evaluate_centorid
from helper.JPEG_layer import *
from helper.centroids import CentroidHelper
from collections import Counter
from models import cifar100_model_dict
import helper.imagenet_utils as imagenet_utils
import torch.distributed as dist
from torch.optim.lr_scheduler import CosineAnnealingLR


# helper functions
def create_folder(folder):
    """Create ``folder`` (and any missing parents) unless it already is a directory.

    A ``FileExistsError`` raised by ``os.makedirs`` is reported with a warning
    instead of propagating.
    """
    if os.path.isdir(folder):
        return

    try:
        os.makedirs(folder)
    except FileExistsError:
        # The path may have been created by another process between the
        # check above and the makedirs call.
        print(f"[warning]: Directory {folder} already exists.")


def count_classes(dataset):
    """Count how many samples of each class ``dataset`` yields.

    ``dataset`` must expose a ``class_to_idx`` mapping (class name -> index)
    and iterate as ``(sample, label)`` pairs.

    Returns:
        dict: class name -> number of samples, in order of first appearance.
    """
    names_by_index = dict((index, name) for name, index in dataset.class_to_idx.items())

    tally = Counter()
    for _sample, label in dataset:
        tally[label] += 1

    per_class = {}
    for label, occurrences in tally.items():
        per_class[names_by_index[label]] = occurrences
    return per_class


def save_class_counts_as_vector(dataset, class_counts, output_file):
    """Write per-class sample counts to ``output_file`` as a JSON list.

    The list has one float entry per class in ``dataset.class_to_idx``,
    positioned at that class's index. Classes missing from ``class_counts``
    stay at 0; names in ``class_counts`` that the dataset does not know are
    ignored.

    Args:
        dataset: object exposing ``class_to_idx`` (class name -> index).
        class_counts: mapping of class name -> number of samples.
        output_file: destination path; missing parent directories are created.
    """
    class_to_idx = dataset.class_to_idx
    class_count_vector = torch.zeros(len(class_to_idx), dtype=torch.float32)

    for class_name, count in class_counts.items():
        class_idx = class_to_idx.get(class_name)
        if class_idx is not None:
            class_count_vector[class_idx] = count

    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_file, "w") as f:
        json.dump(class_count_vector.tolist(), f)
    print(f"Class count vector saved to {output_file}")


# option parsing, output-directory preparation, and the training entry point
def parse_options():
    """Declare the command-line interface of this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser('argument for training')
    add = parser.add_argument

    # optimization
    add("--batch_size", type=int, default=32,
        help="images per gpu, the total batch size is $NGPU x batch_size")
    add("--num_workers", type=int, default=16, metavar="N",
        help="number of data loading workers")
    add("--epochs", type=int, default=20, metavar="N",
        help="number of total epochs to run")
    add('--lr_decay_epochs', type=str, default='110',
        help='where to decay lr, can be a list')
    add('--lr_decay_rate', type=float, default=0.1,
        help='decay rate for learning rate')
    add("--weight_decay", dest="weight_decay", type=float, default=5e-4, metavar="W",
        help="weight decay")
    add("--momentum", type=float, default=0.9, metavar="M",
        help="momentum")
    add("--print_freq", type=int, default=50,
        help="print frequency")
    add("--centroid_update_freq", type=int, default=1,
        help="centroid update frequency")
    add("--start_epoch", type=int, default=0, metavar="N",
        help="start epoch")

    # model and dataset
    add("--model", type=str, default="vgg13",
        help="model name")
    add("--data_path", type=str, default=None,
        help="dataset path")
    add("--finetune_model_path", type=str, default=None,
        help="pretrain underlying model path")
    add("--device", type=str, default="cuda",
        help="device (Use cuda or cpu Default: cuda)")
    add('--dataset', type=str, default='cifar100', choices=['cifar100'],
        help='dataset')
    add("--output_dir", type=str, default=".",
        help="path to save outputs")
    add('--num_samples_per_class', type=int, default=16,
        help='number of samples per class for mini-dataset')
    add("--optimizer", type=str, default="sgd", choices=['sgd', 'adam'],
        help="optimizer")
    add("--trial", type=int, default=1,
        help="trial number")
    add('--seed', type=int, default=0,
        help='seed id, set to 0 if do not want to fix the seed')
    add('--accumulation_steps', type=int, default=1,
        help='Number of micro-batches to accumulate')
    add("--lr_scheduler", type=str, default=None,
        help="the lr scheduler (default: None)")

    # JPEG
    add('--JPEG_enable', action='store_true')
    add('--freeze_model', action='store_true')
    add('--JPEG_alpha_trainable', action='store_true')
    add('--hardness_matching', action='store_true')
    add("--test_only", dest="test_only", action="store_true",
        help="Only test the model")
    add("--test_checkpoint", dest="test_checkpoint", action="store_true",
        help="Only test the model with checkpoint")
    add("--test_path", type=str,
        help="test experiment path")
    add("--train_mode", action='store_true')

    # JPEG Hyper
    add('--lambda_MCMI', type=float, default=0.4,
        help='weight balance for CMI loss')
    add('--JPEG_alpha', type=float, default=20.0,
        help='Tempurature scaling')
    add("--learning_rate", type=float, default=0.1,
        help="initial learning rate")
    add('--JPEG_learning_rate', type=float, default=0.025,
        help='Quantization Table Learning Rate')
    add('--alpha_learning_rate', type=float, default=None,
        help='Alpha Learning Rate')
    add('--Q_inital', type=float, default=1.0,
        help='Initial Quantization Step')
    add('--num_jpeg_layers', type=int, default=1,
        help='the number of consecutive JPEG layers')
    add('--block_size', type=int, default=8,
        help='the JPEG block size')
    add('--num_bit', type=int, default=11,
        help='Number of bits to represent DCT coeff')
    add('--min_Q_Step', type=float, default=1.0,
        help='Minumum Quantization Step')
    add('--max_Q_Step', type=float, default=255,
        help='Maximum Quantization Step')
    add('--num_non_zero_q', type=int, default=5, choices=range(2, 2**10 - 1),
        help='Window size for the reconstruction space')

    # distribute mode
    add("--world_size", type=int, default=1,
        help="number of distributed processes")
    add("--dist_url", type=str, default="env://",
        help="url used to set up distributed training")
    add("--sync_bn", dest="sync_bn", action="store_true",
        help="Use sync batch norm")

    return parser.parse_args()


def prepare_dir(opt):
    """Derive the run's output locations from ``opt`` and create them on disk.

    Test mode (``opt.test_only`` or ``opt.test_checkpoint`` with a
    ``opt.test_path``): the ``opt.json`` saved by an earlier run under
    ``./save/<dataset>/teacher/<model>/<test_path>/`` is loaded, its
    naming/folder fields are copied onto ``opt``, and nothing is created.

    Otherwise: ``opt.lr_decay_epochs`` is parsed from a comma-separated string
    into a list of ints, ``opt.added_layer`` / ``opt.model_name`` are derived
    from the JPEG settings, the trial folder (plus ``Q_tables`` /
    ``Alpha_tables`` sub-folders when applicable) is created, the CSV header is
    appended to ``logs.csv`` and the options are dumped to ``opt.json``.
    Files are written only by the main process (``opt.rank`` missing or 0).

    Args:
        opt: argparse.Namespace produced by ``parse_options``; mutated in place.

    Returns:
        The same ``opt`` object, with the derived attributes set.
    """
    # Only one process should write shared files.
    is_main_process = not hasattr(opt, 'rank') or opt.rank == 0

    if (opt.test_only or opt.test_checkpoint) and opt.test_path is not None:
        saved_opt_path = os.path.join(
            "./save/{}/teacher/{}".format(opt.dataset, opt.model), opt.test_path, "opt.json")
        with open(saved_opt_path, 'r') as f:
            test_opt = argparse.Namespace(**json.load(f))

        # Reuse the naming / folder layout of the run under test.
        for name in ('added_layer', 'model_name', 'output_dir', 'min_Q_Step',
                     'hardness_matching', 'num_jpeg_layers', 'trial_folder'):
            setattr(opt, name, getattr(test_opt, name))
        if test_opt.JPEG_enable:
            opt.q_tables_folder = test_opt.q_tables_folder
        if test_opt.JPEG_alpha_trainable or test_opt.hardness_matching:
            opt.alpha_folder = test_opt.alpha_folder
        opt.logs_fname = test_opt.logs_fname
        opt.opt_fname = test_opt.opt_fname
        return opt

    # "30,60" -> [30, 60]
    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]

    if opt.JPEG_enable:
        opt.added_layer = "JPEG"
        if opt.JPEG_alpha_trainable:
            opt.model_name = 'JPEG{}_lr_{}_alpha_{}_alpha_lr_{}_lambda_{}'.format(
                opt.num_jpeg_layers, opt.JPEG_learning_rate, opt.JPEG_alpha,
                opt.alpha_learning_rate, opt.lambda_MCMI)
        elif opt.hardness_matching:
            opt.model_name = 'JPEG{}_lr_{}_hardness_{}_lambda_{}'.format(
                opt.num_jpeg_layers, opt.JPEG_learning_rate, opt.JPEG_alpha, opt.lambda_MCMI)
        else:
            opt.model_name = 'JPEG{}_lr_{}_alpha_{}_lambda_{}'.format(
                opt.num_jpeg_layers, opt.JPEG_learning_rate, opt.JPEG_alpha, opt.lambda_MCMI)

        # Suffixes recording the non-default settings of this run.
        if not opt.freeze_model:
            opt.model_name += "_underlying_lr_{}".format(opt.learning_rate)
        if opt.finetune_model_path is not None:
            opt.model_name += "_finetune"
        if opt.min_Q_Step < 1.0:
            opt.model_name += "_min_step_{}".format(opt.min_Q_Step)
        if opt.train_mode:
            opt.model_name += "_train_mode"
        if opt.centroid_update_freq <= 0:
            opt.model_name += "_mcmi"
    else:
        opt.added_layer = "vanilla"
        opt.model_name = 'lr_{}_lambda_{}'.format(opt.learning_rate, opt.lambda_MCMI)

    opt.output_dir = './save/{}/teacher/{}'.format(opt.dataset, opt.model)
    create_folder(opt.output_dir)

    base_folder = os.path.join(opt.output_dir, opt.model_name)
    opt.trial_folder = os.path.join(base_folder, "trial_{}".format(opt.trial))
    create_folder(opt.trial_folder)

    if opt.JPEG_enable:
        opt.q_tables_folder = os.path.join(opt.trial_folder, "Q_tables")
        create_folder(opt.q_tables_folder)

    if opt.JPEG_alpha_trainable or opt.hardness_matching:
        opt.alpha_folder = os.path.join(opt.trial_folder, "Alpha_tables")
        create_folder(opt.alpha_folder)

    opt.logs_fname = os.path.join(opt.trial_folder, "logs.csv")
    if is_main_process:
        columns = ['epoch', 'test_acc', 'test_loss',
                   'val_acc', 'val_mcmi_loss', 'val_cmi_value', 'val_ce_loss',
                   'train_acc', 'train_mcmi_loss', 'train_cmi_value', 'train_ce_loss',
                   'train_total_loss']
        if opt.JPEG_enable:
            columns.extend(['Q_min', 'Q_max'])
        if opt.JPEG_enable and opt.JPEG_alpha_trainable:
            columns.extend(['Alpha_min', 'Alpha_max'])
        # NOTE: append mode, so re-running the same trial adds another header
        # row after the rows of the previous run.
        with open(opt.logs_fname, 'a', newline='') as f:
            csv.writer(f).writerow(columns)

    opt.opt_fname = os.path.join(opt.trial_folder, "opt.json")
    if is_main_process:
        # Serialize before opening the file so a failure cannot leave a
        # truncated opt.json behind, and write from one process only so
        # concurrent ranks do not overwrite each other's dump.
        serialized_opt = json.dumps(vars(opt), indent=4)
        with open(opt.opt_fname, 'w') as f:
            f.write(serialized_opt)

    print("\n==> Finish prepare directory: ", opt.model_name)
    return opt


def _unwrap_model(model):
    """Return the module inside a DistributedDataParallel wrapper, else ``model`` itself."""
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model


def _save_tables(lum_param, chrom_param, folder, stem, epoch, num_layers):
    """Dump the luminance/chrominance tables of one epoch into ``folder``.

    Writes ``<stem>_table_epoch_<epoch>.pt`` (both tables concatenated) and
    appends each table to ``lum_<stem>_table.txt`` / ``chrom_<stem>_table.txt``.

    Returns:
        The concatenated table as a CPU tensor.
    """
    def _flatten(param):
        # shape: [num_layers,1,1,1,size,size,1] => [1,num_layers,size,size]
        return param.squeeze(1).squeeze(1).squeeze(1).squeeze(-1).unsqueeze(0).clone().detach()

    lum = _flatten(lum_param)
    chrom = _flatten(chrom_param)
    combined = torch.cat((lum, chrom), 0).cpu()

    # save the exact table
    torch.save(combined, os.path.join(folder, '{}_table_epoch_{}.pt'.format(stem, epoch)))

    # save lum and chrom tables as text for plotting the trend
    per_layer = (
        ('lum', lum.reshape(num_layers, 8, 8)),
        ('chrom', chrom.reshape(num_layers, 8, 8)),
    )
    for prefix, table in per_layer:
        txt_file = os.path.join(folder, '{}_{}_table.txt'.format(prefix, stem))
        with open(txt_file, 'a') as f:
            f.write('epoch_{}\n'.format(epoch))
            for layer_index in range(num_layers):
                np.savetxt(f, table[layer_index].cpu().numpy(), delimiter=',', fmt='%.04f')
                f.write('\n')
    return combined


def main(opt):
    """Run the full training routine described by ``opt``.

    Steps: initialise distributed mode, prepare the output directories, seed
    the RNGs (when ``opt.seed`` is non-zero), build the CIFAR-100 loaders, load
    the underlying model, optionally prepend the JPEG layer, build the
    optimizer / scheduler, initialise the centroids, then loop over epochs
    (epoch 0 calls ``train_cifar100`` with ``backward=False``). After every
    epoch the main process saves the Q / alpha tables, appends a row to the
    CSV log and, unless the model is frozen, keeps the best checkpoint.
    ``last.pth`` is written at the end.

    Args:
        opt: argparse.Namespace produced by ``parse_options``.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not supported.
        RuntimeError: if ``opt.optimizer`` or ``opt.lr_scheduler`` is not supported.
    """
    best_acc = 0

    # initial distributed mode
    imagenet_utils.init_distributed_mode(opt)
    opt = prepare_dir(opt)
    device = torch.device(opt.device)
    # only one process records tables, logs and checkpoints
    is_main_process = not hasattr(opt, 'rank') or opt.rank == 0

    if opt.seed:
        torch.manual_seed(opt.seed)
        torch.cuda.manual_seed(opt.seed)
        torch.cuda.manual_seed_all(opt.seed)
        np.random.seed(opt.seed)
        random.seed(opt.seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # torch.use_deterministic_algorithms(True)
        print("\n==> Use seed", opt.seed)

    # dataloader: with the JPEG layer enabled the loader uses mean 0 and
    # std 1/255 instead of the CIFAR-100 statistics; those statistics are
    # passed to JPEG_layers below instead.
    if opt.JPEG_enable:
        loader_mean = (0, 0, 0)
        loader_std = (1/255., 1/255., 1/255.)
    else:
        loader_mean = (0.5071, 0.4867, 0.4408)
        loader_std = (0.2675, 0.2565, 0.2761)

    if opt.dataset == 'cifar100':
        # NOTE: get_cifar100_dataloaders_224 (same arguments) is the variant
        # for the vit / resnet18 models.
        train_set, train_original_set, val_set, test_set, test_original_set, \
            train_loader, train_original_loader, val_loader, test_loader, test_original_loader = \
            get_cifar100_dataloaders(batch_size=opt.batch_size, num_workers=opt.num_workers, opt=opt,
                                     mean=loader_mean, std=loader_std)
        num_classes = 100
    else:
        raise NotImplementedError(opt.dataset)

    # save the number of samples per class for cifar100
    train_class_counts_file = "./save/{}/train_class_counts.json".format(opt.dataset)
    test_class_counts_file = "./save/{}/test_class_counts.json".format(opt.dataset)
    if not os.path.exists(train_class_counts_file):
        train_class_counts = count_classes(train_set)
        test_class_counts = count_classes(test_set)
        save_class_counts_as_vector(train_set, train_class_counts, train_class_counts_file)
        save_class_counts_as_vector(test_set, test_class_counts, test_class_counts_file)

    # load underlying model
    if opt.model == "ResNet18":  # for vit-cnn analysis
        print("\n==> load pretrain ResNet18 model.")
        from models import MODEL
        underlying_model = MODEL.get("ResNet")()
        path_t = "./save/models/{}_vanilla/ckpt_epoch_240.pth".format(opt.model)
        checkpoint = torch.load(path_t, map_location="cpu")

        # Pick whichever stored state (plain or EMA) has the lower test error.
        test_err = checkpoint["test_err"] if "test_err" in checkpoint else 100
        ema_err = checkpoint["ema_err"] if "ema_err" in checkpoint else 100
        ema_state = "ema_state" if "ema_state" in checkpoint else "model_state"
        best_state = "model_state" if test_err <= ema_err else ema_state

        underlying_model.load_state_dict(checkpoint[best_state])
    elif opt.model in ("convnext_tiny", "vit_small_patch16_224", "swin_tiny_patch4_window7_224"):  # for vit-cnn analysis
        underlying_model = timm.create_model(opt.model, num_classes=num_classes)
        path_t = "./save/models/{}_vanilla/ckpt_epoch_240.pth".format(opt.model)
        timm.models.load_checkpoint(underlying_model, checkpoint_path=path_t)
    else:
        underlying_model = cifar100_model_dict[opt.model](num_classes=num_classes)
        if opt.finetune_model_path is not None:
            print("\n==> load fine-tune model.")
            checkpoint = torch.load(opt.finetune_model_path, weights_only=False)
            underlying_model.load_state_dict(checkpoint['model'])
        else:
            print("\n==> load pretrain model.")
            path_t = "./save/models/{}_vanilla/ckpt_epoch_240.pth".format(opt.model)
            underlying_model.load_state_dict(torch.load(path_t, weights_only=False)['model'])

    # load whole model
    opt_name = opt.optimizer.lower()
    if opt.JPEG_enable:
        print("\n==> JPEG_enable")
        # The sample itself is unused; the access is kept in case the
        # dataset's transforms draw random numbers (would shift seeded runs)
        # -- TODO confirm and drop.
        sample, _ = train_set[0]
        # shape: [224,224,3] or [32,32,3]; use img_shape=(224,224,3) with the 224 loaders.
        jpeg_layer = JPEG_layers(opt=opt, img_shape=(32, 32, 3),
                                 mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761))
        model = CustomModel(jpeg_layer, underlying_model)

        if opt_name == "sgd":
            optimizer_data = [{'params': model.jpeg_layer.lum_qtable, 'lr': opt.JPEG_learning_rate, 'momentum': opt.momentum},
                              {'params': model.jpeg_layer.chrom_qtable, 'lr': opt.JPEG_learning_rate, 'momentum': opt.momentum}]
            if opt.JPEG_alpha_trainable:
                optimizer_data.append({'params': model.jpeg_layer.lum_alpha, 'lr': opt.alpha_learning_rate, 'momentum': opt.momentum})
                optimizer_data.append({'params': model.jpeg_layer.chrom_alpha, 'lr': opt.alpha_learning_rate, 'momentum': opt.momentum})
            if not opt.freeze_model:
                optimizer_data.append({'params': model.underlying_model.parameters(), 'lr': opt.learning_rate, 'momentum': opt.momentum})
        elif opt_name == "adam":
            optimizer_data = [{'params': model.jpeg_layer.lum_qtable}, {'params': model.jpeg_layer.chrom_qtable}]
            if opt.JPEG_alpha_trainable:
                optimizer_data.append({'params': model.jpeg_layer.lum_alpha})
                optimizer_data.append({'params': model.jpeg_layer.chrom_alpha})
            if not opt.freeze_model:
                optimizer_data.append({'params': model.underlying_model.parameters()})
        else:
            raise RuntimeError(f"Invalid optimizer {opt.optimizer}. Only SGD, Adam are supported.")
    else:
        print("\n==> Not JPEG_enable")
        model = underlying_model
        if opt_name == "sgd":
            optimizer_data = [{'params': model.parameters(), 'lr': opt.learning_rate, 'momentum': opt.momentum, "weight_decay": opt.weight_decay}]
        elif opt_name == "adam":
            optimizer_data = [{'params': model.parameters()}]
        else:
            raise RuntimeError(f"Invalid optimizer {opt.optimizer}. Only SGD and Adam are supported.")

    if opt.distributed and opt.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if torch.cuda.is_available():
        model = model.to(device)

    # optimizer
    if opt_name == "sgd":
        print("\n==> Using optimizer SGD.")
        optimizer = torch.optim.SGD(optimizer_data)
    elif opt_name == "adam":
        print("\n==> Using optimizer Adam.")
        optimizer = torch.optim.Adam(optimizer_data, lr=1e-2, betas=(0.9, 0.999), eps=1e-08)
    else:
        raise RuntimeError(f"Invalid optimizer {opt.optimizer}. Only SGD and Adam are supported.")

    # data parallel
    if opt.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.gpu])

    # learning rate scheduler
    lr_scheduler = None
    if opt.lr_scheduler is not None:
        print("\n==> Learning rate scheduler: {}".format(opt.lr_scheduler))
        if opt.lr_scheduler == "cosineAnnealingLR":
            lr_scheduler = CosineAnnealingLR(optimizer, opt.epochs)
        else:
            # Fail before training starts instead of with a NameError at the
            # first scheduler step.
            raise RuntimeError(f"Invalid lr scheduler {opt.lr_scheduler}. Only cosineAnnealingLR is supported.")
    else:
        print("\n==> No learning rate scheduler.")

    # compute centroids
    centroid_helper = CentroidHelper(opt, num_classes)
    if opt.freeze_model:
        if not os.path.exists("./save/{}/teacher/{}/centroids.pt".format(opt.dataset, opt.model)):
            centroid_helper.initialize_centroids(model, train_loader)
    else:
        centroid_helper.initialize_centroids(model, train_loader, save=False)

    # start routine
    for epoch in range(opt.start_epoch, opt.epochs+1):
        if opt.distributed:
            train_loader.sampler.set_epoch(epoch)
            train_original_loader.sampler.set_epoch(epoch)

        print("\n==> Training.")
        time1 = time.time()
        if epoch == 0:
            # epoch 0 runs the training pass with backward=False
            train_acc, train_loss, train_ce_loss, train_mcmi_loss, train_cmi_value = \
                train_cifar100(opt, model, centroid_helper, optimizer, train_loader, train_original_loader, val_loader, epoch, backward=False)
        else:
            train_acc, train_loss, train_ce_loss, train_mcmi_loss, train_cmi_value = \
                train_cifar100(opt, model, centroid_helper, optimizer, train_loader, train_original_loader, val_loader, epoch)
        time2 = time.time()
        print('Epoch [{}]: total time {:.5f} s'.format(epoch, time2-time1))

        # check training results
        val_mcmi_loss, val_cmi_value, val_ce_loss, val_acc = evaluate_centorid(opt, centroid_helper, model, val_loader)

        print("\n==> Validating.")
        time1 = time.time()
        test_acc, test_acc_top5, test_loss = evaluate_cifar100(opt, model, test_loader)
        time2 = time.time()
        print('Acc@1 {top1:.15f} Acc@5 {top5:.15f}'.format(top1=test_acc, top5=test_acc_top5))
        print('Epoch [{}]: total time {:.5f} s'.format(epoch, time2-time1))

        if lr_scheduler is not None:
            lr_scheduler.step()

        # only one process can record
        if is_main_process:
            # The model is wrapped by DDP only in distributed runs, so unwrap
            # conditionally instead of assuming `.module` exists.
            raw_model = _unwrap_model(model)

            # save q_table
            if opt.JPEG_enable:
                qTable = _save_tables(raw_model.jpeg_layer.lum_qtable, raw_model.jpeg_layer.chrom_qtable,
                                      opt.q_tables_folder, 'q', epoch, opt.num_jpeg_layers)

            # save alpha_table
            if opt.JPEG_enable and (opt.JPEG_alpha_trainable or opt.hardness_matching):
                alphaTable = _save_tables(raw_model.jpeg_layer.lum_alpha, raw_model.jpeg_layer.chrom_alpha,
                                          opt.alpha_folder, 'alpha', epoch, opt.num_jpeg_layers)

            # save acc and loss
            new_log = {
                'epoch': epoch,
                'test_acc': test_acc, 'test_loss': test_loss,
                'val_acc': val_acc, 'val_mcmi_loss': val_mcmi_loss, 'val_cmi_value': val_cmi_value, 'val_ce_loss': val_ce_loss,
                'train_acc': train_acc, 'train_mcmi_loss': train_mcmi_loss, 'train_cmi_value': train_cmi_value, 'train_ce_loss': train_ce_loss, 'train_total_loss': train_loss}
            if opt.JPEG_enable:
                new_log['Q_min'] = qTable.min().item()
                new_log['Q_max'] = qTable.max().item()
            if opt.JPEG_enable and opt.JPEG_alpha_trainable:
                # The CSV header declares these two columns under the same
                # condition, so each row has to carry them as well.
                new_log['Alpha_min'] = alphaTable.min().item()
                new_log['Alpha_max'] = alphaTable.max().item()
            with open(opt.logs_fname, 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(new_log.values())

            # save best model
            # NOTE: best.pth stores the (possibly DDP-wrapped) model's
            # state_dict while last.pth stores the unwrapped one; left as is
            # so existing best.pth consumers keep working.
            if not opt.freeze_model and test_acc > best_acc:
                best_acc = test_acc
                checkpoint = {'epoch': epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
                torch.save(checkpoint, os.path.join(opt.trial_folder, "best.pth"))

    # save last model
    if is_main_process:
        print("\n==> save last model.")
        checkpoint = {'model': _unwrap_model(model).state_dict(), 'optimizer': optimizer.state_dict()}
        torch.save(checkpoint, os.path.join(opt.trial_folder, "last.pth"))

    print("\n==> Finish training: ", opt.trial_folder)
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    # Script entry point: parse the CLI, run the routine, report wall-clock time.
    cli_options = parse_options()

    started_at = time.time()
    main(cli_options)
    elapsed = time.time() - started_at

    print(f'\n==> Total time {elapsed:.2f} s.')
