# Train the residual policy with hybrid IL and RL
# Supports Residual MLP and Residual Diffusion
import random
import time
import hydra
from tqdm import tqdm, trange
from omegaconf import DictConfig, OmegaConf
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import trange
import wandb
import warnings
from collections import defaultdict

import furniture_bench  # noqa
from src.behavior.diffusion import DiffusionPolicy
from src.behavior.residual_diffusion import ResidualDiffusionPolicy
from src.behavior.residual_mlp import ResidualMlpPolicy
from src.behavior.base import Actor
from src.dataset.dataset import StateDataset
from torch.utils.data import DataLoader, ConcatDataset, random_split
from src.common.hydra import to_native
from src.dataset.dataloader import FixedStepsDataloader
from src.dataset.rollout_buffer import RolloutBuffer
from src.common.pytorch_util import dict_to_device
from src.eval.eval_utils import get_model_from_api_or_cached
from diffusers.optimization import get_scheduler
from src.gym.env_rl_wrapper import RLPolicyEnvWrapper
from src.gym import turn_off_april_tags
from src.common.config_util import merge_base_bc_config_with_root_config
from furniture_bench.envs.furniture_rl_sim_env import FurnitureRLSimEnv
from furniture_bench.envs.observation import DEFAULT_STATE_OBS

# Silence every UserWarning for the whole run.
warnings.filterwarnings('ignore', category=UserWarning)

# Register the ``eval`` resolver for OmegaConf so configs can use ``${eval:...}``.
# SECURITY: this hands config strings to Python's built-in ``eval``; only load
# trusted configs. The guard avoids the ValueError OmegaConf raises when a
# resolver with the same name was already registered (e.g. by a module that
# was imported earlier).
if not OmegaConf.has_resolver("eval"):
    OmegaConf.register_new_resolver("eval", eval)


@torch.no_grad()
def calculate_advantage(
    values: torch.Tensor,
    next_value: torch.Tensor,
    rewards: torch.Tensor,
    dones: torch.Tensor,
    next_done: torch.Tensor,
    steps_per_iteration: int,
    discount: float,
    gae_lambda: float,
):
    """Compute Generalized Advantage Estimation (GAE) over one rollout.

    The recursion walks backwards in time. Bootstrapping from step ``t`` into
    ``t + 1`` is cut whenever ``dones[t + 1]`` is set. For the final step, the
    values from outside the rollout (``next_done`` / ``next_value``) play that role.

    Args:
        values: Critic estimates per step, indexed as ``values[t]``.
        next_value: Critic estimate for the observation after the last step.
        rewards: Rewards per step, indexed as ``rewards[t]``.
        dones: Done flags per step, indexed as ``dones[t]``.
        next_done: Done flag for the observation after the last step.
        steps_per_iteration: Number of rollout steps to process.
        discount: Discount factor (gamma).
        gae_lambda: GAE lambda.

    Returns:
        ``(advantages, returns)``, where ``returns = advantages + values``.
    """
    advantages = torch.zeros_like(rewards)
    running_gae = 0
    last_step = steps_per_iteration - 1
    for step in range(last_step, -1, -1):
        if step != last_step:
            not_done = 1.0 - dones[step + 1].to(torch.float)
            bootstrap_value = values[step + 1]
        else:
            not_done = 1.0 - next_done.to(torch.float)
            bootstrap_value = next_value

        td_error = rewards[step] + discount * bootstrap_value * not_done - values[step]
        running_gae = td_error + discount * gae_lambda * not_done * running_gae
        advantages[step] = running_gae

    return advantages, advantages + values


@torch.no_grad()
def compute_values(agent, residual_policy, batch, device):
    agent.eval()
    residual_policy.eval()
    nobs = agent._training_obs(batch, flatten=agent.flatten_obs).to(device)
    naction = agent._normalized_action(nobs).to(device)
    obs0 = batch["obs"][:, 0, :].to(device)
    action0 = naction[:, 0, :].to(device)
    residual_nobs = torch.cat([obs0, action0], dim=-1).to(device)
    values = residual_policy.get_value(residual_nobs).squeeze()
    return values


@torch.no_grad()
def compute_advantages_and_values(agent, residual_policy, batch, device):
    """Return ``(advantages, values)`` for a dataset batch.

    ``values`` come from :func:`compute_values`, and advantages are
    ``returns - values``. When the batch carries no ``"returns"`` entry, a zero
    tensor shaped like ``batch["rewards"]`` stands in for the returns.
    """
    has_returns = "returns" in batch
    target = batch["returns"] if has_returns else torch.zeros_like(batch["rewards"])
    estimated = compute_values(agent, residual_policy, batch, device)
    return target - estimated, estimated


@torch.no_grad()
def evaluate_il_online(agent, residual_policy, dataloader, device):
    """
    Compute mean advantage and value for the IL policy on the test set.

    Uses the residual policy's critic to estimate values, and computes advantages
    as dataset returns minus those values (see ``compute_advantages_and_values``;
    batches without a ``"returns"`` entry use zero returns).

    Returns:
        ``(mean_advantage, mean_value)`` as Python floats, or ``(nan, nan)`` when
        the dataloader yields no batches. This mirrors ``np.mean`` of an empty
        list, which the caller uses for the BC loss on the same loader.
    """
    advantage_chunks = []
    value_chunks = []

    for batch in dataloader:
        batch = dict_to_device(batch, device)
        advantages, values = compute_advantages_and_values(agent, residual_policy, batch, device)
        # Flatten so that 0-d results (batch of one) can still be concatenated.
        advantage_chunks.append(advantages.reshape(-1))
        value_chunks.append(values.reshape(-1))

    if not value_chunks:
        print("[IL Online] No batches to evaluate; returning NaN")
        return float("nan"), float("nan")

    mean_advantage = torch.cat(advantage_chunks).mean().item()
    mean_value = torch.cat(value_chunks).mean().item()

    print(f"[IL Online] Mean Advantage: {mean_advantage:.4f}, Mean Value: {mean_value:.4f}")

    return mean_advantage, mean_value


