# Train the residual policy with hybrid IL and RL for ManiSkill
import random
import os
import time
import hydra
from tqdm import tqdm, trange
from omegaconf import DictConfig, OmegaConf
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import trange
import wandb
import warnings
from collections import defaultdict

# Import ManiSkill related modules
try:
    import mani_skill.envs
except ImportError:
    warnings.warn("ManiSkill not found. Please install it to use this script.")

from src.behavior.diffusion import DiffusionPolicy
from src.behavior.residual_diffusion import ResidualDiffusionPolicy
from src.dataset.dataset import StateDataset
from src.dataset.maniskill_dataset import ManiSkillStateDataset
from torch.utils.data import DataLoader, ConcatDataset, random_split
from src.common.hydra import to_native
from src.dataset.dataloader import FixedStepsDataloader
from src.dataset.rollout_buffer import RolloutBuffer
from src.common.pytorch_util import dict_to_device
from src.eval.eval_utils import get_model_from_api_or_cached
from diffusers.optimization import get_scheduler
from src.common.config_util import merge_base_bc_config_with_root_config
from src.eval.maniskill_evaluate import create_maniskill_env, evaluate_and_log_maniskill

# Silence every warning of category UserWarning from this point on.
warnings.filterwarnings('ignore', category=UserWarning)

# Register an "eval" resolver for OmegaConf so that config interpolations of
# the form ${eval:...} are resolved through this callable.
# NOTE(review): the callable is Python's builtin eval, so the resolver's
# argument is executed as Python code -- only load trusted configs.
OmegaConf.register_new_resolver("eval", eval)


@torch.no_grad()
def calculate_advantage(
    values: torch.Tensor,
    next_value: torch.Tensor,
    rewards: torch.Tensor,
    dones: torch.Tensor,
    next_done: torch.Tensor,
    steps_per_iteration: int,
    discount: float,
    gae_lambda: float,
):
    """Compute generalized advantage estimates (GAE) and returns for a rollout.

    The rollout is walked backwards. For the final step the bootstrap value
    and termination flag come from ``next_value`` / ``next_done``; for every
    earlier step they come from index ``t + 1`` of ``values`` / ``dones``.

    Args:
        values: Value estimates, indexed by step along dim 0.
        next_value: Value estimate for the state after the last step.
        rewards: Rewards, indexed by step along dim 0.
        dones: Done flags, indexed by step along dim 0.
        next_done: Done flag for the state after the last step.
        steps_per_iteration: Number of steps to process.
        discount: Discount factor.
        gae_lambda: GAE lambda.

    Returns:
        Tuple ``(advantages, returns)`` where ``returns = advantages + values``.
    """
    advantages = torch.zeros_like(rewards)
    running_gae = 0
    final_step = steps_per_iteration - 1

    for t in range(final_step, -1, -1):
        at_final_step = t == final_step
        upcoming_done = next_done if at_final_step else dones[t + 1]
        bootstrap_value = next_value if at_final_step else values[t + 1]
        # 1 where the episode continues past step t, 0 where it ended.
        not_terminal = 1.0 - upcoming_done.to(torch.float)

        td_error = rewards[t] + discount * bootstrap_value * not_terminal - values[t]
        running_gae = td_error + discount * gae_lambda * not_terminal * running_gae
        advantages[t] = running_gae

    returns = advantages + values
    return advantages, returns


@torch.no_grad()
def compute_values(agent, residual_policy, batch, device):
    """Estimate critic values for the first observation of each sample in a batch.

    The critic input is the first-step slice of ``batch["obs"]`` concatenated
    with the first-step base action predicted by ``agent``.

    Side effect: both ``agent`` and ``residual_policy`` are switched to eval
    mode and are NOT switched back.

    Args:
        agent: Policy exposing ``_training_obs``, ``_normalized_action`` and
            ``flatten_obs``.
        residual_policy: Module exposing ``get_value``.
        batch: Mapping containing at least an ``"obs"`` tensor that is indexed
            as ``[:, 0, :]``.
        device: Device the tensors are moved to.

    Returns:
        The squeezed critic output. It is always at least 1-D, so that
        per-batch results can be concatenated with ``torch.cat`` even when a
        batch holds a single sample.
    """
    agent.eval()
    residual_policy.eval()
    nobs = agent._training_obs(batch, flatten=agent.flatten_obs).to(device)
    naction = agent._normalized_action(nobs).to(device)
    obs0 = batch["obs"][:, 0, :].to(device)
    action0 = naction[:, 0, :].to(device)
    residual_nobs = torch.cat([obs0, action0], dim=-1).to(device)
    values = residual_policy.get_value(residual_nobs).squeeze()
    # A bare squeeze() also drops the batch dimension when the batch has one
    # sample; restore it so callers never receive a 0-dim tensor.
    if values.dim() == 0:
        values = values.unsqueeze(0)
    return values


@torch.no_grad()
def compute_advantages_and_values(agent, residual_policy, batch, device):
    """Return ``(returns - values, values)`` for a batch, filling in missing fields.

    If ``batch`` has no ``"rewards"`` entry, a zero placeholder is inserted
    into ``batch`` (the mapping is mutated). Its shape is derived from
    ``batch["action"]`` when present:

    * more than 2 dims -> ``(shape[0], shape[1])``
    * exactly 2 dims   -> ``(shape[0], 1)``
    * otherwise, or no ``"action"`` key -> ``(1, 1)``

    If building that placeholder raises, a 1-element zero tensor is used.
    If ``batch`` has no ``"returns"`` entry, returns are zeros shaped like the
    rewards.

    NOTE(review): ``returns - values`` relies on broadcasting between the
    returns tensor and the critic output; confirm the shapes line up for the
    batches this is used with.
    """
    if "rewards" not in batch:
        try:
            placeholder_shape = (1, 1)
            if "action" in batch:
                dims = batch["action"].shape
                rank = len(dims)
                if rank > 2:
                    placeholder_shape = (dims[0], dims[1])
                elif rank == 2:
                    placeholder_shape = (dims[0], 1)
            placeholder = torch.zeros(placeholder_shape, device=device)
        except Exception:
            # Best-effort fallback when the shape could not be inferred.
            placeholder = torch.zeros(1, device=device)
        batch["rewards"] = placeholder

    if "returns" in batch:
        returns = batch["returns"]
    else:
        returns = torch.zeros_like(batch["rewards"])

    values = compute_values(agent, residual_policy, batch, device)
    return returns - values, values


@torch.no_grad()
def evaluate_il_online(agent, residual_policy, dataloader, device):
    """
    Compute the mean behavior-cloning loss of the IL policy over a dataloader.

    Despite the function name, no advantage or critic value is computed here:
    ``residual_policy`` is accepted but unused, and the first element of the
    returned tuple is always ``0.0``.

    Batches whose loss computation raises are reported and skipped; any error
    while iterating the dataloader is reported and yields ``(0.0, 0.0)``.

    Args:
        agent: Policy exposing ``compute_loss(batch)`` returning ``(loss, info)``.
        residual_policy: Unused; kept for interface compatibility.
        dataloader: Iterable of batches.
        device: Device each batch is moved to.

    Returns:
        ``(0.0, mean_loss)`` when at least one batch loss was computed,
        otherwise ``(0.0, 0.0)``.
    """
    try:
        eval_loss = []

        for test_batch in dataloader:
            try:
                test_batch = dict_to_device(test_batch, device)
                loss, _ = agent.compute_loss(test_batch)
                eval_loss.append(loss.item())
            except Exception as inner_e:
                # Deliberate best effort: one bad batch must not abort evaluation.
                print(f"[IL Online] Error computing batch loss: {inner_e}")
                continue

        if eval_loss:
            mean_loss = np.mean(eval_loss)
            print(f"[IL Online] Mean BC Loss: {mean_loss:.4f}")
            return 0.0, mean_loss

        print("[IL Online] Could not compute evaluation metrics")
        return 0.0, 0.0
    except Exception as e:
        print(f"[IL Online] Error evaluating IL policy: {e}")
        return 0.0, 0.0


def add_successful_trajectories_to_buffer(
    buffer: "RolloutBuffer",
    obs: torch.Tensor,
    full_nactions: torch.Tensor,
    rewards: torch.Tensor,
    task_success: torch.Tensor,
    device: torch.device,
    max_new_samples: int = None
):
    """
    Add successful trajectories to the rollout buffer for hybrid IL/RL training.

    Environments flagged in ``task_success`` are selected along dim 1 of
    ``obs``, ``full_nactions`` and ``rewards`` and handed to
    ``buffer.add_trajectories``. Only the first ``buffer.state_dim`` features
    of each observation are stored as the state. Nothing is added (and the
    buffer is not touched) when no environment succeeded.

    Args:
        buffer (RolloutBuffer): The rollout buffer to add trajectories to
        obs (torch.Tensor): Observations of shape [steps_per_iteration, num_envs, obs_dim]
        full_nactions (torch.Tensor): Actions of shape [steps_per_iteration, num_envs, action_dim]
        rewards (torch.Tensor): Rewards of shape [steps_per_iteration, num_envs]
        task_success (torch.Tensor): Success flags of shape [num_envs]; any
            dtype is accepted, non-zero means success
        device (torch.device): Device to use for tensor operations
        max_new_samples (int, optional): Maximum number of new samples to add.
            When set to a positive number smaller than the number of successes,
            a random subset (drawn with ``np.random.choice``, without
            replacement) is kept. Defaults to None.
    """
    # Boolean mask on CPU; .bool() makes 0/1 integer or float flags usable as
    # a mask for the indexing below.
    success_mask = task_success.cpu().bool()
    num_success = int(success_mask.sum().item())

    if num_success == 0:
        print("No successful trajectories to add to buffer")
        return

    if max_new_samples is not None and max_new_samples > 0 and num_success > max_new_samples:
        success_env_indices = torch.where(success_mask)[0].numpy()
        chosen_indices = np.random.choice(
            success_env_indices,
            size=max_new_samples,
            replace=False
        )
        # Rebuild the mask so only the sampled environments stay selected.
        success_mask = torch.zeros_like(success_mask)
        success_mask[torch.from_numpy(chosen_indices)] = True
        num_success = max_new_samples

    print(f"Adding {num_success} successful trajectories to buffer")

    # Extract successful trajectories
    success_obs = obs[:, success_mask]  # [steps, num_success, obs_dim]
    success_actions = full_nactions[:, success_mask]  # [steps, num_success, action_dim]
    success_rewards = rewards[:, success_mask]  # [steps, num_success]

    # Every step with a positive reward is flagged as done.
    # NOTE(review): an earlier comment described this as "the first non-zero
    # reward"; the code has always flagged all positive-reward steps. Confirm
    # which of the two the buffer expects.
    success_dones = (success_rewards > 0)

    # Move tensors to appropriate device
    success_obs = success_obs.to(device)
    success_actions = success_actions.to(device)
    success_rewards = success_rewards.to(device)
    success_dones = success_dones.to(device)

    state_dim = buffer.state_dim
    states = success_obs[..., :state_dim]

    buffer.add_trajectories(
        actions=success_actions,
        rewards=success_rewards,
        dones=success_dones,
        states=states
    )

    print(f"Buffer has {buffer.n_trajectories} trajectories with size {buffer.size}")


def train_bc_epoch(
    dataloader,
    agent,
    optimizer,
    scheduler,
    cfg,
    device,
    prefix="bc",
    epoch=0,
    use_cached_values=False
):
    """Run one behavior-cloning epoch and return the mean of each tracked metric.

    ``dataloader`` yields either plain batches (moved to ``device`` here) or,
    when ``use_cached_values`` is true, ``(batch, value_weight)`` pairs whose
    batch is used as-is and whose weight mean scales the loss.

    Per step: zero grads, compute the loss with
    ``base_only=cfg.il_base_only``, scale it by the batch weight, backprop,
    clip gradients, then step the optimizer and the scheduler.

    Returns:
        Dict mapping ``"{prefix}_loss"``, ``"{prefix}_grad_norm"`` (and
        ``"{prefix}_val_weight"`` in cached mode) to their epoch means; empty
        when the dataloader yields nothing.
    """
    loss_key = f"{prefix}_loss"
    grad_key = f"{prefix}_grad_norm"
    weight_key = f"{prefix}_val_weight"
    history = defaultdict(list)

    # Create progress bar
    progress = tqdm(enumerate(dataloader), desc=f"{prefix.upper()} Epoch {epoch}", leave=False)

    for _, item in progress:
        if not use_cached_values:
            batch = dict_to_device(item, device)
            batch_weight = 1.0
        else:
            batch, cached_weights = item
            batch_weight = cached_weights.mean()

        optimizer.zero_grad()
        raw_loss, _ = agent.compute_loss(batch, base_only=cfg.il_base_only)
        loss = raw_loss * batch_weight
        loss.backward()

        # clip_grad_norm acts as a switch: a truthy value gives max_norm 1.0,
        # a falsy value gives 1001.0.
        clip_threshold = 1.0 + 1e3 * (1 - cfg.base_training.clip_grad_norm)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            agent.parameters(),
            max_norm=clip_threshold,
        )

        optimizer.step()
        scheduler.step()

        history[loss_key].append(loss.item())
        history[grad_key].append(grad_norm.item())
        if use_cached_values:
            history[weight_key].append(batch_weight.item())

        recent_loss = np.mean(history[loss_key][-100:])
        current_lr = optimizer.param_groups[0]['lr']
        progress.set_postfix({
            "loss": f"{recent_loss:.4f}",
            "lr": f"{current_lr:.2e}"
        })

    return {name: np.mean(series) for name, series in history.items() if len(series) > 0}


def merge_dataloaders(*loaders, batch_size):
    """Build one shuffled DataLoader over the concatenated datasets of ``loaders``.

    Only each loader's ``.dataset`` is reused; the new loader has its own
    settings (shuffling on, no workers, pinned memory, last batch kept).
    """
    datasets = [source.dataset for source in loaders]
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )
    return DataLoader(ConcatDataset(datasets), **loader_options)


def compute_q_filter(values, iteration, warmup_iterations=5, min_weight=0.3):
    """Turn critic values into per-sample weights in ``[min_weight, 1]``.

    Values are min-max normalized and then rescaled so the smallest value
    maps to ``min_weight`` and the largest to ``1.0``.

    Uniform weights (all ones) are returned when:
      * ``iteration < warmup_iterations`` (filter disabled during warm-up),
      * ``values`` is empty, or
      * all values are equal, where min-max normalization would divide by
        zero and yield NaN weights.

    Args:
        values: Tensor of value estimates.
        iteration: Current training iteration.
        warmup_iterations: Number of initial iterations without filtering.
        min_weight: Weight assigned to the lowest value.

    Returns:
        Tensor of weights with the same shape as ``values``.
    """
    if iteration < warmup_iterations or values.numel() == 0:
        return torch.ones_like(values)

    v_min, v_max = values.min(), values.max()
    value_range = v_max - v_min
    if value_range == 0:
        # Degenerate range: nothing to rank, so weight every sample equally.
        return torch.ones_like(values)

    weights = (values - v_min) / value_range
    weights = min_weight + (1.0 - min_weight) * weights
    return weights


def train_bc_epochs(
    dataloader,
    agent,
    residual_policy,
    optimizer,
    scheduler,
    cfg,
    device,
    num_epochs,
    iteration,
    prefix="bc",
):
    """Train for multiple epochs and aggregate metrics across epochs.

    When ``cfg.enable_q_filter`` is set, the dataloader is iterated once up
    front: every batch is moved to ``device`` and kept in memory together
    with the critic values of its samples. Weights are computed once from all
    values and each cached batch is paired with the weights of exactly its
    own samples. All epochs then train on this cached list, so batches are
    not re-drawn between epochs. Otherwise every epoch iterates ``dataloader``
    directly.

    Args:
        dataloader: Source of training batches; must expose ``.dataset`` and
            ``len()``.
        agent: Policy being trained.
        residual_policy: Critic provider used for the Q-filter values.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        cfg: Config; reads ``enable_q_filter``, ``q_filter_warmup_iterations``
            and ``q_filter_min_weight`` here.
        device: Device batches are moved to.
        num_epochs: Number of epochs to run.
        iteration: Outer training iteration (controls Q-filter warm-up).
        prefix: Prefix for metric names.

    Returns:
        Dict with ``{prefix}_training/*`` counters plus ``_mean`` / ``_std``
        (and ``_min`` / ``_max`` for metrics whose name contains "loss" or
        "value") of every per-epoch metric.
    """
    epoch_metrics = defaultdict(list)
    total_samples = 0
    total_batches = 0

    if cfg.enable_q_filter:
        print("Pre-computing Q-filter values for all epochs...")
        all_batches = []
        per_batch_values = []

        agent.eval()
        residual_policy.eval()
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Computing V", leave=False):
                batch = dict_to_device(batch, device)
                # Flatten so every batch contributes a 1-D tensor; a 0-dim
                # result (single-sample batch) could not be concatenated.
                values = compute_values(agent, residual_policy, batch, device).reshape(-1)
                per_batch_values.append(values)
                all_batches.append(batch)

        all_values = torch.cat(per_batch_values)
        weights = compute_q_filter(
            all_values,
            iteration=iteration,
            warmup_iterations=cfg.q_filter_warmup_iterations,
            min_weight=cfg.q_filter_min_weight
        )
        value_stats = {
            'mean': all_values.mean().item(),
            'std': all_values.std().item(),
            'weight_mean': weights.mean().item(),
            'in_warmup': iteration < cfg.q_filter_warmup_iterations
        }
        print(f"Value stats: {value_stats}")

        # Split by the real per-batch sizes. Equal-size chunking would
        # misalign weights with batches whenever the last batch is shorter.
        batch_sizes = [values.numel() for values in per_batch_values]
        cached_dataloader = list(zip(all_batches, torch.split(weights, batch_sizes)))
    else:
        cached_dataloader = None

    for epoch in trange(num_epochs, desc=f"{prefix.upper()} Epochs"):
        epoch_dataloader = cached_dataloader if cfg.enable_q_filter else dataloader

        metrics = train_bc_epoch(
            epoch_dataloader,
            agent,
            optimizer,
            scheduler,
            cfg,
            device,
            prefix=prefix,
            epoch=epoch,
            use_cached_values=cfg.enable_q_filter
        )

        total_samples += len(dataloader.dataset)
        total_batches += len(dataloader)

        for k, v in metrics.items():
            epoch_metrics[k].append(v)

    agg_metrics = {
        f"{prefix}_training/total_samples": total_samples,
        f"{prefix}_training/total_batches": total_batches,
        f"{prefix}_training/num_epochs": num_epochs,
    }

    for k, v in epoch_metrics.items():
        agg_metrics[f"{k}_mean"] = np.mean(v)
        agg_metrics[f"{k}_std"] = np.std(v)

        if any(key in k for key in ['loss', 'value']):
            agg_metrics[f"{k}_min"] = np.min(v)
            agg_metrics[f"{k}_max"] = np.max(v)

    return agg_metrics


def merge_metrics(metrics_list):
    """Flatten a sequence of metric dicts into one dict; later entries win on key clashes."""
    return {
        name: value
        for entry in metrics_list
        for name, value in dict(entry).items()
    }


def _split_maniskill_state(state, device, robot_dim):
    """Move a state tensor to ``device``, add a batch dim if 1-D, and split it."""
    state = state.to(device)
    if state.dim() == 1:
        state = state.unsqueeze(0)  # Add batch dimension

    split_at = min(robot_dim, state.shape[-1])
    if state.shape[-1] > split_at:
        parts_poses = state[..., split_at:]
    else:
        # No features left after the robot state: empty placeholder.
        parts_poses = torch.zeros((state.shape[0], 0), device=device)

    return {
        "robot_state": state[..., :split_at],
        "parts_poses": parts_poses,
        # Also store full state for compatibility
        "obs": state,
    }


def format_maniskill_obs_for_agent(obs, device, robot_dim=8):
    """
    Format ManiSkill observation for use with the BC agent.

    The flat state vector is split along its last dimension: the first
    ``robot_dim`` features become ``"robot_state"`` and the remainder becomes
    ``"parts_poses"``. The full state is returned under ``"obs"``. A 1-D state
    gets a leading batch dimension.

    Args:
        obs: Observation from ManiSkill environment. Either a tensor, a numpy
            array, or a dict holding one of those under the ``"state"`` key.
        device: Device to put tensors on
        robot_dim: Number of leading features treated as robot state
            (capped at the state width). Defaults to 8.

    Returns:
        Dict with ``"robot_state"``, ``"parts_poses"`` and ``"obs"`` tensors.
        For any other input type a warning is printed and an all-zero
        single-sample observation is returned.
    """
    state = obs["state"] if isinstance(obs, dict) and "state" in obs else obs

    # Handle numpy arrays
    if isinstance(state, np.ndarray):
        state = torch.from_numpy(state)

    if isinstance(state, torch.Tensor):
        return _split_maniskill_state(state, device, robot_dim)

    print(f"Warning: Unexpected observation format: {type(obs)}. Creating dummy observation.")
    return {
        "robot_state": torch.zeros((1, robot_dim), device=device),
        "parts_poses": torch.zeros((1, 0), device=device),
        "obs": torch.zeros((1, robot_dim), device=device),
    }


@hydra.main(
    config_path="../config",
    config_name="base_maniskill_ri",
    version_base="1.2",
)
def main(cfg: DictConfig):

    OmegaConf.set_struct(cfg, False)

    # TRY NOT TO MODIFY: seeding
    if cfg.seed is None:
        cfg.seed = random.randint(0, 2**32 - 1)

    # Ensure valid task name
    if "task" not in cfg.env:
        raise ValueError("Missing required parameter: cfg.env.task")

    run_name = f"{cfg.actor.residual_policy._target_.split('.')[-1]}_RI_{cfg.seed}_{int(time.time())}"

    run_directory = f"runs/{run_name}"
    run_directory += "-delete" if cfg.debug else ""
    print(f"Run directory: {run_directory}")

    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    torch.backends.cudnn.deterministic = cfg.torch_deterministic

    gpu_id = cfg.gpu_id
    device = torch.device(f"cuda:{gpu_id}")
    torch.cuda.set_device(gpu_id)

    # Create ManiSkill environment
    print(f"Creating ManiSkill environment: {cfg.env.task}")
    # Get reward mode from config, defaulting to "sparse" if not specified
    reward_mode = cfg.env.get("reward_mode", "sparse")
    print(f"Using reward mode: {reward_mode}")
    
    # Ensure num_envs is set properly
    print(f"Configured num_envs: {cfg.num_envs}")
    if cfg.num_envs > 100:
        print(f"Using large environment count ({cfg.num_envs}) - ensuring environment is properly configured")
    
    env = create_maniskill_env(
        task_name=cfg.env.task,
        num_envs=cfg.num_envs,
        observation_space="state",  # Always use state-based observations for RL
        control_mode=cfg.control.control_mode,
        sim_backend="physx_cuda",  # physx_cuda-detect
        record_video=False,  # Disable recording for training
        fallback_to_cpu=False,  # Fall back to CPU if GPU PhysX is not available
        reward_mode=reward_mode,  # Pass reward mode from config
    )

    # Load the behavior cloning actor
    base_cfg, base_wts = get_model_from_api_or_cached(
        cfg.base_policy.wandb_id,
        wt_type=cfg.base_policy.wt_type,
        wandb_mode=cfg.wandb.mode,
    )

    merge_base_bc_config_with_root_config(cfg, base_cfg)
    cfg.actor_name = f"residual_{cfg.base_policy.actor.name}"

    # Create dataset using ManiSkill dataset class
    cfg.data.h5_path = '/home/gdch/.maniskill/demos/PickCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.physx_cuda.h5'
    if cfg.data.h5_path:
        print(f"Loading ManiSkill demonstrations from: {cfg.data.h5_path}")
        demo_dataset = ManiSkillStateDataset(
            dataset_paths=cfg.data.h5_path,
            pred_horizon=cfg.data.pred_horizon,
            obs_horizon=cfg.data.obs_horizon,
            action_horizon=cfg.data.action_horizon,
            data_subset=cfg.data.data_subset,
            control_mode=cfg.control.control_mode,
            predict_past_actions=cfg.data.predict_past_actions,
            pad_after=cfg.data.get("pad_after", True),
        )
    else:
        # Fallback to standard StateDataset for compatibility
        print(f"Loading standard demonstrations from: {base_cfg.data_path}")
        demo_dataset = StateDataset(
            dataset_paths=[Path(p) for p in to_native(base_cfg.data_path)],
            pred_horizon=cfg.data.pred_horizon,
            obs_horizon=cfg.data.obs_horizon,
            action_horizon=cfg.data.action_horizon,
            data_subset=cfg.data.data_subset,
            control_mode=cfg.control.control_mode,
            predict_past_actions=cfg.data.predict_past_actions,
            pad_after=cfg.data.get("pad_after", True),
            max_episode_count=cfg.data.get("max_episode_count", None),
            include_future_obs=cfg.data.include_future_obs,
        )

    train_size = int(len(demo_dataset) * (1 - cfg.data.test_split))
    test_size = len(demo_dataset) - train_size
    print(f"Splitting dataset into {train_size} train and {test_size} test samples.")
    train_dataset, test_dataset = random_split(demo_dataset, [train_size, test_size])

    demo_testload_kwargs = dict(
        dataset=test_dataset,
        batch_size=cfg.base_training.batch_size,
        num_workers=0,
        shuffle=True,
        pin_memory=True,
        drop_last=False,
        persistent_workers=False,
    )
    demo_testloader = (
        FixedStepsDataloader(
            **demo_testload_kwargs,
            n_batches=max(
                int(round(cfg.base_training.steps_per_epoch * cfg.data.test_split)), 1
            ),
        )
        if cfg.base_training.steps_per_epoch != -1
        else DataLoader(**demo_testload_kwargs)
    )
    
    il_dataloader = FixedStepsDataloader(
        dataset=train_dataset,
        batch_size=cfg.base_training.batch_size,
        num_workers=0,
        shuffle=True,
        pin_memory=True,
        persistent_workers=False,
        n_batches=train_size//cfg.base_training.batch_size
    )
    print(f"Training dataset size: {len(train_dataset)} | Test dataset size: {len(test_dataset)}")
    print(f"IL dataloader size: {len(il_dataloader)} | Test dataloader size: {len(demo_testloader)}")

    steps_per_iteration = cfg.data_collection_steps

    # Create the residual policy
    agent = ResidualDiffusionPolicy(device, base_cfg)
    if cfg.load_pretrained_wts:
        agent.load_base_state_dict(base_wts)
    agent.set_normalizer(demo_dataset.normalizer.to(device))
    agent.to(device)
    agent.eval()
    
    # Set the inference steps of the actor
    if isinstance(agent, DiffusionPolicy):
        agent.inference_steps = 4

    residual_policy = agent.residual_policy

    # Create optimizers for base actor, residual actor, and critic
    opt_actor = torch.optim.AdamW(
        params=agent.base_actor_parameters,
        lr=cfg.base_training.actor_lr,
        weight_decay=cfg.base_regularization.weight_decay,
    )
    lr_sche_actor = get_scheduler(
        name=cfg.base_lr_scheduler.name,
        optimizer=opt_actor,
        num_warmup_steps=cfg.base_lr_scheduler.warmup_steps,
        num_training_steps=len(demo_dataset) // cfg.base_training.batch_size * cfg.base_training.num_epochs,
    )

    opt_res_actor = optim.AdamW(
        agent.actor_parameters,
        lr=cfg.learning_rate_actor,
        eps=1e-5,
        weight_decay=1e-6,
    )
    lr_sche_res_actor = get_scheduler(
        name=cfg.lr_scheduler.name,
        optimizer=opt_res_actor,
        num_warmup_steps=cfg.lr_scheduler.actor_warmup_steps,
        num_training_steps=cfg.num_iterations,
    )

    opt_res_critic = optim.AdamW(
        agent.critic_parameters,
        lr=cfg.learning_rate_critic,
        eps=1e-5,
        weight_decay=1e-6,
    )
    lr_sche_res_critic = get_scheduler(
        name=cfg.lr_scheduler.name,
        optimizer=opt_res_critic,
        num_warmup_steps=cfg.lr_scheduler.critic_warmup_steps,
        num_training_steps=cfg.num_iterations,
    )

    optimizers = [("residual_actor", opt_res_actor), ("residual_critic", opt_res_critic), ("actor", opt_actor)]
    schedulers = [("residual_actor", lr_sche_res_actor),
                 ("residual_critic", lr_sche_res_critic), ("actor", lr_sche_actor)]

    # Set up replay buffer if enabled
    if cfg.enable_rl_replay:
        print("Enabling RL replay buffer")
        buffer = RolloutBuffer(
            max_size=cfg.base_bc.replay_buffer_size,
            state_dim=agent.obs_dim,
            action_dim=agent.action_dim,
            pred_horizon=agent.pred_horizon,
            obs_horizon=agent.obs_horizon,
            action_horizon=agent.action_horizon,
            device=device,
            predict_past_actions=cfg.data.predict_past_actions,
            include_future_obs=cfg.data.include_future_obs,
        )
        buffer.set_normalizer(demo_dataset.normalizer)

    print(f"PPO batch size: {cfg.batch_size}; mini-batch size: {cfg.minibatch_size}")
    print(f"Total RL timesteps: {cfg.total_timesteps}; Num iterations: {cfg.num_iterations}")
    print(
        f"BC dataset size {len(train_dataset)}; BC batch size: {cfg.base_training.batch_size}; Steps per epoch: {cfg.base_training.steps_per_epoch}")

    print(OmegaConf.to_yaml(cfg, resolve=True))

    # Initialize WandB
    print(f"Initializing WandB with project: {cfg.wandb.project}, entity: {cfg.wandb.entity}, mode: {cfg.wandb.mode}")
    run = wandb.init(
        id=cfg.wandb.continue_run_id,
        resume=None if cfg.wandb.continue_run_id is None else "must",
        project=cfg.wandb.project,
        entity=cfg.wandb.entity,
        config=OmegaConf.to_container(cfg, resolve=True),
        name=run_name,
        save_code=True,
        mode=cfg.wandb.mode if not cfg.debug else "disabled",
    )
    
    # Ensure config is properly logged
    wandb.config.update(OmegaConf.to_container(cfg, resolve=True))
    
    # Set up periodic syncing for offline mode
    use_sync_hook = cfg.wandb.mode == "offline" and wandb.run is not None
    if use_sync_hook:
        print("Setting up WandB sync hooks for offline mode")
        from wandb_osh.hooks import TriggerWandbSyncHook, _comm_default_dir
        trigger_sync = TriggerWandbSyncHook(
            communication_dir=os.environ.get("WANDB_OSH_COMM_DIR", _comm_default_dir),
        )

    # Handle continuing a run
    if cfg.wandb.continue_run_id is not None:
        print(f"[WandB] Continuing run {cfg.wandb.continue_run_id}, {run.name}")

        run_id = f"{cfg.wandb.project}/{cfg.wandb.continue_run_id}"

        # Load the weights from the run
        _, wts = get_model_from_api_or_cached(
            run_id, "latest", wandb_mode=cfg.wandb.mode
        )

        print(f"[WandB] Loading weights from {wts}")

        run_state_dict = torch.load(wts)
        if "model_state_dict" in run_state_dict:
            agent.load_state_dict(run_state_dict["model_state_dict"])
            for (name, opt), (__, scheduler) in zip(optimizers, schedulers):
                opt.load_state_dict(run_state_dict[f"{name}_optimizer_state_dict"])
                scheduler.load_state_dict(run_state_dict[f"{name}_scheduler_state_dict"])

        # Set the best test loss and success rate to the one from the run
        try:
            best_eval_success_rate = run.summary["eval/best_eval_success_rate"]
        except KeyError:
            best_eval_success_rate = run.summary["eval/success_rate"]

        iteration = run.summary["iteration"]
        global_step = run.step
        bc_step = 0

    else:
        global_step = 0
        iteration = 0
        best_eval_success_rate = 0.0
        bc_step = 0

    # Initialize tensors for rollout data
    obs: torch.Tensor = torch.zeros(
        (
            steps_per_iteration,
            cfg.num_envs,
            residual_policy.obs_dim,
        )
    )
    
    # Get the actual action space shape, not the environment shape which might be incorrect
    print(f"Environment action_space.shape: {env.action_space.shape}")
    # Ensure we're using the correct action shape - ManiSkill environments may have unexpected shapes
    action_dim = env.action_space.shape[-1] if len(env.action_space.shape) > 0 else env.action_space.shape[0]
    action_shape = (action_dim,)
    print(f"Using action shape: {action_shape} (extracted from environment)")
    
    actions = torch.zeros((steps_per_iteration, cfg.num_envs) + action_shape)
    full_nactions = torch.zeros(
        (steps_per_iteration, cfg.num_envs) + action_shape
    )
    logprobs = torch.zeros((steps_per_iteration, cfg.num_envs))
    rewards = torch.zeros((steps_per_iteration, cfg.num_envs))
    dones = torch.zeros((steps_per_iteration, cfg.num_envs))
    values = torch.zeros((steps_per_iteration, cfg.num_envs))

    start_time = time.time()
    training_cum_time = 0
    running_mean_success_rate = 0.0

    next_done = torch.zeros(cfg.num_envs)
    next_obs, info = env.reset()
    agent.reset()

    # Create model save dir
    model_save_dir: Path = Path("models") / wandb.run.name
    model_save_dir.mkdir(parents=True, exist_ok=True)
    
    num_bc_epochs = cfg.initial_num_bc_epochs
    rl_per_bc = cfg.rl_per_bc
    rl_counter = 0

    # Main training loop
    while global_step < cfg.total_timesteps:
        iteration += 1
        # Calculate how many training iterations we've done
        training_iterations = iteration - cfg.eval_first
        training_iterations -= (iteration - cfg.eval_first) // cfg.eval_interval
        print(f"Iteration: {iteration}/{cfg.num_iterations}")
        print(f"Run name: {run_name}")
        iteration_start_time = time.time()

        # Determine if this is an evaluation iteration
        eval_rl = (iteration - int(cfg.eval_first)) % cfg.eval_interval == 0

        # Reset the env to have more consistent results
        if eval_rl or cfg.reset_every_iteration:
            if not cfg.eval_first or iteration != 1:
                next_obs, info = env.reset()
                agent.reset()

        print(f"Eval mode: {eval_rl}")

        # Train the base policy with BC if needed
        print("Training base with BC...")
        if not eval_rl and num_bc_epochs > 0 and rl_counter % rl_per_bc == 0:
            bc_steps_this_iter = 0
            all_metrics = []

            if cfg.enable_rl_replay and len(buffer) > 0:
                print("Training on expert data and buffer data")
                buffer_dataloader = DataLoader(
                    buffer, batch_size=cfg.base_training.batch_size,
                    shuffle=True, num_workers=0, pin_memory=True
                )
                print(f"Buffer dataloader size: {len(buffer_dataloader)}")

                merged_dataloader = merge_dataloaders(
                    il_dataloader,
                    buffer_dataloader,
                    batch_size=cfg.base_training.batch_size
                )
                il_metrics = train_bc_epochs(
                    merged_dataloader, agent, residual_policy,
                    opt_actor, lr_sche_actor, cfg, device,
                    num_epochs=num_bc_epochs,
                    iteration=iteration,
                    prefix="bc"
                )
                bc_steps_this_iter += il_metrics["bc_training/total_batches"]
                all_metrics.append(il_metrics)
            else:
                print("Training on expert data only")
                il_metrics = train_bc_epochs(
                    il_dataloader, agent, residual_policy,
                    opt_actor, lr_sche_actor, cfg, device,
                    num_epochs=num_bc_epochs,
                    iteration=iteration,
                    prefix="bc"
                )
                bc_steps_this_iter += il_metrics["bc_training/total_batches"]
                all_metrics.append(il_metrics)

            # Update global BC step counter
            bc_step += bc_steps_this_iter

            # Add combined metrics
            combined_metrics = {
                "bc_training/total_steps": bc_step,
                "bc_training/num_bc_epochs": num_bc_epochs,
                "bc_training/steps_this_iter": bc_steps_this_iter,
                "bc_training/steps_per_rl_step": bc_steps_this_iter / cfg.num_envs,
                "training/bc_learning_rate": opt_actor.param_groups[0]["lr"],
            }
            all_metrics.append(combined_metrics)

            print(f"Completed BC training with {bc_steps_this_iter} steps")
            merged_metrics = merge_metrics(all_metrics)
            # Add iteration for tracking
            merged_metrics["iteration"] = iteration
            wandb.log(merged_metrics, step=global_step)
            
            # If using offline mode, manually trigger sync
            if use_sync_hook:
                print(f"Triggering WandB sync after BC training...")
                trigger_sync()
        else:
            print("BC training skipped")

        # Evaluate IL policy
        print("Evaluating BC policy...")
        il_test_dataloader = iter(demo_testloader)
        agent.eval()
        eval_loss = []
        test_tepoch = tqdm(il_test_dataloader, desc="BC Eval")
        for test_batch in test_tepoch:
            with torch.no_grad():
                test_batch = dict_to_device(test_batch, device)
                loss, _ = agent.compute_loss(test_batch)
                test_loss_cpu = loss.item()
                eval_loss.append(test_loss_cpu)
                test_tepoch.set_postfix(loss=test_loss_cpu)
        test_tepoch.close()

        mean_il_advantage, mean_il_value = evaluate_il_online(agent, residual_policy, demo_testloader, device)
        
        # Log BC evaluation metrics
        eval_metrics = {
            "eval/bc_loss": np.mean(eval_loss),
            "eval/mean_bc_advantage": mean_il_advantage,
            "eval/mean_bc_value": mean_il_value,
            "iteration": iteration,  # Add iteration for tracking
        }
        wandb.log(eval_metrics, step=global_step)
        
        # If using offline mode, manually trigger sync
        if use_sync_hook:
            print(f"Triggering WandB sync after BC evaluation...")
            trigger_sync()

        # ROLLOUT: Collecting online data for RL training
        print("Collecting online data...")
        task_success = torch.zeros(cfg.num_envs, dtype=torch.bool)

        # TODO: not sure if this is correct

        for step in range(0, steps_per_iteration):
            if not eval_rl:
                # Only count environment steps during training
                global_step += cfg.num_envs

            with torch.no_grad():
                formatted_obs = format_maniskill_obs_for_agent(next_obs, device)
                
                
                nobs = agent._normalized_obs([formatted_obs])
                
                # Get the normalized action from the base network
                base_naction = agent._normalized_action(nobs)[:, 0, :]
                
                # Process the observation for the residual policy
                next_nobs = agent.process_obs(formatted_obs)
                next_residual_nobs = torch.cat([next_nobs, base_naction], dim=-1)

            dones[step] = next_done
            obs[step] = next_residual_nobs

            with torch.no_grad():
                residual_naction_samp, logprob, _, value, naction_mean = (
                    residual_policy.get_action_and_value(next_residual_nobs)
                )

            residual_naction = residual_naction_samp if not eval_rl else naction_mean
            naction = base_naction + residual_naction * residual_policy.action_scale

            action = agent.normalizer(naction, "action", forward=False)
            next_obs, reward, terminated, truncated, info = env.step(action)

            if cfg.truncation_as_done:
                next_done = terminated | truncated
            else:
                next_done = terminated

            values[step] = value.flatten().cpu()
            actions[step] = residual_naction.cpu()
            logprobs[step] = logprob.cpu()
            rewards[step] = reward.view(-1).cpu()
            next_done = next_done.view(-1).cpu()
            full_nactions[step] = naction.cpu()

            # Track task success
            if 'success' in info:
                task_success = task_success | info['success'].view(-1).cpu()
                
            if step > 0 and (env_step := step * 1) % 100 == 0:
                print(
                    f"env_step={env_step}, global_step={global_step}, mean_reward={rewards[:step+1].sum(dim=0).mean().item()} fps={env_step * cfg.num_envs / (time.time() - iteration_start_time):.2f}"
                )

        # Calculate success rate - any success during the episode counts
        success_rate = task_success.float().mean().item()
        running_mean_success_rate = 0.5 * running_mean_success_rate + 0.5 * success_rate

        print(
            f"SR: {success_rate:.4%}, SR mean: {running_mean_success_rate:.4%}, SPS: {steps_per_iteration * cfg.num_envs / (time.time() - iteration_start_time):.2f}"
        )

        # EVALUATION
        if eval_rl:
            # Use ManiSkill-specific evaluation
            print(f"Running evaluation for epoch {iteration}")
            
            try:
                best_eval_success_rate = evaluate_and_log_maniskill(
                    config=cfg,
                    actor=agent,
                    best_success_rate=best_eval_success_rate,
                    epoch_idx=iteration,
                    wandb=wandb,
                    device=device,
                    verbose=True,
                    sim_backend="physx_cuda",
                )
            except Exception as e:
                print(f"Error during evaluation: {e}")
                print("Continuing training...")
                # Skip the rest of the evaluation phase
                continue

            # Save a checkpoint after every evaluation that completes.
            # NOTE: this save is unconditional; nothing here checks that the success rate improved.
            model_path = str(model_save_dir / f"actor_chkpt_best_success_rate.pt")
            save_dict = {
                # Save the weights of the residual policy (base + residual)
                "model_state_dict": agent.state_dict(),
                "global_step": global_step,
                "success_rate": success_rate,
                "task_success": task_success.cpu().numpy().tolist(),
                "iteration": iteration,
                "config": OmegaConf.to_container(cfg, resolve=True),
            }
            for (name, opt), (__, scheduler) in zip(optimizers, schedulers):
                save_dict[f"{name}_optimizer_state_dict"] = opt.state_dict()
                save_dict[f"{name}_scheduler_state_dict"] = scheduler.state_dict()
            torch.save(save_dict, model_path)

            wandb.save(model_path)
            print(f"Model saved to {model_path}")

            # Add successful trajectories to the replay buffer
            if cfg.enable_rl_replay and success_rate >= cfg.replay_from_sr:
                add_successful_trajectories_to_buffer(
                    buffer=buffer,
                    obs=obs,
                    full_nactions=full_nactions,
                    rewards=rewards,
                    task_success=task_success,
                    device=device,
                    max_new_samples=cfg.max_replay_new_samples
                )
                buffer.rebuild_seq_indices()

            # Log evaluation results
            eval_metrics = {
                "eval/success_rate": success_rate,
                "eval/best_eval_success_rate": best_eval_success_rate,
                "iteration": iteration,
                "global_step": global_step,
            }
            wandb.log(eval_metrics, step=global_step)
            
            # If using offline mode, manually trigger sync
            if use_sync_hook:
                print(f"Triggering WandB sync after evaluation...")
                trigger_sync()
            
            continue

        # TRAINING RL
        print("Training the residual policy with RL...")

        # Check the original shapes before reshaping
        print(f"Original shapes - obs: {obs.shape}, actions: {actions.shape}, residual_policy.obs_dim: {residual_policy.obs_dim}")
        
        # Reshape observations and actions for batch processing
        b_obs = obs.reshape((-1, residual_policy.obs_dim))

        action_shape = env.action_space.shape
        print(f"Action space shape: {action_shape}")
        
        # Detect significant shape issues
        print(f"Action tensor size: {actions.numel()} elements, expected shape: {actions.shape}")
        expected_size = steps_per_iteration * cfg.num_envs * action_dim
        if actions.numel() != expected_size:
            print(f"WARNING: Action tensor size mismatch! Got {actions.numel()}, expected {expected_size}")
            
        if len(actions.shape) > 2:
            print(f"Handling complex action tensor with shape {actions.shape}")
            last_dim = actions.shape[-1]
            
            if last_dim == action_dim:
                b_actions = actions.reshape(-1, action_dim)
                print(f"Reshaped actions to {b_actions.shape}")
            else:
                # This is a more complex case, we need to be careful
                print(f"WARNING: Last dimension doesn't match action_dim: {last_dim} vs {action_dim}")
                # Create a new tensor with correct shape
                b_actions = torch.zeros((steps_per_iteration * cfg.num_envs, action_dim))
                # Try to extract appropriate data if possible
                if actions.numel() >= b_actions.numel():
                    print("Taking subset of action data")
                    # Take the first elements needed
                    flat_actions = actions.flatten()[:b_actions.numel()]
                    b_actions = flat_actions.reshape(-1, action_dim)
                else:
                    print("Not enough action data, using zeros")
                    # We'll just use the zeros we initialized with
                
                print(f"Created new action tensor with shape {b_actions.shape}")
        else:
            # Simple case - just reshape normally
            b_actions = actions.reshape(-1, action_dim)
            print(f"Simple reshape to {b_actions.shape}")
            
        b_logprobs = logprobs.reshape(-1)
        b_values = values.reshape(-1)
        
        print(f"Reshaped - b_obs: {b_obs.shape}, b_actions: {b_actions.shape}, b_logprobs: {b_logprobs.shape}, b_values: {b_values.shape}")

        # Use the same observation formatting for ManiSkill compatibility
        with torch.no_grad():
            # Format observation for the base policy
            formatted_obs = format_maniskill_obs_for_agent(next_obs, device)
            
            # Use the _normalized_action method to get the base policy action
            # First normalize the observation
            nobs = agent._normalized_obs([formatted_obs])
            
            # Get the normalized action from the base network
            base_naction = agent._normalized_action(nobs)[:, 0, :]
            
            # Process observation for residual policy
            next_nobs = agent.process_obs(formatted_obs)
            next_residual_nobs = torch.cat([next_nobs, base_naction], dim=-1)
            next_value = residual_policy.get_value(next_residual_nobs).reshape(1, -1).cpu()

        # bootstrap value if not done
        advantages, returns = calculate_advantage(
            values,
            next_value,
            rewards,
            dones,
            next_done,
            steps_per_iteration,
            cfg.discount,
            cfg.gae_lambda,
        )

        b_advantages = advantages.reshape(-1).cpu()
        b_returns = returns.reshape(-1).cpu()

        # Optimizing the policy and value network
        # Make sure batch size does not exceed the actual data size
        actual_batch_size = min(cfg.batch_size, b_obs.shape[0])
        print(f"Using batch size: {actual_batch_size} (configured: {cfg.batch_size}, available: {b_obs.shape[0]})")
        
        b_inds = np.arange(actual_batch_size)
        clipfracs = []
        rl_grad_norms = []
        for epoch in trange(cfg.update_epochs, desc="Policy update"):
            early_stop = False

            np.random.shuffle(b_inds)
            minibatch_size = min(cfg.minibatch_size, actual_batch_size)
            for start in range(0, actual_batch_size, minibatch_size):
                end = min(start + minibatch_size, actual_batch_size)
                mb_inds = b_inds[start:end]

                # Get the minibatch and place it on the device
                mb_obs = b_obs[mb_inds].to(device)
                mb_actions = b_actions[mb_inds].to(device)
                mb_logprobs = b_logprobs[mb_inds].to(device)
                mb_advantages = b_advantages[mb_inds].to(device)
                mb_returns = b_returns[mb_inds].to(device)
                mb_values = b_values[mb_inds].to(device)

                # Calculate the loss
                # Debug shape mismatch
                try:
                    _, newlogprob, entropy, newvalue, action_mean = (
                        residual_policy.get_action_and_value(mb_obs, mb_actions)
                    )
                except ValueError as e:
                    print(f"Error in shape: {e}")
                    # Fix action shape if it's wrong - need to ensure it's [batch_size, action_dim]
                    if len(mb_actions.shape) > 2:
                        print(f"Reshaping actions from {mb_actions.shape} to [{mb_actions.shape[0]}, {mb_actions.shape[-1]}]")
                        mb_actions = mb_actions.reshape(mb_actions.shape[0], mb_actions.shape[-1])
                    _, newlogprob, entropy, newvalue, action_mean = (
                        residual_policy.get_action_and_value(mb_obs, mb_actions)
                    )
                logratio = newlogprob - mb_logprobs
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [
                        ((ratio - 1.0).abs() > cfg.clip_coef).float().mean().item()
                    ]

                if cfg.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                        mb_advantages.std() + 1e-8
                    )

                policy_loss = 0

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(
                    ratio, 1 - cfg.clip_coef, 1 + cfg.clip_coef
                )
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if cfg.clip_vloss:
                    v_loss_unclipped = (newvalue - mb_returns) ** 2
                    v_clipped = mb_values + torch.clamp(
                        newvalue - mb_values,
                        -cfg.clip_coef,
                        cfg.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - mb_returns) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean()

                # Entropy loss
                entropy_loss = entropy.mean() * cfg.ent_coef

                ppo_loss = pg_loss - entropy_loss

                # Add the auxiliary regularization loss
                residual_l1_loss = torch.mean(torch.abs(action_mean))
                residual_l2_loss = torch.mean(torch.square(action_mean))

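                # Include the policy terms only after the first cfg.n_iterations_train_only_value
                # iterations; until then policy_loss stays 0 and only the value loss is optimized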
                if iteration > cfg.n_iterations_train_only_value:
                    # PPO objective plus the weighted L1/L2 penalties on the residual action mean
                    policy_loss += ppo_loss
                    policy_loss += cfg.residual_l1 * residual_l1_loss
                    policy_loss += cfg.residual_l2 * residual_l2_loss

                rl_loss: torch.Tensor = policy_loss * cfg.rl_coef
                value_loss: torch.Tensor = v_loss * cfg.vf_coef

                loss = rl_loss + value_loss

                opt_res_actor.zero_grad()
                opt_res_critic.zero_grad()

                loss.backward()
                grad_norm = nn.utils.clip_grad_norm(
                    residual_policy.parameters(), cfg.max_grad_norm
                )
                rl_grad_norms.append(grad_norm.item())

                opt_res_actor.step()
                opt_res_critic.step()

                if cfg.target_kl is not None and approx_kl > cfg.target_kl:
                    print(
                        f"Early stopping at epoch {epoch} due to reaching max kl: {approx_kl:.4f} > {cfg.target_kl:.4f}"
                    )
                    early_stop = True
                    break

            if early_stop:
                break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        action_norms = torch.norm(b_actions, dim=-1).cpu()

        training_cum_time += time.time() - iteration_start_time
        sps = int(global_step / training_cum_time) if training_cum_time > 0 else 0

        # Log all metrics to wandb
        log_data = {
            "grads/rl_mean_grad_norms": np.mean(rl_grad_norms),
            "training/rl_learning_rate_actor": opt_res_actor.param_groups[0]["lr"],
            "training/rl_learning_rate_critic": opt_res_critic.param_groups[0]["lr"],
            "charts/SPS": sps,
            "charts/rewards": rewards.sum().item(),
            "charts/policy_entropy": entropy.mean().item(),
            "charts/success_rate": success_rate,
            "charts/action_norm_mean": action_norms.mean(),
            "charts/action_norm_std": action_norms.std(),
            "values/advantages": b_advantages.mean().item(),
            "values/returns": b_returns.mean().item(),
            "values/values": b_values.mean().item(),
            "values/mean_logstd": residual_policy.actor_logstd.mean().item(),
            "losses/value_loss": v_loss.item(),
            "losses/policy_loss": pg_loss.item(),
            "losses/total_loss": loss.item(),
            "losses/entropy_loss": entropy_loss.item(),
            "losses/old_approx_kl": old_approx_kl.item(),
            "losses/approx_kl": approx_kl.item(),
            "losses/clipfrac": np.mean(clipfracs),
            "losses/explained_variance": explained_var,
            "losses/residual_l1": residual_l1_loss.item(),
            "losses/residual_l2": residual_l2_loss.item(),
            "histograms/values": wandb.Histogram(values),
            "histograms/returns": wandb.Histogram(b_returns),
            "histograms/advantages": wandb.Histogram(b_advantages),
            "histograms/logprobs": wandb.Histogram(logprobs),
            "histograms/rewards": wandb.Histogram(rewards),
            "histograms/action_norms": wandb.Histogram(action_norms),
            "iteration": iteration,  # Always log iteration count for tracking progress
        }
        
        # Add additional metadata
        log_data["epoch"] = iteration
        log_data["global_step"] = global_step 
        
        # Log to wandb
        wandb.log(log_data, step=global_step)
        
        # If using offline mode, manually trigger sync
        if use_sync_hook:
            print(f"Triggering WandB sync at iteration {iteration}...")
            trigger_sync()

        # Step the learning rate scheduler
        lr_sche_res_actor.step()
        lr_sche_res_critic.step()

        # Checkpoint every cfg.checkpoint_interval iterations
        if cfg.checkpoint_interval > 0 and iteration % cfg.checkpoint_interval == 0:
            model_path = str(model_save_dir / f"actor_chkpt_{iteration}.pt")
            torch.save(
                {
                    "model_state_dict": agent.state_dict(),
                    "optimizer_actor_state_dict": opt_res_actor.state_dict(),
                    "optimizer_critic_state_dict": opt_res_critic.state_dict(),
                    "scheduler_actor_state_dict": lr_sche_res_actor.state_dict(),
                    "scheduler_critic_state_dict": lr_sche_res_critic.state_dict(),
                    "config": OmegaConf.to_container(cfg, resolve=True),
                    "success_rate": success_rate,
                    "iteration": iteration,
                },
                model_path,
            )

            wandb.save(model_path)
            print(f"Model saved to {model_path}")

        # Update RL counter for BC scheduling
        rl_counter += 1
        rl_counter = rl_counter % rl_per_bc

        # Print some stats at the end of the iteration
        print(
            f"Iteration {iteration}/{cfg.num_iterations}, global step {global_step}, SPS {sps}"
        )

        print(
            f"At iteration {iteration}, we've done {training_iterations} training iterations "
            f"and {iteration - training_iterations} evaluation iterations"
        )

    print(f"Training finished in {(time.time() - start_time):.2f}s")


# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()