from functools import partial

from fvcore.common.param_scheduler import MultiStepParamScheduler

from detectron2 import model_zoo
from detectron2.config import LazyCall as L
import detectron2.data.transforms as T
from detectron2.data import DatasetCatalog
from detectron2.solver import WarmupParamScheduler
from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate
from detectron2.data.datasets import register_coco_instances

# dataset
# Register the train/val splits as COCO-format instance datasets.
# DatasetCatalog.register() asserts that a name is not registered yet, so
# executing this config a second time in the same process (e.g. calling
# LazyConfig.load twice from a notebook or tool) would raise. Names that are
# already present are skipped so the config can be re-loaded safely.
# NOTE(review): the on-disk directory is spelled "Neurlens" while the dataset
# names use "neuralens" — confirm this is the intended path.
# Loop variables are "_"-prefixed so they stay out of the loaded config.
for _name, _json_file, _image_root in (
    ("neuralens_train", "./DataSet/Neurlens/annotations/instances_train.json", "./DataSet/Neurlens/train/"),
    ("neuralens_val", "./DataSet/Neurlens/annotations/instances_val.json", "./DataSet/Neurlens/val/"),
):
    if _name not in DatasetCatalog.list():
        register_coco_instances(_name, {}, _json_file, _image_root)

# Square training/eval resolution. It is shared by the augmentations below and
# by the model's backbone settings.
image_size = 512
dataloader = model_zoo.get_config("common/data/coco.py").dataloader
dataloader.train.dataset.names = "neuralens_train"
# Large-scale-jitter style augmentation: random horizontal flip, random rescale
# in [0.1, 2.0] relative to image_size, then a fixed-size crop.
# NOTE(review): pad=False means the crop itself does not pad smaller images;
# presumably square padding is left to model.backbone.square_pad — confirm.
dataloader.train.mapper.augmentations = [
    L(T.RandomFlip)(horizontal=True),  # flip first
    L(T.ResizeScale)(
        min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size
    ),
    L(T.FixedSizeCrop)(crop_size=(image_size, image_size), pad=False),
]
dataloader.train.mapper.image_format = "RGB"
dataloader.train.total_batch_size = 64
# Boxes are NOT recomputed from masks after cropping: detectron2's DatasetMapper
# only allows recompute_boxes=True together with use_instance_mask=True, and
# instance masks are disabled here. The annotated boxes are carried through the
# augmentations instead.
dataloader.train.mapper.recompute_boxes = False
dataloader.train.mapper.use_instance_mask = False

# Evaluation: resize so the short edge is image_size, capping the long edge at
# image_size as well.
dataloader.test.dataset.names = "neuralens_val"
dataloader.test.mapper.augmentations = [
    L(T.ResizeShortestEdge)(short_edge_length=image_size, max_size=image_size),
]


# model
# Start from detectron2's ViTDet Mask R-CNN model config and adapt it.
model = model_zoo.get_config("common/models/mask_rcnn_vitdet.py").model
# Stochastic-depth (drop path) rate for the ViT backbone.
model.backbone.net.drop_path_rate = 0.1
# Keep the ViT input size and the backbone's square padding in sync with the
# image_size used by the dataloader augmentations.
model.backbone.net.img_size = image_size
model.backbone.square_pad = image_size
# NOTE(review): assumes the registered COCO annotations define exactly 2
# categories — confirm against instances_train.json.
model.roi_heads.num_classes = 2
# In detectron2's StandardROIHeads, mask_in_features=None turns the mask branch
# off (box-only detection), consistent with use_instance_mask=False in the
# dataloader.
model.roi_heads.mask_in_features = None
# Per-channel normalization statistics. Presumably computed on this dataset in
# RGB order (the dataloader uses image_format="RGB") — confirm.
model.pixel_mean = [37.647, 36.028, 34.562]
model.pixel_std = [22.318, 17.559, 14.245]

# Initialization and trainer settings
# Start from detectron2's default trainer settings (common/train.py).
train = model_zoo.get_config("common/train.py").train
# Enable automatic mixed precision.
train.amp.enabled = True
# Compress DDP gradient communication to fp16.
train.ddp.fp16_compression = True
# TODO(review): "base ckpt" looks like a placeholder rather than a loadable
# checkpoint path/URL — replace it with the pretrained weights before training.
train.init_checkpoint = (
    "base ckpt"
)


# Schedule
# Training length and logging/evaluation cadence, in iterations.
train.max_iter = 32203
train.eval_period = 2000
train.log_period = 100
train.output_dir = "./run/base"

# Step LR schedule: x1.0, then x0.1 at 90% and x0.01 at 95% of training,
# preceded by a warmup over the first 500 iterations starting at 0.001x.
# The milestones are derived from train.max_iter instead of being hard-coded,
# so they stay valid when max_iter is edited above. fvcore's
# MultiStepParamScheduler rejects milestones >= num_updates. For
# max_iter=32203 this yields the same milestones as before: [28982, 30592].
# These values are computed when this file is executed, so a later
# command-line override of train.max_iter does not update them (nor
# num_updates / warmup_length).
lr_multiplier = L(WarmupParamScheduler)(
    scheduler=L(MultiStepParamScheduler)(
        values=[1.0, 0.1, 0.01],
        milestones=[int(train.max_iter * 0.9), int(train.max_iter * 0.95)],
        num_updates=train.max_iter,
    ),
    # warmup_length is expressed as a fraction of the whole schedule.
    warmup_length=500 / train.max_iter,
    warmup_factor=0.001,
)


# Optimizer
# AdamW from detectron2's common/optim.py, with per-layer LR scaling for the ViT.
optimizer = model_zoo.get_config("common/optim.py").AdamW
optimizer.lr = 5e-4
optimizer.weight_decay = 0.1
# Per-parameter LR factor from get_vit_lr_decay_rate (layer-wise decay of 0.7).
# NOTE(review): num_layers presumably must match the depth of the ViT backbone
# in the base model config (12 for ViT-B) — confirm if the backbone changes.
optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7)
# Exclude the positional embedding from weight decay.
optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