def add_successful_trajectories_to_buffer(
    buffer: "RolloutBuffer",
    obs: torch.Tensor,
    full_nactions: torch.Tensor,
    rewards: torch.Tensor,
    n_parts_to_assemble: int,
    device: torch.device,
    max_new_samples: int = None
):
    """
    Add successful trajectories to the rollout buffer for hybrid IL/RL training.

    An env counts as successful when at least ``n_parts_to_assemble`` of its steps
    received a positive reward. This is the same criterion the training loop uses
    for its success rate. Counting steps, rather than summing reward values, keeps
    the criterion correct when rewards are normalized or clipped.

    Args:
        buffer (RolloutBuffer): The rollout buffer to add trajectories to
        obs (torch.Tensor): Observations of shape [steps_per_iteration, num_envs, obs_dim]
        full_nactions (torch.Tensor): Actions of shape [steps_per_iteration, num_envs, action_dim]
        rewards (torch.Tensor): Rewards of shape [steps_per_iteration, num_envs]
        n_parts_to_assemble (int): Number of parts that need to be assembled for success
        device (torch.device): Device to use for tensor operations
        max_new_samples (int, optional): If a positive number smaller than the number
            of successful envs, only this many (chosen uniformly at random with
            ``np.random``) are added. ``None`` or <= 0 adds all of them.
    """
    parts_assembled = (rewards > 0).sum(dim=0)
    success_idxs = (parts_assembled >= n_parts_to_assemble).cpu()
    num_success = int(success_idxs.sum().item())

    if num_success == 0:
        print("No successful trajectories to add to buffer")
        return

    if max_new_samples is not None and 0 < max_new_samples < num_success:
        success_env_indices = torch.where(success_idxs)[0].numpy()
        chosen_indices = np.random.choice(
            success_env_indices,
            size=max_new_samples,
            replace=False
        )
        subset_mask = torch.zeros_like(success_idxs)
        subset_mask[torch.from_numpy(chosen_indices)] = True
        success_idxs = subset_mask
        num_success = max_new_samples

    print(f"Adding {num_success} successful trajectories to buffer")

    # Extract successful trajectories
    success_obs = obs[:, success_idxs]  # [steps, num_success, obs_dim]
    success_actions = full_nactions[:, success_idxs]  # [steps, num_success, action_dim]
    success_rewards = rewards[:, success_idxs]  # [steps, num_success]

    # Mark done only at the first step where the count of positive-reward steps
    # reaches n_parts_to_assemble. The count is monotone in time, so `reached` stays
    # True afterwards, and its cumsum equals 1 exactly at that first step.
    reached = (success_rewards > 0).cumsum(dim=0) >= n_parts_to_assemble
    success_dones = reached & (reached.cumsum(dim=0) == 1)

    # Move tensors to appropriate device
    success_obs = success_obs.to(device)
    success_actions = success_actions.to(device)
    success_rewards = success_rewards.to(device)
    success_dones = success_dones.to(device)

    # NOTE(review): in the rollout loop, obs rows are the residual-policy input
    # (processed obs concatenated with the base action), so this slice keeps the
    # processed-obs part. Confirm that RolloutBuffer, which has its own normalizer
    # attached, expects already-processed states.
    states = success_obs[..., :buffer.state_dim]

    buffer.add_trajectories(
        actions=success_actions,
        rewards=success_rewards,
        dones=success_dones,
        states=states
    )

    print(f"Buffer has {buffer.n_trajectories} trajectories with size {buffer.size}")


def train_bc_epoch(
    dataloader,
    agent,
    optimizer,
    scheduler,
    cfg,
    device,
    prefix="bc",
    epoch=0,
    use_cached_values=False,
):
    """Train the agent with behavior cloning for one pass over ``dataloader``.

    Args:
        dataloader: Iterable of training items. When ``use_cached_values`` is False,
            each item is a batch dict that is moved to ``device`` here. When True,
            each item is a ``(batch, weights)`` pair: the batch is used as-is, and
            the mean of ``weights`` scales that batch's loss (Q-filter weighting).
        agent: Model exposing ``compute_loss(batch, backfill=...)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        cfg: Config; reads ``cfg.backfill`` and ``cfg.base_training.clip_grad_norm``.
        device: Target device for uncached batches.
        prefix: Prefix for metric names and the progress-bar label.
        epoch: Epoch index, used only in the progress-bar label.
        use_cached_values: Selects the item format described under ``dataloader``.

    Returns:
        Dict mapping ``{prefix}_loss``, ``{prefix}_grad_norm`` and, with cached
        values, ``{prefix}_val_weight`` to their mean over the epoch.
    """
    all_metrics = defaultdict(list)

    # clip_grad_norm is treated as a flag: True -> clip at 1.0, False -> 1001
    # (effectively no clipping). It is loop-invariant, so compute it once.
    max_grad_norm = 1.0 + 1e3 * (1 - cfg.base_training.clip_grad_norm)

    # Callers usually hand over the agent in eval mode (main() and compute_values()
    # call agent.eval()), which would disable any train-time-only behavior (e.g.
    # dropout, if the model has it) during the gradient updates. Train in train
    # mode and restore the caller's mode afterwards.
    was_training = agent.training
    agent.train()
    try:
        pbar = tqdm(dataloader, desc=f"{prefix.upper()} Epoch {epoch}", leave=False)

        for batch_data in pbar:
            if use_cached_values:
                batch, value_weight = batch_data
                batch_weight = value_weight.mean()
            else:
                batch = dict_to_device(batch_data, device)
                batch_weight = 1.0

            optimizer.zero_grad()
            loss, _ = agent.compute_loss(batch, backfill=cfg.backfill)
            loss = loss * batch_weight

            loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(
                agent.parameters(),
                max_norm=max_grad_norm,
            )

            optimizer.step()
            scheduler.step()

            all_metrics[f"{prefix}_loss"].append(loss.item())
            all_metrics[f"{prefix}_grad_norm"].append(grad_norm.item())
            if use_cached_values:
                all_metrics[f"{prefix}_val_weight"].append(batch_weight.item())

            pbar.set_postfix({
                "loss": f"{np.mean(all_metrics[f'{prefix}_loss'][-100:]):.4f}",
                "lr": f"{optimizer.param_groups[0]['lr']:.2e}"
            })
    finally:
        agent.train(was_training)

    return {k: np.mean(v) for k, v in all_metrics.items() if len(v) > 0}


