"""
Finetune Agent for baseline models (DP, BC, BC-GMM, IBC) on grasp (pick_up_cup) task.

For non-parameterized models, "finetuning" means:
- BC: Actually finetune the model on the demo trajectory
- DP: DP Inversion - optimize initial noise x_T to reconstruct demo
- BC-GMM: Component selection via posterior (no weight updates)
- IBC: Demo-initialized inference (no weight updates, but uses demo to init)

Demo selection via grasp_style:
- grasp_style=1: Mode 0° (train, episode0)
- grasp_style=2: Mode 270° (train, episode30, noise-free)
- grasp_style=3: Test 45° (novel angle)
"""

import os
import sys
import logging
import subprocess
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import hydra

# Add paths
dppo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(dppo_root)

log = logging.getLogger(__name__)


class FinetuneGraspBaselineAgent:
    """Finetune agent for baseline models (DP, BC, BC-GMM, IBC) on grasp task."""

    def __init__(self, cfg):
        """Set up the agent: seed RNGs, load normalization stats and the model.

        Args:
            cfg: Run configuration. Fields read here are ``device``,
                ``model_type``, ``grasp_style``, ``seed``,
                ``normalization_path``, ``model`` and ``logdir``, plus the
                optional ``do_training`` flag (defaults to ``False``).
        """
        self.cfg = cfg
        self.device = cfg.device
        # One of 'dp', 'bc', 'bc_gmm', 'ibc'.
        self.model_type = cfg.model_type
        self.do_training = cfg.get('do_training', False)
        self.grasp_style = cfg.grasp_style

        # Seed every RNG used by this script so runs are reproducible.
        self.seed = cfg.seed
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(self.seed)
        if torch.cuda.is_available():
            for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
                cuda_seed_fn(self.seed)

        # Min/max statistics used to map raw states/actions to [-1, 1].
        self.norm_stats = self._load_normalization(cfg.normalization_path)

        # Build the model from its Hydra config and put it in eval mode.
        self.model = hydra.utils.instantiate(cfg.model)
        self.model.eval()

        # All artifacts produced by this agent are written below logdir.
        self.output_dir = cfg.logdir
        os.makedirs(self.output_dir, exist_ok=True)

    def _load_normalization(self, norm_path):
        """Load min/max normalization statistics for observations and actions.

        Args:
            norm_path: Path handed to ``np.load``. The loaded object must be
                indexable by the keys 'obs_min', 'obs_max', 'action_min' and
                'action_max'.

        Returns:
            dict with float32 tensors on ``self.device`` under 'obs_min',
            'obs_max', 'obs_range', 'action_min', 'action_max',
            'action_range', and the untouched numpy arrays under
            'obs_min_np', 'obs_max_np', 'action_min_np', 'action_max_np'.
            Range entries smaller than 1e-6 are replaced by 1 so that
            dividing by the range is always safe.
        """
        log.info(f"Loading normalization from {norm_path}")

        keys = ('obs_min', 'obs_max', 'action_min', 'action_max')
        norm_data = np.load(norm_path)
        try:
            # Read every array exactly once; archive-backed loads re-read the
            # file on each key access.
            arrays = {key: norm_data[key] for key in keys}
        finally:
            # Archive objects keep the file open until closed; plain arrays
            # have no close(), hence the getattr.
            close = getattr(norm_data, 'close', None)
            if close is not None:
                close()

        tensors = {
            key: torch.from_numpy(array).float().to(self.device)
            for key, array in arrays.items()
        }

        def _safe_range(upper, lower):
            # Constant dimensions get a unit range instead of (near) zero.
            span = upper - lower
            return torch.where(span < 1e-6, torch.ones_like(span), span)

        stats = dict(tensors)
        stats['obs_range'] = _safe_range(tensors['obs_max'], tensors['obs_min'])
        stats['action_range'] = _safe_range(tensors['action_max'], tensors['action_min'])
        for key, array in arrays.items():
            stats[f'{key}_np'] = array
        return stats

    def _load_demo_from_pkl(self, demo_path):
        """Load a demo from ``low_dim_obs.pkl`` and extract states and actions.

        State format: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7) = 22
        Action format: joint_pos(7) + gripper_open(1) = 8

        The result is cut or padded to ``self.cfg.horizon_steps`` rows; padding
        repeats the last state/action.

        Args:
            demo_path: Directory that contains ``low_dim_obs.pkl``.

        Returns:
            (states, actions) numpy arrays with ``horizon_steps`` rows each.

        Raises:
            FileNotFoundError: ``low_dim_obs.pkl`` does not exist.
            ValueError: The demo contains no observations.
        """
        pkl_path = os.path.join(demo_path, "low_dim_obs.pkl")
        if not os.path.exists(pkl_path):
            raise FileNotFoundError(f"Demo file not found: {pkl_path}")

        log.info(f"Loading demo from {pkl_path}")
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # demo_path at demo files from a trusted source.
        with open(pkl_path, "rb") as f:
            demo = pickle.load(f)

        states = []
        actions = []
        for obs in demo:
            # State: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7)
            states.append(np.concatenate([
                obs.joint_positions,           # 7
                obs.joint_velocities,          # 7
                [obs.gripper_open],            # 1
                obs.gripper_pose,              # 7 (pos + quat)
            ]))

            # Prefer the recorded action; a missing or None `misc` falls back
            # to joint_positions + gripper_open instead of raising.
            misc = getattr(obs, 'misc', None) or {}
            if 'joint_position_action' in misc:
                action = misc['joint_position_action']
            else:
                action = np.concatenate([obs.joint_positions, [obs.gripper_open]])
            actions.append(action)

        if not states:
            # Without this the padding below fails with an opaque shape error.
            raise ValueError(f"Demo at {pkl_path} contains no observations")

        states = np.array(states)
        actions = np.array(actions)

        log.info(f"Loaded demo: {len(states)} steps, states={states.shape}, actions={actions.shape}")

        # Adjust to match expected horizon
        horizon = self.cfg.horizon_steps
        if len(states) > horizon:
            log.info(f"Truncating demo from {len(states)} to {horizon} steps")
            states = states[:horizon]
            actions = actions[:horizon]
        elif len(states) < horizon:
            # Pad by repeating last state/action
            log.info(f"Padding demo from {len(states)} to {horizon} steps")
            pad_len = horizon - len(states)
            states = np.concatenate([states, np.tile(states[-1:], (pad_len, 1))], axis=0)
            actions = np.concatenate([actions, np.tile(actions[-1:], (pad_len, 1))], axis=0)

        return states, actions

    def _normalize_states_actions(self, states, actions):
        """Linearly rescale raw states and actions with the stored min/max stats.

        A value equal to the stored minimum maps to -1 and one equal to the
        stored maximum maps to +1; values outside that interval are not
        clipped. Dimensions whose range is below 1e-6 use a range of 1.

        Args:
            states: Raw states, broadcastable against ``obs_min_np``.
            actions: Raw actions, broadcastable against ``action_min_np``.

        Returns:
            (states_norm, actions_norm) as numpy arrays.
        """
        stats = self.norm_stats
        obs_lo, obs_hi = stats['obs_min_np'], stats['obs_max_np']
        act_lo, act_hi = stats['action_min_np'], stats['action_max_np']

        def _rescale(values, lower, upper):
            # Guard against division by (near) zero on constant dimensions.
            span = upper - lower
            span = np.where(span < 1e-6, 1.0, span)
            return 2.0 * (values - lower) / span - 1.0

        return _rescale(states, obs_lo, obs_hi), _rescale(actions, act_lo, act_hi)

    def finetune_bc(self, all_states, all_actions):
        """
        Finetune BC model on all demos.

        Every epoch is one full-batch Adam step on the MSE between the
        network output for each demo's first observation and that demo's
        action trajectory. The network is put back into eval mode with frozen
        parameters afterwards, even if training is interrupted by an error.

        Args:
            all_states: list of (T, obs_dim) normalized states per demo
            all_actions: list of (T, action_dim) normalized actions per demo

        Returns:
            loss_history: list of loss values, one per epoch
        """
        log.info(f"Finetuning BC model on {len(all_states)} demos...")

        # Stack all demos
        all_states_tensor = torch.stack([torch.FloatTensor(s).to(self.device) for s in all_states])
        all_actions_tensor = torch.stack([torch.FloatTensor(a).to(self.device) for a in all_actions])

        # BC predicts the full trajectory from the first observation. Inputs
        # and targets never change between epochs, so build them once.
        obs = all_states_tensor[:, 0:1, :]  # (batch, 1, obs_dim)
        target_actions = all_actions_tensor  # (batch, T, action_dim)

        # Unfreeze model for training
        network = self.model.network
        network.train()
        for param in network.parameters():
            param.requires_grad = True

        optimizer = torch.optim.Adam(
            network.parameters(),
            lr=self.cfg.finetune.lr
        )
        loss_history = []
        n_epochs = self.cfg.finetune.n_epochs

        try:
            pbar = tqdm(range(n_epochs), desc="Finetuning BC")
            for _ in pbar:
                optimizer.zero_grad()

                pred_actions = network(obs)  # (batch, horizon, action_dim)
                loss = F.mse_loss(pred_actions, target_actions)

                loss.backward()
                optimizer.step()

                # Read the scalar once; each .item() forces a device sync.
                loss_value = loss.item()
                loss_history.append(loss_value)
                pbar.set_postfix({'loss': f'{loss_value:.6f}'})
        finally:
            # Set back to eval mode, also when an epoch raised.
            network.eval()
            for param in network.parameters():
                param.requires_grad = False

        return loss_history

    def finetune_bc_gmm_component_selection(self, all_states, all_actions):
        """
        BC-GMM Component Selection via Posterior for multiple demos.

        Select the best GMM component that explains all demos:
            k* = argmax_k sum_d (log π_k + log N(a*_d; μ_k, Σ_k))

        Each component is treated as a diagonal Gaussian over the flattened
        action trajectory, conditioned on the first state of the demo. The
        first ``self.cfg.num_modes`` components of the network output are
        scored in a single vectorized pass per demo.

        Args:
            all_states: list of (T, obs_dim) normalized states per demo
            all_actions: list of (T, action_dim) normalized actions per demo

        Returns:
            selected_component: int, the index of the best component
            component_log_likelihoods: list of total log-likelihoods for each component

        Raises:
            IndexError: The network returns fewer components than
                ``self.cfg.num_modes``.
        """
        log.info(f"BC-GMM: Selecting best component via posterior for {len(all_states)} demos...")

        self.model.eval()

        K = self.cfg.num_modes
        component_log_likelihoods = [0.0] * K

        if len(all_states) == 0 or len(all_actions) == 0:
            # With no evidence every total stays 0 and argmax picks component 0.
            log.warning("BC-GMM: no demos provided; component selection is not informative")

        with torch.no_grad():
            for demo_idx, (states, actions) in enumerate(zip(all_states, all_actions)):
                states_tensor = torch.FloatTensor(states).to(self.device)
                actions_tensor = torch.FloatTensor(actions).to(self.device)

                # Condition on the first observation only.
                state = states_tensor[0:1]  # (1, obs_dim)

                # Get GMM parameters from network
                means, scales, logits = self.model.network(state)
                # means: (1, K, trajectory_dim)
                # scales: (1, K, trajectory_dim)
                # logits: (1, K) - unnormalized mixture weights
                if means.shape[1] < K:
                    raise IndexError(
                        f"Network returned {means.shape[1]} components, expected at least {K}"
                    )

                target_traj = actions_tensor.reshape(-1)  # (trajectory_dim,)
                log_weights = torch.log_softmax(logits, dim=-1)  # (1, K)
                trajectory_dim = means.shape[2]

                mu = means[0, :K]  # (K, trajectory_dim)
                sigma = scales[0, :K]  # (K, trajectory_dim)

                # Diagonal-Gaussian log-density of the target under every
                # component at once (one reduction instead of K loops/syncs).
                diff = target_traj - mu
                log_prob = -0.5 * (
                    (diff ** 2 / (sigma ** 2)).sum(dim=-1)
                    + trajectory_dim * float(np.log(2 * np.pi))
                    + 2 * torch.log(sigma).sum(dim=-1)
                )

                # Add log mixture weights and move to host in a single transfer.
                demo_totals = (log_weights[0, :K] + log_prob).tolist()
                component_log_likelihoods = [
                    running + current
                    for running, current in zip(component_log_likelihoods, demo_totals)
                ]

                log.info(f"  Demo {demo_idx}: processed")

        # Log results
        for k in range(K):
            log.info(f"  Component {k}: total log-likelihood = {component_log_likelihoods[k]:.4f}")

        # Select best component
        selected_component = int(np.argmax(component_log_likelihoods))
        log.info(f"\nSelected component: k* = {selected_component}")

        return selected_component, component_log_likelihoods

    def finetune_ibc_demo_init(self, all_states, all_actions):
        """
        IBC demo-initialized inference: collect the demo action trajectories.

        No weights are updated. The model is put in eval mode, its parameters
        are frozen, and the energy the network assigns to each demo (first
        state paired with the flattened action trajectory) is logged.

        Args:
            all_states: list of (T, obs_dim) normalized states per demo
            all_actions: list of (T, action_dim) normalized actions per demo

        Returns:
            demo_trajectories: (num_demos, T, action_dim) stacked demo trajectories
        """
        log.info(f"IBC: Preparing {len(all_actions)} demo trajectories for demo-initialized inference...")

        self.model.eval()
        for weight in self.model.parameters():
            weight.requires_grad = False

        # One float tensor per demo, stacked along a new leading axis.
        per_demo_actions = [torch.FloatTensor(traj).to(self.device) for traj in all_actions]
        demo_trajectories = torch.stack(per_demo_actions)  # (num_demos, T, action_dim)

        log.info(f"Demo trajectories shape: {demo_trajectories.shape}")

        # Report the energy of every demo under the current model.
        with torch.no_grad():
            for demo_idx, (demo_states, demo_actions) in enumerate(zip(all_states, all_actions)):
                first_state = torch.FloatTensor(demo_states[0:1]).to(self.device)
                flat_actions = torch.FloatTensor(demo_actions).view(1, -1).to(self.device)
                energy = self.model.network(first_state, flat_actions)
                log.info(f"  Demo {demo_idx}: energy = {energy.item():.4f}")

        return demo_trajectories

    def finetune_dp_inversion(self, all_states, all_actions):
        """
        DP Inversion: Optimize initial noise x_T for each demo trajectory.

        Each demo is optimized independently by
        ``finetune_dp_inversion_single`` so that the multi-demo and the
        single-demo paths share one implementation of the optimization loop.

        Args:
            all_states: list of (T, obs_dim) normalized states per demo
            all_actions: list of (T, action_dim) normalized actions per demo

        Returns:
            all_optimized_x_T: (num_demos, horizon, action_dim) optimized initial noises
            all_loss_histories: list of loss histories per demo
        """
        log.info(f"DP Inversion: Optimizing x_T for {len(all_states)} demos...")

        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

        all_optimized_x_T = []
        all_loss_histories = []

        for demo_idx, (states, actions) in enumerate(zip(all_states, all_actions)):
            log.info(f"\n--- Demo {demo_idx + 1}/{len(all_states)} ---")

            x_T, loss_history = self.finetune_dp_inversion_single(states, actions)

            all_optimized_x_T.append(x_T)
            all_loss_histories.append(loss_history)
            if loss_history:
                log.info(f"Demo {demo_idx}: Final loss = {loss_history[-1]:.6f}")

        if not all_optimized_x_T:
            # torch.cat rejects an empty list; return an empty batch instead.
            empty = torch.empty(
                (0, self.cfg.horizon_steps, self.cfg.action_dim), device=self.device
            )
            return empty, all_loss_histories

        # Stack all optimized x_T: (num_demos, horizon, action_dim)
        return torch.cat(all_optimized_x_T, dim=0), all_loss_histories

    def finetune_dp_inversion_single(self, states, actions):
        """
        DP Inversion: Optimize initial noise x_T for a single demo trajectory.

        x_T is a learnable tensor that is pushed through a deterministic DDIM
        rollout conditioned on the demo's first state; Adam minimizes the MSE
        between the rollout result and the demo actions. Model weights stay
        frozen.

        Args:
            states: (T, obs_dim) normalized states
            actions: (T, action_dim) normalized actions

        Returns:
            optimized_x_T: (1, horizon, action_dim) optimized initial noise
            loss_history: list of loss values (empty when n_epochs is 0)
        """
        log.info("DP Inversion: Optimizing x_T...")

        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

        demo_states = torch.FloatTensor(states).to(self.device)
        demo_actions = torch.FloatTensor(actions).to(self.device)

        horizon_steps = self.cfg.horizon_steps
        action_dim = self.cfg.action_dim

        # Initialize x_T as learnable parameter
        x_T = nn.Parameter(torch.randn(1, horizon_steps, action_dim, device=self.device))

        optimizer = torch.optim.Adam([x_T], lr=self.cfg.dp_inversion.lr)
        loss_history = []

        n_epochs = self.cfg.dp_inversion.n_epochs

        # The DDIM schedule does not depend on the epoch: build the per-step
        # timestep tensor and coefficients once instead of every iteration.
        ddim_alphas = self.model.ddim_alphas
        ddim_alphas_prev = self.model.ddim_alphas_prev
        ddim_sqrt_one_minus_alphas = self.model.ddim_sqrt_one_minus_alphas
        schedule = []
        for i, t in enumerate(self.model.ddim_t):
            index_b = torch.tensor([i], device=self.device).long()
            alpha = ddim_alphas[index_b].view(1, 1, 1)
            alpha_prev = ddim_alphas_prev[index_b].view(1, 1, 1)
            schedule.append((
                torch.tensor([t], device=self.device).long(),
                alpha ** 0.5,
                ddim_sqrt_one_minus_alphas[index_b].view(1, 1, 1),
                alpha_prev ** 0.5,
                (1.0 - alpha_prev).sqrt(),
            ))
        clip_value = self.model.denoised_clip_value

        # Condition: first state
        cond = {'state': demo_states[0:1].unsqueeze(0)}  # (1, 1, obs_dim)

        # Target: full action trajectory
        target_actions = demo_actions.unsqueeze(0)  # (1, horizon, action_dim)

        pbar = tqdm(range(n_epochs), desc="DP Inversion")
        for _ in pbar:
            optimizer.zero_grad()

            # Forward pass through DDIM sampler (deterministic)
            x = x_T.clone()

            for t_b, sqrt_alpha, sqrt_one_minus_alpha, sqrt_alpha_prev, sqrt_one_minus_alpha_prev in schedule:
                # NOTE(review): no_grad makes noise_pred a constant for
                # autograd, so the gradient reaching x_T only follows the
                # explicit x terms below, not the network. Confirm this
                # stop-gradient is intended before relying on the result.
                with torch.no_grad():
                    noise_pred = self.model.actor(x, t_b, cond=cond)

                x_recon = (x - sqrt_one_minus_alpha * noise_pred) / sqrt_alpha

                if clip_value is not None:
                    x_recon = x_recon.clamp(-clip_value, clip_value)

                dir_xt = sqrt_one_minus_alpha_prev * noise_pred
                x = sqrt_alpha_prev * x_recon + dir_xt

            loss = F.mse_loss(x, target_actions)
            loss.backward()
            optimizer.step()

            # Read the scalar once; each .item() forces a device sync.
            loss_value = loss.item()
            loss_history.append(loss_value)
            pbar.set_postfix({'loss': f'{loss_value:.6f}', 'x_T_norm': f'{x_T.norm().item():.3f}'})

        if loss_history:
            log.info(f"Final loss = {loss_history[-1]:.6f}")
        else:
            # n_epochs == 0: x_T is returned exactly as sampled.
            log.warning("DP Inversion ran for 0 epochs; returning the initial x_T")

        return x_T.detach(), loss_history

    def run(self):
        """Load the target demo, optionally adapt the model to it, then evaluate.

        Steps:
          1. Log the demo's approach angle when ``metadata.npy`` exists.
          2. Load the demo and normalize its states and actions.
          3. If ``self.do_training`` is set, run the adaptation routine that
             matches ``self.model_type`` and save its artifact under
             ``self.output_dir``.
          4. Launch the eval run, handing over whichever artifact was saved.
        """
        banner = "=" * 70
        log.info(banner)
        log.info(f"Finetune {self.model_type.upper()} for grasp - grasp_style={self.grasp_style}")
        log.info(banner)

        demo_path = self.cfg.demo_path

        # Metadata is optional and only used for logging.
        metadata_path = os.path.join(demo_path, 'metadata.npy')
        if not os.path.exists(metadata_path):
            log.warning(f"No metadata found at {metadata_path}")
        else:
            metadata = np.load(metadata_path, allow_pickle=True).item()
            approach_angle_deg = np.degrees(metadata.get('approach_angle', 0))
            log.info(f"Target demo metadata: approach_angle={approach_angle_deg:.1f}°")

        raw_states, raw_actions = self._load_demo_from_pkl(demo_path)
        states_norm, actions_norm = self._normalize_states_actions(raw_states, raw_actions)
        log.info(f"Normalized demo: states={states_norm.shape}, actions={actions_norm.shape}")

        # Suffix shared by every artifact file name.
        style_tag = f'style{self.grasp_style}'

        # Keyword arguments for _run_eval; at most one entry is filled in.
        eval_inputs = {
            'checkpoint_path': None,
            'x_T_path': None,
            'selected_component_path': None,
            'demo_trajectory_path': None,
        }

        if not self.do_training:
            log.info(f"\n--- No training for {self.model_type.upper()} (inference only) ---")
        else:
            log.info("\n" + banner)
            log.info(f"Finetuning {self.model_type.upper()}...")
            log.info(banner)

            if self.model_type == 'bc':
                bc_losses = self.finetune_bc([states_norm], [actions_norm])

                # A checkpoint is only written when at least one epoch ran.
                if bc_losses:
                    log.info(f"Final loss: {bc_losses[-1]:.6f}")

                    ckpt_file = os.path.join(self.output_dir, f'finetuned_model_{style_tag}.pt')
                    torch.save({'model': self.model.state_dict()}, ckpt_file)
                    log.info(f"Saved finetuned model to {ckpt_file}")
                    eval_inputs['checkpoint_path'] = ckpt_file

                    self._plot_loss_curve(bc_losses)

            elif self.model_type == 'dp':
                tuned_noise, dp_losses = self.finetune_dp_inversion_single(states_norm, actions_norm)

                noise_file = os.path.join(self.output_dir, f'optimized_x_T_{style_tag}.npy')
                np.save(noise_file, tuned_noise.cpu().numpy())
                log.info(f"Saved optimized x_T to {noise_file} (shape: {tuned_noise.shape})")
                eval_inputs['x_T_path'] = noise_file

                self._plot_loss_curve(dp_losses)

            elif self.model_type == 'bc_gmm':
                best_component, per_component_ll = self.finetune_bc_gmm_component_selection(
                    [states_norm], [actions_norm]
                )

                component_file = os.path.join(self.output_dir, f'selected_component_{style_tag}.npy')
                np.save(component_file, {
                    'selected_component': best_component,
                    'component_log_likelihoods': per_component_ll,
                })
                log.info(f"Saved selected component info to {component_file}")
                eval_inputs['selected_component_path'] = component_file

            elif self.model_type == 'ibc':
                demo_traj = self.finetune_ibc_demo_init([states_norm], [actions_norm])

                trajectory_file = os.path.join(self.output_dir, f'demo_trajectory_{style_tag}.npy')
                np.save(trajectory_file, demo_traj.cpu().numpy())
                log.info(f"Saved demo trajectory to {trajectory_file}")
                eval_inputs['demo_trajectory_path'] = trajectory_file

            else:
                log.warning(f"Training not implemented for {self.model_type}, skipping...")

        log.info("\n" + banner)
        log.info("Running evaluation...")
        log.info(banner)

        self._run_eval(**eval_inputs)

    def _plot_loss_curve(self, loss_history):
        """Save the loss-vs-epoch plot as ``loss_curve.png`` in the output dir.

        Nothing is written when ``loss_history`` is empty.

        Args:
            loss_history: Sequence of per-epoch loss values.
        """
        if not len(loss_history):
            return

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(loss_history)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title(f'{self.model_type.upper()} Finetuning Loss (grasp_style={self.grasp_style})')
        ax.grid(True, alpha=0.3)

        save_path = os.path.join(self.output_dir, 'loss_curve.png')
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
        log.info(f"Saved loss curve to {save_path}")

    def _plot_loss_curves_dp(self, all_loss_histories, all_demo_info):
        """Save one DP-inversion loss curve per demo to ``loss_curves.png``.

        Args:
            all_loss_histories: One loss list per demo; empty lists are skipped.
            all_demo_info: One dict per demo; its optional 'description' entry
                is used as the legend label, otherwise 'Demo <index>'.
        """
        fig, ax = plt.subplots(figsize=(10, 6))

        for idx, (history, info) in enumerate(zip(all_loss_histories, all_demo_info)):
            if len(history) == 0:
                continue
            ax.plot(history, label=info.get('description', f'Demo {idx}'))

        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('DP Inversion Loss')
        ax.legend()
        ax.grid(True, alpha=0.3)

        save_path = os.path.join(self.output_dir, 'loss_curves.png')
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
        log.info(f"Saved loss curves to {save_path}")

    def _run_eval(self, checkpoint_path=None, x_T_path=None, selected_component_path=None, demo_trajectory_path=None):
        """Run ``script/run.py`` with the finetuned eval config in a subprocess.

        The child process is started from ``dppo_root`` with the interpreter
        that is running this script, so it uses the same environment. A
        non-zero exit code is logged, not raised.

        Args:
            checkpoint_path: Optional value for the ``base_policy_path`` override.
            x_T_path: Optional value for the ``x_T_file`` override.
            selected_component_path: Optional value for the
                ``selected_component_file`` override.
            demo_trajectory_path: Optional value for the
                ``demo_trajectory_file`` override.
        """
        # Map model type to finetuned eval config
        finetuned_eval_configs = {
            'bc': 'eval_bc_finetuned',
            'bc_gmm': 'eval_bc_gmm_finetuned',
            'dp': 'eval_diffusion_mlp_finetuned',
            'ibc': 'eval_ibc_finetuned',
        }
        eval_config_name = finetuned_eval_configs.get(self.model_type)
        if eval_config_name is None:
            # Only consult the config for unknown model types, so known types
            # do not require `eval_config_name` to be present in cfg.
            eval_config_name = self.cfg.eval_config_name

        eval_cmd = [
            # A bare "python" may resolve to a different interpreter on PATH.
            sys.executable or "python", "script/run.py",
            "--config-dir=cfg/rlbench/eval/grasp",
            f"--config-name={eval_config_name}",
            f"grasp_style={self.grasp_style}",  # Pass grasp_style for angle-based success check
        ]

        # Append an override for every artifact that was actually produced.
        optional_overrides = (
            ("base_policy_path", checkpoint_path),
            ("x_T_file", x_T_path),
            ("selected_component_file", selected_component_path),
            ("demo_trajectory_file", demo_trajectory_path),
        )
        eval_cmd.extend(
            f"{key}={value}" for key, value in optional_overrides if value is not None
        )

        log.info(f"Running eval command: {' '.join(eval_cmd)}")

        result = subprocess.run(
            eval_cmd,
            cwd=dppo_root,
            capture_output=False,
        )

        if result.returncode != 0:
            log.error(f"Eval command failed with return code {result.returncode}")
        else:
            log.info("Evaluation completed successfully!")
