from __future__ import print_function
import os
import argparse
import time
import json
import numpy as np
import random
import torch
import torch.optim as optim
import torch.nn as nn
import torch.backends.cudnn as cudnn
from models import cifar100_model_dict
from dataset.cifar100 import get_cifar100_dataloaders_default, get_cifar100_dataloaders_sample
# from helper.util import adjust_learning_rate
from models.util import Embed, ConvReg, LinearEmbed
from models.util import Connector, Translator, Paraphraser
from distiller_zoo import TTM, WTTM, DistillKL, CRDLoss, ITLoss, DIST, DKDloss
from distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss
from distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss, LSKDLoss
from helper.loops import train_distill, validate, init
from helper.JPEG_layer import *


# Train on the GPU whenever CUDA is usable; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Model Training will be done on this device: ", device)


def parse_option(argv=None):
    """Define and parse the command-line options of the distillation script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, so existing callers
            that pass nothing behave exactly as before.

    Returns:
        argparse.Namespace holding every option declared below. Values are
        returned as parsed; path and schedule post-processing is not done here.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
    parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=240, help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='150,180,210',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    # dataset
    parser.add_argument('--dataset', type=str, default='cifar100',
                        choices=['cifar100', 'cifar10'], help='dataset')

    # model
    parser.add_argument('--JPEG_enable', action='store_true')
    parser.add_argument('--train_mode', action='store_true')
    parser.add_argument('--model_t', type=str, default=None, help='teacher model')
    parser.add_argument('--model_s', type=str, default=None, help='student model')
    parser.add_argument("--base_path", type=str, default=None)
    # The value selects which saved quantization-table file is loaded for the
    # JPEG teacher; it is not a training length.
    parser.add_argument('--q_table_epoch', type=int, default=20,
                        help='epoch index of the saved q-table to load')
    parser.add_argument("--finetune_model_path", type=str, default=None)

    # distillation
    parser.add_argument('--add', type=str, default='kd', choices=['kd', 'ttm', 'wttm'])
    parser.add_argument('--distill', type=str, default='kd',
                        choices=['kd', 'cc', 'fitnet', 'ft', 'ab', 'sp', 'rkd', 'fsp', 'at',
                                 'dist', 'dkd', 'itrd', 'crd', 'lskd', 'wttm'])
    parser.add_argument('--trial', type=str, default='1', help='trial id')
    parser.add_argument('--seed', type=int, default=0,
                        help='seed id, set to 0 if do not want to fix the seed')
    parser.add_argument('--init_epochs', type=int, default=30,
                        help='init training for two-stage methods')

    parser.add_argument('-r', '--gamma', type=float, default=1, help='weight for classification')
    parser.add_argument('-a', '--alpha', type=float, default=0, help='weight balance for additional loss')
    parser.add_argument('-b', '--beta', type=float, default=0, help='weight balance for main loss')

    # KD distillation
    parser.add_argument('--kd_T', type=float, default=4, help='temperature for KD distillation')

    # TTM and WTTM distillation
    parser.add_argument('--ttm_l', type=float, default=1, help='exponent for TTM and WTTM distillation')

    # DKD distillation
    parser.add_argument('--dkd_alpha', default=1, type=float)
    parser.add_argument('--dkd_beta', default=2, type=float)

    # ITRD distillation
    parser.add_argument('--lambda_corr', type=float, default=2.0, help='correlation loss weight')
    parser.add_argument('--lambda_mutual', type=float, default=1.0, help='mutual information loss weight')
    parser.add_argument('--alpha_it', type=float, default=1.01, help='Renyis alpha')

    # DIST distillation
    parser.add_argument('--dist_beta', type=float, default=1.0, help='weight for inter loss')
    parser.add_argument('--dist_gamma', type=float, default=1.0, help='weight for intra loss')
    parser.add_argument('--dist_tau', type=float, default=4.0, help='temperature for DIST distillation')

    # NCE distillation
    parser.add_argument('--feat_dim', default=128, type=int, help='feature dimension')
    parser.add_argument('--mode', default='exact', type=str, choices=['exact', 'relax'])
    parser.add_argument('--nce_k', default=16384, type=int, help='number of negative samples for NCE')
    parser.add_argument('--nce_t', default=0.07, type=float, help='temperature parameter for softmax')
    parser.add_argument('--nce_m', default=0.5, type=float, help='momentum for non-parametric updates')

    # FitNet distillation
    parser.add_argument('--hint_layer', default=2, type=int, choices=[0, 1, 2, 3, 4])

    return parser.parse_args(argv)


def prepare_dir(args):
    """Derive experiment names/paths from ``args`` and create the save folder.

    ``args`` is modified in place: ``experiment_name``, ``save_folder``,
    ``log_fname`` and ``results_fname`` are added, ``lr_decay_epochs`` is
    converted from a comma-separated string to a list of ints,
    ``finetune_model_path`` (when given) is expanded to a checkpoint path and
    ``learning_rate`` is overridden for a few student architectures.

    Args:
        args: Namespace produced by the option parser.

    Returns:
        Tuple ``(args, opt)``. ``opt`` is a Namespace built from
        ``<base_path>/opt.json`` when ``args.JPEG_enable`` is set, else ``None``.

    Raises:
        ValueError: if ``args.JPEG_enable`` is set but ``args.base_path`` is not.
    """
    opt = None
    if args.JPEG_enable:
        if args.base_path is None:
            # Without this guard os.path.join(None, ...) fails with an opaque TypeError.
            raise ValueError("--base_path is required when --JPEG_enable is set")
        opt_json_filepath = os.path.join(args.base_path, 'opt.json')
        with open(opt_json_filepath, 'r') as f:
            opt = argparse.Namespace(**json.load(f))
        experiment_name = opt.model_name + "_trial_" + str(args.trial)
    else:
        # The "vallina" spelling is part of the on-disk folder name, so it is left untouched.
        experiment_name = "vallina" + "_trial_" + str(args.trial)

    if args.finetune_model_path is not None:
        # The raw CLI value is a checkpoint tag: it goes into the experiment
        # name first and is then expanded to the checkpoint file location.
        experiment_name += "_" + args.finetune_model_path
        args.finetune_model_path = './save/cifar100/mcmi/{}.pth'.format(args.finetune_model_path)
    if args.train_mode:
        experiment_name += "_train_mode"
    args.experiment_name = os.path.join(
        "{}_{}".format(args.model_t, args.model_s), "{}".format(args.distill), experiment_name)

    args.save_folder = os.path.join("./save/{}/student".format(args.dataset), args.experiment_name)
    # exist_ok avoids the check-then-create race when several trials start together.
    os.makedirs(args.save_folder, exist_ok=True)

    if args.model_s in ['MobileNetV2', 'ShuffleV1', 'ShuffleV2']:
        args.learning_rate = 0.01

    # Blank entries (e.g. from a trailing comma) are skipped instead of crashing int('').
    args.lr_decay_epochs = [int(tok) for tok in args.lr_decay_epochs.split(',') if tok.strip()]

    # log files: per-experiment progress log, and a results file shared by all
    # experiments of the same teacher/student/distill combination
    args.log_fname = os.path.join(args.save_folder, 'working.txt')
    args.results_fname = os.path.join(
        "./save/{}/student/{}_{}/{}".format(args.dataset, args.model_t, args.model_s, args.distill),
        'results.txt')

    return args, opt


def _seed_everything(seed):
    """Seed the Python, NumPy and torch RNGs and switch cuDNN to deterministic mode."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True


