# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
import datetime
import json
import random
import time
from pathlib import Path

import numpy as np
import torch
from torch import tensor
from torch.utils.data import DataLoader, DistributedSampler

import datasets
import util.misc as utils
from datasets import build_dataset, get_coco_api_from_dataset
from engine import *
from models import build_model
import os
import sys
sys.path.append('../')
from quant_detr import letsquant, letsquant_teacher
import pdb

import logging
from logging import handlers
from sys import path

from torch.utils.tensorboard import SummaryWriter
import copy
from util.plot_utils import *


class Logger(object):
    """Configure a named ``logging.Logger`` that writes plain messages to *filename*.

    The underlying logger is obtained with ``logging.getLogger(filename)``, so every
    ``Logger`` built for the same path shares it. Handlers are only attached once per
    logger, which means constructing a second ``Logger`` for the same file does not
    duplicate every record.
    """

    # Short level names accepted by ``__init__`` mapped to ``logging`` constants.
    level_relations = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warning': logging.WARNING,
        'error': logging.ERROR,
        'crit': logging.CRITICAL
    }

    def __init__(self, filename, printflag=False, level='info', when='D', backCount=3, fmt='%(message)s'):
        """Attach a time-rotating file handler (and optionally a console handler).

        Args:
            filename: path of the log file; also used as the logger name.
            printflag: when true, also echo records through a ``logging.StreamHandler``.
            level: one of the keys of ``level_relations``.
            when: rotation interval passed to ``TimedRotatingFileHandler``.
            backCount: number of rotated files to keep (``backupCount``).
            fmt: ``logging.Formatter`` format string for every handler added here.

        Raises:
            ValueError: if *level* is not a key of ``level_relations``.
        """
        if level not in self.level_relations:
            # Previously this reached ``setLevel(None)`` and failed with a TypeError.
            raise ValueError('unknown log level {!r}; expected one of {}'.format(
                level, sorted(self.level_relations)))
        self.logger = logging.getLogger(filename)
        format_str = logging.Formatter(fmt)
        self.logger.setLevel(self.level_relations[level])

        existing = self.logger.handlers
        # FileHandler subclasses StreamHandler, so compare exact types for the console handler.
        if printflag and not any(type(h) is logging.StreamHandler for h in existing):
            sh = logging.StreamHandler()
            sh.setFormatter(format_str)
            self.logger.addHandler(sh)

        # FileHandler stores the absolute path in ``baseFilename``.
        target = os.path.abspath(filename)
        if not any(isinstance(h, handlers.TimedRotatingFileHandler) and h.baseFilename == target
                   for h in existing):
            th = handlers.TimedRotatingFileHandler(filename=filename, when=when,
                                                   backupCount=backCount, encoding='utf-8')
            th.setFormatter(format_str)
            self.logger.addHandler(th)