def merge_dataloaders(*loaders, batch_size):
    """Build one shuffled DataLoader over the datasets behind ``loaders``.

    Only each loader's ``.dataset`` is reused. Its sampler, batch size and other
    settings are discarded in favor of the ones below.

    NOTE(review): with ``pin_memory=True``, samples that are already CUDA tensors
    (e.g. from a buffer living on the GPU) cannot be pinned. Confirm what the
    merged datasets return.
    """
    datasets = [source.dataset for source in loaders]
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )
    return DataLoader(ConcatDataset(datasets), **loader_kwargs)


def compute_q_filter(values, iteration, warmup_iterations=5, min_weight=0.3):
    """Map value estimates to per-sample loss weights in ``[min_weight, 1]``.

    During the first ``warmup_iterations`` iterations every weight is 1. After
    that, the values are min-max normalized and mapped affinely, so the lowest
    value gets ``min_weight`` and the highest gets 1.

    Args:
        values: Tensor of value estimates.
        iteration: Current training iteration.
        warmup_iterations: Number of initial iterations with uniform weights.
        min_weight: Weight assigned to the lowest value.

    Returns:
        Tensor shaped like ``values``. The result is all ones during warm-up, for
        empty input, or when all values are identical. In that last case min-max
        normalization would divide by zero, and the resulting NaN weights would
        poison the weighted BC loss.
    """
    if iteration < warmup_iterations or values.numel() == 0:
        return torch.ones_like(values)

    v_min, v_max = values.min(), values.max()
    value_range = v_max - v_min
    if value_range <= 0:
        # No spread -> no ranking signal; don't down-weight anything.
        return torch.ones_like(values)

    weights = (values - v_min) / value_range
    return min_weight + (1.0 - min_weight) * weights


def _cache_batches_with_q_weights(dataloader, agent, residual_policy, cfg, device, iteration):
    """Move every batch to ``device`` once and pair it with its Q-filter weights."""
    print("Pre-computing Q-filter values for all epochs...")
    cached_batches = []
    batch_values = []

    # compute_values() switches both modules to eval; restore the caller's modes after.
    agent_was_training = agent.training
    residual_was_training = residual_policy.training
    agent.eval()
    residual_policy.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Computing V", leave=False):
            batch = dict_to_device(batch, device)
            values = compute_values(agent, residual_policy, batch, device)
            # Flatten so a batch of one (possibly a 0-d tensor) still concatenates.
            batch_values.append(values.reshape(-1))
            cached_batches.append(batch)
    agent.train(agent_was_training)
    residual_policy.train(residual_was_training)

    if not cached_batches:
        print("No batches available for Q-filter pre-computation")
        return []

    all_values = torch.cat(batch_values)
    weights = compute_q_filter(
        all_values,
        iteration=iteration,
        warmup_iterations=cfg.q_filter_warmup_iterations,
        min_weight=cfg.q_filter_min_weight
    )
    value_stats = {
        'mean': all_values.mean().item(),
        'std': all_values.std().item(),
        'weight_mean': weights.mean().item(),
        'in_warmup': iteration < cfg.q_filter_warmup_iterations
    }
    print(f"Value stats: {value_stats}")

    # Split by each batch's actual size. Tensor.chunk(n) makes equal chunks of
    # ceil(total / n), which misaligns weights and batches whenever the last batch
    # is partial (e.g. batches [4, 4, 1] would get chunks [3, 3, 3]).
    per_batch_weights = weights.split([v.numel() for v in batch_values])
    return list(zip(cached_batches, per_batch_weights))


def train_bc_epochs(
    dataloader,
    agent,
    residual_policy,
    optimizer,
    scheduler,
    cfg,
    device,
    num_epochs,
    iteration,
    prefix="bc",
):
    """Train for multiple epochs and aggregate metrics across epochs.

    When ``cfg.enable_q_filter`` is set, the Q-filter values are computed once at
    the start. The resulting (batch, weight) pairs are held in memory on
    ``device`` and reused, in the same order, for every epoch.

    Returns:
        Dict with sample/batch/epoch counters under ``{prefix}_training/...``,
        plus the ``_mean``/``_std`` of every per-epoch metric. Metric names
        containing ``loss`` or ``value`` also get ``_min``/``_max``.
    """
    epoch_metrics = defaultdict(list)
    total_samples = 0
    total_batches = 0

    if cfg.enable_q_filter:
        epoch_dataloader = _cache_batches_with_q_weights(
            dataloader, agent, residual_policy, cfg, device, iteration
        )
    else:
        epoch_dataloader = dataloader

    for epoch in trange(num_epochs, desc=f"{prefix.upper()} Epochs"):
        metrics = train_bc_epoch(
            epoch_dataloader,
            agent,
            optimizer,
            scheduler,
            cfg,
            device,
            prefix=prefix,
            epoch=epoch,
            use_cached_values=cfg.enable_q_filter
        )

        total_samples += len(dataloader.dataset)
        total_batches += len(dataloader)

        for k, v in metrics.items():
            epoch_metrics[k].append(v)

    agg_metrics = {
        f"{prefix}_training/total_samples": total_samples,
        f"{prefix}_training/total_batches": total_batches,
        f"{prefix}_training/num_epochs": num_epochs,
    }

    for k, v in epoch_metrics.items():
        agg_metrics[f"{k}_mean"] = np.mean(v)
        agg_metrics[f"{k}_std"] = np.std(v)

        if any(key in k for key in ['loss', 'value']):
            agg_metrics[f"{k}_min"] = np.min(v)
            agg_metrics[f"{k}_max"] = np.max(v)

    return agg_metrics