def _acc_to_str(acc):
    """Format an accuracy for the log files; accepts a tensor or a plain number."""
    if torch.is_tensor(acc):
        return str(acc.cpu().numpy())
    return str(acc)


def _load_model_weights(model, checkpoint_path):
    """Load the ``"model"`` entry of a checkpoint file into ``model``."""
    # SECURITY: weights_only=False unpickles arbitrary objects; only point this
    # at checkpoints you trust.
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts; the
    # model is moved to the target device afterwards.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    model.load_state_dict(checkpoint["model"])


def _build_dataloaders(args, loader_mean, loader_std, mean, std):
    """Create the training/validation loaders plus a validation loader using the plain statistics.

    Returns ``(train_loader, val_loader, val_loader_original, n_data)``. The
    first two use ``loader_mean``/``loader_std``; ``val_loader_original`` uses
    ``mean``/``std``.
    """
    common = dict(batch_size=args.batch_size, num_workers=args.num_workers)
    if args.distill in ['crd']:
        train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(
            k=args.nce_k, mode=args.mode, mean=loader_mean, std=loader_std, **common)
        _, val_loader_original, _ = get_cifar100_dataloaders_sample(
            k=args.nce_k, mode=args.mode, mean=mean, std=std, **common)
    else:
        train_loader, val_loader, n_data = get_cifar100_dataloaders_default(
            opt=args, mean=loader_mean, std=loader_std, is_instance=True, **common)
        _, val_loader_original, _ = get_cifar100_dataloaders_default(
            opt=args, mean=mean, std=std, is_instance=True, **common)
    return train_loader, val_loader, val_loader_original, n_data


