# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import argparse
import datetime
import os
import random
import time
from pathlib import Path
import sys
import shutil

import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler

import datasets
import util.misc as utils
from datasets import build_dataset
from engine import test
from models import build_model


def get_outdir(output_dir):
    """Return ``output_dir`` joined with a timestamp-named sub-directory.

    The sub-directory name is the current local time formatted as
    ``YYMMDD-HHMMSS``. Nothing is created on disk; only the path string is
    built.

    Args:
        output_dir: Base directory under which the run directory will live.

    Returns:
        The joined path as produced by ``os.path.join``.
    """
    stamp = f"{datetime.datetime.now():%y%m%d-%H%M%S}"
    return os.path.join(output_dir, stamp)


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this helper parses the text instead.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as "false" because ``bool('')`` was the only
    # way to switch a flag off with the previous ``type=bool`` declaration.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def get_args_parser():
    """Build the argument parser shared by training, evaluation and testing.

    The parser is created with ``add_help=False`` so it can be used as a
    ``parents=[...]`` entry of another ``ArgumentParser``.

    Returns:
        argparse.ArgumentParser: Parser with all options registered.
    """
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument('--phaseII', default=True, type=_str2bool,
                        help="Whether to train phaseII")
    parser.add_argument('--lr', default=5e-4, type=float)
    parser.add_argument('--lr_backbone', default=5e-4, type=float)
    parser.add_argument('--batch_size', default=4, type=int)
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=240, type=int)
    # One or more epoch indices, e.g. ``--lr_drop 40 80``.
    parser.add_argument('--lr_drop', default=[260], type=int, nargs='+')
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')

    # Model parameters
    parser.add_argument('--frozen_weights', type=str, default=None,
                        help="Path to the pretrained model. If set, only the mask head will be trained")
    # NOTE(review): an earlier comment here said "always true", but the default is False.
    parser.add_argument('--pos_refinement', default=False, type=_str2bool,
                        help="Whether to refine the position with PE")

    # * Backbone
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")  # resnet50 hourglassnet
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned', 'sine1d'),
                        help="Type of positional embedding to use on top of the image features")

    # * Transformer
    parser.add_argument('--deformable', default=True, type=_str2bool,
                        help="Whether to use deformable detr")
    parser.add_argument('--return_interm_layers', default=True, type=_str2bool,
                        help="Whether to use multiply layers in feature extractor")
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=1024, type=int,  # deformable 1024
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=128, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=4, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    # This also controls the number of nodes; modify yita in vpd at the same time.
    parser.add_argument('--num_queries', default=2000, type=int,
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')

    # * Matcher
    parser.add_argument('--set_cost_class', default=1, type=float,
                        help="Class coefficient in the matching cost")
    parser.add_argument('--set_cost_line', default=5, type=float,
                        help="L1 box coefficient in the matching cost")
    parser.add_argument('--set_cost_point', default=5, type=float,
                        help="L1 box coefficient in the matching cost")

    # Loss
    parser.add_argument('--aux_loss', default=True, type=_str2bool,
                        help="Auxiliary decoding losses (loss at each layer)")

    # * Loss coefficients
    parser.add_argument('--aux_loss_coef', default=0.8, type=float)
    parser.add_argument('--loss_cls_coef', default=1, type=float)
    parser.add_argument('--loss_line_coef', default=10, type=float)
    parser.add_argument('--loss_junction_map_coef', default=1, type=float)
    parser.add_argument('--loss_line_map_coef', default=1, type=float)
    parser.add_argument('--loss_score_coef', default=1, type=float)
    parser.add_argument('--loss_dis_coef', default=0, type=float)
    parser.add_argument('--loss_endpoint_coef', default=10, type=float)
    parser.add_argument('--cls_weight', default=2, type=float)
    parser.add_argument('--score_weight', default=10, type=float)
    # The defaults are lists, so command-line values must be collected into
    # lists too (``nargs='+'``) rather than parsed as a single float.
    parser.add_argument('--line_weight', default=[8, 4, 2, 1], type=float, nargs='+')
    parser.add_argument('--endpoint_weight', default=[64, 16, 4, 1], type=float, nargs='+')

    # dataset parameters
    parser.add_argument('--dataset_file', default='wireframe')
    # ``type=int`` keeps command-line overrides the same type as the defaults.
    parser.add_argument('--num_sample', default=300, type=int)
    parser.add_argument('--fsize', default=128, type=int)
    parser.add_argument('--remove_difficult', action='store_true')

    parser.add_argument('--output_dir', default='/disk0/projects/HoughTransformer2/out/wireframe',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--visual', action='store_true')
    parser.add_argument('--error_save_path', default=None)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser


def main(args):
    """Build datasets and model from ``args``, then test, evaluate or train.

    Steps, in order: initialise distributed mode, seed the RNGs, build the
    train/valid/test datasets and loaders, build the model, optionally load
    weights from ``args.frozen_weights`` / ``args.resume``, build the AdamW
    optimizer and MultiStepLR scheduler, and finally either run ``test``,
    run ``evaluate``, or run the training loop with per-epoch checkpointing.

    Args:
        args: Parsed command-line namespace (see ``get_args_parser``). It is
            mutated: ``aux_loss`` is disabled in test mode, ``resume`` may be
            filled with the default phase-I checkpoint, and ``epochs`` /
            ``lr_drop`` are overridden when ``phaseII`` is set.

    Raises:
        AssertionError: If both ``--eval`` and ``--test`` are given, or if
            ``--frozen_weights`` is given without ``args.masks``.
    """
    # torch.autograd.set_detect_anomaly(True)
    assert not (args.eval and args.test), "--eval and --test are mutually exclusive"
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        # The parser does not define ``--masks``; use getattr so a missing
        # attribute yields the intended assertion message, not AttributeError.
        assert getattr(args, 'masks', False), "Frozen training is meant for segmentation only"
    if args.test:
        args.aux_loss = False
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)

    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='valid', args=args)
    dataset_test = build_dataset(image_set='test', args=args)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True)

    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn, num_workers=args.num_workers)
    # NOTE(review): ``sampler_val`` is built above but the validation and test
    # loaders are created with ``sampler=None``, so every process iterates the
    # full split. Left as-is; confirm whether this is intentional.
    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=None,
                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
    data_loader_test = DataLoader(dataset_test, args.batch_size, sampler=None,
                                  drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    output_dir = Path(args.output_dir)
    print("output_dir:", output_dir)

    # Phase II starts from a phase-I checkpoint. Fall back to the built-in
    # path only when the user did not pass ``--resume`` explicitly, so an
    # explicit checkpoint is no longer silently overridden.
    if args.phaseII and not args.test and not args.resume:
        args.resume = "/241023-193144/checkpoint.pth"

    if args.resume:
        if args.resume.startswith('https') and "deformable" not in args.resume:
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
            new_state_dict = {}
            for k in checkpoint['model']:
                # Skip the classification head and query embeddings.
                if ("class_embed" in k) or ("query_embed" in k):
                    continue
                if "input_proj" in k:
                    # Insert a ``.0`` after the 10-character key prefix.
                    new_state_dict[k[:10] + '.0' + k[10:]] = checkpoint['model'][k]
                    continue
                if "backbone" in k:
                    new_state_dict[k[:8]+k[10:]] = checkpoint['model'][k]  # remove jointer, thus loss a 0 in key
                    continue
                new_state_dict[k] = checkpoint['model'][k]
            model_without_ddp.load_state_dict(new_state_dict, strict=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
            if args.test or args.eval:
                # For test/eval load every weight (no decoder / p2 filtering).
                phaseI_model = dict(checkpoint['model'])
            else:
                # For training keep only weights whose key contains neither
                # 'decoder' nor 'p2'.
                phaseI_model = {k: v for k, v in checkpoint['model'].items()
                                if 'decoder' not in k and 'p2' not in k}
            missing_keys, unexpected_keys = model_without_ddp.load_state_dict(phaseI_model, strict=False)
            if len(missing_keys) > 0:
                print('Missing Keys: {}'.format(missing_keys))
            if len(unexpected_keys) > 0:
                print('Unexpected Keys: {}'.format(unexpected_keys))

    if args.phaseII:
        # Phase II uses a fixed schedule that overrides --epochs / --lr_drop.
        args.epochs = 120
        args.lr_drop = [40, 80]
        # finetune version: decoder / p2 / tgt_embed / query_embed parameters
        # use the base learning rate, everything else a reduced one.
        param_dicts = [
            {"params": [p for n, p in model_without_ddp.named_parameters() if
                        "decoder" in n or "p2" in n or "tgt_embed" in n or "query_embed" in n]},
            {
                "params": [p for n, p in model_without_ddp.named_parameters() if
                           "decoder" not in n and "p2" not in n and "tgt_embed" not in n and "query_embed" not in n],
                "lr": args.lr * 0.2,
            },
        ]
    else:
        # Original note (translated): place this after freezing and group
        # directly by ``requires_grad``.
        param_dicts = [
            {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n]},
            {
                "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n],
                "lr": args.lr_backbone,
            },
        ]

    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_drop)

    if args.test:
        test(model, criterion, postprocessors, data_loader_test, device, args)
        return

    # ``evaluate`` and ``train_one_epoch`` were used below without ever being
    # imported (only ``test`` is imported at file level), which raised
    # NameError. Imported here, after the test-only path has returned, so that
    # path is unaffected. Assumes ``engine`` provides both — TODO confirm.
    from engine import evaluate, train_one_epoch

    if args.eval:  # for debug
        evaluate(model, criterion, postprocessors, data_loader_val, device, args)
        return

    print("Start training")
    start_time = time.time()
    best_eval = 0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train, optimizer, device, epoch,
            args.clip_max_norm)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # if epoch >= args.epochs - 10:
            #     checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                }, checkpoint_path)

        eval_stats = evaluate(
            model, criterion, postprocessors, data_loader_val, device, args
        )

        if eval_stats > best_eval:
            best_eval = eval_stats
            # Only copy when a checkpoint was actually written this epoch.
            if args.output_dir and utils.is_main_process():
                shutil.copyfile(output_dir / 'checkpoint.pth', output_dir / 'best.pth')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    # os.environ['CUDA_VISIBLE_DEVICES'] = "4,5,6,7"
    parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    # Directory created by this process for the current run, if any. Cleanup
    # below is restricted to it so a failure that happens before the
    # timestamped directory is assigned can never remove the base output dir.
    run_dir = None
    try:
        # Only training runs get a fresh timestamped directory and a log file.
        if args.output_dir and utils.is_main_process() and not args.test and not args.eval:
            args.output_dir = get_outdir(args.output_dir)
            run_dir = args.output_dir
            Path(run_dir).mkdir(parents=True, exist_ok=True)
            # Redirect stdout through the project logger writing train.txt.
            sys.stdout = utils.Logger(os.path.join(args.output_dir, 'train.txt'))
        main(args)
    except BaseException:
        # Discard the directory of a failed or interrupted run, then re-raise.
        # ``ignore_errors`` keeps a cleanup failure from masking the original
        # exception.
        if run_dir is not None and utils.is_main_process():
            shutil.rmtree(run_dir, ignore_errors=True)
        raise
