"""
FTL-IGM Finetune Agent for pick_place task with PDP-FiLM model.

Unlike close_drawer (single trajectory), pick_place has TWO learned phases:
1. REACH (64 steps): home -> pregrasp
2. CARRY (64 steps): lift -> prerelease

Given a wall-avoiding demo (from block_setting/style{id}), this agent:
1. Loads the demo directly from block_setting/style{id}/episodes/episode0/
2. Segments the demo into REACH and CARRY phases
3. Encodes each phase to get warm start z_reach and z_carry
4. Optimizes z_reach using FTL-IGM (denoising score matching) on REACH segment
5. Optimizes z_carry using FTL-IGM on CARRY segment
6. Saves both optimized z values
7. Automatically runs the eval agent with the optimized z values

Demo structure (168 observations total):
- REACH: observations[0:64]
- DESCEND: observations[64:72]
- GRASP: observations[72:80]
- LIFT: observations[80:88]
- CARRY: observations[88:152]
- DESCEND_RELEASE: observations[152:160]
- RELEASE: observations[160:168]
"""

import os
import sys
import logging
import subprocess
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import hydra

# Make the repository root (three directory levels above this file) and the
# sub-directories imported from below visible on sys.path.
_this_dir = os.path.dirname(os.path.abspath(__file__))
dppo_root = os.path.dirname(os.path.dirname(_this_dir))
sys.path.extend([
    dppo_root,
    os.path.join(dppo_root, 'RLBench'),
    os.path.join(dppo_root, 'RLBench_pick_place', 'encoder'),
])

# Module-level logger used throughout this file.
log = logging.getLogger(__name__)

# Phase layout of a pick_place demo. The phases are contiguous, so every
# half-open (start, end) observation range is derived from the phase lengths.
_PHASE_LENGTHS = (
    ('reach', 64),            # observations[0:64]
    ('descend', 8),           # observations[64:72]
    ('grasp', 8),             # observations[72:80]
    ('lift', 8),              # observations[80:88]
    ('carry', 64),            # observations[88:152]
    ('descend_release', 8),   # observations[152:160]
    ('release', 8),           # observations[160:168]
)

PHASE_CONFIG = {}
_phase_start = 0
for _phase_name, _phase_len in _PHASE_LENGTHS:
    PHASE_CONFIG[_phase_name] = (_phase_start, _phase_start + _phase_len)
    _phase_start += _phase_len

# Bounds used to normalize the end-effector xyz position
# (same as training dataset).
EE_POS_MIN = np.asarray((0.0, -0.6, 0.0), dtype=np.float32)
EE_POS_MAX = np.asarray((1.0, 0.6, 1.6), dtype=np.float32)


def normalize_ee_position(ee_raw):
    """Rescale a raw EE position so the workspace bounds map to -1 and +1.

    Args:
        ee_raw: raw xyz position(s); anything that broadcasts against the
            3-element bounds.

    Returns:
        The linearly rescaled position. Values outside the bounds are not
        clipped, so they fall outside [-1, 1].
    """
    span = EE_POS_MAX - EE_POS_MIN
    shifted = ee_raw - EE_POS_MIN
    return 2.0 * shifted / span - 1.0


