import os
import subprocess
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Literal

# Service endpoints and experiment-tracking defaults for this run.
# These must be assigned BEFORE the third-party imports below: huggingface_hub
# (pulled in by transformers) reads HF_ENDPOINT when it is first imported, so
# setting it afterwards leaves the default endpoint in effect.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["WANDB_BASE_URL"] = "https://api.bandw.top"
os.environ["WANDB_PROJECT"] = "SOMA-TEST"
os.environ["WANDB_RUN_GROUP"] = "Robocasa"

import torch  # noqa: E402
import tyro  # noqa: E402
from transformers import TrainingArguments  # noqa: E402
from gr00t.data.dataset import LeRobotMixtureDatasetPrior, LeRobotPriorDatasetV2  # noqa: E402
from gr00t.data.schema import EmbodimentTag  # noqa: E402
from gr00t.experiment.data_config import load_data_config  # noqa: E402
from gr00t.experiment.runner import TrainRunner  # noqa: E402
from gr00t.model.SOMA import SOMA  # noqa: E402
from gr00t.model.transforms import EMBODIMENT_TAG_MAPPING  # noqa: E402
from gr00t.utils.peft import get_lora_model  # noqa: E402
# Command-line configuration for the fine-tuning script. The `__main__` block
# parses it with `tyro.cli(ArgsConfig)`, so every field below is a CLI option and
# its default is the value used when the option is not given.
@dataclass
class ArgsConfig:
    """Configuration for GR00T model fine-tuning."""
    # The default is a list of 24 `gr1_arms_waist.*` task datasets under a
    # machine-specific absolute path; override it when running elsewhere.
    # main() builds a single dataset when exactly one path is given and a
    # mixture dataset otherwise.
    dataset_path: List[str] = field(default_factory=lambda:
        [
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.CanToDrawer",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.CupToDrawer",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.CuttingboardToBasket",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.CuttingboardToCardboardBox",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.CuttingboardToPan",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.CuttingboardToPot",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.CuttingboardToTieredBasket",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.PlaceBottleToCabinet",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.PlacematToBasket",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.PlacematToBowl",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.PlacematToPlate",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.PlacematToTieredShelf",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.PlaceMilkToMicrowave",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.PlateToBowl",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.PlateToCardboardBox",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.PlateToPan",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.PlateToPlate",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.PotatoToMicrowave",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.TrayToCardboardBox",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.TrayToPlate",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.TrayToPot",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.TrayToTieredBasket",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.TrayToTieredShelf",
            "/workspaces/Jeff/Isaac-GR00T/data/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim/gr1_arms_waist.WineToCabinet",
        ])
    """Path to the dataset directory or directories, we assume all datasets have the same data config"""
    output_dir: str = "./abl_checkpoint/SOMA_Robocasa_300_shot_interval_10"
    """Directory to save model checkpoints."""
    # Forwarded unchanged as `few_shot_num` to every LeRobotPriorDatasetV2 built in main().
    # Presumably limits how many demonstrations are used per dataset; confirm
    # against the dataset class.
    few_shot_num: int = 300
    data_config: str = "fourier_gr1_arms_waist"
    """
    Data configuration to use for training.
    Options:
    - Built-in configs: Use predefined config names like 'so100', 'fourier_gr1_arms_only', 'unitree_g1'.
    - External configs: Use 'module:ClassName' format to load custom configs from external files. e.g. 'my_dir.my_configs:RobotConfig'
    See gr00t/experiment/data_config.py for more details.
    """
    batch_size: int = 60
    """Batch size per GPU for training."""
    max_steps: int = 30000
    """Maximum number of training steps."""
    # NOTE: the `__main__` launcher starts a single-node torchrun job
    # (`--nnodes=1`) and rejects values above the locally visible GPU count, so
    # this is the number of GPUs on one machine.
    num_gpus: int = 32
    """Number of GPUs to use for training."""
    save_steps: int = 5000
    """Number of steps between saving checkpoints."""
    base_model_path: str = "/workspaces/Jeff/Isaac-GR00T/checkpoint/GR00T-N1.5-3B"
    """Path or HuggingFace model ID for the base model."""
    tune_llm: bool = False
    """Whether to fine-tune the language model backbone."""
    tune_visual: bool = True
    """Whether to fine-tune the vision tower."""
    tune_projector: bool = True
    """Whether to fine-tune the projector."""
    tune_diffusion_model: bool = True
    """Whether to fine-tune the diffusion model."""
    # Passed to TrainRunner as `resume_from_checkpoint`.
    resume: bool = False
    """Whether to resume from a checkpoint."""
    learning_rate: float = 3e-5
    """Learning rate for training."""
    weight_decay: float = 1e-5
    """Weight decay for AdamW optimizer."""
    warmup_ratio: float = 0.05
    """Ratio of total training steps used for warmup."""
    # main() only wraps the model with LoRA when this is greater than 0.
    lora_rank: int = 0
    """Rank for the LORA model. If 0, no LORA will be used."""
    lora_alpha: int = 16
    """Alpha value for the LORA model."""
    lora_dropout: float = 0.1
    """Dropout rate for the LORA model."""
    # main() passes the negation of this flag as `action_head_only` to get_lora_model.
    lora_full_model: bool = False
    """Whether to use the full model for LORA. If False, only the action head will be trained."""
    dataloader_num_workers: int = 10
    """Number of workers for data loading per GPU."""
    dataloader_prefetch_factor: int = 4
    """Prefetch factor for data loading."""
    report_to: Literal["wandb", "tensorboard", "azure_ml"] = "wandb"
    """Where to report training metrics (e.g., 'wandb', 'tensorboard', 'azure_ml')."""
    # The set of allowed values is built at class-definition time from the keys
    # of EMBODIMENT_TAG_MAPPING.
    embodiment_tag: Literal[tuple(EMBODIMENT_TAG_MAPPING.keys())] = "gr1"
    """Embodiment tag to use for training. e.g. 'new_embodiment', 'gr1'"""
    video_backend: Literal["decord", "torchvision_av"] = "decord"
    """Video backend to use for training. [decord, torchvision_av]"""
    # The two balance_* flags are only read in the multi-dataset branch of main().
    balance_dataset_weights: bool = True
    """Used in LeRobotMixtureDataset. If True, we will balance the dataset weights, by multiplying the total trajectory to each dataset"""
    balance_trajectory_weights: bool = True
    """Used in LeRobotMixtureDataset. If True, sample trajectories within a dataset weighted by their length; otherwise, equal weighting."""
