"""
pretrain_fe.py

Pre-trains OpenVLA with Function Encoder using multi-dataset calibration buffer mechanism.
Based on the finetune.py architecture and test_fe_calibration.py implementation.
"""

import logging
import os
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import draccus
import torch
import torch.distributed as dist
import tqdm
from huggingface_hub import snapshot_download
from peft import LoraConfig, PeftModel, get_peft_model, get_peft_model_state_dict
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor

import wandb
from experiments.robot.openvla_utils import (
    check_model_logic_mismatch,
    model_is_on_hf_hub,
    update_auto_map,
)
from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
from prismatic.models.action_heads import FunctionEncoderActionHead
from prismatic.models.backbones.llm.prompting import PurePromptBuilder
from prismatic.training.train_utils import (
    get_current_action_mask,
    get_next_actions_mask,
)
from prismatic.util.data_utils import PaddedCollatorForActionPrediction
from prismatic.vla.action_tokenizer import ActionTokenizer
from prismatic.vla.calibration_buffer import CalibrationManager
from prismatic.vla.constants import (
    ACTION_DIM,
    NUM_ACTIONS_CHUNK,
)
from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset
from prismatic.vla.datasets.rlds.oxe import OXE_NAMED_MIXTURES

# Sane defaults: applied at import time, before any processor/tokenizer is created further down.
# NOTE(review): presumably silences the HuggingFace tokenizers fork/parallelism warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Module-level logger named after this module; every progress message in this file goes through it.
logger = logging.getLogger(__name__)


@dataclass
class PretrainFEConfig:
    """Configuration for Function Encoder pre-training, consumed by ``pretrain_fe``.

    Fields are grouped by concern (dataset, Function Encoder, calibration buffer, architecture,
    training, LoRA, logging); the trailing comment on each field describes it.

    NOTE(review): the per-field comments are intentionally kept on the field lines — draccus
    presumably surfaces them as ``--help`` text; confirm before moving them into this docstring.
    NOTE(review): in the code visible in this file, ``use_film`` and ``lr_warmup_steps`` are never
    read, and ``grad_accumulation_steps`` only contributes to the run ID string — confirm whether
    they are meant to be honored by the training loop.
    NOTE(review): the ``"$WORKDIR"`` default of ``data_root_dir`` is not expanded anywhere in this
    file; it looks like a placeholder that must be overridden — confirm.
    """

    # fmt: off
    vla_path: str = "openvla/openvla-7b"             # Path to OpenVLA model (on HuggingFace Hub or stored locally)

    # Dataset
    data_root_dir: Path = Path("$WORKDIR")           # Directory containing RLDS datasets
    data_mix: str = "oxe_magic_soup_plus"            # Data mixture name (e.g., oxe_magic_soup_plus)
    run_root_dir: Path = Path("runs")                # Path to directory to store logs & checkpoints
    shuffle_buffer_size: int = 100_000               # Dataloader shuffle buffer size

    # Function Encoder specific parameters
    fe_basis_functions: int = 32                     # Number of basis functions (k)
    n_continuous_actions: int = 6                    # Number of continuous action dimensions

    # Calibration Buffer parameters
    calibration_buffer_size: int = 512               # Size of each dataset's calibration buffer
    calibrate_interval: int = 16                     # Recalibrate every N steps
    prefill_samples_per_dataset: int = 256           # Samples per dataset for prefilling

    # Architecture parameters
    use_film: bool = False                           # If True, uses FiLM to infuse language inputs into visual features
    num_images_in_input: int = 1                     # Number of images in the VLA input (default: 1)
    use_proprio: bool = False                        # If True, includes robot proprioceptive state in input
    image_aug: bool = True                           # If True, trains with image augmentations

    # Training configuration
    batch_size: int = 16                             # Batch size per device
    learning_rate: float = 1e-4                      # Learning rate
    lr_warmup_steps: int = 1000                      # Number of steps to warm up learning rate
    num_steps_before_decay: int = 200_000            # Number of steps before LR decays by 10x
    grad_accumulation_steps: int = 1                 # Number of gradient accumulation steps
    max_steps: int = 300_000                         # Max number of training steps
    save_freq: int = 5_000                           # Checkpoint saving frequency in steps
    save_latest_checkpoint_only: bool = False        # If True, deletes the previous checkpoint after each save (default keeps all)
    resume: bool = False                             # If True, resumes from checkpoint
    resume_step: Optional[int] = None                # Step to resume from

    # LoRA configuration
    use_lora: bool = True                            # If True, uses LoRA fine-tuning
    lora_rank: int = 32                              # LoRA rank
    lora_dropout: float = 0.05                       # LoRA dropout

    # Logging
    wandb_project: Optional[str] = None              # Weights & Biases project name
    wandb_entity: Optional[str] = None               # Weights & Biases entity name
    run_id_note: Optional[str] = None                # Extra note to add to run ID
    wandb_log_freq: int = 10                         # WandB logging frequency in steps
    seed: int = 7                                    # Random seed

    # fmt: on


