import os
import random
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Optional, Union

import hydra
import numpy as np
import torch
import wandb
from diffusers.optimization import get_scheduler
from gymnasium import Env
from ipdb import set_trace as bp
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm, trange

from src.behavior import get_actor
from src.behavior.base import Actor
from src.common.earlystop import EarlyStopper
from src.common.files import get_processed_paths, path_override
from src.common.hydra import to_native
from src.common.pytorch_util import dict_to_device
from src.dataset.dataloader import FixedStepsDataloader
from src.dataset.dataset import ImageDataset, StateDataset
from src.dataset.maniskill_dataset import ManiSkillStateDataset, ManiSkillImageDataset # for maniskill only
from src.eval.eval_utils import get_model_from_api_or_cached
from src.eval.rollout import do_rollout_evaluation
from src.eval.maniskill_evaluate import create_maniskill_eval_env, evaluate_maniskill # for maniskill only
from src.gym import get_rl_env
from src.models.ema import SwitchEMA

# Import the wandb Run type for type hinting
from wandb.apis.public.runs import Run
from wandb.errors.util import CommError
from wandb_osh.hooks import TriggerWandbSyncHook, _comm_default_dir

# Module-level wandb-osh sync hook. Its communication directory is taken from
# the WANDB_OSH_COMM_DIR environment variable, falling back to the library's
# `_comm_default_dir`. It is not referenced in the visible part of this file;
# presumably it is called later (e.g. after logging/checkpointing) to ask an
# external wandb-osh process to sync an offline run -- TODO confirm.
trigger_sync = TriggerWandbSyncHook(
    communication_dir=os.environ.get("WANDB_OSH_COMM_DIR", _comm_default_dir),
)


# TF32 for CUDA matmul / cuDNN is intentionally left disabled: the two flags
# below stay commented out. The banner is printed at import time so run logs
# record that choice.
print("=== Activate TF32 training? Deactivated for now...")
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True


def log_action_mse(log_dict, category, pred_action, gt_action, is_maniskill=False):
    """Store action-prediction MSE metrics in ``log_dict`` (mutated in place).

    The overall MSE between ``pred_action`` and ``gt_action`` is always logged
    under ``action_sample/{category}_action_mse_error``. Per-component errors
    are added on top of it, as a best effort:

    * ManiSkill layout (``is_maniskill=True``): ``_pos`` from ``[..., :3]``,
      ``_ori`` from ``[..., 3:6]`` and ``_grip`` from ``[..., 6]``. Each is
      logged only when the action dimension ``D`` is wide enough for its slice.
    * Otherwise the last axis is split into 10-wide chunks: ``_pos`` from
      chunk ``[:3]``, ``_rot`` from chunk ``[3:9]`` and ``_width`` from chunk
      ``[9]``. This breakdown is skipped when ``D`` is not a positive multiple
      of 10.

    Args:
        log_dict: Dict that receives the metric tensors.
        category: Key prefix, e.g. ``"train"`` or ``"val"``.
        pred_action: Predicted actions, unpacked as ``(B, T, D)``.
        gt_action: Ground-truth actions, expected to have the same shape as
            ``pred_action``.
        is_maniskill: Use the ManiSkill component layout instead of the
            10-wide chunk layout.

    Raises:
        ValueError: If ``pred_action`` is not 3-dimensional.
    """
    mse = torch.nn.functional.mse_loss
    B, T, D = pred_action.shape
    key = f"action_sample/{category}_action_mse_error"

    log_dict[key] = mse(pred_action, gt_action)

    # The component breakdown is best-effort. A pred/gt shape mismatch surfaces
    # from torch as RuntimeError/IndexError and simply skips the remaining
    # component metrics. Any other error (or KeyboardInterrupt) is no longer
    # silently swallowed, as it was with the previous bare `except:`.
    try:
        if is_maniskill:
            if D >= 3:
                log_dict[f"{key}_pos"] = mse(pred_action[..., :3], gt_action[..., :3])
            if D >= 6:
                log_dict[f"{key}_ori"] = mse(pred_action[..., 3:6], gt_action[..., 3:6])
            if D >= 7:
                log_dict[f"{key}_grip"] = mse(pred_action[..., 6], gt_action[..., 6])
        elif D > 0 and D % 10 == 0:
            # Explicit guard instead of relying on view() raising for widths
            # that cannot be split into 10-wide chunks.
            pred_chunks = pred_action.reshape(B, T, -1, 10)
            gt_chunks = gt_action.reshape(B, T, -1, 10)
            log_dict[f"{key}_pos"] = mse(pred_chunks[..., :3], gt_chunks[..., :3])
            log_dict[f"{key}_rot"] = mse(pred_chunks[..., 3:9], gt_chunks[..., 3:9])
            log_dict[f"{key}_width"] = mse(pred_chunks[..., 9], gt_chunks[..., 9])
    except (RuntimeError, IndexError):
        pass


