import os
import time
import json
import re
import csv
import helper.presets as presets
import numpy as np
import random
import torch
import argparse
# import torch.utils.data
import torchvision
import torchvision.transforms
from helper.JPEG_layer import * 
from helper.loops import train_imagenet, evaluate_imagenet, evaluate_centorid
import helper.imagenet_utils as imagenet_utils
from helper.sampler import RASampler, FixedClassSampler, DistributedFixedClassSampler
# from torch import nn
from models import load_model_imagenet
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
from helper.transforms import get_mixup_cutmix
from helper.centroids import CentroidHelper
from collections import Counter
import torch.distributed as dist


# helper functions
def create_folder(folder):
    """Create ``folder`` (and any missing parents) unless it is already a directory.

    A ``FileExistsError`` raised by ``os.makedirs`` is reported with a warning
    instead of being propagated, because another process may create the same
    path between the ``isdir`` check and the ``makedirs`` call.
    """
    if os.path.isdir(folder):
        return
    try:
        os.makedirs(folder)
    except FileExistsError:
        # The directory may have been created by another process in the meantime.
        print(f"[warning]: Directory {folder} already exists.")


def find_all_pt_files(folder_path):
    """Return the ``q_table_epoch_<N>.pt`` file names in ``folder_path``, ordered by epoch.

    Only names that match the pattern in full are kept, so leftovers such as
    ``q_table_epoch_3.pt.bak`` are ignored. Sorting is numeric on ``<N>``
    (epoch 2 comes before epoch 10). The sorted list is also printed.

    Args:
        folder_path: Directory to scan (not recursive).

    Returns:
        A list of bare file names (no directory component).
    """
    pattern = re.compile(r'q_table_epoch_(\d+)\.pt')

    matched_files = []
    for filename in os.listdir(folder_path):
        # fullmatch: the whole name must be a checkpoint name, not merely start like one.
        match = pattern.fullmatch(filename)
        if match:
            matched_files.append((int(match.group(1)), filename))
    matched_files.sort(key=lambda item: item[0])
    sorted_filenames = [filename for _, filename in matched_files]
    print(sorted_filenames)
    return sorted_filenames


def _get_cache_path(filepath):
    """Map ``filepath`` to a cache file under ``~/.torch/vision/datasets/imagefolder``.

    The file name is the first 10 hex digits of the SHA-1 of ``filepath``
    followed by ``.pt``; the leading ``~`` is expanded to the user's home.
    """
    import hashlib

    digest = hashlib.sha1(filepath.encode()).hexdigest()
    file_name = digest[:10] + ".pt"
    unexpanded = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", file_name)
    return os.path.expanduser(unexpanded)


def count_classes(dataset):
    """Count how many samples each class has.

    Args:
        dataset: Object exposing ``class_to_idx`` (class name -> index) and
            ``samples`` (iterable of ``(item, class_index)`` pairs).

    Returns:
        Dict mapping class name to its number of samples. Classes with no
        samples are absent.
    """
    # Invert the name -> index mapping so labels can be reported by name.
    name_by_index = {}
    for name, index in dataset.class_to_idx.items():
        name_by_index[index] = name

    tally = Counter([label for _, label in dataset.samples])

    counts_by_name = {}
    for index, occurrences in tally.items():
        counts_by_name[name_by_index[index]] = occurrences
    return counts_by_name


def save_class_counts_as_vector(dataset, class_counts, output_file):
    """Write per-class sample counts to ``output_file`` as a JSON list.

    The list has one entry per class in ``dataset.class_to_idx``, positioned
    at that class's index; classes missing from ``class_counts`` stay at 0.
    Names in ``class_counts`` that the dataset does not know are ignored.
    Values pass through a float32 tensor, so they are written as floats.
    """
    index_of = dataset.class_to_idx
    vector = torch.zeros(len(index_of), dtype=torch.float32)

    for name, n_samples in class_counts.items():
        if name not in index_of:
            continue
        vector[index_of[name]] = n_samples

    with open(output_file, "w") as f:
        json.dump(vector.tolist(), f)
    print(f"Class count vector saved to {output_file}")