def main(config: ArgsConfig):
    """Fine-tune a SOMA model on one or more datasets.

    The steps run in this order: build the training dataset (a single dataset
    for one path, a mixture for several), load the pretrained model, rebuild
    the action head when the data config's action horizon differs from the
    checkpoint's, optionally wrap the model with LoRA, and start training via
    ``TrainRunner``.

    Args:
        config: Parsed command-line configuration.

    Raises:
        FileNotFoundError: If any entry of ``config.dataset_path`` does not
            exist on disk.
    """
    embodiment_tag = EmbodimentTag(config.embodiment_tag)
    data_config_cls = load_data_config(config.data_config)
    modality_configs = data_config_cls.modality_config()
    transforms = data_config_cls.transform()

    # Validate every path before loading anything, so a typo in the last entry
    # is reported immediately rather than after the earlier datasets have been
    # loaded. An explicit raise (not `assert`) keeps the check under `python -O`.
    missing_paths = [path for path in config.dataset_path if not os.path.exists(path)]
    if missing_paths:
        raise FileNotFoundError(f"Dataset path(s) do not exist: {missing_paths}")

    def build_dataset(dataset_path):
        """Build one dataset that shares the modality, transform and few-shot setup."""
        return LeRobotPriorDatasetV2(
            dataset_path=dataset_path,
            modality_configs=modality_configs,
            transforms=transforms,
            embodiment_tag=embodiment_tag,
            video_backend=config.video_backend,
            few_shot_num=config.few_shot_num,
        )

    if len(config.dataset_path) == 1:
        train_dataset = build_dataset(config.dataset_path[0])
    else:
        single_datasets = [build_dataset(path) for path in config.dataset_path]
        train_dataset = LeRobotMixtureDatasetPrior(
            # Every dataset enters the mixture with the same base weight.
            data_mixture=[(dataset, 1.0) for dataset in single_datasets],
            mode="train",
            balance_dataset_weights=config.balance_dataset_weights,
            balance_trajectory_weights=config.balance_trajectory_weights,
            seed=42,
            metadata_config={
                "percentile_mixing_method": "weighted_average",
            },
        )
        print(f"Loaded {len(single_datasets)} datasets, with {config.dataset_path} ")

    # Action horizon implied by the data config (number of action indices).
    data_action_horizon = len(data_config_cls.action_indices)

    model = SOMA.from_pretrained(
        pretrained_model_name_or_path=config.base_model_path,
        tune_llm=config.tune_llm,
        tune_visual=config.tune_visual,
        tune_projector=config.tune_projector,
        tune_diffusion_model=config.tune_diffusion_model,
    )

    if data_action_horizon != model.action_head.config.action_horizon:
        print(
            f"Recreating action head with action_horizon {data_action_horizon} (was {model.action_head.config.action_horizon})"
        )
        # NOTE: this is the existing config object, updated in place (no copy).
        new_action_head_config = model.action_head.config
        new_action_head_config.action_horizon = data_action_horizon

        # Imported here, as in the original code, only when a rebuild is needed.
        from gr00t.model.action_head.flow_matching_action_head import (
            FlowmatchingActionHead,
        )

        new_action_head = FlowmatchingActionHead(new_action_head_config)
        # strict=False: missing or unexpected keys are tolerated when copying
        # the pretrained weights into the rebuilt head.
        new_action_head.load_state_dict(model.action_head.state_dict(), strict=False)
        model.action_head = new_action_head

        # Keep every stored copy of the horizon in sync with the new head.
        model.config.action_horizon = data_action_horizon
        model.action_horizon = data_action_horizon
        model.config.action_head_cfg["action_horizon"] = data_action_horizon

        # Re-apply the trainable flags to the freshly created head.
        model.action_head.set_trainable_parameters(
            tune_projector=config.tune_projector, tune_diffusion_model=config.tune_diffusion_model
        )

    model.compute_dtype = "bfloat16"
    model.config.compute_dtype = "bfloat16"

    if config.lora_rank > 0:
        model = get_lora_model(
            model,
            rank=config.lora_rank,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            action_head_only=not config.lora_full_model,
        )

    # Prefetching and persistent workers only make sense with worker processes;
    # a prefetch factor combined with zero workers is rejected, so pass None then.
    use_worker_processes = config.dataloader_num_workers > 0

    training_args = TrainingArguments(
        output_dir=config.output_dir,
        run_name=None,
        remove_unused_columns=False,
        deepspeed="",
        gradient_checkpointing=False,
        bf16=True,
        tf32=True,
        per_device_train_batch_size=config.batch_size,
        gradient_accumulation_steps=1,
        dataloader_num_workers=config.dataloader_num_workers,
        dataloader_pin_memory=False,
        dataloader_prefetch_factor=(
            config.dataloader_prefetch_factor if use_worker_processes else None
        ),
        dataloader_persistent_workers=use_worker_processes,
        optim="adamw_torch",
        adam_beta1=0.95,
        adam_beta2=0.999,
        adam_epsilon=1e-8,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        lr_scheduler_type="cosine",
        logging_steps=10.0,
        # Per the TrainingArguments documentation, a positive max_steps
        # overrides num_train_epochs.
        num_train_epochs=300,
        max_steps=config.max_steps,
        save_strategy="steps",
        save_steps=config.save_steps,
        save_total_limit=5,
        report_to=config.report_to,
        seed=42,
        do_eval=False,
        ddp_find_unused_parameters=False,
        ddp_bucket_cap_mb=100,
        torch_compile_mode=None,
    )

    experiment = TrainRunner(
        train_dataset=train_dataset,
        model=model,
        training_args=training_args,
        resume_from_checkpoint=config.resume,
    )
    experiment.train()