def add_quant_args(parser):
    """Register the quantization-related command-line options on *parser*.

    Args:
        parser: an ``argparse.ArgumentParser`` (or argument group) to extend.

    Note:
        argparse applies ``%``-formatting to help strings when it renders them.
        A literal percent sign must therefore be written as ``%%``; a bare ``%``
        makes ``--help`` fail with ``ValueError``.
    """
    # Boolean switches naming model components to quantize; they are consumed by the
    # quantization wrapper elsewhere (semantics defined there).
    parser.add_argument('--quant-class-embed', action='store_true')
    parser.add_argument('--quant-bbox-embed', action='store_true')
    parser.add_argument('--quant-input-proj', action='store_true')
    parser.add_argument('--quant-backbone', action='store_true')
    parser.add_argument('--quant-encoder', action='store_true')
    parser.add_argument('--quant-decoder', action='store_true')
    parser.add_argument('--quant-softmax', action='store_true')
    # ------------ quant_detr_segm ------------
    parser.add_argument('--quant-bbox-atten', action='store_true')
    parser.add_argument('--quant-mask-head', action='store_true')

    parser.add_argument('--ILP', default='', help='whether use ILP bit_config or not')

    parser.add_argument('--quant-scheme',
                        type=str,
                        default=None,
                        help='(old_default = detr8w8a) quantization bit configuration')
    parser.add_argument('--bias-bit',
                        type=int,
                        default=16,
                        help='quantizaiton bit-width for bias  (old_config_bit = 32) ')
    # store_false: ``args.channel_wise`` defaults to True and the flag switches it off.
    parser.add_argument('--channel-wise',
                        action='store_false',
                        help='channel-wise quantization is enabled by default; '
                             'pass this flag to disable it')
    parser.add_argument('--act-range-momentum',
                        type=float,
                        default=0.99,
                        help='momentum of the activation range moving average, '
                             '-1 stands for using minimum of min and maximum of max')
    parser.add_argument('--act-percentile',
                        type=float,
                        default=0,
                        help='the percentage used for activation percentile '
                             '(0 means no percentile, 99.9 means cut off 0.1%%)')
    parser.add_argument('--weight-percentile',
                        type=float,
                        default=0,
                        help='the percentage used for weight percentile '
                             '(0 means no percentile, 99.9 means cut off 0.1%%)')
    parser.add_argument('--fix-BN',
                        action='store_true',
                        help='whether to fix BN statistics and fold BN during training')
    parser.add_argument('--fix-BN-threshold',
                        type=int,
                        default=None,
                        help='when to start training with fixed and folded BN, '
                             'after the threshold iteration, the original fix-BN will be overwritten to be True')
    parser.add_argument('--checkpoint-iter',
                        type=int,
                        default=-1,
                        help='the iteration that we save all the featuremap for analysis')
    parser.add_argument('--fixed-point-quantization',
                        action='store_true',
                        help='whether to skip deployment-oriented operations and '
                             'use fixed-point rather than integer-only quantization')