class FinetunePickPlaceAgent:
    """FTL-IGM finetuning agent for pick_place PDP-FiLM with two-phase optimization."""

    def __init__(self, cfg):
        """Prepare everything the finetune run needs from ``cfg``.

        Seeds the RNGs, loads the normalization statistics, optionally loads
        the warm-start encoder, instantiates and freezes the diffusion model,
        loads the subgoal waypoints and creates the output directory.

        Args:
            cfg: run configuration; this method reads ``device``, ``style``,
                ``seed``, ``normalization_path``, ``encoder_checkpoint``,
                ``model`` and ``logdir`` from it.
        """
        self.cfg = cfg
        self.device = cfg.device
        self.style = cfg.style  # one style id selects both the demo and the wall

        # Seed every RNG in use so a run can be reproduced.
        self.seed = cfg.seed
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)

        # Min/max statistics for states and actions.
        self.norm_stats = self._load_normalization(cfg.normalization_path)

        # The encoder only provides the warm start; without a checkpoint the
        # latent size is taken from the model config instead.
        if cfg.encoder_checkpoint is None:
            self.encoder = None
            self.latent_dim = cfg.model.network.z_dim
            log.info(f"No encoder provided, using z_dim={self.latent_dim} from model config")
        else:
            self.encoder = self._load_encoder(cfg.encoder_checkpoint)

        # Diffusion model: evaluation mode, network weights frozen so that
        # only z receives gradients during optimization.
        self.model = hydra.utils.instantiate(cfg.model)
        self.model.eval()
        for weight in self.model.network.parameters():
            weight.requires_grad = False

        # Waypoints provide the subgoals used for conditioning.
        self._load_waypoints()

        # All artifacts of the run are written below this directory.
        self.output_dir = cfg.logdir
        os.makedirs(self.output_dir, exist_ok=True)

    def _load_normalization(self, norm_path):
        """Load normalization statistics."""
        log.info(f"Loading normalization from {norm_path}")
        norm_data = np.load(norm_path)

        obs_min = torch.from_numpy(norm_data['obs_min']).float().to(self.device)
        obs_max = torch.from_numpy(norm_data['obs_max']).float().to(self.device)
        action_min = torch.from_numpy(norm_data['action_min']).float().to(self.device)
        action_max = torch.from_numpy(norm_data['action_max']).float().to(self.device)

        obs_range = obs_max - obs_min
        obs_range = torch.where(obs_range < 1e-6, torch.ones_like(obs_range), obs_range)
        action_range = action_max - action_min
        action_range = torch.where(action_range < 1e-6, torch.ones_like(action_range), action_range)

        return {
            'obs_min': obs_min, 'obs_max': obs_max, 'obs_range': obs_range,
            'action_min': action_min, 'action_max': action_max, 'action_range': action_range,
            'obs_min_np': norm_data['obs_min'], 'obs_max_np': norm_data['obs_max'],
            'action_min_np': norm_data['action_min'], 'action_max_np': norm_data['action_max'],
        }

    def _load_encoder(self, checkpoint_path):
        """Build the frozen trajectory encoder used for warm starts.

        The same encoder is used for both REACH and CARRY. Its architecture
        arguments come from the ``config`` dict stored in the checkpoint, and
        ``self.latent_dim`` is set from that config as a side effect.

        Args:
            checkpoint_path: checkpoint holding ``config`` and
                ``model_state_dict`` entries.

        Returns:
            The encoder on ``self.device``, in eval mode, with gradients
            disabled on all of its parameters.
        """
        from trajectory_encoder_normalized_ee import TrajectoryVAENormalizedEE

        log.info(f"Loading encoder from {checkpoint_path}")
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        enc_cfg = ckpt['config']

        # state_dim / action_dim may be absent from the config; the other
        # entries are required.
        arch_kwargs = {
            'state_dim': enc_cfg.get('state_dim', 22),
            'action_dim': enc_cfg.get('action_dim', 8),
        }
        for key in ('hidden_dim', 'num_layers', 'num_heads', 'latent_dim', 'horizon'):
            arch_kwargs[key] = enc_cfg[key]

        encoder = TrajectoryVAENormalizedEE(**arch_kwargs).to(self.device)
        encoder.load_state_dict(ckpt['model_state_dict'])
        encoder.eval()
        for weight in encoder.parameters():
            weight.requires_grad = False

        self.latent_dim = enc_cfg['latent_dim']
        return encoder

    def _load_waypoints(self):
        """Load waypoint positions and derive the normalized subgoals.

        Waypoints are read from the style's ``summary.npy`` when it exists,
        otherwise from the generic ``make_dataset/stack_blocks_init.npz`` next
        to the demo base directory.

        Sets:
            self.waypoints: dict with ``home_pos``, ``pregrasp_pos``,
                ``lift_pos`` and ``prerelease_pos``.
            self.subgoal_reach: normalized ``pregrasp_pos``.
            self.subgoal_carry: normalized ``prerelease_pos``.

        Raises:
            FileNotFoundError: if neither waypoint file exists.
        """
        waypoint_keys = ('home_pos', 'pregrasp_pos', 'lift_pos', 'prerelease_pos')

        summary_path = os.path.join(
            self.cfg.demo_base_path,
            f'style{self.style}',
            'summary.npy'
        )
        # Generic init file, used only when the style has no summary.
        init_path = os.path.join(
            os.path.dirname(self.cfg.demo_base_path),
            'make_dataset', 'stack_blocks_init.npz'
        )

        if os.path.exists(summary_path):
            # NOTE: allow_pickle=True unpickles arbitrary objects; only load
            # summaries from a trusted source.
            summary = np.load(summary_path, allow_pickle=True).item()
            self.waypoints = {key: np.array(summary[key]) for key in waypoint_keys}
            log.info(f"Loaded waypoints from {summary_path}")
        elif os.path.exists(init_path):
            # Pull the arrays out and close the archive instead of keeping the
            # file handle open for the lifetime of the agent.
            with np.load(init_path) as init_data:
                self.waypoints = {key: init_data[key] for key in waypoint_keys}
            log.info(f"Loaded waypoints from {init_path}")
        else:
            raise FileNotFoundError(f"Cannot find waypoints. Tried: {summary_path}, {init_path}")

        # Compute normalized subgoals for conditioning
        self.subgoal_reach = normalize_ee_position(
            np.array(self.waypoints['pregrasp_pos'], dtype=np.float32)
        )
        self.subgoal_carry = normalize_ee_position(
            np.array(self.waypoints['prerelease_pos'], dtype=np.float32)
        )
        log.info(f"Subgoals - REACH: {self.subgoal_reach}, CARRY: {self.subgoal_carry}")

    def _load_demo_from_pkl(self, demo_path):
        """
        Load demo from low_dim_obs.pkl and segment into REACH and CARRY phases.

        State format: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7) = 22
        Action format: joint_pos(7) + gripper_open(1) = 8

        Args:
            demo_path: episode directory containing ``low_dim_obs.pkl``.

        Returns:
            Tuple ``(reach_states, reach_actions, reach_ee, carry_states,
            carry_actions, carry_ee)`` where the states/actions are the rows
            of the REACH / CARRY ranges in ``PHASE_CONFIG`` and the ``*_ee``
            arrays are the matching raw EE positions (first three entries of
            ``gripper_pose``).

        Raises:
            FileNotFoundError: if ``low_dim_obs.pkl`` is missing.
            AssertionError: if a segment does not have
                ``cfg.horizon_steps`` rows.
        """
        pkl_path = os.path.join(demo_path, "low_dim_obs.pkl")
        if not os.path.exists(pkl_path):
            raise FileNotFoundError(
                f"Demo file not found: {pkl_path}\n"
                f"Please generate demo using the block_setting generator."
            )

        log.info(f"Loading demo from {pkl_path}")
        # NOTE: pickle.load can execute arbitrary code; only load demos that
        # come from a trusted source.
        with open(pkl_path, "rb") as f:
            demo = pickle.load(f)

        log.info(f"Demo has {len(demo)} observations")

        # Extract states, actions and raw EE positions from ALL observations
        all_states = []
        all_actions = []
        all_ee_positions = []

        for obs in demo:
            # State: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7)
            all_states.append(np.concatenate([
                obs.joint_positions,
                obs.joint_velocities,
                [obs.gripper_open],
                obs.gripper_pose,              # pos + quat
            ]))

            # Raw EE position for encoding
            all_ee_positions.append(obs.gripper_pose[:3])

            # Prefer the recorded action; a missing or None ``misc`` falls
            # back to joint_positions + gripper_open.
            misc = getattr(obs, 'misc', None) or {}
            if 'joint_position_action' in misc:
                action = misc['joint_position_action']
            else:
                action = np.concatenate([obs.joint_positions, [obs.gripper_open]])
            all_actions.append(action)

        all_states = np.array(all_states)
        all_actions = np.array(all_actions)
        all_ee_positions = np.array(all_ee_positions)

        log.info(f"Extracted: states={all_states.shape}, actions={all_actions.shape}, ee={all_ee_positions.shape}")

        # Segment into REACH and CARRY
        reach_start, reach_end = PHASE_CONFIG['reach']
        carry_start, carry_end = PHASE_CONFIG['carry']

        reach_states = all_states[reach_start:reach_end]
        reach_actions = all_actions[reach_start:reach_end]
        reach_ee = all_ee_positions[reach_start:reach_end]

        carry_states = all_states[carry_start:carry_end]
        carry_actions = all_actions[carry_start:carry_end]
        carry_ee = all_ee_positions[carry_start:carry_end]

        log.info(f"REACH segment: states={reach_states.shape}, actions={reach_actions.shape}")
        log.info(f"CARRY segment: states={carry_states.shape}, actions={carry_actions.shape}")

        # Verify horizon. Raised explicitly rather than via ``assert`` so the
        # check still runs under ``python -O``; the exception type and message
        # are unchanged for existing callers.
        horizon = self.cfg.horizon_steps
        for label, segment in (('REACH', reach_states), ('CARRY', carry_states)):
            if segment.shape[0] != horizon:
                raise AssertionError(f"{label} has {segment.shape[0]} steps, expected {horizon}")

        return reach_states, reach_actions, reach_ee, carry_states, carry_actions, carry_ee

    def _normalize_states_actions(self, states, actions):
        """Normalize states and actions to [-1, 1]."""
        obs_min = self.norm_stats['obs_min_np']
        obs_max = self.norm_stats['obs_max_np']
        action_min = self.norm_stats['action_min_np']
        action_max = self.norm_stats['action_max_np']

        obs_range = obs_max - obs_min
        obs_range = np.where(obs_range < 1e-6, 1.0, obs_range)
        action_range = action_max - action_min
        action_range = np.where(action_range < 1e-6, 1.0, action_range)

        # Normalize to [-1, 1]
        states_norm = 2.0 * (states - obs_min) / obs_range - 1.0
        actions_norm = 2.0 * (actions - action_min) / action_range - 1.0

        return states_norm, actions_norm

    def _normalize_ee_trajectory(self, ee_pos):
        """
        Normalize EE trajectory to trajectory-relative coordinates.

        Same as pick_place_parameterized_sequence.normalize_ee_trajectory_numpy()

        ee_pos: (T, 3) - end-effector positions (RAW, not normalized)
        Returns: (T, 3) - [progress, perp1_offset, perp2_offset]
        """
        T = ee_pos.shape[0]

        start_pos = ee_pos[0]
        end_pos = ee_pos[-1]

        # Build local frame
        line_vec = end_pos - start_pos
        dist = np.linalg.norm(line_vec)
        if dist < 1e-6:
            dist = 1e-6
        line_dir = line_vec / dist

        # Perpendicular axes
        world_up = np.array([0.0, 0.0, 1.0])
        dot = np.dot(world_up, line_dir)
        perp1 = world_up - dot * line_dir
        perp1_len = np.linalg.norm(perp1)

        if perp1_len < 1e-6:
            world_forward = np.array([0.0, 1.0, 0.0])
            dot_fwd = np.dot(world_forward, line_dir)
            perp1 = world_forward - dot_fwd * line_dir
            perp1_len = np.linalg.norm(perp1)

        perp1 = perp1 / perp1_len
        perp2 = np.cross(line_dir, perp1)
        perp2 = perp2 / np.linalg.norm(perp2)

        # Normalize each point
        ee_normalized = np.zeros((T, 3))

        for t in range(T):
            vec_from_start = ee_pos[t] - start_pos

            # Progress along the line
            progress = np.dot(vec_from_start, line_dir) / dist

            # Point on straight line at this progress
            point_on_line = start_pos + progress * (end_pos - start_pos)

            # Offset from straight line
            offset = ee_pos[t] - point_on_line

            # Project onto perpendicular axes
            perp1_offset = np.dot(offset, perp1) / dist
            perp2_offset = np.dot(offset, perp2) / dist

            ee_normalized[t] = [progress, perp1_offset, perp2_offset]

        return ee_normalized

    def _encode_phase(self, ee_positions, phase_name):
        """Encode a single phase's EE trajectory to get z embedding."""
        if self.encoder is None:
            log.info(f"No encoder - using random z for {phase_name}")
            return torch.randn(self.latent_dim).to(self.device) * 0.1

        # Normalize EE trajectory to trajectory-relative coords
        ee_normalized = self._normalize_ee_trajectory(ee_positions)

        # Encode
        ee_tensor = torch.tensor(
            ee_normalized, dtype=torch.float32, device=self.device
        ).unsqueeze(0)  # (1, T, 3)

        with torch.no_grad():
            z = self.encoder.encode(ee_tensor).squeeze(0)  # (latent_dim,)

        log.info(f"Encoded {phase_name}: z={z.cpu().numpy()}, norm={z.norm().item():.3f}")
        return z

    def optimize_z(self, demo_states, demo_actions, subgoal, n_epochs=300, lr=3e-3,
                   n_timestep_samples=16, z_reg_weight=0.0, warm_start=True, init_z=None,
                   phase_name='phase'):
        """
        Fit the latent z of one phase with FTL-IGM (denoising score matching).

        Only z is optimized (with Adam). Each epoch draws
        ``n_timestep_samples`` diffusion timesteps and noise tensors, noises
        the demo actions accordingly and minimizes the MSE between the
        network's output and the drawn noise, conditioned on the first demo
        state concatenated with the subgoal.

        Args:
            demo_states: (T, state_dim) - normalized states in [-1, 1]
            demo_actions: (T, action_dim) - normalized actions in [-1, 1]
            subgoal: (3,) - normalized subgoal position for conditioning
            n_epochs: number of optimization epochs
            lr: learning rate
            n_timestep_samples: number of timesteps to sample per iteration
            z_reg_weight: weight of the L2-norm penalty on z (0 = no reg)
            warm_start: if True and init_z is given, start from init_z
            init_z: initial z value (from encoder)
            phase_name: 'reach' or 'carry' for logging

        Returns:
            z: optimized latent code (detached)
            loss_history: list with one loss value per epoch
        """
        device = self.device
        demo_states = torch.FloatTensor(demo_states).to(device)
        demo_actions = torch.FloatTensor(demo_actions).to(device)
        subgoal_tensor = torch.FloatTensor(subgoal).to(device)

        # Starting point for z: the encoder estimate when available and
        # requested, otherwise a small random vector.
        use_init_z = warm_start and init_z is not None
        if use_init_z:
            log.info(f"[{phase_name}] Warm start: Using encoder z (norm={init_z.norm().item():.3f})")
            z = nn.Parameter(init_z.clone().detach().requires_grad_(True))
        else:
            log.info(f"[{phase_name}] Cold start: Random initialization for z")
            z = nn.Parameter(torch.randn(self.latent_dim).to(device) * 0.1)

        optimizer = torch.optim.Adam([z], lr=lr)

        # Coefficients used to noise the clean actions at timestep t.
        alpha_bar = self.model.alphas_cumprod
        signal_scale = torch.sqrt(alpha_bar)
        noise_scale = torch.sqrt(1 - alpha_bar)

        # Condition (same as eval): first state + subgoal, shaped (1, 1, obs_dim).
        cond_state = torch.cat(
            [demo_states[0:1], subgoal_tensor.unsqueeze(0)], dim=-1
        ).unsqueeze(0)

        K = n_timestep_samples
        noise_shape = (K, self.cfg.horizon_steps, self.cfg.action_dim)

        # These expanded views do not change between epochs.
        clean_actions = demo_actions.unsqueeze(0).expand(K, -1, -1)  # (K, T, action_dim)
        cond_expanded = cond_state.expand(K, -1, -1)

        loss_history = []
        pbar = tqdm(range(n_epochs), desc=f"Optimizing z ({phase_name})")
        for _ in pbar:
            optimizer.zero_grad()

            # Draw K timesteps, then K noise tensors.
            t = torch.randint(0, self.model.denoising_steps, (K,), device=device).long()
            noise = torch.randn(*noise_shape).to(device)

            # Noise the clean actions at the sampled timesteps.
            noised_actions = (signal_scale[t].view(K, 1, 1) * clean_actions +
                              noise_scale[t].view(K, 1, 1) * noise)

            cond_batch = {'state': cond_expanded}
            z_batch = z.unsqueeze(0).expand(K, -1)

            # Gradients are needed w.r.t. z even if the caller disabled them.
            with torch.enable_grad():
                predicted_noise = self.model.network(noised_actions, t, cond=cond_batch, z=z_batch)

            # Denoising loss plus the optional penalty on the norm of z.
            loss = F.mse_loss(predicted_noise, noise)
            if z_reg_weight > 0:
                loss = loss + z_reg_weight * torch.norm(z, p=2)

            loss.backward()
            torch.nn.utils.clip_grad_norm_([z], max_norm=1.0)
            optimizer.step()

            loss_value = loss.item()
            loss_history.append(loss_value)
            pbar.set_postfix({'loss': f'{loss_value:.4f}', 'z_norm': f'{z.norm().item():.3f}'})

        return z.detach(), loss_history

    def run(self):
        """Execute the full finetune pipeline for the configured style.

        Loads and segments the demo, normalizes both segments, encodes them
        for the initial z values, optimizes z for REACH and then CARRY
        (skipped when ``cfg.ftl_igm.n_epochs`` is 0), saves the z values,
        plots the loss curves and finally launches the eval agent.
        """
        banner = "=" * 70
        log.info(banner)
        log.info(f"FTL-IGM Finetune for pick_place (style={self.style})")
        log.info(f"Two-phase optimization: REACH and CARRY")
        log.info(banner)

        # 1. Locate the target demo and its metadata inside block_setting
        style_dir = os.path.join(self.cfg.demo_base_path, f'style{self.style}')
        demo_path = os.path.join(style_dir, 'episodes', 'episode0')
        summary_path = os.path.join(style_dir, 'summary.npy')

        if os.path.exists(summary_path):
            summary = np.load(summary_path, allow_pickle=True).item()
            log.info(f"Target demo metadata: style={summary.get('style')}, "
                    f"wall_config={summary.get('wall_config')}")

        # 2. Load and segment demo
        (reach_states, reach_actions, reach_ee,
         carry_states, carry_actions, carry_ee) = self._load_demo_from_pkl(demo_path)

        # Normalize states/actions
        reach_states_norm, reach_actions_norm = self._normalize_states_actions(reach_states, reach_actions)
        carry_states_norm, carry_actions_norm = self._normalize_states_actions(carry_states, carry_actions)

        log.info(f"Normalized REACH: states={reach_states_norm.shape}, actions={reach_actions_norm.shape}")
        log.info(f"Normalized CARRY: states={carry_states_norm.shape}, actions={carry_actions_norm.shape}")

        # 3. Encode both phases to get warm start z values
        z_reach_init = self._encode_phase(reach_ee, 'reach')
        z_carry_init = self._encode_phase(carry_ee, 'carry')

        # 4./5. Optimize z_reach, then z_carry, with identical settings
        n_epochs = self.cfg.ftl_igm.n_epochs
        phases = (
            ('reach', reach_states_norm, reach_actions_norm, self.subgoal_reach, z_reach_init),
            ('carry', carry_states_norm, carry_actions_norm, self.subgoal_carry, z_carry_init),
        )
        z_final = {}
        histories = {}
        for name, states_norm, actions_norm, subgoal, z_init in phases:
            tag = name.upper()
            if n_epochs > 0:
                log.info(f"\n--- FTL-IGM Optimization: {tag} ---")
                z_opt, history = self.optimize_z(
                    demo_states=states_norm,
                    demo_actions=actions_norm,
                    subgoal=subgoal,
                    n_epochs=n_epochs,
                    lr=self.cfg.ftl_igm.lr,
                    n_timestep_samples=self.cfg.ftl_igm.n_timestep_samples,
                    z_reg_weight=self.cfg.ftl_igm.z_reg_weight,
                    warm_start=self.cfg.ftl_igm.warm_start,
                    init_z=z_init,
                    phase_name=name
                )
                log.info(f"Optimized z_{name}: {z_opt.cpu().numpy()}")
                log.info(f"Final {tag} loss: {history[-1]:.4f}")
            else:
                # Nothing to optimize: keep the initial z as the result.
                log.info(f"\n--- Skipping FTL-IGM for {tag} (n_epochs=0) ---")
                z_opt, history = z_init, []
            z_final[name] = z_opt
            histories[name] = history

        # 6. Save optimized z values
        # Combined file has shape (2, z_dim): row 0 = z_reach, row 1 = z_carry
        z_reach_np = z_final['reach'].cpu().numpy().reshape(1, -1)
        z_carry_np = z_final['carry'].cpu().numpy().reshape(1, -1)
        z_combined = np.concatenate([z_reach_np, z_carry_np], axis=0)

        z_save_path = os.path.join(self.output_dir, f'optimized_z_style{self.style}.npy')
        np.save(z_save_path, z_combined)
        log.info(f"Saved optimized z values to {z_save_path} (shape: {z_combined.shape})")
        log.info(f"  z_reach = {z_combined[0]}")
        log.info(f"  z_carry = {z_combined[1]}")

        # Each z is also written to its own file for clarity
        np.save(os.path.join(self.output_dir, f'z_reach_style{self.style}.npy'), z_reach_np)
        np.save(os.path.join(self.output_dir, f'z_carry_style{self.style}.npy'), z_carry_np)

        # 7. Plot loss curves
        self._plot_loss_curves(histories['reach'], histories['carry'])

        # 8. Run eval agent with optimized z values
        log.info("\n" + banner)
        log.info("FTL-IGM Optimization Complete! Running evaluation...")
        log.info(banner)

        self._run_eval(z_save_path)

    def _plot_loss_curves(self, loss_history_reach, loss_history_carry):
        """Save a side-by-side figure of the REACH and CARRY loss curves.

        Nothing is written when both histories are empty. A phase with an
        empty history gets a 'Skipped' placeholder panel. The figure is
        saved as ``loss_curves.png`` in ``self.output_dir``.

        Args:
            loss_history_reach: per-epoch losses of the REACH optimization.
            loss_history_carry: per-epoch losses of the CARRY optimization.
        """
        if len(loss_history_reach) == 0 and len(loss_history_carry) == 0:
            log.info("No loss history to plot (optimization skipped)")
            return

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # Left panel: REACH, right panel: CARRY.
        panels = (('REACH', loss_history_reach), ('CARRY', loss_history_carry))
        for ax, (label, history) in zip(axes, panels):
            if len(history) == 0:
                ax.text(0.5, 0.5, 'Skipped', ha='center', va='center', transform=ax.transAxes)
                ax.set_title(f'{label} Phase (skipped)')
                continue

            ax.plot(history)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title(f'{label} Phase (style={self.style})')
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        save_path = os.path.join(self.output_dir, 'loss_curves.png')
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()
        log.info(f"Saved loss curves to {save_path}")

    def _run_eval(self, z_file_path):
        """Run the eval agent on the optimized z values in a subprocess.

        The command ``script/run.py`` is executed from ``dppo_root`` with the
        eval config for the environment, the wall style, the z file and
        ``n_noise_samples=40``. A non-zero exit code is logged as an error;
        it is not raised.

        Args:
            z_file_path: path of the saved ``(2, z_dim)`` z file; may be
                relative to the current working directory.
        """
        eval_config_name = self.cfg.get('eval_config_name', 'eval_para_diffusion_film')
        env_name = self.cfg.get('env_name', 'pick_place')

        # The child process runs with cwd=dppo_root, so a relative path would
        # be resolved against the wrong directory there.
        z_file_path = os.path.abspath(z_file_path)

        # Build the eval command
        # For finetuned z, we use n_noise_samples=40 to get 40 evaluation samples
        # (in dual z mode, there's only 1 "mode" so n_noise_samples controls total samples)
        eval_cmd = [
            # Reuse the interpreter running this process (and thus its
            # environment); fall back to PATH lookup if it is unknown.
            sys.executable or "python", "script/run.py",
            f"--config-dir=cfg/rlbench/eval/{env_name}",
            f"--config-name={eval_config_name}",
            f"wall_style={self.style}",
            f"z_file={z_file_path}",
            "n_noise_samples=40",  # 40 evaluation samples with the finetuned z
        ]

        log.info(f"Running eval command: {' '.join(eval_cmd)}")

        # Run from the dppo root directory; output goes straight to the
        # parent's stdout/stderr.
        result = subprocess.run(
            eval_cmd,
            cwd=dppo_root,
            capture_output=False,
        )

        if result.returncode != 0:
            log.error(f"Eval command failed with return code {result.returncode}")
        else:
            log.info("Evaluation completed successfully!")