def merge_metrics(metrics_list):
    """Flatten a sequence of metric dicts into one dict.

    If a key appears in more than one dict, the value from the later dict wins.
    The key keeps the position where it first appeared.
    """
    return {
        name: value
        for metrics in metrics_list
        for name, value in metrics.items()
    }


@hydra.main(
    config_path="../config",
    config_name="base_ri",
    version_base="1.2",
)
def main(cfg: DictConfig):

    OmegaConf.set_struct(cfg, False)

    # TRY NOT TO MODIFY: seeding
    if cfg.seed is None:
        cfg.seed = random.randint(0, 2**32 - 1)

    if "task" not in cfg.env:
        cfg.env.task = "one_leg"

    run_name = str(cfg.run_name) if cfg.run_name else f"{cfg.actor.residual_policy._target_.split('.')[-1]}_RI_{cfg.seed}_{int(time.time())}"

    run_directory = f"runs/{run_name}"
    run_directory += "-delete" if cfg.debug else ""
    print(f"Run directory: {run_directory}")

    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    torch.backends.cudnn.deterministic = cfg.torch_deterministic

    gpu_id = cfg.gpu_id
    device = torch.device(f"cuda:{gpu_id}")
    torch.cuda.set_device(gpu_id)

    turn_off_april_tags()

    env: FurnitureRLSimEnv = FurnitureRLSimEnv(
        act_rot_repr=cfg.control.act_rot_repr,
        action_type=cfg.control.control_mode,
        april_tags=False,
        concat_robot_state=True,
        ctrl_mode=cfg.control.controller,
        obs_keys=DEFAULT_STATE_OBS,
        furniture=cfg.env.task,
        # gpu_id=1,
        compute_device_id=gpu_id,
        graphics_device_id=gpu_id,
        headless=cfg.headless,
        num_envs=cfg.num_envs,
        observation_space="state",
        randomness=cfg.env.randomness,
        max_env_steps=100_000_000,
    )

    n_parts_to_assemble = env.n_parts_assemble

    env: RLPolicyEnvWrapper = RLPolicyEnvWrapper(
        env,
        max_env_steps=cfg.num_env_steps,
        normalize_reward=cfg.normalize_reward,
        reset_on_success=cfg.reset_on_success,
        reset_on_failure=cfg.reset_on_failure,
        reward_clip=cfg.clip_reward,
        device=device,
    )

    # Load the behavior cloning actor
    base_cfg, base_wts = get_model_from_api_or_cached(
        cfg.base_policy.wandb_id,
        wt_type=cfg.base_policy.wt_type,
        wandb_mode=cfg.wandb.mode,
    )

    merge_base_bc_config_with_root_config(cfg, base_cfg)
    cfg.actor_name = f"residual_{cfg.base_policy.actor.name}"
    # base_cfg.data_path = ["./data/processed/diffik/sim/round_table/teleop/low/success.zarr"]

    demo_dataset = StateDataset(
        dataset_paths=[Path(p) for p in to_native(base_cfg.data_path)],
        pred_horizon=cfg.data.pred_horizon,
        obs_horizon=cfg.data.obs_horizon,
        action_horizon=cfg.data.action_horizon,
        data_subset=cfg.data.data_subset,
        control_mode=cfg.control.control_mode,
        predict_past_actions=cfg.data.predict_past_actions,
        pad_after=cfg.data.get("pad_after", True),
        max_episode_count=cfg.data.get("max_episode_count", None),
        include_future_obs=cfg.data.include_future_obs,
    )
    print(f"Number of samples: {cfg.data.data_subset}, dataloader length: {len(demo_dataset)}")
    
    train_size = int(len(demo_dataset) * (1 - cfg.data.test_split))
    test_size = len(demo_dataset) - train_size
    print(f"IL batch size {cfg.base_training.batch_size}")
    print(f"Splitting dataset into {train_size} train and {test_size} test samples.")
    train_dataset, test_dataset = random_split(demo_dataset, [train_size, test_size])

    demo_testload_kwargs = dict(
        dataset=test_dataset,
        batch_size=cfg.base_training.batch_size,
        num_workers=0,
        shuffle=True,
        pin_memory=True,
        drop_last=False,
        persistent_workers=False,
    )
    demo_testloader = (
        FixedStepsDataloader(
            **demo_testload_kwargs,
            n_batches=max(
                int(round(cfg.base_training.steps_per_epoch * cfg.data.test_split)), 1
            ),
        )
        if cfg.base_training.steps_per_epoch != -1
        else DataLoader(**demo_testload_kwargs)
    )
    
    il_dataloader = FixedStepsDataloader(
        dataset=train_dataset,
        batch_size=cfg.base_training.batch_size,
        num_workers=0,
        shuffle=True,
        pin_memory=True,
        persistent_workers=False,
        n_batches=train_size//cfg.base_training.batch_size
    )
    print(f"Training dataset size: {len(train_dataset)} | Test dataset size: {len(test_dataset)}")
    print(f"IL dataloader size: {len(il_dataloader)} | Test dataloader size: {len(demo_testloader)}")

    steps_per_iteration = cfg.data_collection_steps

    if cfg.base_policy.actor.name == "diffusion":
        print("[Policy] Using diffusion policy")
        agent = ResidualDiffusionPolicy(device, base_cfg)
    elif cfg.base_policy.actor.name == "mlp":
        print("[Policy] Using MLP policy")
        agent = ResidualMlpPolicy(device, base_cfg)
    else:
        raise ValueError(f"Unknown actor type: {cfg.base_policy.actor}")

    if cfg.load_pretrained_wts:
        agent.load_base_state_dict(base_wts)
    agent.set_normalizer(demo_dataset.normalizer.to(device))
    agent.to(device)
    agent.eval()
    # Set the inference steps of the actor
    if isinstance(agent, DiffusionPolicy):
        agent.inference_steps = 4

    residual_policy = agent.residual_policy

    if cfg.il_base_only:
        print("[Policy] Using base actor only")
        opt_actor = torch.optim.AdamW(
            params=agent.base_actor_parameters,
            # TODO: training config is copied from BC base config and should be renamed
            lr=cfg.base_training.actor_lr,
            weight_decay=cfg.base_regularization.weight_decay,
        )
    else:
        print("[Policy] Using full actor")
        opt_actor = torch.optim.AdamW(
            params=agent.full_actor_parameters,
            lr=cfg.base_training.actor_lr,
            weight_decay=cfg.base_regularization.weight_decay,
        )
        
    lr_sche_actor = get_scheduler(
        name=cfg.base_lr_scheduler.name,
        optimizer=opt_actor,
        num_warmup_steps=cfg.base_lr_scheduler.warmup_steps,
        num_training_steps=len(demo_dataset) // cfg.base_training.batch_size * cfg.base_training.num_epochs,
    )

    opt_res_actor = optim.AdamW(
        agent.actor_parameters,
        lr=cfg.learning_rate_actor,
        eps=1e-5,
        weight_decay=1e-6,
    )
    lr_sche_res_actor = get_scheduler(
        name=cfg.lr_scheduler.name,
        optimizer=opt_res_actor,
        num_warmup_steps=cfg.lr_scheduler.actor_warmup_steps,
        num_training_steps=cfg.num_iterations,
    )

    opt_res_critic = optim.AdamW(
        agent.critic_parameters,
        lr=cfg.learning_rate_critic,
        eps=1e-5,
        weight_decay=1e-6,
    )
    lr_sche_res_critic = get_scheduler(
        name=cfg.lr_scheduler.name,
        optimizer=opt_res_critic,
        num_warmup_steps=cfg.lr_scheduler.critic_warmup_steps,
        num_training_steps=cfg.num_iterations,
    )

    optimizers = [("residual_actor", opt_res_actor), ("residual_critic", opt_res_critic), ("actor", opt_actor)]
    schedulers = [("residual_actor", lr_sche_res_actor),
                  ("residual_critic", lr_sche_res_critic), ("actor", lr_sche_actor)]

    if cfg.enable_rl_replay:
        print("Enabling RL replay buffer")
        buffer = RolloutBuffer(
            max_size=cfg.base_bc.replay_buffer_size,
            state_dim=agent.obs_dim,
            action_dim=agent.action_dim,
            pred_horizon=agent.pred_horizon,
            obs_horizon=agent.obs_horizon,
            action_horizon=agent.action_horizon,
            device=device,
            predict_past_actions=cfg.data.predict_past_actions,
            include_future_obs=cfg.data.include_future_obs,
        )
        buffer.set_normalizer(demo_dataset.normalizer)

    print(f"PPO batch size: {cfg.batch_size}; mini-batch size: {cfg.minibatch_size}")
    print(f"Total RL timesteps: {cfg.total_timesteps}; Num iterations: {cfg.num_iterations}")
    print(
        f"BC dataset size {len(train_dataset)}; BC batch size: {cfg.base_training.batch_size}; Steps per epoch: {cfg.base_training.steps_per_epoch}")

    print(OmegaConf.to_yaml(cfg, resolve=True))

    run = wandb.init(
        id=cfg.wandb.continue_run_id,
        resume=None if cfg.wandb.continue_run_id is None else "must",
        project=cfg.wandb.project,
        entity=cfg.wandb.entity,
        config=OmegaConf.to_container(cfg, resolve=True),
        name=run_name,
        save_code=True,
        mode=cfg.wandb.mode if not cfg.debug else "disabled",
    )

    if cfg.wandb.continue_run_id is not None:
        print(f"[WandB] Continuing run {cfg.wandb.continue_run_id}, {run.name}")

        run_id = f"{cfg.wandb.project}/{cfg.wandb.continue_run_id}"

        # Load the weights from the run
        _, wts = get_model_from_api_or_cached(
            run_id, "latest", wandb_mode=cfg.wandb.mode
        )

        print(f"[WandB] Loading weights from {wts}")

        run_state_dict = torch.load(wts)
        if "model_state_dict" in run_state_dict:
            agent.load_state_dict(run_state_dict["model_state_dict"])
            for (name, opt), (__, scheduler) in zip(optimizers, schedulers):
                opt.load_state_dict(run_state_dict[f"{name}_optimizer_state_dict"])
                scheduler.load_state_dict(run_state_dict[f"{name}_scheduler_state_dict"])

        # Set the best test loss and success rate to the one from the run
        try:
            best_eval_success_rate = run.summary["eval/best_eval_success_rate"]
        except KeyError:
            best_eval_success_rate = run.summary["eval/success_rate"]

        iteration = run.summary["iteration"]
        global_step = run.step
        bc_step = 0

    else:
        global_step = 0
        iteration = 0
        best_eval_success_rate = 0.0
        bc_step = 0

    obs: torch.Tensor = torch.zeros(
        (
            steps_per_iteration,
            cfg.num_envs,
            residual_policy.obs_dim,
        )
    )
    actions = torch.zeros((steps_per_iteration, cfg.num_envs) + env.action_space.shape)
    full_nactions = torch.zeros(
        (steps_per_iteration, cfg.num_envs) + env.action_space.shape
    )
    logprobs = torch.zeros((steps_per_iteration, cfg.num_envs))
    rewards = torch.zeros((steps_per_iteration, cfg.num_envs))
    dones = torch.zeros((steps_per_iteration, cfg.num_envs))
    values = torch.zeros((steps_per_iteration, cfg.num_envs))

    start_time = time.time()
    training_cum_time = 0
    running_mean_success_rate = 0.0

    next_done = torch.zeros(cfg.num_envs)
    next_obs = env.reset()
    agent.reset()

    # Create model save dir
    model_save_dir: Path = Path("models") / wandb.run.name
    model_save_dir.mkdir(parents=True, exist_ok=True)
    
    num_bc_epochs = cfg.initial_num_bc_epochs
    rl_per_bc = cfg.rl_per_bc
    rl_counter = 0

    while global_step < cfg.total_timesteps:
        iteration += 1
        # Calculate how many training iterations we've done
        training_iterations = iteration - cfg.eval_first
        training_iterations -= (iteration - cfg.eval_first) // cfg.eval_interval
        print(f"Iteration: {iteration}/{cfg.num_iterations}")
        print(f"Run name: {run_name}")
        iteration_start_time = time.time()

        # If eval first flag is set, we will evaluate the model before doing any training
        eval_rl = (iteration - int(cfg.eval_first)) % cfg.eval_interval == 0

        # Also reset the env to have more consistent results
        if eval_rl or cfg.reset_every_iteration:
            if not cfg.eval_first or iteration != 1:
                next_obs = env.reset()
                agent.reset()

        print(f"Eval mode: {eval_rl}")

        print("Training base with BC...")
        if not eval_rl and num_bc_epochs > 0 and rl_counter % rl_per_bc == 0:
            bc_steps_this_iter = 0
            all_metrics = []

            if cfg.enable_rl_replay and len(buffer) > 0:
                print("Training on expert data and buffer data")
                buffer_dataloader = DataLoader(
                    buffer, batch_size=cfg.base_training.batch_size,
                    shuffle=True, num_workers=0, pin_memory=True
                )
                print(f"Buffer dataloader size: {len(buffer_dataloader)}")

                merged_dataloader = merge_dataloaders(
                    il_dataloader,
                    buffer_dataloader,
                    batch_size=cfg.base_training.batch_size
                )
                il_metrics = train_bc_epochs(
                    merged_dataloader, agent, residual_policy,
                    opt_actor, lr_sche_actor, cfg, device,
                    num_epochs=num_bc_epochs,
                    iteration=iteration,
                    prefix="bc"
                )
                bc_steps_this_iter += il_metrics["bc_training/total_batches"]
                all_metrics.append(il_metrics)
            else:
                print("Training on expert data only")
                il_metrics = train_bc_epochs(
                    il_dataloader, agent, residual_policy,
                    opt_actor, lr_sche_actor, cfg, device,
                    num_epochs=num_bc_epochs,
                    iteration=iteration,
                    prefix="bc"
                )
                bc_steps_this_iter += il_metrics["bc_training/total_batches"]
                all_metrics.append(il_metrics)

            # Update global BC step counter
            bc_step += bc_steps_this_iter

            # Add combined metrics
            combined_metrics = {
                "bc_training/total_steps": bc_step,
                "bc_training/num_bc_epochs": num_bc_epochs,
                "bc_training/steps_this_iter": bc_steps_this_iter,
                "bc_training/steps_per_rl_step": bc_steps_this_iter / cfg.num_envs,
                "training/bc_learning_rate": opt_actor.param_groups[0]["lr"],
            }
            all_metrics.append(combined_metrics)

            print(f"Completed BC training with {bc_steps_this_iter} steps")
            merged_metrics = merge_metrics(all_metrics)
            wandb.log(merged_metrics, step=global_step)
        else:
            print("BC training skipped")

        # Evaluate IL policy
        print("Evaluating BC policy...")
        il_test_dataloader = iter(demo_testloader)
        agent.eval()
        eval_loss = []
        test_tepoch = tqdm(il_test_dataloader, desc="BC Eval")
        for test_batch in test_tepoch:
            with torch.no_grad():
                test_batch = dict_to_device(test_batch, device)
                loss, _ = agent.compute_loss(test_batch, backfill=cfg.backfill)
                test_loss_cpu = loss.item()
                eval_loss.append(test_loss_cpu)
                test_tepoch.set_postfix(loss=test_loss_cpu)
        test_tepoch.close()

        mean_il_advantage, mean_il_value = evaluate_il_online(agent, residual_policy, demo_testloader, device)
        wandb.log({
            "eval/bc_loss": np.mean(eval_loss),
            "eval/mean_bc_advantage": mean_il_advantage,
            "eval/mean_bc_value": mean_il_value,
        }, step=global_step)

        # ROLLOUT: Collecting online data for RL training
        print("Collecting online data...")

        for step in range(0, steps_per_iteration):
            if not eval_rl:
                # Only count environment steps during training
                global_step += cfg.num_envs

            # Get the base normalized action
            base_naction = agent.base_action_normalized(next_obs)

            # Process the obs for the residual policy
            next_nobs = agent.process_obs(next_obs)
            next_residual_nobs = torch.cat([next_nobs, base_naction], dim=-1)

            dones[step] = next_done
            obs[step] = next_residual_nobs

            with torch.no_grad():
                residual_naction_samp, logprob, _, value, naction_mean = (
                    residual_policy.get_action_and_value(next_residual_nobs)
                )

            residual_naction = residual_naction_samp if not eval_rl else naction_mean
            naction = base_naction + residual_naction * residual_policy.action_scale

            action = agent.normalizer(naction, "action", forward=False)
            next_obs, reward, next_done, truncated, info = env.step(action)

            if cfg.truncation_as_done:
                next_done = next_done | truncated

            values[step] = value.flatten().cpu()
            actions[step] = residual_naction.cpu()
            logprobs[step] = logprob.cpu()
            rewards[step] = reward.view(-1).cpu()
            next_done = next_done.view(-1).cpu()
            full_nactions[step] = naction.cpu()

            if step > 0 and (env_step := step * 1) % 100 == 0:
                print(
                    f"env_step={env_step}, global_step={global_step}, mean_reward={rewards[:step+1].sum(dim=0).mean().item()} fps={env_step * cfg.num_envs / (time.time() - iteration_start_time):.2f}"
                )

        # Calculate the success rate
        # Find the rewards that are not zero
        # Env is successful if it received a reward more than or equal to n_parts_to_assemble
        env_success = (rewards > 0).sum(dim=0) >= n_parts_to_assemble
        success_rate = env_success.float().mean().item()

        # Calculate the share of timesteps that come from successful trajectories that account for the success rate and the varying number of timesteps per trajectory
        # Count total timesteps in successful trajectories
        timesteps_in_success = rewards[:, env_success]

        # Find index of last reward in each trajectory
        last_reward_idx = torch.argmax(timesteps_in_success, dim=0)

        # Calculate the total number of timesteps in successful trajectories
        total_timesteps_in_success = last_reward_idx.sum().item()

        # Calculate the share of successful timesteps
        success_timesteps_share = total_timesteps_in_success / rewards.numel()

        running_mean_success_rate = 0.5 * running_mean_success_rate + 0.5 * success_rate

        print(
            f"SR: {success_rate:.4%}, SR mean: {running_mean_success_rate:.4%}, SPS: {steps_per_iteration * cfg.num_envs / (time.time() - iteration_start_time):.2f}"
        )

        # EVALUATION
        if eval_rl:
            # Save the model if the evaluation success rate improves
            if success_rate > best_eval_success_rate:
                best_eval_success_rate = success_rate
                sr_threshold = cfg.get("terminate_schedule_sr_threshold", -1)
                if sr_threshold != -1 and success_rate > sr_threshold:
                    print(f"Terminating schedule early due to success rate {success_rate:.4%} > {sr_threshold:.4%}")
                model_path = str(model_save_dir / f"actor_chkpt_best_success_rate.pt")
                save_dict = {
                    # Save the weights of the residual policy (base + residual)
                    "model_state_dict": agent.state_dict(),
                    "global_step": global_step,
                    "success_rate": success_rate,
                    "success_timesteps_share": success_timesteps_share,
                    "iteration": iteration,
                    "config": OmegaConf.to_container(cfg, resolve=True),
                }
                for (name, opt), (__, scheduler) in zip(optimizers, schedulers):
                    save_dict[f"{name}_optimizer_state_dict"] = opt.state_dict()
                    save_dict[f"{name}_scheduler_state_dict"] = scheduler.state_dict()
                torch.save(save_dict, model_path)

                wandb.save(model_path)
                print(f"Evaluation success rate improved. Model saved to {model_path}")

                # Add successful trajectories to the replay buffer when success rate increases
                if cfg.enable_rl_replay and success_rate >= cfg.replay_from_sr:
                    add_successful_trajectories_to_buffer(
                        buffer=buffer,
                        obs=obs,
                        full_nactions=full_nactions,
                        rewards=rewards,
                        n_parts_to_assemble=n_parts_to_assemble,
                        device=device,
                        max_new_samples=cfg.max_replay_new_samples
                    )
                    buffer.rebuild_seq_indices()

            wandb.log(
                {
                    "eval/success_rate": success_rate,
                    "eval/best_eval_success_rate": best_eval_success_rate,
                    "iteration": iteration,
                },
                step=global_step,
            )
            
            continue

        # TRAINING RL
        print("Training the residual policy with RL...")

        b_obs = obs.reshape((-1, residual_policy.obs_dim))
        b_actions = actions.reshape((-1,) + env.action_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_values = values.reshape(-1)

        # Get the base normalized action
        # Process the obs for the residual policy
        base_naction = agent.base_action_normalized(next_obs)
        next_nobs = agent.process_obs(next_obs)
        next_residual_nobs = torch.cat([next_nobs, base_naction], dim=-1)
        next_value = residual_policy.get_value(next_residual_nobs).reshape(1, -1).cpu()

        # bootstrap value if not done
        advantages, returns = calculate_advantage(
            values,
            next_value,
            rewards,
            dones,
            next_done,
            steps_per_iteration,
            cfg.discount,
            cfg.gae_lambda,
        )

        b_advantages = advantages.reshape(-1).cpu()
        b_returns = returns.reshape(-1).cpu()

        # Optimizing the policy and value network
        b_inds = np.arange(cfg.batch_size)
        clipfracs = []
        rl_grad_norms = []
        for epoch in trange(cfg.update_epochs, desc="Policy update"):
            early_stop = False

            np.random.shuffle(b_inds)
            for start in range(0, cfg.batch_size, cfg.minibatch_size):
                end = start + cfg.minibatch_size
                mb_inds = b_inds[start:end]

                # Get the minibatch and place it on the device
                mb_obs = b_obs[mb_inds].to(device)
                mb_actions = b_actions[mb_inds].to(device)
                mb_logprobs = b_logprobs[mb_inds].to(device)
                mb_advantages = b_advantages[mb_inds].to(device)
                mb_returns = b_returns[mb_inds].to(device)
                mb_values = b_values[mb_inds].to(device)

                # Calculate the loss
                _, newlogprob, entropy, newvalue, action_mean = (
                    residual_policy.get_action_and_value(mb_obs, mb_actions)
                )
                logratio = newlogprob - mb_logprobs
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [
                        ((ratio - 1.0).abs() > cfg.clip_coef).float().mean().item()
                    ]

                if cfg.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                        mb_advantages.std() + 1e-8
                    )

                policy_loss = 0

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(
                    ratio, 1 - cfg.clip_coef, 1 + cfg.clip_coef
                )
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if cfg.clip_vloss:
                    v_loss_unclipped = (newvalue - mb_returns) ** 2
                    v_clipped = mb_values + torch.clamp(
                        newvalue - mb_values,
                        -cfg.clip_coef,
                        cfg.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - mb_returns) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean()

                # Entropy loss
                entropy_loss = entropy.mean() * cfg.ent_coef

                ppo_loss = pg_loss - entropy_loss

                # Add the auxiliary regularization loss
                residual_l1_loss = torch.mean(torch.abs(action_mean))
                residual_l2_loss = torch.mean(torch.square(action_mean))

                # Normalize the losses so that each term has the same scale
                if iteration > cfg.n_iterations_train_only_value:

                    # Scale the losses using the calculated scaling factors
                    policy_loss += ppo_loss
                    policy_loss += cfg.residual_l1 * residual_l1_loss
                    policy_loss += cfg.residual_l2 * residual_l2_loss

                rl_loss: torch.Tensor = policy_loss * cfg.rl_coef
                value_loss: torch.Tensor = v_loss * cfg.vf_coef

                loss = rl_loss + value_loss

                opt_res_actor.zero_grad()
                opt_res_critic.zero_grad()

                loss.backward()
                grad_norm = nn.utils.clip_grad_norm(
                    residual_policy.parameters(), cfg.max_grad_norm
                )
                rl_grad_norms.append(grad_norm.item())

                opt_res_actor.step()
                opt_res_critic.step()

                if cfg.target_kl is not None and approx_kl > cfg.target_kl:
                    print(
                        f"Early stopping at epoch {epoch} due to reaching max kl: {approx_kl:.4f} > {cfg.target_kl:.4f}"
                    )
                    early_stop = True
                    break

            if early_stop:
                break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        action_norms = torch.norm(b_actions[:, :3], dim=-1).cpu()

        training_cum_time += time.time() - iteration_start_time
        sps = int(global_step / training_cum_time) if training_cum_time > 0 else 0

        wandb.log(
            {
                "grads/rl_mean_grad_norms": np.mean(rl_grad_norms),
                "training/rl_counter": rl_counter,
                "training/rl_learning_rate_actor": opt_res_actor.param_groups[0]["lr"],
                "training/rl_learning_rate_critic": opt_res_critic.param_groups[0]["lr"],
                "charts/SPS": sps,
                "charts/rewards": rewards.sum().item(),
                "charts/policy_entropy": entropy.mean().item(),
                "charts/success_rate": success_rate,
                "charts/success_timesteps_share": success_timesteps_share,
                "charts/action_norm_mean": action_norms.mean(),
                "charts/action_norm_std": action_norms.std(),
                "values/advantages": b_advantages.mean().item(),
                "values/returns": b_returns.mean().item(),
                "values/values": b_values.mean().item(),
                "values/mean_logstd": residual_policy.actor_logstd.mean().item(),
                "losses/value_loss": v_loss.item(),
                "losses/policy_loss": pg_loss.item(),
                "losses/total_loss": loss.item(),
                "losses/entropy_loss": entropy_loss.item(),
                "losses/old_approx_kl": old_approx_kl.item(),
                "losses/approx_kl": approx_kl.item(),
                "losses/clipfrac": np.mean(clipfracs),
                "losses/explained_variance": explained_var,
                "losses/residual_l1": residual_l1_loss.item(),
                "losses/residual_l2": residual_l2_loss.item(),
                "histograms/values": wandb.Histogram(values),
                "histograms/returns": wandb.Histogram(b_returns),
                "histograms/advantages": wandb.Histogram(b_advantages),
                "histograms/logprobs": wandb.Histogram(logprobs),
                "histograms/rewards": wandb.Histogram(rewards),
                "histograms/action_norms": wandb.Histogram(action_norms),
            },
            step=global_step,
        )

        # Step the learning rate scheduler
        lr_sche_res_actor.step()
        lr_sche_res_critic.step()

        # Checkpoint every cfg.checkpoint_interval steps
        if cfg.checkpoint_interval > 0 and iteration % cfg.checkpoint_interval == 0:
            model_path = str(model_save_dir / f"actor_chkpt_{iteration}.pt")
            torch.save(
                {
                    "model_state_dict": agent.state_dict(),
                    "optimizer_actor_state_dict": opt_res_actor.state_dict(),
                    "optimizer_critic_state_dict": opt_res_critic.state_dict(),
                    "scheduler_actor_state_dict": lr_sche_res_actor.state_dict(),
                    "scheduler_critic_state_dict": lr_sche_res_critic.state_dict(),
                    "config": OmegaConf.to_container(cfg, resolve=True),
                    "success_rate": success_rate,
                    "iteration": iteration,
                },
                model_path,
            )

            wandb.save(model_path)
            print(f"Model saved to {model_path}")

        rl_counter += 1

        # Print some stats at the end of the iteration
        print(
            f"Iteration {iteration}/{cfg.num_iterations}, global step {global_step}, SPS {sps}"
        )

        print(
            f"At iteration {iteration}, we've done {training_iterations} training iterations "
            f"and {iteration - training_iterations} evaluation iterations"
        )

    print(f"Training finished in {(time.time() - start_time):.2f}s")


if __name__ == "__main__":
    # Script entry point: run the residual-policy training routine only when this
    # file is executed directly (not when it is imported as a module).
    main()