def get_args_parser():
    """Build the (help-less) parent parser holding every DETR / quantization option.

    The returned parser is created with ``add_help=False`` so it can be passed as a
    ``parents=[...]`` entry to the top-level parser, which adds ``--help`` itself.
    Quantization options are registered through ``add_quant_args``.
    """
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    add_quant_args(parser)
    parser.add_argument('--lr', default=1e-5, type=float)
    parser.add_argument('--lr_backbone', default=1e-5, type=float)
    parser.add_argument('--batch_size', default=2, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--lr_drop', default=200, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')

    # Model parameters
    parser.add_argument('--frozen_weights', type=str, default=None,
                        help="Path to the pretrained model. If set, only the mask head will be trained")
    # * Backbone
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")

    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=100, type=int,
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')

    # * Segmentation
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")

    # Loss
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                        help="Disables auxiliary decoding losses (loss at each layer)")
    # * Matcher
    parser.add_argument('--set_cost_class', default=1, type=float,
                        help="Class coefficient in the matching cost")
    parser.add_argument('--set_cost_bbox', default=5, type=float,
                        help="L1 box coefficient in the matching cost")
    parser.add_argument('--set_cost_giou', default=2, type=float,
                        help="giou box coefficient in the matching cost")
    # * Loss coefficients
    parser.add_argument('--mask_loss_coef', default=1, type=float)
    parser.add_argument('--dice_loss_coef', default=1, type=float)
    parser.add_argument('--bbox_loss_coef', default=5, type=float)
    parser.add_argument('--giou_loss_coef', default=2, type=float)
    parser.add_argument('--eos_coef', default=0.1, type=float,
                        help="Relative classification weight of the no-object class")
    parser.add_argument('--hero_loss_coef', default=0, type=float)
    parser.add_argument('--mode', type=str)

    # dataset parameters
    parser.add_argument('--dataset_file', default='coco')
    parser.add_argument('--coco_path', type=str)
    parser.add_argument('--coco_panoptic_path', type=str)
    parser.add_argument('--remove_difficult', action='store_true')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true')
    # ``main`` reads ``args.eval_mini_dataset`` in eval mode; it was never registered,
    # so ``--eval`` used to crash with AttributeError.
    parser.add_argument('--eval_mini_dataset', action='store_true',
                        help='with --eval, evaluate on the training split loader instead of the validation split')
    parser.add_argument('--num_workers', default=2, type=int)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # ----------- huanrui_regularization -----------
    parser.add_argument('--huanrui_hessian', action="store_true",
                        help="use huanrui_hessian.")
    parser.add_argument('--grad_L2', default=0, type=float,
                        help='if grad_L2 = 0, normally training; else training with regularization')
    parser.add_argument('--grad_L2_base', default=1e-3, type=float,
                        help='Starting regualrization strength for linear scheduling')
    parser.add_argument('--per_class', action="store_true",
                        help="whether to use per-class or not when eval")
    parser.add_argument('--super_category', type=str, default=None,
                        help="whether to use super_category or not when eval(give the path of super_category.json")
    parser.add_argument('--qat_sp', type=str, default=None,
                        help="whether to use super_category or not when QAT(give the path of super_category.json")
    parser.add_argument('--mAP_super_category', type=str, default=None,
                        help="whether to use super_category or not when eval with New coco indicator(give the path of super_category.json")
    parser.add_argument('--mAP_write_xls_path', type=str, default=None,
                        help="whether to use per-class and write xls or not when eval")
    parser.add_argument('--loss_write_xls_path', type=str, default=None,
                        help="whether to write loss into xls or not when eval")
    parser.add_argument('--grad_norm', action="store_true",
                        help="whether to use grad_norm or not when eval")
    parser.add_argument('--resume_custom', type=str, default=None,
                        help="checkpoint path loaded (strict=False) into the quantized model after --resume; "
                             "its optimizer/lr_scheduler/epoch entries are restored when training")
    parser.add_argument('--dual_loss', action="store_true",
                        help="train with both losses instead of regularization.")

    # ------------ distillation_teacher_model ------------

    parser.add_argument('--teacher', action="store_true",
                        help="use distillation training.")
    parser.add_argument('--teacher_Qmodel_scheme', default='', help='teacher_Qmodel quantization bit configuration')
    parser.add_argument('--teacher_Qmodel_ILP', default='',
                        help='teacher_Qmodel quantization bit configuration ILP file path')
    parser.add_argument('--KL_alpha', default=1, type=float, help='logit_loss (the coefficient of KL loss)')
    parser.add_argument('--hub_delta', default=1, type=float,
                        help='(hubber loss param)when hub_delta = 1, bbox_loss is MSEloss')
    # Only these two values are handled when the distillation criterion is built.
    parser.add_argument('--bbox_loss', default='L1', choices=('L1', 'hub'),
                        help='bbox_loss_function_type: L1 / hub')
    parser.add_argument('--distillation_aux', action='store_true',
                        help="whether compute aux_outputs loss when distillation")

    # ------------ compute_KL or MSE ------------
    parser.add_argument('--compute_KL', action='store_true')
    parser.add_argument('--float_result_folder_path', default='./KL_result/float',
                        help='path to save float result numpy file')
    parser.add_argument('--quant_result_folder_path', default='./KL_result/quant',
                        help='path to save quant result numpy file')
    parser.add_argument('--KL_path', default='./KL_result/KL.txt', help='path to save KL_result file')

    parser.add_argument('--interlayer', default='backbone', type=str,
                        help="Intermediate layer Name([ backbone, transformer_encoder_layer, transformer_decoder_layer]) is used to compute the output by KL or MSE")
    parser.add_argument('--compute_MSE', action='store_true')
    parser.add_argument('--MSE_weight', default=1, type=float,
                        help='decoder layer output (the coefficient of MSE loss)')

    return parser


def _load_checkpoint(path):
    """Load a checkpoint onto the CPU from a local file or an ``https`` URL."""
    if path.startswith('https'):
        return torch.hub.load_state_dict_from_url(path, map_location='cpu', check_hash=True)
    return torch.load(path, map_location='cpu')


def _strip_detr_prefix(state_dict):
    """Return a copy of *state_dict* with every ``"detr."`` substring removed from its keys.

    Key collisions are resolved exactly as the original in-place renaming loop did.
    """
    renamed = dict(state_dict)
    for key in list(state_dict):
        renamed[key.replace("detr.", "")] = renamed.pop(key)
    return renamed


def _unwrap_ddp(module):
    """Return the module wrapped by ``DistributedDataParallel``, or *module* itself."""
    if isinstance(module, torch.nn.parallel.DistributedDataParallel):
        return module.module
    return module


