import os
import socket
import numpy as np
import random
import json
import argparse
import torch
import time
import torch.nn as nn
import torch.backends.cudnn as cudnn
from helper.JPEG_layer import * 
from mdistiller.dist import utils
from mdistiller.models import cifar_model_dict, imagenet_model_dict, tiny_imagenet_model_dict
from mdistiller.distillers import distiller_dict
from mdistiller.dataset import get_dataset
from mdistiller.engine.utils import load_checkpoint, log_msg
from mdistiller.engine.cfg import CFG as cfg
# from mdistiller.engine.cfg import show_cfg
from mdistiller.engine import trainer_dict


# Module-level defaults: record this machine's host name and let cuDNN
# autotune convolution algorithms (main() turns this off when a seed is given).
hostname = socket.gethostname()
cudnn.benchmark = True


def _seed_everything(seed):
    """Seed Python, NumPy and torch RNGs and force deterministic kernels.

    Args:
        seed: non-zero integer seed (0 is treated by the caller as "do not seed").
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # With deterministic algorithms enabled, cuBLAS ops on CUDA >= 10.2 raise a
    # RuntimeError unless CUBLAS_WORKSPACE_CONFIG is set. Keep a user-provided
    # value; otherwise use the setting recommended in the PyTorch docs.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.use_deterministic_algorithms(True)


def _install_qtables(model_teacher, opt, q_table_epoch, squeeze_leading):
    """Load saved JPEG quantisation tables and install them on the teacher's JPEG layer.

    Args:
        model_teacher: a ``CustomModel`` whose ``jpeg_layer`` receives the tables.
        opt: namespace loaded from ``opt.json``; ``opt.q_tables_folder`` holds the tables.
        q_table_epoch: selects the file ``q_table_epoch_{q_table_epoch}.pt``.
        squeeze_leading: if True, drop dim 0 of each table before reshaping
            (the CIFAR path does this, the ImageNet path does not).
    """
    q_table_path = os.path.join(opt.q_tables_folder, 'q_table_epoch_{}.pt'.format(q_table_epoch))
    q_tables = torch.load(q_table_path)
    lum_qtable, chrom_qtable = q_tables[0], q_tables[1]

    def _reshape(table):
        if squeeze_leading:
            table = table.squeeze(0)
        # Insert three singleton dims after dim 0 and one at the end; presumably
        # this is the layout JPEG_layers broadcasts against -- confirm there.
        return table.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(-1)

    model_teacher.jpeg_layer.lum_qtable = nn.Parameter(_reshape(lum_qtable))
    model_teacher.jpeg_layer.chrom_qtable = nn.Parameter(_reshape(chrom_qtable))


def _build_distiller(cfg, args, model_student, model_teacher, num_data):
    """Instantiate the configured distiller for a student/teacher pair.

    In distributed mode the distiller is returned as-is and is asked to wrap the
    student in DDP itself; otherwise it is moved to CUDA and wrapped in
    ``torch.nn.DataParallel``.
    """
    distiller_cls = distiller_dict[cfg.DISTILLER.TYPE]
    # CRD takes the number of training samples as an extra positional argument.
    extra_args = (num_data,) if cfg.DISTILLER.TYPE == "CRD" else ()
    if args.distributed:
        return distiller_cls(model_student, model_teacher, cfg, *extra_args, wrap_student_in_ddp=True, local_rank=args.gpu)
    distiller = distiller_cls(model_student, model_teacher, cfg, *extra_args)
    return torch.nn.DataParallel(distiller.cuda())


def main(cfg, args):
    """Build data, student/teacher models and the distiller, then train (or analyse).

    Args:
        cfg: frozen config with ``DATASET``, ``DISTILLER``, ``EXPERIMENT`` and ``SOLVER`` sections.
        args: parsed command-line namespace. Mutated in place: ``experiment_name`` is
            set and ``finetune_model_path`` (if given) is expanded to a checkpoint path.
            ``utils.init_distributed_mode`` presumably also sets fields such as ``gpu``.
    """
    utils.init_distributed_mode(args)

    # A seed of 0 means "do not fix the seed" (see the --seed help text).
    if args.seed:
        _seed_everything(args.seed)
    else:
        torch.backends.cudnn.benchmark = True

    if args.JPEG_enable:
        # JPEG-layer hyper-parameters (including the q-table folder) live in opt.json.
        opt_json_filepath = os.path.join(args.base_path, 'opt.json')
        with open(opt_json_filepath, 'r', encoding='utf-8') as f:
            opt = argparse.Namespace(**json.load(f))
        experiment_name = opt.model_name + "_trial_" + str(args.trial)
    else:
        # NOTE: "vallina" (sic) is kept because it is part of existing output paths.
        experiment_name = "vallina" + "_trial_" + str(args.trial)

    if args.finetune_model_path is not None:
        # NOTE(review): appended without a separator; kept to preserve existing run names.
        experiment_name += args.finetune_model_path
        teacher_dir = 'Resnet34' if cfg.DISTILLER.TEACHER == "ResNet34" else "Resnet50"
        args.finetune_model_path = './save/{}/teacher/{}/{}/trial_1/last.pth'.format(
            cfg.DATASET.TYPE, teacher_dir, args.finetune_model_path)
    if args.train_mode:
        experiment_name += "_train_mode"
    args.experiment_name = os.path.join(cfg.EXPERIMENT.TAG, cfg.EXPERIMENT.PROJECT, experiment_name)

    # init dataloader & models
    train_loader, val_loader, num_data, num_classes = get_dataset(cfg, args)
    # NOTE(review): set after init_distributed_mode(); confirm it still takes effect
    # with this PyTorch version (the debug level may be read at process-group init).
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

    # vanilla
    if cfg.DISTILLER.TYPE == "NONE":
        if cfg.DATASET.TYPE == "imagenet":
            model_student = imagenet_model_dict[cfg.DISTILLER.STUDENT](pretrained=False)
        elif cfg.DATASET.TYPE == "tiny_imagenet":
            model_student = tiny_imagenet_model_dict[cfg.DISTILLER.STUDENT][0](num_classes=num_classes)
        else:
            model_student = cifar_model_dict[cfg.DISTILLER.STUDENT][0](num_classes=num_classes)
        # NOTE(review): unlike the distillation path, this distiller is neither moved
        # to CUDA nor wrapped here; presumably the trainer handles that -- confirm.
        distiller = distiller_dict[cfg.DISTILLER.TYPE](model_student)

    # distillation
    else:
        print(log_msg("Loading teacher model", "INFO"))
        print(cfg.DATASET.TYPE)
        if cfg.DATASET.TYPE == "imagenet":
            model_student = imagenet_model_dict[cfg.DISTILLER.STUDENT](pretrained=False)
            teacher_fn = imagenet_model_dict[cfg.DISTILLER.TEACHER]
            if args.JPEG_enable:
                if args.finetune_model_path is not None:
                    print("load JPEG + fine-tuned model.")
                    underlying_model = teacher_fn(pretrained=False)
                    underlying_model.load_state_dict(torch.load(args.finetune_model_path)['model'])
                else:
                    print("load JPEG + pretrained model.")
                    underlying_model = teacher_fn(pretrained=True)
                jpeg_layer = JPEG_layers(opt=opt, img_shape=(opt.train_crop_size, opt.train_crop_size, 3), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
                model_teacher = CustomModel(jpeg_layer, underlying_model)
                _install_qtables(model_teacher, opt, args.q_table_epoch, squeeze_leading=False)
            else:
                if args.finetune_model_path is not None:
                    print("load fine-tuned model.")
                    model_teacher = teacher_fn(pretrained=False)
                    model_teacher.load_state_dict(torch.load(args.finetune_model_path)['model'])
                else:
                    print("load pretrained model.")
                    model_teacher = teacher_fn(pretrained=True)

        elif cfg.DATASET.TYPE == "tiny_imagenet":
            net, pretrain_model_path = tiny_imagenet_model_dict[cfg.DISTILLER.TEACHER]
            assert (pretrain_model_path is not None), "no pretrain model for teacher {}".format(cfg.DISTILLER.TEACHER)
            model_teacher = net(num_classes=num_classes)
            model_teacher.load_state_dict(load_checkpoint(pretrain_model_path)["model"])
            model_student = tiny_imagenet_model_dict[cfg.DISTILLER.STUDENT][0](num_classes=num_classes)

        else:
            model_student = cifar_model_dict[cfg.DISTILLER.STUDENT][0](num_classes=num_classes)
            net, pretrain_model_path = cifar_model_dict[cfg.DISTILLER.TEACHER]
            assert (pretrain_model_path is not None), "no pretrain model for teacher {}".format(cfg.DISTILLER.TEACHER)
            if args.JPEG_enable:
                underlying_model = net(num_classes=num_classes)
                if args.finetune_model_path is not None:
                    print("load JPEG + fine-tuned model.")
                    underlying_model.load_state_dict(load_checkpoint(args.finetune_model_path)["model"])
                else:
                    print("load JPEG + pretrained model.")
                    underlying_model.load_state_dict(load_checkpoint(pretrain_model_path)["model"])
                jpeg_layer = JPEG_layers(opt=opt, img_shape=train_loader.dataset.data[0].shape, mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761))
                model_teacher = CustomModel(jpeg_layer, underlying_model)
                _install_qtables(model_teacher, opt, args.q_table_epoch, squeeze_leading=True)
            else:
                model_teacher = net(num_classes=num_classes)
                if args.finetune_model_path is not None:
                    print("load fine-tuned model.")
                    model_teacher.load_state_dict(load_checkpoint(args.finetune_model_path)["model"])
                else:
                    print("load pretrained model.")
                    model_teacher.load_state_dict(load_checkpoint(pretrain_model_path)["model"])

        distiller = _build_distiller(cfg, args, model_student, model_teacher, num_data)

        # Non-distributed distillers are wrapped in DataParallel; reach through `.module`.
        base_distiller = distiller if args.distributed else distiller.module
        print(log_msg("Extra parameters of {}: {}\033[0m".format(cfg.DISTILLER.TYPE, base_distiller.get_extra_parameters()), "INFO"))

    # train
    trainer = trainer_dict[cfg.SOLVER.TRAINER](args.experiment_name, distiller, train_loader, val_loader, cfg, args)

    if args.analysis:
        trainer.analysis()
    else:
        trainer.train(resume=args.resume)

    if args.distributed:
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    # Measure the duration of the whole run with a monotonic clock.
    start_time = time.perf_counter()
    # The first positional argument of ArgumentParser is `prog` (the program name
    # shown in usage), so the description must be passed by keyword.
    parser = argparse.ArgumentParser(description="training for knowledge distillation.")

    parser.add_argument("--cfg", type=str, default="")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--analysis", action="store_true")
    parser.add_argument("--train_mode", action="store_true")
    parser.add_argument("--JPEG_enable", action="store_true")
    parser.add_argument("--trial", type=int, default=1)
    parser.add_argument('--seed', type=int, default=0, help='seed id, set to 0 if do not want to fix the seed')
    parser.add_argument("--log", type=str, default="")
    parser.add_argument("--base_path", type=str, default=None)
    parser.add_argument("--finetune_model_path", type=str, default=None)
    parser.add_argument("--q_table_epoch", type=int, default=1, help='the epoch of JPEG layer qtable')
    parser.add_argument("--distributed", action="store_true")
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
    # Everything after the recognised flags is forwarded as config overrides.
    parser.add_argument("options", default=None, nargs=argparse.REMAINDER)

    args = parser.parse_args()
    # Load the config file, apply command-line overrides, then make it read-only.
    cfg.merge_from_file(args.cfg)
    cfg.merge_from_list(args.options)
    cfg.freeze()

    main(cfg, args)

    print('total time: {}'.format(time.perf_counter() - start_time))