# main functions
def parse_option(add_help=True):
    """Build the command-line parser and parse ``sys.argv``.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser`` (controls ``-h/--help``).

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    # optimization
    parser.add_argument("--batch_size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size")
    parser.add_argument("--num_workers", default=16, type=int, metavar="N", help="number of data loading workers")
    parser.add_argument("--epochs", default=20, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument('--lr_decay_epochs', type=str, default='110', help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
    parser.add_argument("--weight_decay", "--wd",  default=1e-4, type=float, metavar="W", help="weight decay", dest="weight_decay")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")

    # model and dataset
    parser.add_argument("--model", default="Resnet34", type=str, choices=['Resnet18','Resnet34','Resnet50','mobilenet_v2','ConvNeXt_tiny','Mnasnet',"Swin_T"], help="model name")
    parser.add_argument("--data_path", default=None, type=str, help="dataset path")
    parser.add_argument("--finetune_model_path", default=None, type=str, help="pretrain underlying model path")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100','imagenet'], help='dataset')
    parser.add_argument("--output_dir", default=".", type=str, help="path to save outputs")
    parser.add_argument('--num_samples_per_class', type=int, default=16, help='number of samples per class for mini-dataset')
    parser.add_argument("--optimizer", default="sgd", choices=['sgd','adam'], type=str, help="optimizer")
    parser.add_argument("--trial", default=1, type=int, help="trial number")
    parser.add_argument('--seed', type=int, default=0, help='seed id, set to 0 if do not want to fix the seed')
    parser.add_argument('--accumulation_steps', type=int, default=1, help='Number of micro-batches to accumulate')

    # JPEG
    parser.add_argument('--JPEG_enable', action='store_true')
    parser.add_argument('--freeze_model', action='store_true')
    parser.add_argument('--JPEG_alpha_trainable', action='store_true')
    parser.add_argument('--analysis_mode', action='store_true')
    parser.add_argument('--hardness_matching', action='store_true')
    parser.add_argument('--JPEG_layer_blockwise', action='store_true')
    parser.add_argument("--train_mode", action='store_true')

    # JPEG Hyper
    # NOTE: the flag spellings --Q_inital and --outter_block_size are part of
    # the CLI (and of saved opt.json files), so they are kept as they are.
    parser.add_argument('--lambda_MCMI', type=float, default=0.4, help='weight balance for CMI loss')
    parser.add_argument('--JPEG_alpha', type=float, default=20.0, help='Temperature scaling')
    parser.add_argument("--learning_rate", default=0.1, type=float, help="initial learning rate")
    parser.add_argument('--JPEG_learning_rate', type=float, default=0.01, help='Quantization Table Learning Rate')
    parser.add_argument('--alpha_learning_rate', type=float, default=None, help='Alpha Learning Rate')
    parser.add_argument('--Q_inital', type=float, default=1.0, help='Initial Quantization Step')
    parser.add_argument('--num_jpeg_layers', type=int, default=1, help='the number of consecutive JPEG layers')
    parser.add_argument('--block_size', type=int, default=8, help='the JPEG block size')
    parser.add_argument('--outter_block_size', type=int, default=32, help='the outer block size for blockwise JPEG')
    parser.add_argument('--num_bit', type=int, default=11, help='Number of bits to represent DCT coeff')
    parser.add_argument('--min_Q_Step', type=float, default=1.0, help='Minimum Quantization Step')
    parser.add_argument('--max_Q_Step', type=float, default=255, help='Maximum Quantization Step')
    # metavar keeps usage/help from listing every one of the ~1000 allowed values.
    parser.add_argument('--num_non_zero_q', type=int, default=5, choices=range(2,2**10 - 1), metavar="N", help='Window size for the reconstruction space (2..1022)')

    # other parameter for datasets
    parser.add_argument("--norm_weight_decay", default=None, type=float, help="weight decay for Normalization layers (default: None, same value as --wd)")
    parser.add_argument("--bias_weight_decay", default=None, type=float, help="weight decay for bias parameters of all layers (default: None, same value as --wd)")
    parser.add_argument("--transformer_embedding_decay", default=None, type=float,help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)")
    parser.add_argument("--label_smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing")
    parser.add_argument("--mixup_alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
    parser.add_argument("--cutmix_alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
    parser.add_argument("--lr_scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr_warmup_epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument("--lr_warmup_method", default="constant", type=str, help="the warmup method (default: constant)")
    parser.add_argument("--lr_warmup_decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr_step_size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr_gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr_min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
    parser.add_argument("--print_freq", default=50, type=int, help="print frequency")
    parser.add_argument("--centroid_update_freq", default=1, type=int, help="centroid update frequency")
    parser.add_argument("--start_epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument("--cache_dataset", dest="cache_dataset", help="Cache the datasets for quicker initialization. It also serializes the transforms", action="store_true")
    parser.add_argument("--sync_bn", dest="sync_bn", help="Use sync batch norm", action="store_true",)
    parser.add_argument("--test_only", dest="test_only", help="Only test the model", action="store_true",)
    parser.add_argument("--test_checkpoint", dest="test_checkpoint", help="Only test the model with checkpoint", action="store_true",)
    parser.add_argument("--test_path", type=str, help="test experiment path")
    parser.add_argument("--auto_augment", default=None, type=str, help="auto augment policy (default: None)")
    parser.add_argument("--ra_magnitude", default=9, type=int, help="magnitude of auto augment policy")
    parser.add_argument("--augmix_severity", default=3, type=int, help="severity of augmix policy")
    parser.add_argument("--random_erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

    # distributed training parameters
    parser.add_argument("--world_size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist_url", default="env://", type=str, help="url used to set up distributed training")
    parser.add_argument("--use_deterministic_algorithms", action="store_true", help="Forces the use of deterministic algorithms only.")
    parser.add_argument("--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)")
    parser.add_argument("--val_resize_size", default=256, type=int, help="the resize size used for validation (default: 256)")
    parser.add_argument("--val_crop_size", default=224, type=int, help="the central crop size used for validation (default: 224)")
    parser.add_argument("--train_crop_size", default=224, type=int, help="the random crop size used for training (default: 224)")
    parser.add_argument("--clip_grad_norm", default=None, type=float, help="the maximum gradient norm (default None)")
    parser.add_argument("--ra_sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    parser.add_argument("--ra_reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)")
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
    parser.add_argument("--use_v2", action="store_true", help="Use V2 transforms")

    opt = parser.parse_args()
    return opt
    

def prepare_dir(opt):
    """Derive experiment names/paths from ``opt`` and create the output layout.

    Test mode (``opt.test_only`` or ``opt.test_checkpoint``): the ``opt.json``
    of the experiment at ``./save/<dataset>/teacher/<model>/<test_path>`` is
    read and its path-related fields are copied onto ``opt``; nothing is
    created.

    Training mode:
      * ``opt.lr_decay_epochs`` is converted from a comma-separated string to
        a list of ints;
      * ``opt.added_layer`` and ``opt.model_name`` are derived from the options;
      * ``./save/<dataset>/teacher/<model>/<model_name>/trial_<trial>`` is
        created, plus ``Q_tables`` / ``Alpha_tables`` when applicable;
      * the ``logs.csv`` header is written (rank 0, or when ``opt`` has no
        ``rank``), only if the log is new or empty so re-runs do not insert
        extra header rows;
      * the final options are written to ``opt.json``.

    Args:
        opt: Parsed options namespace; modified in place.

    Returns:
        The same ``opt`` object.
    """
    if opt.test_only or opt.test_checkpoint:
        saved_opt_path = os.path.join("./save/{}/teacher/{}".format(opt.dataset, opt.model), opt.test_path, "opt.json")
        with open(saved_opt_path, 'r') as f:
            test_opt = argparse.Namespace(**json.load(f))
        opt.added_layer = test_opt.added_layer
        opt.model_name = test_opt.model_name
        opt.output_dir = test_opt.output_dir
        opt.trial_folder = test_opt.trial_folder
        if test_opt.JPEG_enable:
            opt.q_tables_folder = test_opt.q_tables_folder
        if test_opt.JPEG_alpha_trainable or test_opt.hardness_matching:
            opt.alpha_folder = test_opt.alpha_folder
        opt.logs_fname = test_opt.logs_fname
        opt.opt_fname = test_opt.opt_fname
        return opt

    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]

    if opt.JPEG_enable:
        opt.added_layer = "JPEG"
        if opt.JPEG_alpha_trainable:
            opt.model_name = 'JPEG{}_lr_{}_alpha_{}_alpha_lr_{}_lambda_{}'.format(opt.num_jpeg_layers, opt.JPEG_learning_rate, opt.JPEG_alpha, opt.alpha_learning_rate, opt.lambda_MCMI)
        elif opt.hardness_matching:
            opt.model_name = 'JPEG{}_lr_{}_hardness_{}_lambda_{}'.format(opt.num_jpeg_layers, opt.JPEG_learning_rate, opt.JPEG_alpha, opt.lambda_MCMI)
        else:
            opt.model_name = 'JPEG{}_lr_{}_alpha_{}_lambda_{}'.format(opt.num_jpeg_layers, opt.JPEG_learning_rate, opt.JPEG_alpha, opt.lambda_MCMI)
        if not opt.freeze_model:
            opt.model_name += "_underlying_lr_{}".format(opt.learning_rate)
        if opt.JPEG_layer_blockwise:
            opt.model_name += "_blockwise"
        if opt.finetune_model_path is not None:
            opt.model_name += "_finetune"
        if opt.train_mode:
            opt.model_name += "_train_mode"
    else:
        opt.added_layer = "vanilla"
        opt.model_name = 'lr_{}_lambda_{}'.format(opt.learning_rate, opt.lambda_MCMI)

    opt.output_dir = './save/{}/teacher/{}'.format(opt.dataset, opt.model)
    create_folder(opt.output_dir)

    base_folder = os.path.join(opt.output_dir, opt.model_name)
    opt.trial_folder = os.path.join(base_folder, "trial_{}".format(opt.trial))
    create_folder(opt.trial_folder)

    if opt.JPEG_enable:
        opt.q_tables_folder = os.path.join(opt.trial_folder, "Q_tables")
        create_folder(opt.q_tables_folder)

    if opt.JPEG_alpha_trainable or opt.hardness_matching:
        opt.alpha_folder = os.path.join(opt.trial_folder, "Alpha_tables")
        create_folder(opt.alpha_folder)

    opt.logs_fname = os.path.join(opt.trial_folder, "logs.csv")
    if (not hasattr(opt, 'rank') or opt.rank == 0):
        columns = ['epoch', 'test_acc', 'test_loss', 'val_acc', 'val_mcmi_loss', 'val_cmi_value', 'val_ce_loss', 'train_acc', 'train_mcmi_loss', 'train_cmi_value', 'train_ce_loss', 'train_total_loss']
        if opt.JPEG_enable:
            columns.extend(['Q_min', 'Q_max'])
        if opt.JPEG_enable and opt.JPEG_alpha_trainable:
            columns.extend(['Alpha_min', 'Alpha_max'])
        # The log is opened in append mode so earlier rows survive a re-run;
        # the header is only needed once, at the top of a new/empty file.
        log_is_empty = not os.path.exists(opt.logs_fname) or os.path.getsize(opt.logs_fname) == 0
        if log_is_empty:
            with open(opt.logs_fname, 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(columns)

    opt.opt_fname = os.path.join(opt.trial_folder, "opt.json")
    # Serialise before opening the file so a non-serialisable option raises
    # without leaving a truncated opt.json behind.
    serialized_opt = json.dumps(vars(opt), indent=4)
    with open(opt.opt_fname, 'w') as f:
        f.write(serialized_opt)

    print("\n==> Finish prepare directory: ", opt.model_name)
    return opt


def load_data(traindir, valdir, opt, mean=(0.,0.,0.), std=(1/255., 1/255., 1/255.)):
    """Build the train / val / test ``ImageFolder`` datasets and their samplers.

    * train: ``traindir`` with the training preset (augmentations from ``opt``).
    * val:   ``traindir`` again, but with the evaluation preset.
    * test:  ``valdir`` with the evaluation preset, or with the transforms of
      ``opt.weights`` when both ``opt.weights`` and ``opt.test_only`` are set.

    Args:
        traindir: Root folder of the training images.
        valdir: Root folder of the held-out images.
        opt: Options namespace (crop/resize sizes, interpolation, backend,
            augmentation settings, ``distributed`` flag, ...).
        mean, std: Normalisation statistics forwarded to the presets.

    Returns:
        ``(train_dataset, val_dataset, test_dataset,
        train_sampler, val_sampler, test_sampler)``.
    """
    # Data loading code
    print("\n==> Loading data.")
    val_resize_size, val_crop_size, train_crop_size = (opt.val_resize_size, opt.val_crop_size, opt.train_crop_size,)
    interpolation = InterpolationMode(opt.interpolation)

    # Augmentation settings are optional on opt, hence getattr with fallbacks.
    augmentation_kwargs = dict(
        auto_augment_policy=getattr(opt, "auto_augment", None),
        random_erase_prob=getattr(opt, "random_erase", 0.0),
        ra_magnitude=getattr(opt, "ra_magnitude", None),
        augmix_severity=getattr(opt, "augmix_severity", None),
    )
    # Keyword arguments every preset below receives.
    shared_kwargs = dict(interpolation=interpolation, backend=opt.backend, use_v2=opt.use_v2, mean=mean, std=std)
    eval_kwargs = dict(crop_size=val_crop_size, resize_size=val_resize_size, **shared_kwargs)

    # training data
    train_transform = presets.ClassificationPresetTrain(crop_size=train_crop_size, **augmentation_kwargs, **shared_kwargs)
    train_dataset = torchvision.datasets.ImageFolder(traindir, train_transform)

    # validating data: the training folder seen through the evaluation preset
    val_dataset = torchvision.datasets.ImageFolder(traindir, presets.ClassificationPresetEval(**eval_kwargs))

    # testing data
    use_weight_transforms = opt.weights and opt.test_only
    if use_weight_transforms:
        weights = torchvision.models.get_weight(opt.weights)
        preprocessing = weights.transforms(antialias=True)
        if opt.backend == "tensor":
            preprocessing = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(), preprocessing])
    else:
        preprocessing = presets.ClassificationPresetEval(**eval_kwargs)
    test_dataset = torchvision.datasets.ImageFolder(valdir, preprocessing)

    # Creating samplers
    if not opt.distributed:
        train_sampler = torch.utils.data.RandomSampler(train_dataset)
        val_sampler = torch.utils.data.RandomSampler(val_dataset)
        test_sampler = torch.utils.data.SequentialSampler(test_dataset)
    else:
        if getattr(opt, "ra_sampler", False):
            train_sampler = RASampler(train_dataset, shuffle=True, repetitions=opt.ra_reps)
            # NOTE(review): built on train_dataset, not val_dataset; both read
            # the same folder, so presumably equivalent -- confirm intent.
            val_sampler = RASampler(train_dataset, shuffle=False, repetitions=opt.ra_reps)
        else:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
        test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, shuffle=False)

    return train_dataset, val_dataset, test_dataset, train_sampler, val_sampler, test_sampler


def load_mini_data(traindir, opt, mean=(0.0, 0.0, 0.0), std=(1/255., 1/255., 1/255.)):
    """Build the training ``ImageFolder`` plus a fixed-samples-per-class sampler.

    The dataset uses the same training preset as ``load_data``; the sampler is
    ``DistributedFixedClassSampler`` when ``opt.distributed`` is set and
    ``FixedClassSampler`` otherwise, both configured with
    ``opt.num_samples_per_class``.

    Args:
        traindir: Root folder of the training images.
        opt: Options namespace (crop size, interpolation, backend,
            augmentation settings, ``distributed``, ``num_samples_per_class``).
        mean, std: Normalisation statistics forwarded to the preset.

    Returns:
        ``(dataset, train_sampler)``.
    """
    # Data loading code
    print("\n==> Loading mini data.")
    interpolation = InterpolationMode(opt.interpolation)

    # Loading training data; augmentation settings are optional on opt.
    dataset = torchvision.datasets.ImageFolder(traindir,
        presets.ClassificationPresetTrain(
            crop_size=opt.train_crop_size,
            interpolation=interpolation,
            auto_augment_policy=getattr(opt, "auto_augment", None),
            random_erase_prob=getattr(opt, "random_erase", 0.0),
            ra_magnitude=getattr(opt, "ra_magnitude", None),
            augmix_severity=getattr(opt, "augmix_severity", None),
            backend=opt.backend,
            use_v2=opt.use_v2,
            mean=mean, std=std,),)

    # Creating the sampler
    if opt.distributed:
        train_sampler = DistributedFixedClassSampler(dataset, num_samples_per_class=opt.num_samples_per_class)
    else:
        train_sampler = FixedClassSampler(dataset, num_samples_per_class=opt.num_samples_per_class)
    return dataset, train_sampler


def _save_jpeg_tables(lum_param, chrom_param, folder, name, epoch, opt):
    """Save one (luminance, chrominance) table pair for ``epoch``; return the stacked CPU tensor.

    Writes ``<name>_table_epoch_<epoch>.pt`` into ``folder`` and, unless
    ``opt.JPEG_layer_blockwise`` is set, appends a text dump of every layer to
    ``lum_<name>_table.txt`` / ``chrom_<name>_table.txt``.
    """
    # shape: [num_layers,1,1,1,size,size,1] => [1,num_layers,size,size] or
    # shape: [num_layers,1,1,B1,1,size,size,1] => [1,num_layers,B1,size,size]
    lum = lum_param.squeeze().unsqueeze(0).clone().detach()
    chrom = chrom_param.squeeze().unsqueeze(0).clone().detach()
    stacked = torch.cat((lum, chrom), 0).cpu()

    # save the exact table
    torch.save(stacked, os.path.join(folder, '{}_table_epoch_{}.pt'.format(name, epoch)))

    if not opt.JPEG_layer_blockwise:
        # append per-layer text dumps for plotting the trend over epochs
        for channel, table in (('lum', lum), ('chrom', chrom)):
            per_layer = table.reshape(opt.num_jpeg_layers, 8, 8)
            with open(os.path.join(folder, '{}_{}_table.txt'.format(channel, name)), 'a') as f:
                f.write('epoch_{}\n'.format(epoch))
                for layer_index in range(opt.num_jpeg_layers):
                    np.savetxt(f, per_layer[layer_index].cpu().numpy(), delimiter=',', fmt='%.04f')
                    f.write('\n')
    return stacked


def main(opt):
    """Run the full training routine described by ``opt``.

    Steps: initialise distributed mode, prepare output directories, optionally
    seed the RNGs, build datasets/loaders, load the underlying model (optionally
    wrapped with a JPEG layer), build the optimizer, initialise centroids if no
    centroid file exists, then for every epoch train, evaluate, and (on the
    recording process only) save tables, a log row and a checkpoint.

    Epoch 0 calls ``train_imagenet`` with ``backward=False``; later epochs use
    the default.

    Args:
        opt: Options namespace from ``parse_option``. ``opt.distributed`` and
            ``opt.gpu`` are read here but not set in this file -- presumably
            by ``imagenet_utils.init_distributed_mode``; confirm there.
    """
    # initial distributed mode
    imagenet_utils.init_distributed_mode(opt)
    opt = prepare_dir(opt)
    device = torch.device(opt.device)
    # only one process records tables, logs and checkpoints
    is_main_process = not hasattr(opt, 'rank') or opt.rank == 0

    if opt.seed:
        torch.manual_seed(opt.seed)
        torch.cuda.manual_seed(opt.seed)
        torch.cuda.manual_seed_all(opt.seed)
        np.random.seed(opt.seed)
        random.seed(opt.seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        print("\n==> Use seed", opt.seed)

    # load dataset
    train_dir = os.path.join(opt.data_path, "train")
    test_dir = os.path.join(opt.data_path, "val")
    if opt.JPEG_enable:
        # the JPEG layer receives its own mean/std below, so the data
        # pipeline is given mean 0 and std 1/255 instead
        mean, std = (0.,0.,0.), (1/255., 1/255., 1/255.)
    else:
        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    train_set, val_set, test_set, train_sampler, val_sampler, test_sampler = load_data(train_dir, test_dir, opt, mean=mean, std=std)
    mini_dataset, mini_sampler = load_mini_data(train_dir, opt, mean=mean, std=std)

    # save the number of samples per class (skipped once the train file exists)
    train_counts_file = "./save/{}/train_class_counts.json".format(opt.dataset)
    if not os.path.exists(train_counts_file):
        train_class_counts = count_classes(train_set)
        val_class_counts = count_classes(test_set)
        save_class_counts_as_vector(train_set, train_class_counts, train_counts_file)
        save_class_counts_as_vector(test_set, val_class_counts, "./save/{}/test_class_counts.json".format(opt.dataset))

    # data augmentation
    num_classes = len(train_set.classes)
    mixup_cutmix = get_mixup_cutmix(mixup_alpha=opt.mixup_alpha, cutmix_alpha=opt.cutmix_alpha, num_classes=num_classes, use_v2=opt.use_v2)
    if mixup_cutmix is not None:
        def collate_fn(batch):
            return mixup_cutmix(*default_collate(batch))
    else:
        print("\n==> No mixup cutmix.")
        collate_fn = default_collate

    # data_loader
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=opt.batch_size, sampler=train_sampler, num_workers=opt.num_workers, pin_memory=True, collate_fn=collate_fn)
    val_loader   = torch.utils.data.DataLoader(val_set,   batch_size=opt.batch_size, sampler=val_sampler,   num_workers=opt.num_workers, pin_memory=True)
    mini_loader  = torch.utils.data.DataLoader(mini_dataset,  batch_size=opt.batch_size, sampler=mini_sampler,  num_workers=opt.num_workers, pin_memory=True, collate_fn=collate_fn)
    test_loader  = torch.utils.data.DataLoader(test_set,  batch_size=opt.batch_size, sampler=test_sampler,  num_workers=opt.num_workers, pin_memory=True)

    # load underlying model
    print("\n==> Loading model {}".format(opt.model))
    if opt.finetune_model_path is not None:
        underlying_model = load_model_imagenet(opt.model, pretrained=False)
        # Load onto CPU so the checkpoint's saved device does not matter; the
        # model is moved to the target device further below.
        checkpoint = torch.load(opt.finetune_model_path, map_location="cpu")
        underlying_model.load_state_dict(checkpoint['model'])
    else:
        underlying_model = load_model_imagenet(opt.model, pretrained=True)

    # load whole model
    opt_name = opt.optimizer.lower()
    if opt.JPEG_enable:
        print("\n==> JPEG_enable")
        if opt.JPEG_layer_blockwise:
            jpeg_layer = JPEG_layer_blockwise(opt=opt, img_shape=(opt.train_crop_size, opt.train_crop_size, 3), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        else:
            jpeg_layer = JPEG_layers(opt=opt, img_shape=(opt.train_crop_size, opt.train_crop_size, 3), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        model = CustomModel(jpeg_layer, underlying_model)
        if opt_name == "sgd":
            optimizer_data = [{'params': model.jpeg_layer.lum_qtable, 'lr': opt.JPEG_learning_rate, 'momentum': opt.momentum,},
                              {'params': model.jpeg_layer.chrom_qtable, 'lr': opt.JPEG_learning_rate, 'momentum': opt.momentum,}]
            if opt.JPEG_alpha_trainable:
                optimizer_data.append({'params': model.jpeg_layer.lum_alpha, 'lr': opt.alpha_learning_rate, 'momentum': opt.momentum,})
                optimizer_data.append({'params': model.jpeg_layer.chrom_alpha, 'lr': opt.alpha_learning_rate, 'momentum': opt.momentum,})
            if not opt.freeze_model:
                optimizer_data.append({'params': model.underlying_model.parameters(), 'lr': opt.learning_rate, 'momentum': opt.momentum,})
        elif opt_name == "adam":
            optimizer_data = [{'params': model.jpeg_layer.lum_qtable}, {'params': model.jpeg_layer.chrom_qtable}]
            if opt.JPEG_alpha_trainable:
                optimizer_data.append({'params': model.jpeg_layer.lum_alpha})
                optimizer_data.append({'params': model.jpeg_layer.chrom_alpha})
            if not opt.freeze_model:
                optimizer_data.append({'params': model.underlying_model.parameters()})
        else:
            raise RuntimeError(f"Invalid optimizer {opt.optimizer}. Only SGD, Adam are supported.")
    else:
        print("\n==> not JPEG_enable")
        model = underlying_model
        if opt_name == "sgd":
            optimizer_data = [{'params': model.parameters(), 'lr': opt.learning_rate, 'momentum': opt.momentum}]
        elif opt_name == "adam":
            optimizer_data = [{'params': model.parameters()}]
        else:
            raise RuntimeError(f"Invalid optimizer {opt.optimizer}. Only SGD, Adam are supported.")

    if opt.distributed and opt.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if torch.cuda.is_available():
        model = model.to(device)

    # optimizer
    if opt_name == "sgd":
        print("\n==> Using optimizer SGD.")
        optimizer = torch.optim.SGD(optimizer_data)
    elif opt_name == "adam":
        print("\n==> Using optimizer Adam.")
        optimizer = torch.optim.Adam(optimizer_data, lr=1e-2, betas=(0.9, 0.999), eps=1e-08)
    else:
        raise RuntimeError(f"Invalid optimizer {opt.optimizer}. Only SGD, and Adam are supported.")

    # data parallel
    # Keep a handle on the unwrapped module: `.module` only exists on the DDP
    # wrapper, so saving through it would fail in single-process runs.
    model_without_ddp = model
    if opt.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.gpu])

    # compute centroids
    centroid_helper = CentroidHelper(opt, num_classes)
    # NOTE(review): this path hard-codes "imagenet" while other paths use
    # opt.dataset -- confirm against where CentroidHelper stores centroids.
    if not os.path.exists("./save/imagenet/teacher/{}/centroids.pt".format(opt.model)):
        centroid_helper.initialize_centroids(model, train_loader)

    # routine
    for epoch in range(opt.start_epoch, opt.epochs+1):
        if opt.distributed:
            train_sampler.set_epoch(epoch)

        print("\n==> Training.")
        time1 = time.time()
        if epoch == 0:
            train_acc, train_loss, train_ce_loss, train_mcmi_loss, train_cmi_value = \
                train_imagenet(opt, model, centroid_helper, optimizer, train_loader, val_loader, mini_loader, epoch, backward=False)
        else:
            train_acc, train_loss, train_ce_loss, train_mcmi_loss, train_cmi_value = \
                train_imagenet(opt, model, centroid_helper, optimizer, train_loader, val_loader, mini_loader, epoch)
        time2 = time.time()
        print('Epoch [{}]: total time {:.5f} s'.format(epoch, time2-time1))

        # check training results
        val_mcmi_loss, val_cmi_value, val_ce_loss, val_acc = evaluate_centorid(opt, centroid_helper, model, val_loader)

        print("\n==> Validating.")
        time1 = time.time()
        test_acc, test_acc_top5, test_loss = evaluate_imagenet(opt, model, test_loader)
        time2 = time.time()
        print('Acc@1 {top1:.15f} Acc@5 {top5:.15f}'.format(top1=test_acc, top5=test_acc_top5))
        print('Epoch [{}]: total time {:.5f} s'.format(epoch, time2-time1))

        # only one process can record
        if is_main_process:
            qTable = None
            alphaTable = None

            # save q_table
            if opt.JPEG_enable:
                qTable = _save_jpeg_tables(
                    model_without_ddp.jpeg_layer.lum_qtable,
                    model_without_ddp.jpeg_layer.chrom_qtable,
                    opt.q_tables_folder, 'q', epoch, opt)

            # save alpha_table
            if opt.JPEG_enable and (opt.JPEG_alpha_trainable or opt.hardness_matching):
                alphaTable = _save_jpeg_tables(
                    model_without_ddp.jpeg_layer.lum_alpha,
                    model_without_ddp.jpeg_layer.chrom_alpha,
                    opt.alpha_folder, 'alpha', epoch, opt)

            # save acc and loss (key order must match the header from prepare_dir)
            new_log = {
                'epoch': epoch,
                'test_acc': test_acc, 'test_loss': test_loss,
                'val_acc': val_acc, 'val_mcmi_loss': val_mcmi_loss, 'val_cmi_value': val_cmi_value, 'val_ce_loss': val_ce_loss,
                'train_acc': train_acc, 'train_mcmi_loss': train_mcmi_loss, 'train_cmi_value': train_cmi_value, 'train_ce_loss': train_ce_loss, 'train_total_loss': train_loss}
            if opt.JPEG_enable:
                new_log['Q_min'] = qTable.min().item()
                new_log['Q_max'] = qTable.max().item()
            if opt.JPEG_enable and opt.JPEG_alpha_trainable:
                # the header declares these columns under the same condition
                new_log['Alpha_min'] = alphaTable.min().item()
                new_log['Alpha_max'] = alphaTable.max().item()
            with open(opt.logs_fname, 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(new_log.values())

            if not opt.freeze_model:
                # the underlying model is being trained, so keep a per-epoch checkpoint
                print("\n==> Model not frozen, saving checkpoint.")
                checkpoint = {'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict()}
                torch.save(checkpoint, os.path.join(opt.trial_folder, "{}.pth".format(epoch)))

    if is_main_process and not opt.freeze_model:
        # save model of last epoch if train underlying model
        print("\n==> save last model.")
        checkpoint = {'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict()}
        torch.save(checkpoint, os.path.join(opt.trial_folder, "last.pth"))

    print("\n==> Finish training: ", opt.trial_folder)
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    # Script entry point: parse the CLI, run training, report wall-clock time.
    opt = parse_option()

    timeStart = time.time()
    main(opt)
    timeEnd = time.time()

    elapsed_seconds = timeEnd - timeStart
    print('\n==> Total time {:.2f} s.'.format(elapsed_seconds))