def _build_teacher(args, model, model_without_ddp, device):
    """Create the (not yet frozen) distillation teacher.

    The teacher is always a separate module from the student. The teacher is frozen
    and set to eval mode later, and that must not affect the student being trained.
    Reads the module-level ``log`` created in the ``__main__`` block.
    """
    if args.teacher_Qmodel_scheme != '':
        # The original passed an undefined ``model_c_without_ddp`` here (NameError).
        # NOTE(review): assumes letsquant_teacher expects an independent copy of the
        # float student - confirm against quant_detr.
        return letsquant_teacher(copy.deepcopy(model_without_ddp), log, args=args)
    if args.masks:
        # Load the pretrained segmentation weights into the student (as before), then
        # snapshot them as an independent teacher instead of aliasing the student.
        checkpoint_segm = torch.load(args.frozen_weights, map_location='cpu')
        model.load_state_dict(checkpoint_segm['model'])
        return copy.deepcopy(model)
    if args.dataset_file == 'coco_panoptic':
        # Build a float-backbone teacher regardless of the student's --quant-backbone.
        temp_qb = args.quant_backbone
        args.quant_backbone = False
        try:
            teacher, _, _, _ = build_model(args)
        finally:
            args.quant_backbone = temp_qb
        teacher.to(device)
        checkpoint_teacher = _load_checkpoint(args.resume)
        teacher.load_state_dict(_strip_detr_prefix(checkpoint_teacher['model']), strict=False)
        return teacher
    return torch.hub.load('facebookresearch/detr:main',
                          f'detr_{args.backbone}', pretrained=True).to(device)


