# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from hydra.core.config_store import ConfigStore

from cosmos_transfer2._src.imaginaire.lazy_config import LazyCall as L
from cosmos_transfer2._src.imaginaire.utils.checkpoint_db import get_checkpoint_path
from cosmos_transfer2._src.predict2.conditioner import ReMapkey
from cosmos_transfer2.multiview_config import DEFAULT_CHECKPOINT

# Root directory of the Waymo Transfer2 training data, used for both the train
# and val dataloaders below. The fallback is a machine-specific path; set the
# COSMOS_WAYMO_TRANSFER2_DATASET_DIR environment variable to point at another
# copy without editing this file (an empty value also falls back to the default).
dataset_dir = (
    os.environ.get("COSMOS_WAYMO_TRANSFER2_DATASET_DIR")
    or "/home/dancer/dataset/waymo_transfer2/training/"
)

def _passthrough_remap(key: str):
    """Build a lazy ``ReMapkey`` whose input and output keys are both ``key``.

    All conditioner entries in the experiment below share ``dropout_rate=0.0``
    and ``dtype=None``; only the key differs, so it is the sole argument here.
    """
    return L(ReMapkey)(
        input_key=key,
        output_key=key,
        dropout_rate=0.0,
        dtype=None,
    )


# Waymo multiview post-training experiment. It extends the experiment named by
# DEFAULT_CHECKPOINT.experiment, swaps the training data for the Waymo HD-map
# control-input dataset, and overrides only the fields listed here.
transfer2_waymo_multiview_post_train = dict(
    defaults=[
        f"/experiment/{DEFAULT_CHECKPOINT.experiment}",
        {"override /data_train": "waymo_multiview_train_data_control_input_hdmap"},
    ],
    # Train and val both read from the same directory.
    dataloader_train=dict(dataset=dict(dataset_dir=dataset_dir)),
    dataloader_val=dict(dataset=dict(dataset_dir=dataset_dir)),
    job=dict(
        project="cosmos_transfer_v2p5",
        group="auto_multiview",
        name="2b_cosmos_multiview_post_train_example",
    ),
    optimizer=dict(lr=2e-5),
    # Initialize from the default checkpoint without restoring its training
    # state; object-store (S3) loading and saving are both turned off.
    checkpoint=dict(
        save_iter=1000,
        # pyrefly: ignore  # missing-attribute
        load_path=get_checkpoint_path(DEFAULT_CHECKPOINT.s3.uri),
        load_training_state=False,
        strict_resume=False,
        load_from_object_store=dict(enabled=False),  # Loading from local filesystem, not S3
        save_to_object_store=dict(enabled=False),
    ),
    model=dict(
        config=dict(
            base_load_from=None,
            ema=dict(enabled=True),
            condition_locations=["no_cam"],
            net=dict(
                sac_config=dict(mode="predict2_2b_720_aggressive", every_n_blocks=1),
            ),
            conditioner=dict(
                # NOTE(review): this entry is keyed "control_input_hdmap_video" but
                # remaps "control_input_sparse_video" -> "control_input_sparse_video";
                # presumably an intentional override of the base experiment — confirm.
                control_input_hdmap_video=_passthrough_remap("control_input_sparse_video"),
                control_input_sparse_mask=_passthrough_remap("control_input_sparse_mask"),
                control_input_sparse_ctrl=_passthrough_remap("control_input_sparse_ctrl"),
                control_input_reference_video=_passthrough_remap("control_input_reference_video"),
                control_input_reference_mask=_passthrough_remap("control_input_reference_mask"),
                control_input_reference_ctrl=_passthrough_remap("control_input_reference_ctrl"),
            ),
        ),
    ),
    trainer=dict(
        grad_accum_iter=1,
        logging_iter=500,
        max_iter=40000,
        # Every callback is configured with save_s3=False.
        callbacks=dict(
            heart_beat=dict(save_s3=False),
            iter_speed=dict(hit_thres=200, save_s3=False),
            device_monitor=dict(save_s3=False),
            every_n_sample_reg=dict(every_n=5000, sample_n_views=5, control_weights=[1.0], save_s3=False),
            every_n_sample_ema=dict(every_n=1000, sample_n_views=5, control_weights=[1.0], save_s3=False),
            wandb=dict(save_s3=False),
            wandb_10x=dict(save_s3=False),
            dataloader_speed=dict(save_s3=False),
            frame_loss_log=dict(save_s3=False),
        ),
    ),
    model_parallel=dict(context_parallel_size=2, tensor_model_parallel_size=1),
)

cs = ConfigStore.instance()

# Register each experiment dict with Hydra's ConfigStore in the "experiment"
# group, stored under the "_global_" package. The registered name is the
# lower-cased name of the module-level variable bound to the dict, found by an
# identity scan of globals(). `_item` itself also matches, but the config's own
# variable was inserted into globals() earlier, so it is found first.
for _item in (transfer2_waymo_multiview_post_train,):
    experiment_name = next(
        name for name, value in globals().items() if value is _item
    ).lower()
    cs.store(
        name=experiment_name,
        node=_item,
        group="experiment",
        package="_global_",
    )
