import os
from pathlib import Path
import random
import re
import yaml

import click
import numpy as np
import torch
from dotenv import find_dotenv, load_dotenv

# Load environment variables from a .env file at import time, before the CLI
# runs. usecwd=True makes find_dotenv start its search from the current
# working directory.
load_dotenv(find_dotenv(usecwd=True))


@click.command()
@click.option("--config", type=str, required=True, help="Path to the config file.")
@click.option("--probing", type=bool, default=False)
@click.option("--pass_ckpt", type=str, default=None)
@click.option("--validate_probing_only", type=bool, default=False)
@click.option("--validate_video_only", type=bool, default=False)
@click.option("--train_adapter_only", type=bool, default=False)
@click.option("--batch_size", type=int, default=None)
@click.option("--num_workers", type=int, default=None)
@click.option("--start_new_wandb_run_for_new_finetune", type=bool, default=None)
@click.option("--dataset_target", type=str, default=None)
@click.option("--training_mode", type=str, default=None)
@click.option("--only_train_fdm", type=bool, default=None)
@click.option("--alternating_stage_steps", type=int, default=None)
@click.option("--alternating_enable_warmup", type=bool, default=None)
@click.option("--alternating_warmup_steps", type=int, default=None)
@click.option("--warmup_warmup_steps", type=int, default=None)
@click.option("--idm_size", type=str, default=None)
@click.option("--use_difference_loss", type=bool, default=None)
@click.option("--drop_level", type=str, default=None)
@click.option("--adaln_fuse_type", type=str, default=None)
@click.option("--init_act_block_last_layer_0", type=bool, default=None)
@click.option("--motion_loss_weight", type=bool, default=None)
@click.option("--clip_grad", type=float, default=None)
@click.option("--lr_warmup_steps", type=int, default=None)
@click.option("--wm_lr", type=float, default=None)
@click.option("--use_action_attention", type=bool, default=None)
@click.option("--action_attention_rmsnorm", type=bool, default=None)
@click.option("--action_attention_outnorm", type=bool, default=None)
@click.option("--action_attention_num_blocks", type=int, default=None)
@click.option("--action_attention_use_causal", type=bool, default=None)
@click.option("--action_attention_self_adaln", type=bool, default=None)
@click.option("--n_max_state_action", type=int, default=None)
@click.option("--finetune_wm_adapt_realaction", type=bool, default=None)
@click.option("--finetune_wm_use_adapter", type=bool, default=None)
@click.option("--load_adapter_path", type=str, default=None)
@click.option("--realaction_add_noise_level", type=float, default=None)
@click.option("--enable_validate_during_train", type=bool, default=None)
@click.option("--probe_action_dim", type=int, default=None)
@click.option("--val_interval", type=int, default=None)
@click.option("--load_transformer_pretrained_weights", type=bool, default=None)
@click.option("--validate_bootstrap_only", type=bool, default=None)
@click.option("--robodesk_dataset_dir", type=str, default="/path/to/robodesk_dataset")
@click.option("--wm_resize_obs_64", type=bool, default=None)
@click.option("--load_idm_only", type=bool, default=None)
def main(config: str, probing: bool, pass_ckpt: str, validate_probing_only: bool, validate_video_only: bool, train_adapter_only: bool,
         batch_size: int, num_workers: int, start_new_wandb_run_for_new_finetune: bool, dataset_target: str,
         training_mode: str, only_train_fdm: bool, alternating_stage_steps: int, alternating_enable_warmup: bool,
         alternating_warmup_steps: int, warmup_warmup_steps: int, idm_size: str, use_difference_loss: bool,
         drop_level: str, adaln_fuse_type: str, init_act_block_last_layer_0: bool, motion_loss_weight: bool,
         clip_grad: float, lr_warmup_steps: int, wm_lr: float, use_action_attention: bool,
         action_attention_rmsnorm: bool, action_attention_outnorm: bool, action_attention_num_blocks: int, action_attention_use_causal: bool,
         action_attention_self_adaln: bool, n_max_state_action: int, finetune_wm_adapt_realaction: bool, load_adapter_path: str,
         realaction_add_noise_level: float, enable_validate_during_train: bool, finetune_wm_use_adapter: bool, probe_action_dim: int,
         val_interval: int, load_transformer_pretrained_weights: bool, validate_bootstrap_only: bool,
         robodesk_dataset_dir: str, wm_resize_obs_64: bool, load_idm_only: bool):
    """Build the datasets, model and trainer from a config directory and train.

    ``--config`` is a directory that must contain ``dataset_config.yaml``,
    ``model_config.yaml`` and ``trainer_config.yaml``.

    Options whose default is ``None`` are overrides: when they are omitted the
    value from the YAML files is kept. ``--validate_probing_only``,
    ``--validate_video_only``, ``--train_adapter_only`` and
    ``--validate_bootstrap_only`` are always written to the trainer config.

    ``--dataset_target`` selects a predefined dataset mixture and, for some
    targets, forces ``probe_action_dim`` (this wins over ``--probe_action_dim``).
    An unknown target raises ``ValueError``.
    """
    # NOTE(review): ``--motion_loss_weight`` is parsed as a bool although the
    # name suggests a numeric weight; confirm the intended type before changing.
    from torch.multiprocessing import set_start_method
    from .utils.config import LoadConfig
    from .utils.data import load_dataset_from_config, load_dataset
    from . import AutoModelForLatentAction
    from .models import ModelConfig

    set_start_method("spawn")
    config = Path(config)

    ds_config = LoadConfig.load(config / "dataset_config.yaml")
    model_config = ModelConfig.load(config / "model_config.yaml")
    model_cfg = model_config.config
    is_fdm_arch = "FDM_Arch" in model_config.architecture

    # ------------------------------------------------------------------
    # Dataset mixture selection: each entry is a list of [name, weight].
    # ------------------------------------------------------------------
    if dataset_target is not None:
        ds_target_dict = {
            "full": [["agibotworld_beta", 2.0], ["magic_soup_v2", 3.0], ["egofull_v2", 5.0]],
            "default": [["mini_ego_v1", 3.0], ["bridgert1", 1.0]],
            "bridgert1": [["bridgert1", 1.0]],
            "bridge": [["bridge_dataset", 1.0]],
            "rt1": [["fractal20220817_data", 1.0]],
            "ssv2": [["ssv2", 1.0]],
            "xhand": [["xhand", 1.0]],
            "agibotoxe": [["agibotworld_beta", 2.0], ["magic_soup_v2", 3.0]],
            "droid": [["droid", 1.0]],
            "kuka": [["kuka", 1.0]],
            "bc_z": [["bc_z", 1.0]],
            "language_table": [["language_table", 1.0]],
            "stanford_hydra_dataset_converted_externally_to_rlds": [["stanford_hydra_dataset_converted_externally_to_rlds", 1.0]],
            "furniture_bench_dataset_converted_externally_to_rlds": [["furniture_bench_dataset_converted_externally_to_rlds", 1.0]],
            "nyu_franka_play_dataset_converted_externally_to_rlds": [["nyu_franka_play_dataset_converted_externally_to_rlds", 1.0]],
            "taco_play": [["taco_play", 1.0]],
            "jaco_play": [["jaco_play", 1.0]],
            "agibot": [["agibotworld_beta", 1.0]],
            "oxe": [["magic_soup_v2", 1.0]],
            "libero": [["libero_v1", 1.0]],
            "libero_spatial": [["libero_spatial_no_noops", 1.0]],
            "libero_object": [["libero_object_no_noops", 1.0]],
            "libero_goal": [["libero_goal_no_noops", 1.0]],
            "libero_10": [["libero_10_no_noops", 1.0]],
            "egofull_v2": [["bridge", 1.0], ["egofull_v2", 1.0]],
            "robodesk": [["robodesk", 1.0]],
        }
        # Explicit membership check instead of try/except so the ValueError is
        # not reported with an unrelated chained KeyError traceback.
        if dataset_target not in ds_target_dict:
            raise ValueError(f"Unknown dataset target: {dataset_target}")
        ds_config.target = ds_target_dict[dataset_target]

    # ------------------------------------------------------------------
    # Model config overrides (applied only when the option was given).
    # ------------------------------------------------------------------
    for key, value in (
        ("training_mode", training_mode),
        ("only_train_fdm", only_train_fdm),
        ("use_difference_loss", use_difference_loss),
        ("probe_action_dim", probe_action_dim),
    ):
        if value is not None:
            model_cfg[key] = value

    for key, value in (
        ("stage_steps", alternating_stage_steps),
        ("enable_warmup", alternating_enable_warmup),
        ("warmup_steps", alternating_warmup_steps),
    ):
        if value is not None:
            model_cfg['alternating_config'][key] = value

    if warmup_warmup_steps is not None:
        model_cfg['warmup_config']['warmup_steps'] = warmup_warmup_steps

    # Only the "big" preset is recognised; any other value leaves the YAML
    # settings untouched.
    if idm_size == "big":
        model_cfg["encoder_depth"] = 16
        model_cfg["action_latent_dim"] = 128
        model_cfg["num_learned_tokens"] = 4

    # Some dataset targets force a specific probe action dimension. This is
    # applied after --probe_action_dim, so the forced value takes precedence.
    forced_probe_dims = {
        "libero": 7,
        "xhand": 18,
        "bridge": 7,
        "rt1": 7,
        "agibot": 14,
        "robodesk": 5,
        **dict.fromkeys(
            (
                "droid", "kuka", "bc_z", "taco_play", "jaco_play", "language_table",
                "stanford_hydra_dataset_converted_externally_to_rlds",
                "furniture_bench_dataset_converted_externally_to_rlds",
                "nyu_franka_play_dataset_converted_externally_to_rlds",
            ),
            7,
        ),
    }
    forced_dim = forced_probe_dims.get(dataset_target)
    if forced_dim is not None:
        model_cfg["probe_action_dim"] = forced_dim

    # ------------------------------------------------------------------
    # Trainer selection.
    # ------------------------------------------------------------------
    if probing:
        from .trainer.trainer_prober import Trainer_Prober as Trainer
        from .trainer.trainer_prober import TrainerConfig_Prober as TrainerConfig
    elif is_fdm_arch:
        from .trainer import Trainer_FDM as Trainer
        from .trainer import TrainerConfig_FDM as TrainerConfig
    else:
        from .trainer import Trainer, TrainerConfig

    # The checksum input is built from the dataset/model configs *after* the
    # overrides above have been applied.
    trainer_config = TrainerConfig.load(
        config / "trainer_config.yaml",
        checksum_input=ds_config.model_dump_json() + model_config.model_dump_json(),
    )

    # ------------------------------------------------------------------
    # Trainer config overrides.
    # ------------------------------------------------------------------
    trainer_config.validate_probing_only = validate_probing_only
    trainer_config.validate_video_only = validate_video_only
    trainer_config.train_adapter_only = train_adapter_only
    # NOTE(review): unlike the three flags above this option defaults to None,
    # so None is written here when it is omitted; confirm that is intended.
    trainer_config.validate_bootstrap_only = validate_bootstrap_only

    # "none"/"null"/"" on the command line mean "no checkpoint".
    if pass_ckpt is not None and pass_ckpt.lower() not in ["none", "", "null"]:
        trainer_config.init_checkpoint = pass_ckpt
    if dataset_target == "egofull_v2":
        trainer_config.dataset_target = dataset_target
    if batch_size is not None:
        if dataset_target == "egofull_v2":
            trainer_config.batch_size = batch_size * 2
            trainer_config.val_batch_size = batch_size * 2
        else:
            trainer_config.batch_size = batch_size
            # Validation batch size is capped at 16 when the textual form of
            # the dataset target contains "'xhand".
            trainer_config.val_batch_size = batch_size if "'xhand" not in str(ds_config.target) else min(batch_size, 16)

    # Maps trainer-config attribute name -> CLI value (None means "keep YAML").
    optional_trainer_overrides = {
        "num_workers": num_workers,
        "start_new_wandb_run_for_new_finetune": start_new_wandb_run_for_new_finetune,
        "clip_grad": clip_grad,
        "warmup_steps": lr_warmup_steps,
        "wm_lr": wm_lr,
        "n_max_state_action": n_max_state_action,
        "finetune_wm_adapt_realaction": finetune_wm_adapt_realaction,
        "finetune_wm_use_adapter": finetune_wm_use_adapter,
        "load_adapter_path": load_adapter_path,
        "realaction_add_noise_level": realaction_add_noise_level,
        "enable_validate_during_train": enable_validate_during_train,
        "val_interval": val_interval,
        "wm_resize_obs_64": wm_resize_obs_64,
        "load_idm_only": load_idm_only,
    }
    for attr_name, value in optional_trainer_overrides.items():
        if value is not None:
            setattr(trainer_config, attr_name, value)

    # ------------------------------------------------------------------
    # World-model architecture config (FDM architectures only).
    # ------------------------------------------------------------------
    if is_fdm_arch:
        with open(model_cfg["wm_config_path"], 'r', encoding='utf-8') as file:
            arch_config = yaml.safe_load(file)
        transformer_cfg = arch_config['transformer']
        for key, value in (
            ("drop_level", drop_level),
            ("adaln_fuse_type", adaln_fuse_type),
            ("init_act_block_last_layer_0", init_act_block_last_layer_0),
            ("use_action_attention", use_action_attention),
            ("action_attention_rmsnorm", action_attention_rmsnorm),
            ("action_attention_outnorm", action_attention_outnorm),
            ("action_attention_num_blocks", action_attention_num_blocks),
            ("action_attention_use_causal", action_attention_use_causal),
            ("action_attention_self_adaln", action_attention_self_adaln),
        ):
            if value is not None:
                transformer_cfg[key] = value
        for key, value in (
            ("motion_loss_weight", motion_loss_weight),
            ("load_transformer_pretrained_weights", load_transformer_pretrained_weights),
        ):
            if value is not None:
                arch_config[key] = value

    # ------------------------------------------------------------------
    # Seeding.
    # ------------------------------------------------------------------
    step_pattern = re.compile(r"step=(\d+)-train")

    def set_seed(seed_value):
        """Seed the RNGs with a value offset by process rank and resume step.

        The resume step is the largest ``step=<N>-train`` found among the
        entries of ``trainer_config.ckpt_dir``; if that directory is missing
        or empty, it is parsed from the init checkpoint name instead (unless a
        fresh wandb run was requested for the finetune).
        """
        from .utils.seed import set_seed as util_set_seed
        random.seed(seed_value)
        np.random.seed(seed_value)
        torch.manual_seed(seed_value)
        rank = int(os.getenv('LOCAL_RANK', '0'))
        ckpt_dir = trainer_config.ckpt_dir

        init_seed_step = 0
        ckpt_entries = list(ckpt_dir.iterdir()) if ckpt_dir.exists() else []
        if ckpt_entries:
            step_matches = (step_pattern.search(entry.name) for entry in ckpt_entries)
            steps = [int(m.group(1)) for m in step_matches if m]
            if steps:
                init_seed_step = max(steps)
        elif trainer_config.init_checkpoint and not trainer_config.start_new_wandb_run_for_new_finetune:
            match = step_pattern.search(trainer_config.init_checkpoint.split("/")[-1])
            if match:
                init_seed_step = int(match.group(1))

        # The large per-rank stride keeps the seeds of different ranks apart.
        current_seed = seed_value + rank * 10000000 + init_seed_step
        util_set_seed(current_seed, include_tensorflow=True)

    set_seed(trainer_config.seed)

    # ------------------------------------------------------------------
    # Data loaders.
    # ------------------------------------------------------------------
    device = f"cuda:{os.getenv('LOCAL_RANK', '0')}"
    # Evaluated once so the dataset setup and the trainer hook below cannot
    # disagree about whether per_dataset_config exists.
    use_datasetwise_itv = bool(getattr(trainer_config, "datasetwise_itv", None)) and is_fdm_arch

    if use_datasetwise_itv:
        from .trainer.trainer_fdm import prepare_dataset_cfg, dict_to_namespace
        arch_config.update({"train_wm_seq_length": model_cfg["d_t"], "sample_length": model_cfg["d_t"] * arch_config["vae_max_compress_rate"]})
        target, dataset_cfg_file = prepare_dataset_cfg(ds_config.model_dump(), dict_to_namespace(arch_config))
        per_dataset_config = dataset_cfg_file['per_dataset_config']

        if dataset_target == "robodesk":
            from .data.robodesk_loader import create_robodesk_loader
            seq_len = per_dataset_config['robodesk']['expect_valid_frames']

            def get_all_file_paths(directory):
                """Return the paths of all files found recursively under ``directory``."""
                return [os.path.join(root, f) for root, _, files in os.walk(directory) for f in files]

            # Only files whose path contains 'noise_0.1' are used.
            robodesk_h5_list = [path for path in get_all_file_paths(robodesk_dataset_dir) if 'noise_0.1' in path]
            # Use the resolved trainer_config.num_workers (as every other
            # loader does); the raw CLI value is None when the flag is omitted.
            robodesk_kwargs = dict(
                seq_len=seq_len,
                n_max_state_action=trainer_config.n_max_state_action,
                num_workers=trainer_config.num_workers,
            )
            train_dl = create_robodesk_loader(robodesk_h5_list, split="train", batch_size=trainer_config.batch_size, **robodesk_kwargs)
            val_dl = create_robodesk_loader(robodesk_h5_list, split="valid", batch_size=trainer_config.val_batch_size, **robodesk_kwargs)
        else:
            # These two keys are stripped before the per-dataset config is
            # handed to load_dataset.
            excluded_keys = ('expect_valid_frames', 'micro_frame_size')
            per_dataset_cfg_simple = {
                name: {k: v for k, v in cfg.items() if k not in excluded_keys}
                for name, cfg in per_dataset_config.items()
            }
            dataset_kwargs = dict(
                target=target,
                common_config=dataset_cfg_file['common_config'],
                per_dataset_config=per_dataset_cfg_simple,
                auto_search_dataset=True,
                stablize_device=device,
            )
            loader_kwargs = dict(
                num_workers=trainer_config.num_workers,
                include_dataset_name=True,
                include_action=True,
                include_control_frequency=True,
                flatten_state_action=trainer_config.flatten_state_action,
                n_max_state_action=trainer_config.n_max_state_action,
            )
            train_ds = load_dataset(is_train=True, **dataset_kwargs)
            train_dl = train_ds.get_dataloader(batch_size=trainer_config.batch_size, **loader_kwargs)

            val_ds = load_dataset(is_train=False, **dataset_kwargs)
            val_dl = val_ds.get_dataloader(batch_size=trainer_config.val_batch_size, **loader_kwargs)
    else:
        loader_kwargs = dict(
            include_action=True,
            include_control_frequency=True,
            flatten_state_action=trainer_config.flatten_state_action,
            n_max_state_action=trainer_config.n_max_state_action,
            num_workers=trainer_config.num_workers,
        )
        train_ds = load_dataset_from_config(ds_config, is_train=True, stablize_device=device)
        val_ds = load_dataset_from_config(ds_config, is_train=False, stablize_device=device)
        train_dl = train_ds.get_dataloader(batch_size=trainer_config.batch_size, **loader_kwargs)
        val_dl = val_ds.get_dataloader(batch_size=trainer_config.val_batch_size, **loader_kwargs)

    # ------------------------------------------------------------------
    # Model and trainer.
    # ------------------------------------------------------------------
    if is_fdm_arch:
        model_cfg["wm_config"] = arch_config
        model = AutoModelForLatentAction.from_config(model_config)
    else:
        model = AutoModelForLatentAction.from_config(model_config, n_datasets=37)

    trainer = Trainer(
        model,
        input_keys=trainer_config.input_keys,
        train_dataloader=train_dl,
        val_dataloader=val_dl,
        cfg=trainer_config,
        config_to_log={"dataset": ds_config.model_dump(), "model": model_config.dump()},
    )

    if use_datasetwise_itv:
        trainer.per_dataset_config = per_dataset_config

    trainer.train()

# Script entry point. main() performs package-relative imports
# (e.g. ``from .utils.config import ...``), so this module has to be executed
# as part of its package (for example with ``python -m``) rather than as a
# standalone file.
if __name__ == "__main__":
    main()