def _build_teacher(args, opt, num_classes, train_loader, mean, std):
    """Instantiate the teacher, load its weights and, if enabled, wrap it with the JPEG layer."""
    pretrain_model_path = "./save/models/{}_vanilla/ckpt_epoch_240.pth".format(args.model_t)
    finetuned = args.finetune_model_path is not None
    weights_path = args.finetune_model_path if finetuned else pretrain_model_path

    backbone = cifar100_model_dict[args.model_t](num_classes=num_classes)
    if not args.JPEG_enable:
        print("load fine-tuned model." if finetuned else "load pretrained model.")
        _load_model_weights(backbone, weights_path)
        return backbone

    print("load JPEG + fine-tuned model." if finetuned else "load JPEG + pretrained model.")
    _load_model_weights(backbone, weights_path)
    jpeg_layer = JPEG_layers(opt=opt, img_shape=train_loader.dataset.data[0].shape, mean=mean, std=std)
    model_t = CustomModel(jpeg_layer, backbone)

    # Replace the JPEG layer's quantization tables with the saved ones.
    q_table_path = os.path.join(opt.q_tables_folder, 'q_table_epoch_{}.pt'.format(args.q_table_epoch))
    q_table = torch.load(q_table_path, map_location='cpu')
    lum_qtable, chrom_qtable = q_table[0], q_table[1]
    model_t.jpeg_layer.lum_qtable = nn.Parameter(
        lum_qtable.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(-1))
    model_t.jpeg_layer.chrom_qtable = nn.Parameter(
        chrom_qtable.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(-1))
    return model_t