if __name__ == "__main__":
    # Entry point: parse the CLI, validate the GPU request, then either train
    # in this process or re-launch this script under torchrun.
    config = tyro.cli(ArgsConfig)

    # Echo the effective configuration so it ends up in the run log.
    print("\n" + "=" * 50)
    print("GR00T FINE-TUNING CONFIGURATION:")
    print("=" * 50)
    for key, value in vars(config).items():
        print(f"{key}: {value}")
    print("=" * 50 + "\n")

    # Without CUDA the count falls back to 1, so a single-process run is still allowed.
    available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1

    # Explicit raises rather than `assert`, so the checks still run under `python -O`.
    if config.num_gpus <= 0:
        raise ValueError("Number of GPUs must be greater than 0")
    if config.num_gpus > available_gpus:
        raise ValueError(
            f"Number of GPUs requested ({config.num_gpus}) is greater than the available GPUs ({available_gpus})"
        )
    print(f"Using {config.num_gpus} GPUs")

    if config.num_gpus == 1:
        # Single-GPU run: restrict this process to device 0 and train in-process.
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        main(config)
    elif os.environ.get("IS_TORCHRUN", "0") == "1":
        # Already a torchrun worker (flag set by the launcher branch below).
        main(config)
    else:
        # Launcher: re-run this script under torchrun with the original CLI arguments.
        # Any CUDA_VISIBLE_DEVICES restriction is dropped so the child processes
        # do not inherit it.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        script_path = Path(__file__).absolute()
        raw_args_list = sys.argv[1:]
        cmd = [
            "torchrun",
            "--standalone",
            f"--nproc_per_node={config.num_gpus}",
            "--nnodes=1",  # default to 1 node for now
            str(script_path),
            *raw_args_list,
        ]
        print("Running torchrun command: ", cmd)
        env = os.environ.copy()
        # Marks the child processes so they take the worker branch above.
        env["IS_TORCHRUN"] = "1"
        # Propagate torchrun's exit status as this script's exit status.
        sys.exit(subprocess.run(cmd, env=env).returncode)