def get_run_id(cfg: "PretrainFEConfig") -> str:
    """Build the run ID that names the run directory and the W&B run.

    When resuming, the ID is recovered from ``cfg.vla_path`` so the run continues under its
    original name:

    * ``<run_id>/checkpoint-<step>`` (the layout this script writes) -> ``<run_id>``
    * ``<run_id>--<step>_chkpt`` (legacy single-directory layout)     -> ``<run_id>``
    * anything else                                                   -> last path component

    A trailing ``/`` on ``cfg.vla_path`` is ignored. For a fresh run the ID encodes the data
    mixture, number of basis functions, effective batch size, learning rate and optional
    LoRA / augmentation / note suffixes.

    Args:
        cfg: Configuration object; only the attributes named above are read.

    Returns:
        The run ID string.
    """
    if cfg.resume:
        ckpt_path = Path(cfg.vla_path)
        # Path.name ignores a trailing slash, which a plain split("/")[-1] would turn into "".
        run_id = ckpt_path.name
        if run_id.startswith("checkpoint-") and ckpt_path.parent.name:
            # Checkpoints from this script live in <run_dir>/checkpoint-<step>; the run ID is the parent.
            return ckpt_path.parent.name
        # Legacy naming: drop a trailing "--<step>_chkpt" segment, but never collapse the ID to "".
        head, sep, tail = run_id.rpartition("--")
        if sep and "chkpt" in tail:
            run_id = head
        return run_id

    run_id = (
        f"fe-pretrain+{cfg.data_mix}"
        f"+k{cfg.fe_basis_functions}"
        f"+b{cfg.batch_size * cfg.grad_accumulation_steps}"
        f"+lr-{cfg.learning_rate}"
    )
    if cfg.use_lora:
        run_id += f"+lora-r{cfg.lora_rank}"
    if cfg.image_aug:
        run_id += "--image_aug"
    if cfg.run_id_note is not None:
        run_id += f"--{cfg.run_id_note}"
    return run_id


def load_checkpoint(path: Path, module_name: str, step: int, device: "str | int | torch.device" = "cpu") -> dict:
    """Load the state dict saved for one module and strip any DDP ``module.`` key prefix.

    Args:
        path: Directory containing ``{module_name}--{step}_checkpoint.pt``.
        module_name: Prefix of the checkpoint file name (e.g. ``"fe_action_head"``).
        step: Training step embedded in the checkpoint file name.
        device: Where to map the loaded tensors. A device string or ``torch.device`` is passed
            to ``torch.load`` unchanged; an ``int`` is interpreted as a CUDA device ordinal.

    Returns:
        A new dict with the same values, where keys that started with ``"module."`` have that
        prefix removed.
    """
    checkpoint_path = path / f"{module_name}--{step}_checkpoint.pt"
    # Log from rank 0 only; when no process group exists (single-process use) just log.
    if not dist.is_initialized() or dist.get_rank() == 0:
        logger.info(f"Loading checkpoint: {checkpoint_path}")

    # torch.load treats any map_location that is not a str/device/dict as a callable, so a bare
    # int ordinal (what the caller in this file passes) must be converted to a device string.
    map_location = f"cuda:{device}" if isinstance(device, int) else device
    state_dict = torch.load(checkpoint_path, weights_only=True, map_location=map_location)

    # Remove DDP wrapper prefix if present
    prefix = "module."
    return {(key[len(prefix):] if key.startswith(prefix) else key): value for key, value in state_dict.items()}