def _build_distill_criterion(args, model_s, model_t, feat_s, feat_t, n_data, train_loader,
                             criterion_div, module_list, trainable_list):
    """Create the distillation loss selected by ``args.distill``.

    Auxiliary modules required by a method are appended to ``module_list`` /
    ``trainable_list`` in place, and two-stage methods run their ``init``
    training here. Returns ``(criterion_div, criterion_kd)``; ``criterion_div``
    is passed through unchanged except for ``lskd``, which replaces it.
    """
    distill = args.distill

    # Losses that are constructed without arguments and need no extra modules.
    stateless = {'at': Attention, 'nst': NSTLoss, 'sp': Similarity,
                 'rkd': RKDLoss, 'pkt': PKT, 'kdsvd': KDSVD}
    if distill in stateless:
        return criterion_div, stateless[distill]()

    if distill == 'kd':
        criterion_kd = DistillKL(args.kd_T)
    elif distill == 'lskd':
        criterion_div = LSKDLoss(args.kd_T)
        criterion_kd = DistillKL(args.kd_T)
    elif distill == 'ttm':
        criterion_kd = TTM(args.ttm_l)
    elif distill == 'wttm':
        criterion_kd = WTTM(args.ttm_l)
    elif distill == 'crd':
        args.s_dim = feat_s[-1].shape[1]
        args.t_dim = feat_t[-1].shape[1]
        args.n_data = n_data
        criterion_kd = CRDLoss(args)
        for embed in (criterion_kd.embed_s, criterion_kd.embed_t):
            module_list.append(embed)
        for embed in (criterion_kd.embed_s, criterion_kd.embed_t):
            trainable_list.append(embed)
    elif distill == 'itrd':
        args.s_dim = feat_s[-1].shape[1]
        args.t_dim = feat_t[-1].shape[1]
        args.n_data = n_data
        criterion_kd = ITLoss(args)
        module_list.append(criterion_kd)
        trainable_list.append(criterion_kd)
        module_list.append(criterion_kd.embed)
        trainable_list.append(criterion_kd.embed)
    elif distill == 'dist':
        criterion_kd = DIST(args.dist_beta, args.dist_gamma, args.dist_tau)
    elif distill == 'dkd':
        criterion_kd = DKDloss(args.kd_T)
    elif distill == 'fitnet':
        criterion_kd = HintLoss()
        regress_s = ConvReg(feat_s[args.hint_layer].shape, feat_t[args.hint_layer].shape)
        module_list.append(regress_s)
        trainable_list.append(regress_s)
    elif distill == 'cc':
        criterion_kd = Correlation()
        embed_s = LinearEmbed(feat_s[-1].shape[1], args.feat_dim)
        embed_t = LinearEmbed(feat_t[-1].shape[1], args.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif distill == 'vid':
        s_n = [f.shape[1] for f in feat_s[1:-1]]
        t_n = [f.shape[1] for f in feat_t[1:-1]]
        criterion_kd = nn.ModuleList([VIDLoss(s, t, t) for s, t in zip(s_n, t_n)])
        # add this as some parameters in VIDLoss need to be updated
        trainable_list.append(criterion_kd)
    elif distill == 'ab':
        s_shapes = [f.shape for f in feat_s[1:-1]]
        t_shapes = [f.shape for f in feat_t[1:-1]]
        connector = Connector(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(connector)
        init_trainable_list.append(model_s.get_feat_modules())
        criterion_kd = ABLoss(len(feat_s[1:-1]))
        init(model_s, model_t, init_trainable_list, criterion_kd, train_loader, args)
        # classification
        module_list.append(connector)
    elif distill == 'ft':
        s_shape = feat_s[-2].shape
        t_shape = feat_t[-2].shape
        paraphraser = Paraphraser(t_shape)
        translator = Translator(s_shape, t_shape)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(paraphraser)
        criterion_init = nn.MSELoss()
        init(model_s, model_t, init_trainable_list, criterion_init, train_loader, args)
        # classification
        criterion_kd = FactorTransfer()
        module_list.append(translator)
        module_list.append(paraphraser)
        trainable_list.append(translator)
    elif distill == 'fsp':
        s_shapes = [s.shape for s in feat_s[:-1]]
        t_shapes = [t.shape for t in feat_t[:-1]]
        criterion_kd = FSP(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(model_s.get_feat_modules())
        init(model_s, model_t, init_trainable_list, criterion_kd, train_loader, args)
    else:
        raise NotImplementedError(distill)

    return criterion_div, criterion_kd


def main():
    """Train a student network by distilling from a (optionally JPEG-wrapped) teacher.

    Parses the options, builds the loaders, teacher, student and losses,
    validates both networks once, then runs the epoch loop with a step
    learning-rate schedule. The per-epoch student accuracy and the best
    accuracy are appended to ``args.log_fname``; the best accuracy is also
    appended to ``args.results_fname``.

    Raises:
        NotImplementedError: for a dataset other than ``cifar100`` or an
            unknown ``--add`` / ``--distill`` value.
    """
    best_acc = 0
    args = parse_option()
    args, opt = prepare_dir(args)

    print(torch.cuda.is_available())
    if args.seed:
        _seed_everything(args.seed)
    elif torch.cuda.is_available():
        # Autotuned kernels are only enabled when reproducibility was not requested.
        cudnn.benchmark = True

    # Only CIFAR-100 loaders exist, although the CLI also lists cifar10.
    if args.dataset != 'cifar100':
        raise NotImplementedError(args.dataset)
    num_classes = 100
    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)
    if args.JPEG_enable:
        # mean 0 / std 1/255 rescales inputs back to the 0..255 range; the JPEG
        # layer receives mean/std itself when it is constructed.
        loader_mean, loader_std = (0., 0., 0.), (1/255., 1/255., 1/255.)
    else:
        loader_mean, loader_std = mean, std

    train_loader, val_loader, val_loader_original, n_data = _build_dataloaders(
        args, loader_mean, loader_std, mean, std)

    # model (student first, then teacher, so seeded initialisation order is stable)
    model_s = cifar100_model_dict[args.model_s](num_classes=num_classes)
    model_t = _build_teacher(args, opt, num_classes, train_loader, mean, std)

    if args.train_mode:
        model_t.train()
    else:
        model_t.eval()
    model_t.to(device)
    model_s.to(device)

    # Dummy forward pass: only the feature shapes are used, so no graph is needed.
    # The tensor is drawn on the CPU and then moved, keeping the seeded RNG stream unchanged.
    with torch.no_grad():
        data = torch.randn(2, 3, 32, 32).to(device)
        feat_s, _ = model_s(data, is_feat=True)
        feat_t, _ = model_t(data, is_feat=True)

    module_list = nn.ModuleList([model_s])
    trainable_list = nn.ModuleList([model_s])

    criterion_cls = nn.CrossEntropyLoss()

    # additional loss selected by --add
    if args.add == 'kd':
        criterion_div = DistillKL(args.kd_T)
    elif args.add == 'ttm':
        criterion_div = TTM(args.ttm_l)
    elif args.add == 'wttm':
        criterion_div = WTTM(args.ttm_l)
    else:
        raise NotImplementedError(args.add)

    criterion_div, criterion_kd = _build_distill_criterion(
        args, model_s, model_t, feat_s, feat_t, n_data, train_loader,
        criterion_div, module_list, trainable_list)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)     # classification loss
    criterion_list.append(criterion_div)     # additional loss
    criterion_list.append(criterion_kd)      # distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(), lr=args.learning_rate,
                          momentum=args.momentum, weight_decay=args.weight_decay)

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)
    # Both containers stay ModuleLists on every device.
    module_list.to(device)
    criterion_list.to(device)

    print("-------------------------------------------------------------")
    # validate teacher accuracy
    if args.JPEG_enable:
        teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, args)
        print('JPEG + teacher accuracy: ', teacher_acc.item())

        teacher_acc, _, _ = validate(val_loader_original, model_t.underlying_model, criterion_cls, args)
        print('teacher accuracy: ', teacher_acc.item())

        print("Loaded Teacher Validation Accuracy during training: ")
        quantizationTable = torch.cat((model_t.jpeg_layer.lum_qtable, model_t.jpeg_layer.chrom_qtable), 0)
        print(quantizationTable.min().item(), quantizationTable.max().item())
    else:
        teacher_acc, _, _ = validate(val_loader_original, model_t, criterion_cls, args)
        print('teacher accuracy: ', teacher_acc.item())
    test_acc_student, _, _ = validate(val_loader_original, model_s, criterion_cls, args)
    print('student initial accuracy: ', test_acc_student)

    print("-------------------------------------------------------------")

    # routine
    decay_epochs = np.asarray(args.lr_decay_epochs)
    for epoch in range(1, args.epochs + 1):
        # Step schedule: one decay factor for every milestone already passed.
        steps = np.sum(epoch > decay_epochs)
        if steps > 0:
            new_lr = args.learning_rate * (args.lr_decay_rate ** steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr

        time1 = time.time()
        print("==> training...")
        train_distill(epoch, train_loader, module_list, criterion_list, optimizer, args)

        print("==> validating student...")
        test_acc_student, _, _ = validate(val_loader_original, model_s, criterion_cls, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}\n'.format(epoch, time2 - time1))

        with open(args.log_fname, 'a') as log:
            log.write(_acc_to_str(test_acc_student) + '\n')

        # track the best model
        if test_acc_student > best_acc:
            best_acc = test_acc_student
            # NOTE: checkpoint writing is disabled in this script; only the
            # best accuracy is tracked. The message is kept for log compatibility.
            print('saving the best model!\n')

    print('best accuracy:', best_acc)

    # save best accuracy; best_acc is still the int 0 when no epoch improved on
    # it (e.g. --epochs 0), which _acc_to_str handles.
    with open(args.log_fname, 'a') as log:
        log.write('best: ' + _acc_to_str(best_acc) + '\n')

    with open(args.results_fname, 'a') as log:
        log.write('{}: '.format(args.experiment_name) + _acc_to_str(best_acc) + '\n')


if __name__ == '__main__':
    # Wall-clock the whole run, from argument parsing to the final log write.
    started_at = time.time()
    main()
    elapsed = time.time() - started_at

    print('\n==> Total time {:.2f} s.'.format(elapsed))