def set_dryrun_params(cfg: DictConfig):
    """Shrink ``cfg`` in place for a quick smoke-test run.

    Does nothing unless ``cfg.dryrun`` is truthy. Otherwise struct mode is
    lifted temporarily and the following overrides are applied:

    * ``steps_per_epoch``: capped to 10, keeping the ``-1`` sentinel.
    * ``data_subset`` and ``dataloader_workers``: made tiny.
    * Sampling and evaluation: run every epoch.
    * Rollouts: when enabled, run every epoch with no loss threshold.
    * wandb: disabled.

    Struct mode is re-enabled afterwards.
    """
    if not cfg.dryrun:
        return

    OmegaConf.set_struct(cfg, False)

    training = cfg.training
    data = cfg.data

    # -1 means "use the whole dataloader"; any other value is capped at 10.
    training.steps_per_epoch = -1 if training.steps_per_epoch == -1 else 10
    data.data_subset = 5
    data.dataloader_workers = 0
    training.sample_every = 1
    training.eval_every = 1

    rollout = cfg.rollout
    if rollout.rollouts:
        rollout.every = 1
        rollout.loss_threshold = float("inf")

    cfg.wandb.mode = "disabled"

    OmegaConf.set_struct(cfg, True)


def now():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM``."""
    current = datetime.now()
    # datetime.__format__ delegates to strftime with the given spec.
    return f"{current:%Y-%m-%d %H:%M}"


# @hydra.main(config_path="../config/bc", config_name="base")
@hydra.main(config_path="../config", config_name="base_maniskill", version_base="1.2")
def main(cfg: DictConfig):
    set_dryrun_params(cfg)
    OmegaConf.resolve(cfg)

    # Set the random seed
    if cfg.get("seed") is None:
        OmegaConf.set_struct(cfg, False)
        cfg.seed = np.random.randint(0, 2**32 - 1)
        OmegaConf.set_struct(cfg, True)

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    print(OmegaConf.to_yaml(cfg))
    # Initialize environments to None
    env: Optional[Env] = None
    eval_env: Optional[Env] = None
    
    # Set up device
    device = torch.device(
        f"cuda:{cfg.training.gpu_id}" if torch.cuda.is_available() else "cpu"
    )
    torch.cuda.set_device(cfg.training.gpu_id)
    
    # For ManiSkill, create the evaluation environment at the beginning
    is_maniskill = cfg.get("dataset_type", "") == "maniskill"
    if is_maniskill and cfg.rollout.rollouts:
        try:
            # Determine backend based on availability
            sim_backend = "physx_cuda"
            if hasattr(cfg, "force_cpu_eval") and cfg.force_cpu_eval or not torch.cuda.is_available():
                sim_backend = "cpu"
            
            task_name = cfg.task
            reward_mode = cfg.env.get("reward_mode", "sparse") if hasattr(cfg, "env") else "sparse"
            
            eval_num_envs = cfg.eval_num_envs if hasattr(cfg, "eval_num_envs") else cfg.rollout.num_envs
            eval_record_video = hasattr(cfg, "eval_record_video") and cfg.eval_record_video
            
            print(f"Creating persistent ManiSkill evaluation environment for {task_name}")
            print(f"Using {eval_num_envs} environments, recording video: {eval_record_video}")
            
            eval_env = create_maniskill_eval_env(
                task_name=task_name,
                num_envs=eval_num_envs,
                observation_space=cfg.observation_type,
                control_mode=cfg.control.control_mode,
                sim_backend=sim_backend,
                record_video=eval_record_video,
                video_dir=os.path.join("runs", "videos") if eval_record_video else None,
                fallback_to_cpu=True,  # Allow fallback for evaluation
                reward_mode=reward_mode,
            )
            print("Successfully created persistent evaluation environment")
        except Exception as e:
            print(f"WARNING: Failed to create persistent evaluation environment: {e}")
            print("Will create evaluation environment when needed")

    state_dict = None

    # Check if we are continuing a run
    run_exists = False
    if cfg.wandb.continue_run_id is not None:
        try:
            run: Run = wandb.Api().run(
                f"{cfg.wandb.project}/{cfg.wandb.continue_run_id}"
            )
            run_exists = True
        except (ValueError, CommError):
            pass

    if run_exists:
        print(f"Continuing run {cfg.wandb.continue_run_id}, {run.name}")

        run_id = cfg.wandb.continue_run_id
        run_path = f"{cfg.wandb.project}/{run_id}"
        wandb_mode = cfg.wandb.mode

        data_paths_override = cfg.data.data_paths_override

        # Load the weights from the run and override the config with the one from the run
        try:
            cfg, wts = get_model_from_api_or_cached(
                run_path, "last", wandb_mode=wandb_mode
            )
        except:
            cfg, wts = get_model_from_api_or_cached(
                run_path, "latest", wandb_mode=wandb_mode
            )

        # Ensure we set the `continue_run_id` to the run_id
        cfg.wandb.continue_run_id = run_id
        cfg.wandb.mode = wandb_mode
        cfg.data.data_paths_override = data_paths_override

        state_dict = torch.load(wts)

        epoch_idx = state_dict.get("epoch", run.summary.get("epoch", 0))
        cfg.training.start_epoch = epoch_idx

        # Set the best test loss and success rate to the one from the run
        best_test_loss = state_dict.get(
            "best_test_loss", run.summary.get("test_epoch_loss", float("inf"))
        )
        test_loss_mean = best_test_loss
        best_success_rate = state_dict.get(
            "best_success_rate", run.summary.get("best_success_rate", 0)
        )
        epoch_idx = state_dict.get("epoch", run.summary.get("epoch", 0))
        global_step = state_dict.get("global_step", run.lastHistoryStep)

        prev_best_success_rate = best_success_rate
    else:
        # Train loop
        best_test_loss = float("inf")
        test_loss_mean = float("inf")
        best_success_rate = 0
        prev_best_success_rate = 0
        global_step = 0

    # Check for dataset type - standard or ManiSkill
    is_maniskill = cfg.get("dataset_type", "") == "maniskill"
    
    if not is_maniskill:
        # Original code for furniturebench
        if cfg.data.data_paths_override is None:
            data_path = get_processed_paths(
                controller=to_native(cfg.control.controller),
                domain=to_native(cfg.data.environment),
                task=to_native(cfg.data.task),
                demo_source=to_native(cfg.data.demo_source),
                randomness=to_native(cfg.data.randomness),
                demo_outcome=to_native(cfg.data.demo_outcome),
                suffix=to_native(cfg.data.suffix),
            )
        else:
            data_path = path_override(cfg.data.data_paths_override)

        print(f"Using data from {data_path}")

        dataset: Union[ImageDataset, StateDataset]

        if cfg.observation_type == "image":
            dataset = ImageDataset(
                dataset_paths=data_path,
                pred_horizon=cfg.data.pred_horizon,
                obs_horizon=cfg.data.obs_horizon,
                action_horizon=cfg.data.action_horizon,
                data_subset=cfg.data.data_subset,
                control_mode=cfg.control.control_mode,
                predict_past_actions=cfg.data.predict_past_actions,
                pad_after=cfg.data.get("pad_after", True),
                max_episode_count=cfg.data.get("max_episode_count", None),
                minority_class_power=cfg.data.get("minority_class_power", False),
                load_into_memory=cfg.data.get("load_into_memory", True),
            )
        elif cfg.observation_type == "state":
            dataset = StateDataset(
                dataset_paths=data_path,
                pred_horizon=cfg.data.pred_horizon,
                obs_horizon=cfg.data.obs_horizon,
                action_horizon=cfg.data.action_horizon,
                data_subset=cfg.data.data_subset,
                control_mode=cfg.control.control_mode,
                predict_past_actions=cfg.data.predict_past_actions,
                pad_after=cfg.data.get("pad_after", True),
                max_episode_count=cfg.data.get("max_episode_count", None),
                include_future_obs=cfg.data.include_future_obs,
            )
        else:
            raise ValueError(f"Unknown observation type: {cfg.observation_type}")
    else:
        # ManiSkill dataset code path
        data_path = cfg.data.h5_path
        print(f"Using ManiSkill data from {data_path}")
        
        if hasattr(cfg, "task") and cfg.task and cfg.get("physx_cuda_detect_dims", True):
            try:
                print("Detecting observation and action dimensions from ManiSkill environment...")
                import gymnasium as gym
                import mani_skill.envs
                
                env_kwargs = {
                    "obs_mode": "state",
                    "control_mode": "pd_ee_delta_pos",
                }
                
                test_env = gym.make(cfg.task, **env_kwargs)
                obs, _ = test_env.reset()
                
                # Get actual dimensions
                if isinstance(obs, dict) and "state" in obs:
                    actual_obs_dim = obs["state"].shape[0]
                elif isinstance(obs, torch.Tensor):
                    actual_obs_dim = obs.shape[0]
                elif isinstance(obs, np.ndarray):
                    actual_obs_dim = obs.shape[0]
                else:
                    actual_obs_dim = None
                
                actual_action_dim = test_env.action_space.shape[0]
                
                if actual_obs_dim is not None:
                    OmegaConf.set_struct(cfg, False)
                    cfg.robot_state_dim = actual_obs_dim
                    OmegaConf.set_struct(cfg, True)
                
                print(f"Detected action dimension: {actual_action_dim}")
                OmegaConf.set_struct(cfg, False)
                cfg.action_dim = actual_action_dim
                OmegaConf.set_struct(cfg, True)
                
                test_env.close()
            except Exception as e:
                print(f"Failed to physx_cuda-detect dimensions: {e}")

        if cfg.observation_type == "image":
            dataset = ManiSkillImageDataset(
                dataset_paths=data_path,
                pred_horizon=cfg.data.pred_horizon,
                obs_horizon=cfg.data.obs_horizon,
                action_horizon=cfg.data.action_horizon,
                data_subset=cfg.data.data_subset,
                control_mode=cfg.control.control_mode,
                predict_past_actions=cfg.data.predict_past_actions,
                pad_after=cfg.data.get("pad_after", True),
            )
        elif cfg.observation_type == "state":
            dataset = ManiSkillStateDataset(
                dataset_paths=data_path,
                pred_horizon=cfg.data.pred_horizon,
                obs_horizon=cfg.data.obs_horizon,
                action_horizon=cfg.data.action_horizon,
                data_subset=cfg.data.data_subset,
                control_mode=cfg.control.control_mode,
                predict_past_actions=cfg.data.predict_past_actions,
                pad_after=cfg.data.get("pad_after", True),
            )
        else:
            raise ValueError(f"Unknown observation type: {cfg.observation_type}")

    # Split the dataset into train and test (effective, meaning that this is after upsampling)
    train_size = int(len(dataset) * (1 - cfg.data.test_split))
    test_size = len(dataset) - train_size
    print(f"Splitting dataset into {train_size} train and {test_size} test samples.")
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    OmegaConf.set_struct(cfg, False)
    if (job_id := os.environ.get("SLURM_JOB_ID")) is not None:
        cfg.slurm_job_id = job_id

    cfg.robot_state_dim = dataset.robot_state_dim
    cfg.action_dim = dataset.action_dim

    # Always set parts_poses_dim for compatibility - for ManiSkill it will be 0
    cfg.parts_poses_dim = dataset.parts_poses_dim

    # Create the policy network
    actor: Actor = get_actor(
        cfg,
        device,
    )
    actor.set_normalizer(dataset.normalizer.to(device))
    actor.to(device)

    # Set the data path in the cfg object
    if isinstance(data_path, list):
        cfg.data_path = [str(f) for f in data_path]
    else:
        cfg.data_path = str(data_path)

    # Update the cfg object with the action dimension
    cfg.n_episodes = len(dataset.episode_ends)
    cfg.n_samples = dataset.n_samples

    # Update the cfg object with the observation dimension
    cfg.timestep_obs_dim = actor.timestep_obs_dim
    OmegaConf.set_struct(cfg, True)

    if cfg.training.load_checkpoint_run_id is not None:
        api = wandb.Api()
        run = api.run(cfg.training.load_checkpoint_run_id)
        model_path = (
            [f for f in run.files() if f.name.endswith(".pt")][0]
            .download(exist_ok=True)
            .name
        )
        print(f"Loading checkpoint from {cfg.training.load_checkpoint_run_id}")
        actor.load_state_dict(torch.load(model_path))

    # Create dataloaders
    trainload_kwargs = dict(
        dataset=train_dataset,
        batch_size=cfg.training.batch_size,
        num_workers=cfg.data.dataloader_workers,
        shuffle=True,
        pin_memory=True,
        drop_last=False,
        persistent_workers=False,
    )
    trainloader = (
        FixedStepsDataloader(**trainload_kwargs, n_batches=cfg.training.steps_per_epoch)
        if cfg.training.steps_per_epoch != -1
        else DataLoader(**trainload_kwargs)
    )

    testload_kwargs = dict(
        dataset=test_dataset,
        batch_size=cfg.training.batch_size,
        num_workers=cfg.data.dataloader_workers,
        shuffle=True,
        pin_memory=True,
        drop_last=False,
        persistent_workers=False,
    )

    testloader = (
        FixedStepsDataloader(
            **testload_kwargs,
            n_batches=max(
                int(round(cfg.training.steps_per_epoch * cfg.data.test_split)), 1
            ),
        )
        if cfg.training.steps_per_epoch != -1
        else DataLoader(**testload_kwargs)
    )

    # Create lists for optimizers and lr schedulers

    opt_noise = torch.optim.AdamW(
        params=actor.actor_parameters(),
        lr=cfg.training.actor_lr,
        weight_decay=cfg.regularization.weight_decay,
    )
    lr_scheduler = get_scheduler(
        name=cfg.lr_scheduler.name,
        optimizer=opt_noise,
        num_warmup_steps=cfg.lr_scheduler.warmup_steps,
        num_training_steps=len(trainloader) * cfg.training.num_epochs,
    )

    optimizers = [("actor", opt_noise)]
    lr_schedulers = [lr_scheduler]

    if cfg.observation_type == "image":

        opt_encoder = torch.optim.AdamW(
            params=actor.encoder_parameters(),
            lr=cfg.training.encoder_lr,
            weight_decay=cfg.regularization.weight_decay,
        )
        lr_scheduler_encoder = get_scheduler(
            name=cfg.lr_scheduler.name,
            optimizer=opt_encoder,
            num_warmup_steps=cfg.lr_scheduler.encoder_warmup_steps,
            num_training_steps=len(trainloader) * cfg.training.num_epochs,
        )

        optimizers.append(("encoder", opt_encoder))
        lr_schedulers.append(lr_scheduler_encoder)

    if state_dict is not None:
        if "model_state_dict" in state_dict:
            actor.load_state_dict(state_dict["model_state_dict"])
            for (name, opt), scheduler in zip(optimizers, lr_schedulers):
                opt.load_state_dict(state_dict[f"{name}_optimizer_state_dict"])
                scheduler.load_state_dict(state_dict[f"{name}_scheduler_state_dict"])

        else:
            actor.load_state_dict(state_dict)

        print(f"Loaded weights from run {run_id}")

    if cfg.training.ema.use:
        ema = SwitchEMA(actor, cfg.training.ema.decay)
        ema.register()

    early_stopper = EarlyStopper(
        patience=cfg.early_stopper.patience,
        smooth_factor=cfg.early_stopper.smooth_factor,
    )
    config_dict = OmegaConf.to_container(cfg, resolve=True)

    # Init wandb
    run = wandb.init(
        id=cfg.wandb.continue_run_id,
        name=cfg.wandb.name,
        resume=None if cfg.wandb.continue_run_id is None else "allow",
        project=cfg.wandb.get("project", "maniskill_bc"),
        entity=cfg.wandb.get("entity", "flare-robot"),
        config=config_dict,
        mode=cfg.wandb.mode,
        notes=cfg.wandb.notes,
    )

    if cfg.wandb.watch_model:
        run.watch(actor, log="all", log_freq=1000)

    # Print the run name and storage location
    print(f"Run name: {run.name}")
    print(f"Run storage location: {run.dir}")

    # In sweeps, the init is ignored, so to make sure that the cfg is saved correctly
    # to wandb we need to log it manually
    wandb.config.update(config_dict)

    # save stats to wandb and update the cfg object
    train_size = int(dataset.n_samples * (1 - cfg.data.test_split))
    test_size = dataset.n_samples - train_size

    dataset_stats = {
        "num_samples_train": train_size,
        "num_samples_test": test_size,
        "num_episodes_train": int(
            len(dataset.episode_ends) * (1 - cfg.data.test_split)
        ),
        "num_episodes_test": int(len(dataset.episode_ends) * cfg.data.test_split),
        "dataset_metadata": dataset.metadata,
    }

    # Add the dataset stats to the wandb summary
    wandb.summary.update(dataset_stats)

    starttime = now()
    wandb.summary["start_time"] = starttime

    # Create model save dir
    model_save_dir = Path(cfg.training.model_save_dir) / wandb.run.name
    model_save_dir.mkdir(parents=True, exist_ok=True)

    print(f"Job started at: {starttime}")

    early_stop = False

    pbar_desc = f"Epoch ({cfg.task}, {cfg.observation_type}{f', {cfg.vision_encoder.model}' if cfg.observation_type == 'image' else ''})"

    tglobal = trange(
        cfg.training.start_epoch,
        cfg.training.num_epochs,
        initial=cfg.training.start_epoch,
        total=cfg.training.num_epochs,
        desc=pbar_desc,
    )

    for epoch_idx in tglobal:
        epoch_loss = list()
        test_loss = list()

        epoch_log = {
            "epoch": epoch_idx,
        }

        train_losses_log = defaultdict(list)

        # batch loop
        actor.train()
        tepoch = tqdm(trainloader, desc="Training", leave=False, total=len(trainloader))
        for batch in tepoch:
            # Zero the gradients in all optimizers
            for _, opt in optimizers:
                opt.zero_grad()

            # Get a batch on device and compute loss and gradients
            batch = dict_to_device(batch, device)
            loss, losses_log = actor.compute_loss(batch)
            loss.backward()

            # Gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(
                actor.parameters(),
                max_norm=1.0 + 1e3 * (1 - cfg.training.clip_grad_norm),
            )

            # Step the optimizers and schedulers
            for (_, opt), scheduler in zip(optimizers, lr_schedulers):
                opt.step()
                scheduler.step()

            if cfg.training.ema.use:
                ema.update()

            # Log the loss and gradients
            loss_cpu = loss.item()

            train_losses_log["grad_norm"] = grad_norm.item()

            for k, v in losses_log.items():
                train_losses_log[k].append(v)

            epoch_loss.append(loss_cpu)

            # Update the global step
            global_step += 1

            tepoch.set_postfix(loss=loss_cpu)

        tepoch.close()

        epoch_log["epoch_loss"] = np.mean(epoch_loss)

        for k, v in train_losses_log.items():
            epoch_log[f"train_{k}"] = np.mean(v)

        # Add the learning rates to the log
        for name, opt in optimizers:
            epoch_log[f"{name}_lr"] = opt.param_groups[0]["lr"]

        # Prepare the save dict once and we can reuse below
        save_dict = {
            "model_state_dict": actor.state_dict(),
            "best_test_loss": best_test_loss,
            "best_success_rate": best_success_rate,
            "epoch": epoch_idx,
            "global_step": global_step,
            "config": OmegaConf.to_container(cfg, resolve=True),
        }

        # Add the optimizer and scheduler states to the save dict
        for (name, opt), scheduler in zip(optimizers, lr_schedulers):
            save_dict[f"{name}_optimizer_state_dict"] = opt.state_dict()
            save_dict[f"{name}_scheduler_state_dict"] = scheduler.state_dict()

        if (
            cfg.training.eval_every > 0
            and (epoch_idx + 1) % cfg.training.eval_every == 0
        ):
            # Evaluation loop
            actor.eval()

            if cfg.training.ema.use:
                ema.apply_shadow()

            eval_losses_log = defaultdict(list)

            test_tepoch = tqdm(testloader, desc="Validation", leave=False)
            for test_batch in test_tepoch:
                with torch.no_grad():
                    # device transfer for test_batch
                    test_batch = dict_to_device(test_batch, device)

                    # Get test loss
                    test_loss_val, losses_log = actor.compute_loss(test_batch)

                    # logging
                    test_loss_cpu = test_loss_val.item()
                    test_loss.append(test_loss_cpu)
                    test_tepoch.set_postfix(loss=test_loss_cpu)

                    # Append the losses to the log
                    for k, v in losses_log.items():
                        eval_losses_log[k].append(v)

            test_tepoch.close()

            epoch_log["test_epoch_loss"] = test_loss_mean = np.mean(test_loss)
            # Update the epoch log with the mean of the evaluation losses

            for k, v in eval_losses_log.items():
                epoch_log[f"test_{k}"] = np.mean(v)

            if (
                cfg.rollout.rollouts
                and (epoch_idx + 1) % cfg.rollout.every == 0
                and np.mean(test_loss_mean) < cfg.rollout.loss_threshold
            ):
                if is_maniskill:
                    # Use dedicated ManiSkill evaluation
                    try:
                        # Get evaluation parameters
                        eval_episodes = cfg.eval_episodes if hasattr(cfg, "eval_episodes") else cfg.rollout.count
                        
                        current_eval_env = eval_env
                        need_cleanup = False
       
                        if hasattr(cfg, "eval_record_video") and cfg.eval_record_video and wandb.run:
                            video_dir = os.path.join("runs", wandb.run.name, "videos", f"epoch_{epoch_idx}")
                            if hasattr(current_eval_env, "set_video_directory") and callable(getattr(current_eval_env, "set_video_directory")):
                                current_eval_env.set_video_directory(video_dir)
                        
                        # Run evaluation
                        print(f"Running evaluation with {eval_episodes} episodes...")
                        metrics = evaluate_maniskill(
                            actor=actor,
                            env=current_eval_env,
                            num_episodes=eval_episodes,
                            device=device,
                        )
                        
                        # Process results
                        success_rate = metrics.get("success_rate", 0.0)
                        best_success_rate = max(best_success_rate, success_rate)
                        
                        # Log to wandb
                        wandb_metrics = {
                            "eval/success_rate": success_rate,
                            "eval/best_success_rate": best_success_rate,
                            "epoch": epoch_idx,
                        }
                        
                        # Add other metrics
                        for k, v in metrics.items():
                            if isinstance(v, np.ndarray):
                                wandb_metrics[f"eval/{k}"] = float(np.mean(v))
                            else:
                                wandb_metrics[f"eval/{k}"] = v
                        
                        wandb.log(wandb_metrics)
                        
                        # Clean up temporary environment if we created one
                        if need_cleanup and current_eval_env is not None:
                            current_eval_env.close()
                        
                        # Print results
                        print(f"Epoch {epoch_idx}: Success rate = {success_rate:.4f}, Best = {best_success_rate:.4f}")
                        
                    except Exception as e:
                        print(f"WARNING: ManiSkill evaluation failed with error: {e}")
                        print(f"Continuing with previous best success rate: {best_success_rate}")
                        # Log metric for the failure
                        wandb.log(
                            {
                                "eval/eval_failure": True,
                                "eval/best_success_rate": best_success_rate,
                                "epoch": epoch_idx,
                            }
                        )
                else:
                    # Use FurnitureBench
                    if env is None:
                        env = get_rl_env(
                            cfg.training.gpu_id,
                            task=cfg.task,
                            num_envs=cfg.rollout.num_envs,
                            randomness=cfg.rollout.randomness,
                            observation_space=cfg.observation_type,
                            resize_img=False,
                            act_rot_repr=cfg.control.act_rot_repr,
                            action_type=cfg.control.control_mode,
                            parts_poses_in_robot_frame=cfg.rollout.parts_poses_in_robot_frame,
                            headless=True,
                            verbose=True,
                        )

                    best_success_rate = do_rollout_evaluation(
                        config=cfg,
                        env=env,
                        save_rollouts_to_file=cfg.rollout.save_rollouts,
                        save_rollouts_to_wandb=False,
                        actor=actor,
                        best_success_rate=best_success_rate,
                        epoch_idx=epoch_idx,
                    )

            # Save the model if the test loss is the best so far
            if (
                cfg.training.store_best_test_loss_model
                and test_loss_mean < best_test_loss
            ):
                best_test_loss = test_loss_mean
                save_path = str(model_save_dir / f"actor_chkpt_best_test_loss.pt")
                torch.save(save_dict, save_path)
                wandb.save(save_path)

            # Save the model if the success rate is the best so far
            if (
                cfg.training.store_best_success_rate_model
                and best_success_rate > prev_best_success_rate
            ):
                prev_best_success_rate = best_success_rate
                save_path = str(model_save_dir / f"actor_chkpt_best_success_rate.pt")
                torch.save(save_dict, save_path)
                wandb.save(save_path)

            if (
                cfg.training.checkpoint_interval > 0
                and (epoch_idx + 1) % cfg.training.checkpoint_interval == 0
            ):
                save_path = str(model_save_dir / f"actor_chkpt_{epoch_idx}.pt")
                torch.save(save_dict, save_path)
                wandb.save(save_path)

            # Run diffusion sampling on a training batch
            if (
                cfg.training.sample_every > 0
                and (epoch_idx + 1) % cfg.training.sample_every == 0
            ):

                with torch.no_grad():
                    # sample trajectory from training set, and evaluate difference
                    train_sampling_batch = dict_to_device(
                        next(iter(trainloader)), device
                    )
                    
                    # For ManiSkill datasets, ensure robot_state and parts_poses are correct
                    # Using same convention as furniture bench
                    if is_maniskill:
                        ####
                        ####
                        ####
                        # TODO HARDCODED: robot_state should be first 8 dims ONLY FOR PANDA
                        ####
                        ####
                        ####
                        if "robot_state" in train_sampling_batch and train_sampling_batch["robot_state"].shape[-1] != 8:
                            # Fix the dimensions - robot_state should be first 8 dims
                            robot_state_dim = min(8, train_sampling_batch["obs"].shape[-1])
                            train_sampling_batch["robot_state"] = train_sampling_batch["obs"][..., :robot_state_dim]
                            
                            # Parts poses should be the remaining dimensions
                            if train_sampling_batch["obs"].shape[-1] > robot_state_dim:
                                train_sampling_batch["parts_poses"] = train_sampling_batch["obs"][..., robot_state_dim:]
                            else:
                                empty_shape = list(train_sampling_batch["obs"].shape[:-1]) + [0]
                                train_sampling_batch["parts_poses"] = torch.zeros(
                                    empty_shape, dtype=torch.float32, device=device
                                )
                    
                    pred_action = actor.action_pred(train_sampling_batch)
                    gt_action = actor.normalizer(
                        train_sampling_batch["action"], "action", forward=False
                    )
                    log_action_mse(epoch_log, "train", pred_action, gt_action, is_maniskill=is_maniskill)

                    val_sampling_batch = dict_to_device(next(iter(testloader)), device)
                    
                    # Apply the same fixes to validation batch
                    if is_maniskill:
                        ####
                        ####
                        ####
                        # TODO HARDCODED: robot_state should be first 8 dims ONLY FOR PANDA
                        ####
                        ####
                        ####
                        if "robot_state" in val_sampling_batch and val_sampling_batch["robot_state"].shape[-1] != 8:
                            # Fix the dimensions - robot_state should be first 8 dims
                            robot_state_dim = min(8, val_sampling_batch["obs"].shape[-1])
                            val_sampling_batch["robot_state"] = val_sampling_batch["obs"][..., :robot_state_dim]
                            
                            # Parts poses should be the remaining dimensions
                            if val_sampling_batch["obs"].shape[-1] > robot_state_dim:
                                val_sampling_batch["parts_poses"] = val_sampling_batch["obs"][..., robot_state_dim:]
                            else:
                                empty_shape = list(val_sampling_batch["obs"].shape[:-1]) + [0]
                                val_sampling_batch["parts_poses"] = torch.zeros(
                                    empty_shape, dtype=torch.float32, device=device
                                )
                    
                    gt_action = actor.normalizer(
                        val_sampling_batch["action"], "action", forward=False
                    )
                    pred_action = actor.action_pred(val_sampling_batch)
                    log_action_mse(epoch_log, "val", pred_action, gt_action, is_maniskill=is_maniskill)

            # If using EMA, restore the model
            if cfg.training.ema.use:
                ema.restore()

            # Since we now have a new test loss, we can update the early stopper
            early_stop = early_stopper.update(test_loss_mean)
            epoch_log["early_stopper/counter"] = early_stopper.counter
            epoch_log["early_stopper/best_loss"] = early_stopper.best_loss
            epoch_log["early_stopper/ema_loss"] = early_stopper.ema_loss

        # We store the last model at the end of each epoch for better checkpointing
        if cfg.training.store_last_model:
            save_path = str(model_save_dir / f"actor_chkpt_last.pt")
            torch.save(save_dict, save_path)
            wandb.save(save_path)

        # If switch is enabled, copy the the shadow to the model at the end of each epoch
        if cfg.training.ema.use and cfg.training.ema.switch:
            ema.copy_to_model()

        # Log epoch stats
        wandb.log(epoch_log, step=global_step)
        tglobal.set_postfix(
            time=now(),
            loss=epoch_log["epoch_loss"],
            test_loss=test_loss_mean,
            best_success_rate=best_success_rate,
            stopper_counter=early_stopper.counter,
        )

        # If we are in offline mode, trigger the sync
        if (
            cfg.wandb.mode == "offline"
            and (epoch_idx % cfg.wandb.get("osh_sync_interval", 1)) == 0
        ):
            trigger_sync()

        # Now that everything is logged and restored, we can check if we need to stop
        if early_stop:
            print(
                f"Early stopping at epoch {epoch_idx} as test loss did not improve for {early_stopper.patience} epochs."
            )
            break

    tglobal.close()
    
    if eval_env is not None:
        print("Closing persistent evaluation environment")
        try:
            eval_env.close()
        except Exception as e:
            print(f"Error closing evaluation environment: {e}")
    
    wandb.finish()


if "__main__" == __name__:
    # Only invoke main() when this file is executed directly as a script;
    # importing the module (e.g. from tests or other tools) must not start it.
    main()