def main(args):
    """Train (or with ``--eval`` only evaluate) a quantized DETR, optionally distilled.

    Relies on the module-level ``log`` created in the ``__main__`` block, and on
    ``evaluate`` / ``train_one_epoch*`` / ``detr_criterion_teacher`` from ``engine``.
    """
    torch.multiprocessing.set_start_method('spawn')

    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset per rank)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, criterion_reg, postprocessors = build_model(args)
    model.to(device)
    model_without_ddp = model

    model_teacher = None
    if args.teacher:
        model_teacher = _build_teacher(args, model, model_without_ddp, device)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

        if args.teacher:
            model_teacher = torch.nn.parallel.DistributedDataParallel(
                model_teacher, device_ids=[args.gpu], find_unused_parameters=True)

    for p in model.parameters():
        p.requires_grad = True

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    if args.masks:
        segm_parameters = sum(
            p.numel() for name, p in model.named_parameters()
            if ('bbox_attention' in name or 'mask_head' in name) and p.requires_grad)
        print('number of Segm_head params:', segm_parameters)

    # Backbone parameters get their own learning rate.
    param_dicts = [
        {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='val', args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True)

    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn, num_workers=args.num_workers)

    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)

    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)

    # NOTE(review): the two ``letsquant`` calls below rebind only ``model_without_ddp``;
    # ``model`` (trained/evaluated) and the optimizer still reference the pre-quantization
    # objects. This is only correct if letsquant modifies the module in place - confirm.
    if args.frozen_weights is not None:
        model_without_ddp = letsquant(model_without_ddp, log, args=args).to(device)
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'], strict=False)

    writer = SummaryWriter(args.output_dir)
    output_dir = Path(args.output_dir)
    if args.resume:
        checkpoint = _load_checkpoint(args.resume)
        model_without_ddp.load_state_dict(_strip_detr_prefix(checkpoint['model']), strict=False)

        model_without_ddp = letsquant(model_without_ddp, log, args=args).to(device)

        if args.resume_custom is not None:
            checkpoint_r = torch.load(args.resume_custom, map_location='cpu')
            model_without_ddp.load_state_dict(checkpoint_r['model'], strict=False)

            if not args.eval and 'optimizer' in checkpoint_r and 'lr_scheduler' in checkpoint_r and 'epoch' in checkpoint_r:
                optimizer.load_state_dict(checkpoint_r['optimizer'])
                lr_scheduler.load_state_dict(checkpoint_r['lr_scheduler'])
                args.start_epoch = checkpoint_r['epoch'] + 1

    criterion_teacher = None
    if args.teacher:
        # Setting the attribute on a DDP wrapper would not reach the wrapped model's forward.
        _unwrap_ddp(model_teacher).aux_loss = True

        bbox_fun_by_option = {'L1': 'L1', 'hub': 'hubber'}
        if args.bbox_loss not in bbox_fun_by_option:
            # Previously this only printed, leaving criterion_teacher unbound for training.
            raise ValueError('bbox_loss_function only has [ L1 / hub ] now')
        criterion_teacher = detr_criterion_teacher(args, T=6, bbox_fun=bbox_fun_by_option[args.bbox_loss],
                                                   bxloss_style='mean', KL_alpha=args.KL_alpha,
                                                   delta=args.hub_delta, distillation_aux=args.distillation_aux)

        for p in model_teacher.parameters():
            p.requires_grad = False
        model_teacher.eval()

    if args.eval:
        # --eval_mini_dataset may be missing on args built by older parsers.
        eval_loader = data_loader_train if getattr(args, 'eval_mini_dataset', False) else data_loader_val
        test_stats, coco_evaluator = evaluate(args, log, model, criterion, postprocessors,
                                              eval_loader, base_ds, device, args.output_dir)

        if args.output_dir:
            utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
        return

    for p in model.parameters():
        p.requires_grad = True

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        if args.teacher:
            train_stats = train_one_epoch_teacher(args,
                                                  log, writer, model, model_teacher, criterion, criterion_teacher,
                                                  criterion_reg, data_loader_train, optimizer, device, epoch,
                                                  args.clip_max_norm)
        else:
            train_stats = train_one_epoch(args, writer,
                                          model, criterion, data_loader_train, optimizer, device, epoch,
                                          args.clip_max_norm)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoint before LR drop and every 100 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
                checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                }, checkpoint_path)

        # Evaluate twice: once on plain categories, once with the QAT super-category file.
        args.super_category = None
        test_stats, coco_evaluator = evaluate(args, log,
                                              model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir)

        args.super_category = args.qat_sp
        sp_test_stats, sp_coco_evaluator = evaluate(args, log,
                                                    model, criterion_reg, postprocessors, data_loader_val,
                                                    base_ds, device, args.output_dir)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     **{f'sploss_test_{k}': v for k, v in sp_test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        args.super_category = None
        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
            # Either may be absent (e.g. an evaluator without COCO stats), so start unset.
            mAP = mAP_SP = None
            for k, v in log_stats.items():
                if k in ('test_coco_eval_bbox', 'test_coco_eval_masks'):
                    writer.add_scalar('mAP_all', v[0], epoch)
                    mAP = v[0]
                elif k in ('sploss_test_coco_eval_bbox', 'sploss_test_coco_eval_masks'):
                    writer.add_scalar('sploss_mAP_all', v[0], epoch)
                    mAP_SP = v[0]
                elif isinstance(v, dict):
                    print(f'{k}:{v}')
                else:
                    writer.add_scalar(k, v, epoch)
            if mAP is not None and mAP_SP is not None:
                writer.add_scalar('mAP_diff', mAP - mAP_SP, epoch)

            # for evaluation logs
            if coco_evaluator is not None:
                (output_dir / 'eval').mkdir(exist_ok=True)
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ['latest.pth']
                    if epoch % 50 == 0:
                        filenames.append(f'{epoch:03}.pth')
                    for name in filenames:
                        torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                   output_dir / "eval" / name)

    # Whole-model pickles in the working directory; written by the main process only
    # so ranks do not race on the same file.
    if args.masks:
        # segm model
        utils.save_on_master(model, f"DetrSegm_{args.epochs}epoch.pth")
    elif args.dataset_file == 'coco_panoptic':
        # detr model
        utils.save_on_master(model, f"Segm-Detr-part_{args.epochs}epoch.pth")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    writer.close()


if __name__ == '__main__':
    # ``log`` must stay a module-level name: ``main`` reads it as a global.
    savepath = './log'
    logpath = f'{savepath}/run_1223.log'
    if not os.path.exists(savepath):
        os.makedirs(savepath)
    log = Logger(logpath, level='info')

    # Top-level parser adds --help on top of the shared option set.
    cli = argparse.ArgumentParser('DETR training and evaluation script',
                                  parents=[get_args_parser()])
    args = cli.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