def _decode_dataset_names(raw_names) -> list:
    """Return the batch's dataset names as ``str``, decoding ``bytes`` entries as UTF-8."""
    return [name.decode("utf-8") if isinstance(name, bytes) else str(name) for name in raw_names]


def _buffer_batch_samples(calibration_manager, batch, dataset_names) -> None:
    """Split a collated batch into per-sample dicts and hand each one to the calibration manager."""
    for i, dataset_name in enumerate(dataset_names):
        # Tensors are stored exactly as they come out of the collated batch (no device transfer here).
        sample = {
            "pixel_values": batch["pixel_values"][i],
            "input_ids": batch["input_ids"][i],
            "attention_mask": batch["attention_mask"][i],
            "labels": batch["labels"][i],
            "actions": batch["actions"][i],
        }
        calibration_manager.add_training_sample(sample, dataset_name)


def _save_training_state(checkpoint_dir, tag, vla, fe_action_head, processor, global_step) -> None:
    """Write VLA weights, FE head weights, calibration state and processor into ``checkpoint_dir``.

    ``tag`` is embedded in the file names (a step number for periodic checkpoints, ``"final"``
    for the last one). LoRA-only weights are saved when the unwrapped VLA is a ``PeftModel``,
    otherwise the full state dict is saved.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Save VLA LoRA-only state if using LoRA, otherwise full state
    vla_to_save = vla.module if hasattr(vla, "module") else vla
    if isinstance(vla_to_save, PeftModel):
        lora_state = get_peft_model_state_dict(vla_to_save)
        torch.save(lora_state, checkpoint_dir / f"vla_lora--{tag}_checkpoint.pt")
    else:
        torch.save(vla_to_save.state_dict(), checkpoint_dir / f"vla--{tag}_checkpoint.pt")

    # Save FE action head state
    fe_to_save = fe_action_head.module if hasattr(fe_action_head, "module") else fe_action_head
    torch.save(fe_to_save.state_dict(), checkpoint_dir / f"fe_action_head--{tag}_checkpoint.pt")

    # Save calibration state
    torch.save(
        {
            "dataset_coefficients": fe_to_save.dataset_coefficients,
            "global_step": global_step,
        },
        checkpoint_dir / f"calibration_state--{tag}.pt",
    )

    # Save processor
    processor.save_pretrained(checkpoint_dir)


@draccus.wrap()
def pretrain_fe(cfg: PretrainFEConfig) -> None:
    """Main pre-training function for Function Encoder VLA.

    Steps, in order:

    1. Load the base VLA and processor, build the ``FunctionEncoderActionHead``, optionally wrap
       the VLA with LoRA, optionally restore a checkpoint, and wrap both modules in DDP.
    2. Phase 1 — iterate the dataloader and feed samples to the ``CalibrationManager`` until
       every rank reports that its datasets hold ``cfg.prefill_samples_per_dataset`` samples.
    3. Phase 2 — run one calibration pass over all datasets in eval mode.
    4. Phase 3 — train with an L1 + regularization loss, recalibrating every
       ``cfg.calibrate_interval`` steps, logging to W&B and checkpointing from rank 0.
    5. Save a final checkpoint and destroy the process group.

    Args:
        cfg: Run configuration (supplied by the ``draccus.wrap()`` decorator).

    Raises:
        RuntimeError: If no process group exists and the torchrun environment variables
            (``RANK`` / ``WORLD_SIZE``) are missing.
        ValueError: If ``cfg.resume`` is set without ``cfg.resume_step``.
    """
    import shutil  # local import: only used when pruning old checkpoints

    # Loading counterpart of get_peft_model_state_dict (which is used when saving LoRA weights).
    from peft import set_peft_model_state_dict

    # Every rank check below needs a process group, so establish it before the first
    # dist.get_rank() call instead of asserting afterwards.
    if not dist.is_initialized():
        if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
            raise RuntimeError("Please run with torchrun.")
        dist.init_process_group(backend="nccl")

    # Fail early: a missing resume step would otherwise surface as a bad file name / TypeError later.
    if cfg.resume and cfg.resume_step is None:
        raise ValueError("`resume_step` must be set when `resume=True`.")

    rank = dist.get_rank()
    is_main = rank == 0

    if is_main:
        logger.info("=" * 20)
        logger.info("FUNCTION ENCODER PRE-TRAINING")
        logger.info("=" * 20)

    # Select this rank's GPU and seed
    device_id = rank % torch.cuda.device_count()
    # torch.load's map_location does not accept a bare int, so keep a string form for loading.
    load_device = f"cuda:{device_id}"
    torch.cuda.set_device(device_id)
    torch.cuda.empty_cache()
    torch.manual_seed(cfg.seed)

    # Create run directory and ID
    run_id = get_run_id(cfg)
    run_dir = cfg.run_root_dir / run_id
    os.makedirs(run_dir, exist_ok=True)

    if is_main:
        logger.info(f"📁 Run directory: {run_dir}")
        logger.info(f"🎯 Training mixture: {cfg.data_mix}")
        logger.info(f"🔧 FE basis functions: {cfg.fe_basis_functions}")

    # Initialize Weights & Biases
    if cfg.wandb_project is not None and is_main:
        wandb.init(
            project=cfg.wandb_project,
            entity=cfg.wandb_entity,
            name=run_id,
            config=cfg.__dict__,
        )

    # === Load Base VLA Model ===
    if is_main:
        logger.info(f"🤖 Loading VLA model from {cfg.vla_path}")

    if model_is_on_hf_hub(cfg.vla_path):
        vla_download_path = snapshot_download(repo_id=cfg.vla_path)
        cfg.vla_path = vla_download_path
    else:
        # Register models if loading locally
        AutoConfig.register("openvla", OpenVLAConfig)
        AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
        AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
        AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)

    # Update config on rank 0 only, then make every rank wait for it
    if is_main:
        update_auto_map(cfg.vla_path)
        check_model_logic_mismatch(cfg.vla_path)

    dist.barrier()

    # Load processor and VLA
    # NOTE(review): when resuming, cfg.vla_path is also the checkpoint directory, but the checkpoints
    # written by this script contain only LoRA/FE/calibration/processor files — confirm the base
    # model can actually be loaded from such a directory.
    processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True)
    vla = AutoModelForVision2Seq.from_pretrained(
        cfg.vla_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).to(device_id)

    # Set number of images
    vla.vision_backbone.set_num_images_in_input(cfg.num_images_in_input)

    # === Create Function Encoder Action Head ===
    if is_main:
        logger.info(f"🔧 Creating FunctionEncoderActionHead with k={cfg.fe_basis_functions}")

    # Get model dimensions
    vla_module = vla.module if hasattr(vla, "module") else vla
    llm_dim = vla_module.llm_dim

    fe_action_head = (
        FunctionEncoderActionHead(
            input_dim=llm_dim * ACTION_DIM,
            hidden_dim=llm_dim,
            action_dim=ACTION_DIM,
            k=cfg.fe_basis_functions,
            n_continuous=cfg.n_continuous_actions,
        )
        .to(device_id)
        .to(torch.bfloat16)
    )

    # === Apply LoRA ===
    if cfg.use_lora:
        if is_main:
            logger.info(f"📌 Applying LoRA with rank {cfg.lora_rank}")
        lora_config = LoraConfig(
            r=cfg.lora_rank,
            lora_alpha=min(cfg.lora_rank, 16),
            lora_dropout=cfg.lora_dropout,
            target_modules="all-linear",
            init_lora_weights="gaussian",
        )
        vla = get_peft_model(vla, lora_config)
        if is_main:
            vla.print_trainable_parameters()

    # === Resume from checkpoint if needed ===
    if cfg.resume:
        ckpt_root = Path(cfg.vla_path)
        if is_main:
            logger.info(f"🔄 Resuming from checkpoint at step {cfg.resume_step}")

        # Load VLA state (prefer LoRA-only checkpoint if present)
        vla_lora_ckpt = ckpt_root / f"vla_lora--{cfg.resume_step}_checkpoint.pt"
        if vla_lora_ckpt.exists():
            vla_state = load_checkpoint(ckpt_root, "vla_lora", cfg.resume_step, load_device)
            if isinstance(vla, PeftModel):
                # LoRA weights were saved with get_peft_model_state_dict, which drops the adapter
                # name from the keys; a plain load_state_dict(strict=False) would match nothing.
                set_peft_model_state_dict(vla, vla_state)
            else:
                # No adapter to load into (use_lora=False): keep the original best-effort load.
                vla.load_state_dict(vla_state, strict=False)
        else:
            vla_state = load_checkpoint(ckpt_root, "vla", cfg.resume_step, load_device)
            vla.load_state_dict(vla_state)

        # Load FE action head state
        fe_state = load_checkpoint(ckpt_root, "fe_action_head", cfg.resume_step, load_device)
        fe_action_head.load_state_dict(fe_state)

        # Load calibration state
        calibration_state_path = ckpt_root / f"calibration_state--{cfg.resume_step}.pt"
        if calibration_state_path.exists():
            calibration_state = torch.load(calibration_state_path, map_location=load_device)
            fe_action_head.dataset_coefficients = calibration_state.get("dataset_coefficients", {})
            if is_main:
                logger.info(f"  Loaded calibration state for {len(fe_action_head.dataset_coefficients)} datasets")

    # === Wrap with DDP ===
    vla = DDP(vla, device_ids=[device_id], find_unused_parameters=True)
    fe_action_head = DDP(fe_action_head, device_ids=[device_id])

    # === Get datasets in mixture ===
    datasets = [n for n, _ in OXE_NAMED_MIXTURES[cfg.data_mix]]
    if is_main:
        logger.info(f"📋 Datasets in mixture ({len(datasets)}):")
        for i, ds in enumerate(datasets):
            logger.info(f"  {i + 1}. {ds}")

    # === Create CalibrationManager ===
    if is_main:
        logger.info("📊 Creating CalibrationManager:")
        logger.info(f"  - Buffer size: {cfg.calibration_buffer_size}")
        logger.info(f"  - Calibrate interval: {cfg.calibrate_interval} steps")
        logger.info(f"  - World size: {dist.get_world_size()}")

    calibration_manager = CalibrationManager(
        dataset_names=datasets,
        buffer_size=cfg.calibration_buffer_size,
    )

    # === Create Dataset and DataLoader ===
    if is_main:
        logger.info(f"📚 Loading dataset mixture: {cfg.data_mix}")

    action_tokenizer = ActionTokenizer(processor.tokenizer)
    batch_transform = RLDSBatchTransform(
        action_tokenizer=action_tokenizer,
        base_tokenizer=processor.tokenizer,
        image_transform=processor.image_processor.apply_transform,
        prompt_builder_fn=PurePromptBuilder,
        predict_stop_token=True,
        use_wrist_image=(cfg.num_images_in_input > 1),
        use_proprio=cfg.use_proprio,
    )

    dataset = RLDSDataset(
        data_root_dir=cfg.data_root_dir,
        data_mix=cfg.data_mix,
        batch_transform=batch_transform,
        resize_resolution=(224, 224),
        shuffle_buffer_size=cfg.shuffle_buffer_size,
        train=True,
        image_aug=cfg.image_aug,
    )

    collator = PaddedCollatorForActionPrediction(
        processor.tokenizer.model_max_length,
        processor.tokenizer.pad_token_id,
        padding_side="right",
    )

    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        sampler=None,
        collate_fn=collator,
        num_workers=0,
    )

    # === Initialize uniform coefficients for all datasets ===
    # NOTE(review): this also overwrites coefficients restored from a checkpoint above; Phase 2
    # recalibrates right after, so this is presumably intended — confirm.
    if is_main:
        logger.info(f"\nInitializing uniform coefficients for {len(datasets)} datasets")
    fe_action_head_module = fe_action_head.module if hasattr(fe_action_head, "module") else fe_action_head
    for dataset_name in datasets:
        uniform_l1 = torch.ones(cfg.fe_basis_functions, device=device_id) / cfg.fe_basis_functions
        uniform_l2 = torch.ones(cfg.fe_basis_functions, device=device_id) / cfg.fe_basis_functions
        fe_action_head_module.set_dataset_coefficients(dataset_name, uniform_l1, uniform_l2)

    # === Setup Optimizer and Scheduler ===
    # NOTE(review): cfg.lr_warmup_steps and cfg.grad_accumulation_steps are not applied below.
    optimizer = AdamW(
        list(vla.parameters()) + list(fe_action_head.parameters()), lr=cfg.learning_rate, weight_decay=0.01
    )

    lr_scheduler = MultiStepLR(
        optimizer,
        milestones=[cfg.num_steps_before_decay],
        gamma=0.1,
    )

    # Get number of vision patches for hidden state extraction
    vla_module = vla.module if hasattr(vla, "module") else vla
    num_patches = vla_module.vision_backbone.get_num_patches() * vla_module.vision_backbone.get_num_images_in_input()

    # === Phase 1: Prefill Calibration Buffers ===
    if is_main:
        logger.info("=" * 20)
        logger.info("PHASE 1: BUFFER PREFILL")
        logger.info("=" * 20)
        logger.info(f"Prefilling buffers with {cfg.prefill_samples_per_dataset} samples per dataset...")

    logger.info(f"[Rank {rank}] Prefilling buffers for datasets: {list(calibration_manager.my_datasets)}")

    for batch in dataloader:
        # Check if THIS rank's datasets have enough samples
        my_datasets_prefilled = all(
            len(calibration_manager.buffers[dataset_name]) >= cfg.prefill_samples_per_dataset
            for dataset_name in calibration_manager.my_datasets
        )

        # MIN-reduce the per-rank flag: the result is 1 only when ALL ranks are done, so every
        # rank leaves the loop on the same iteration and collective calls stay matched.
        all_prefilled = torch.tensor([1.0 if my_datasets_prefilled else 0.0], device=device_id)
        dist.all_reduce(all_prefilled, op=dist.ReduceOp.MIN)

        if all_prefilled.item() > 0.5:
            # Route any remaining pending samples before breaking
            calibration_manager.route_pending_samples()
            break

        # Add samples to calibration buffers
        _buffer_batch_samples(calibration_manager, batch, _decode_dataset_names(batch["dataset_names"]))

        # Route samples to correct nodes after each batch during prefill
        calibration_manager.route_pending_samples()

    if is_main:
        logger.info("Prefill complete!")

    message = f"[Rank {rank}] My datasets with buffers:"
    for ds in calibration_manager.my_datasets:
        count = len(calibration_manager.buffers[ds])
        message += f"\n  - {ds}: {count} samples"
    logger.info(message)

    # === Phase 2: Initial Calibration ===
    if is_main:
        logger.info("=" * 20)
        logger.info("PHASE 2: INITIAL CALIBRATION")
        logger.info("=" * 20)
        logger.info("Performing initial calibration...")

    vla.eval()
    fe_action_head.eval()

    calibration_manager.calibrate_all_datasets(
        fe_action_head=fe_action_head,
        vla=vla,
        collator=collator,
    )

    if is_main:
        logger.info("✅ Initial calibration complete")

    # === Phase 3: Training Loop ===
    if is_main:
        logger.info("=" * 20)
        logger.info("PHASE 3: TRAINING WITH PERIODIC CALIBRATION")
        logger.info("=" * 20)
        logger.info(f"Running {cfg.max_steps} training steps with calibration every {cfg.calibrate_interval} steps...")

    vla.train()
    fe_action_head.train()

    global_step = cfg.resume_step if cfg.resume else 0
    calibration_count = 0
    l1_losses = deque(maxlen=100)
    reg_losses = deque(maxlen=100)
    # Defined up front so the final summary still works if the loop below runs zero iterations
    # (e.g. resuming at or beyond max_steps).
    avg_l1_loss = 0.0
    avg_reg_loss = 0.0

    # Track L1 losses per dataset - initialize with all datasets in mixture (longer window)
    l1_losses_per_dataset = {dataset_name: deque(maxlen=500) for dataset_name in datasets}

    with tqdm.tqdm(initial=global_step, total=cfg.max_steps, desc="Training", disable=not is_main) as pbar:
        for batch in dataloader:
            if global_step >= cfg.max_steps:
                break

            # Check if we should calibrate
            if global_step > 0 and global_step % cfg.calibrate_interval == 0:
                vla.eval()
                fe_action_head.eval()

                calibration_manager.calibrate_all_datasets(
                    fe_action_head=fe_action_head,
                    vla=vla,
                    collator=collator,
                )

                vla.train()
                fe_action_head.train()
                calibration_count += 1
                if is_main:
                    logger.info(f"Calibration #{calibration_count} complete")

            # Get ground truth actions and dataset names
            ground_truth_actions = batch["actions"].to(torch.bfloat16).to(device_id)
            dataset_names = _decode_dataset_names(batch["dataset_names"])

            # Add samples to calibration buffers
            _buffer_batch_samples(calibration_manager, batch, dataset_names)

            # Forward pass
            with torch.autocast("cuda", dtype=torch.bfloat16):
                output = vla(
                    input_ids=batch["input_ids"].to(device_id),
                    attention_mask=batch["attention_mask"].to(device_id),
                    pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
                    labels=batch["labels"].to(device_id),
                    output_hidden_states=True,
                )

                # Get hidden states: drop the leading vision-patch positions and the final position
                last_hidden_states = output.hidden_states[-1]
                text_hidden_states = last_hidden_states[:, num_patches:-1]

                # Get action masks
                ground_truth_token_ids = batch["labels"][:, 1:].to(device_id)
                current_action_mask = get_current_action_mask(ground_truth_token_ids)
                next_actions_mask = get_next_actions_mask(ground_truth_token_ids)

                # Extract action hidden states
                batch_size = batch["input_ids"].shape[0]
                actions_hidden_states = text_hidden_states[current_action_mask | next_actions_mask]
                actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1).to(
                    torch.bfloat16
                )

                # Predict actions with FE head
                # NOTE(review): predict_action is called on the unwrapped module rather than through
                # the DDP wrapper's forward — confirm FE head gradients are still synchronized.
                fe_action_head_module = fe_action_head.module if hasattr(fe_action_head, "module") else fe_action_head
                predicted_actions, reg_loss = fe_action_head_module.predict_action(actions_hidden_states, dataset_names)

                # Compute loss
                l1_loss = torch.nn.functional.l1_loss(predicted_actions, ground_truth_actions)

                # Compute L1 loss per dataset (vectorized per-sample, append individually)
                per_sample_l1 = (predicted_actions - ground_truth_actions).abs().mean(dim=(1, 2))  # (B,)
                for i, dataset_name in enumerate(dataset_names):
                    l1_losses_per_dataset[dataset_name].append(per_sample_l1[i].item())

                loss = l1_loss + reg_loss

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(list(vla.parameters()) + list(fe_action_head.parameters()), max_norm=1.0)
            optimizer.step()
            lr_scheduler.step()

            # Update metrics
            l1_losses.append(l1_loss.item())
            reg_losses.append(reg_loss.item())
            avg_l1_loss = sum(l1_losses) / len(l1_losses) if l1_losses else 0
            avg_reg_loss = sum(reg_losses) / len(reg_losses) if reg_losses else 0
            global_step += 1

            # Update progress bar
            pbar.set_postfix(
                {
                    "l1_loss": f"{avg_l1_loss:.4f}",
                    "reg_loss": f"{avg_reg_loss:.4f}",
                    "calibrations": calibration_count,
                }
            )
            pbar.update()

            # Logging to WandB
            if cfg.wandb_project and global_step % cfg.wandb_log_freq == 0 and is_main:
                log_dict = {
                    "train/l1_loss": avg_l1_loss,
                    "train/reg_loss": avg_reg_loss,
                    "train/learning_rate": lr_scheduler.get_last_lr()[0],
                    "train/global_step": global_step,
                    "train/calibrations": calibration_count,
                }

                # Add per-dataset L1 losses
                for dataset_name, losses in l1_losses_per_dataset.items():
                    if losses:  # Only log if we have data
                        log_dict[f"train/l1_loss_{dataset_name}"] = sum(losses) / len(losses)

                wandb.log(log_dict)

                # Reset all running windows (overall and per-dataset) after logging
                l1_losses.clear()
                reg_losses.clear()
                for losses in l1_losses_per_dataset.values():
                    losses.clear()

            # Save checkpoint (rank 0 only)
            if global_step % cfg.save_freq == 0 and is_main:
                checkpoint_dir = run_dir / f"checkpoint-{global_step}"
                _save_training_state(checkpoint_dir, global_step, vla, fe_action_head, processor, global_step)
                logger.info(f"💾 Checkpoint saved at step {global_step}")

                # Remove the previous checkpoint if only the latest should be kept
                if cfg.save_latest_checkpoint_only and global_step > cfg.save_freq:
                    old_checkpoint_dir = run_dir / f"checkpoint-{global_step - cfg.save_freq}"
                    if old_checkpoint_dir.exists():
                        shutil.rmtree(old_checkpoint_dir)

    # === Final Summary ===
    if is_main:
        logger.info("=" * 20)
        logger.info("TRAINING COMPLETE")
        logger.info("=" * 20)
        logger.info("📊 Final Statistics:")
        logger.info(f"  - Total steps: {global_step}")
        logger.info(f"  - Total calibrations: {calibration_count}")
        logger.info(f"  - Final avg L1 loss: {avg_l1_loss:.4f}")
        logger.info(f"  - Final avg REG loss: {avg_reg_loss:.4f}")
        logger.info(f"  - Datasets tracked: {len(calibration_manager.buffers)}")

    # Save final checkpoint
    if is_main:
        final_checkpoint_dir = run_dir / "checkpoint-final"
        _save_training_state(final_checkpoint_dir, "final", vla, fe_action_head, processor, global_step)
        logger.info(f"✅ Final checkpoint saved to: {final_checkpoint_dir}")

    if is_main:
        logger.info("🎉 FE PRE-TRAINING COMPLETE!")

    # Cleanup
    dist.destroy_process_group()


if __name__ == "__main__":
    # Called without arguments: the @draccus.wrap() decorator on pretrain_fe supplies `cfg`.
    pretrain_fe()
