"""
Finetune Agent for Parameterized Behavior Cloning (PBC) on grasp (pick_up_cup) task.

Given a target demo, this agent:
1. Loads the demo from the specified path
2. Encodes the demo to get warm start z
3. Optimizes z using direct action regression (much simpler than diffusion!)
4. Saves the optimized z to a file
5. Automatically runs the eval agent with the optimized z

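grasp_style labels the target demo (the demo itself is loaded from cfg.demo_path); it names the
saved z file and is forwarded to the eval agent: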
- grasp_style=1: Mode 0° (train, episode0)
- grasp_style=2: Mode 270° (train, episode30, noise-free)
- grasp_style=3: Test 45° (novel angle)

Key difference from PDP-FiLM finetuning:
- PDP-FiLM: MSE(predicted_noise, actual_noise) with timestep sampling
- PBC: MSE(predicted_actions, demo_actions) - direct regression, no timesteps, no noise
"""

import os
import sys
import logging
import subprocess
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import hydra

# Add paths
# Directory three levels above this file's absolute path (three nested
# os.path.dirname calls). It is appended to sys.path so the `model.*` imports
# below can resolve, is the base for the encoder file paths, and is the working
# directory of the eval subprocess. Must stay above the `model.*` imports.
dppo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(dppo_root)

# NOTE(review): neither of these names is referenced elsewhere in this file;
# presumably imported so the modules are loaded for hydra instantiation of
# cfg.model — confirm before removing.
from model.baselines.parameterized_bc_eval import ParameterizedBCEval
from model.baselines.parameterized_bc import FiLMBCNetwork

# Import encoder and relative trajectory conversion from the correct grasp encoder directory
# (RLBench_grasp/encoder). importlib loads each file directly by path, which avoids
# sys.path conflicts with other encoder directories.
import importlib.util

# Load RLBench_grasp/encoder/trajectory_encoder.py under the module name
# "grasp_trajectory_encoder" and execute it; its top-level code runs at import
# time of this file. The module is not inserted into sys.modules.
_encoder_path = os.path.join(
    dppo_root, "RLBench_grasp", "encoder", "trajectory_encoder.py"
)
_spec = importlib.util.spec_from_file_location("grasp_trajectory_encoder", _encoder_path)
_encoder_module = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(_encoder_module)
TrajectoryVAE = _encoder_module.TrajectoryVAE

# Same pattern for RLBench_grasp/encoder/train_grasp_encoder.py; only the
# make_trajectory_relative helper is taken from it (used before encoding a demo).
_train_encoder_path = os.path.join(
    dppo_root, "RLBench_grasp", "encoder", "train_grasp_encoder.py"
)
_spec2 = importlib.util.spec_from_file_location("train_grasp_encoder", _train_encoder_path)
_train_encoder_module = importlib.util.module_from_spec(_spec2)
_spec2.loader.exec_module(_train_encoder_module)
make_trajectory_relative = _train_encoder_module.make_trajectory_relative

# Module-level logger used throughout the agent.
log = logging.getLogger(__name__)


class FinetuneGraspPBCAgent:
    """Finetune agent for Parameterized BC on grasp (pick_up_cup) task.

    Only the latent code ``z`` is optimized; the PBC network instantiated from
    ``cfg.model`` is frozen. The resulting ``z`` is saved to
    ``<cfg.logdir>/optimized_z_style<grasp_style>.npy`` and passed to the eval
    agent, which is launched as a subprocess.
    """

    def __init__(self, cfg):
        """Seed RNGs and load normalization stats, the optional encoder and the frozen model.

        Args:
            cfg: Hydra config. Fields read here: ``device``, ``grasp_style``,
                ``seed``, ``normalization_path``, optional ``encoder_checkpoint``,
                ``model`` (plus ``model.network.z_dim`` when no encoder is given)
                and ``logdir``.
        """
        self.cfg = cfg
        self.device = cfg.device
        self.grasp_style = cfg.grasp_style

        # Set random seed for reproducibility. manual_seed_all already covers the
        # current CUDA device, so a separate torch.cuda.manual_seed is redundant.
        self.seed = cfg.seed
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.seed)

        # Load normalization stats.
        self.norm_stats = self._load_normalization(cfg.normalization_path)

        # Encoder is optional (null -> random z init). With an encoder,
        # _load_encoder sets self.latent_dim from the checkpoint config;
        # otherwise the latent size comes from the model config.
        encoder_checkpoint = cfg.get("encoder_checkpoint", None)
        if encoder_checkpoint is not None:
            self.encoder = self._load_encoder(encoder_checkpoint)
        else:
            log.info("No encoder checkpoint provided - will use random z initialization")
            self.encoder = None
            self.latent_dim = cfg.model.network.z_dim

        # Load PBC model and freeze its network weights: only z is optimized.
        self.model = hydra.utils.instantiate(cfg.model)
        self.model.eval()
        for param in self.model.network.parameters():
            param.requires_grad = False

        # Create output directory.
        self.output_dir = cfg.logdir
        os.makedirs(self.output_dir, exist_ok=True)

    def _load_normalization(self, norm_path):
        """Load min/max normalization statistics from an ``.npz`` archive.

        Args:
            norm_path: path to an archive containing ``obs_min``, ``obs_max``,
                ``action_min`` and ``action_max`` arrays.

        Returns:
            dict with float32 tensors on ``self.device`` (``obs_min``, ``obs_max``,
            ``obs_range``, ``action_min``, ``action_max``, ``action_range``) and the
            raw numpy arrays under the ``*_np`` keys. Ranges below 1e-6 are
            replaced by 1.0 to avoid division by (near) zero.
        """
        log.info(f"Loading normalization from {norm_path}")
        # Use the archive as a context manager: np.load on an .npz returns an
        # NpzFile that otherwise keeps the underlying file handle open.
        with np.load(norm_path) as norm_data:
            raw = {
                key: norm_data[key]
                for key in ("obs_min", "obs_max", "action_min", "action_max")
            }

        def to_tensor(array):
            # torch.tensor copies, so the tensors never alias the *_np arrays.
            return torch.tensor(array, dtype=torch.float32, device=self.device)

        obs_min = to_tensor(raw["obs_min"])
        obs_max = to_tensor(raw["obs_max"])
        action_min = to_tensor(raw["action_min"])
        action_max = to_tensor(raw["action_max"])

        obs_range = obs_max - obs_min
        obs_range = torch.where(obs_range < 1e-6, torch.ones_like(obs_range), obs_range)
        action_range = action_max - action_min
        action_range = torch.where(
            action_range < 1e-6, torch.ones_like(action_range), action_range
        )

        return {
            "obs_min": obs_min,
            "obs_max": obs_max,
            "obs_range": obs_range,
            "action_min": action_min,
            "action_max": action_max,
            "action_range": action_range,
            "obs_min_np": raw["obs_min"],
            "obs_max_np": raw["obs_max"],
            "action_min_np": raw["action_min"],
            "action_max_np": raw["action_max"],
        }

    def _load_encoder(self, checkpoint_path):
        """Load the frozen trajectory VAE encoder and record its latent size.

        Args:
            checkpoint_path: torch checkpoint with a ``config`` dict (architecture
                hyperparameters) and a ``model_state_dict``.

        Returns:
            The ``TrajectoryVAE`` in eval mode with all parameters frozen. As a side
            effect, sets ``self.latent_dim`` from ``config["latent_dim"]``.
        """
        log.info(f"Loading encoder from {checkpoint_path}")
        # weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
        checkpoint = torch.load(
            checkpoint_path, map_location=self.device, weights_only=False
        )
        config = checkpoint["config"]

        encoder = TrajectoryVAE(
            state_dim=config["state_dim"],
            action_dim=config["action_dim"],
            hidden_dim=config["hidden_dim"],
            num_layers=config["num_layers"],
            num_heads=config["num_heads"],
            latent_dim=config["latent_dim"],
            horizon=config["horizon"],
        ).to(self.device)

        encoder.load_state_dict(checkpoint["model_state_dict"])
        encoder.eval()
        for param in encoder.parameters():
            param.requires_grad = False

        self.latent_dim = config["latent_dim"]
        return encoder

    @staticmethod
    def _extract_states_actions(demo):
        """Convert a sequence of observations into stacked state/action arrays.

        State per step: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7).
        Action per step: ``obs.misc["joint_position_action"]`` when present,
        otherwise joint_positions + [gripper_open].

        Args:
            demo: iterable of observation objects exposing ``joint_positions``,
                ``joint_velocities``, ``gripper_open``, ``gripper_pose`` and,
                optionally, a ``misc`` mapping (may be missing or ``None``).

        Returns:
            (states, actions) as numpy arrays with one row per observation.

        Raises:
            ValueError: if ``demo`` contains no observations.
        """
        states = []
        actions = []
        for obs in demo:
            states.append(
                np.concatenate(
                    [
                        obs.joint_positions,  # 7
                        obs.joint_velocities,  # 7
                        [obs.gripper_open],  # 1
                        obs.gripper_pose,  # 7 (pos + quat)
                    ]
                )
            )

            # A missing or None misc falls back to the proprioceptive action.
            misc = getattr(obs, "misc", None) or {}
            if "joint_position_action" in misc:
                action = misc["joint_position_action"]
            else:
                action = np.concatenate([obs.joint_positions, [obs.gripper_open]])
            actions.append(action)

        if not states:
            # Fail early: padding an empty demo would otherwise crash with an
            # opaque numpy shape error.
            raise ValueError("Demo contains no observations")

        return np.array(states), np.array(actions)

    def _load_demo_from_pkl(self, demo_path):
        """
        Load demo from low_dim_obs.pkl and extract states/actions.

        State format: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7) = 22
        Action format: joint_pos(7) + gripper_open(1) = 8

        The result is truncated, or padded by repeating the last step, to
        ``cfg.horizon_steps`` rows.

        Raises:
            FileNotFoundError: if ``<demo_path>/low_dim_obs.pkl`` does not exist.
            ValueError: if the demo contains no observations.
        """
        pkl_path = os.path.join(demo_path, "low_dim_obs.pkl")
        if not os.path.exists(pkl_path):
            raise FileNotFoundError(f"Demo file not found: {pkl_path}")

        log.info(f"Loading demo from {pkl_path}")
        # pickle executes arbitrary code on load; demo files must come from a trusted source.
        with open(pkl_path, "rb") as f:
            demo = pickle.load(f)

        states, actions = self._extract_states_actions(demo)

        log.info(
            f"Loaded demo: {len(demo)} steps, states={states.shape}, actions={actions.shape}"
        )

        # Adjust to match expected horizon.
        horizon = self.cfg.horizon_steps
        if len(states) > horizon:
            log.info(f"Truncating demo from {len(states)} to {horizon} steps")
            states = states[:horizon]
            actions = actions[:horizon]
        elif len(states) < horizon:
            # Pad by repeating last state/action.
            log.info(f"Padding demo from {len(states)} to {horizon} steps")
            pad_len = horizon - len(states)
            states = np.concatenate(
                [states, np.tile(states[-1:], (pad_len, 1))], axis=0
            )
            actions = np.concatenate(
                [actions, np.tile(actions[-1:], (pad_len, 1))], axis=0
            )

        return states, actions

    def _normalize_states_actions(self, states, actions):
        """Normalize states and actions to [-1, 1] using the loaded min/max stats.

        Dimensions whose range is below 1e-6 use a range of 1.0, matching
        ``_load_normalization``.

        Returns:
            (states_norm, actions_norm) numpy arrays with the input shapes.
        """
        obs_min = self.norm_stats["obs_min_np"]
        obs_max = self.norm_stats["obs_max_np"]
        action_min = self.norm_stats["action_min_np"]
        action_max = self.norm_stats["action_max_np"]

        obs_range = obs_max - obs_min
        obs_range = np.where(obs_range < 1e-6, 1.0, obs_range)
        action_range = action_max - action_min
        action_range = np.where(action_range < 1e-6, 1.0, action_range)

        states_norm = 2.0 * (states - obs_min) / obs_range - 1.0
        actions_norm = 2.0 * (actions - action_min) / action_range - 1.0

        return states_norm, actions_norm

    def _encode_demo(self, states, actions):
        """Encode a normalized demo into a latent z with the frozen encoder.

        The demo is treated as a batch of one, converted with
        ``make_trajectory_relative`` and encoded; the batch dimension is squeezed
        from the result.
        """
        states_tensor = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        actions_tensor = torch.as_tensor(actions, dtype=torch.float32, device=self.device)

        with torch.no_grad():
            states_rel, actions_rel = make_trajectory_relative(
                states_tensor.unsqueeze(0), actions_tensor.unsqueeze(0)
            )
            z = self.encoder.encode(states_rel, actions_rel).squeeze(0)

        return z

    def optimize_z(
        self,
        demo_states,
        demo_actions,
        n_epochs=300,
        lr=3e-3,
        z_reg_weight=0.0,
        warm_start=True,
        init_z=None,
    ):
        """
        Optimize z using direct action regression.

        Unlike diffusion finetuning there is no timestep sampling and no noise;
        the loss is simply MSE(predicted_actions, demo_actions), optionally plus
        ``z_reg_weight * ||z||_2``. Gradients on z are clipped to norm 1.0.

        Args:
            demo_states: (T, state_dim) - normalized states in [-1, 1]
            demo_actions: (T, action_dim) - normalized actions in [-1, 1]
            n_epochs: number of optimization epochs (Adam steps)
            lr: learning rate
            z_reg_weight: weight for the L2-norm penalty on z (0 = no reg)
            warm_start: if True and init_z is given, start from init_z
            init_z: initial z value (from encoder)

        Returns:
            z: optimized latent code (detached)
            loss_history: list of per-epoch loss values
        """
        demo_states = torch.as_tensor(demo_states, dtype=torch.float32, device=self.device)
        demo_actions = torch.as_tensor(demo_actions, dtype=torch.float32, device=self.device)

        # Initialize z.
        if warm_start and init_z is not None:
            log.info(f"Warm start: Using encoder z (norm={init_z.norm().item():.3f})")
            # nn.Parameter sets requires_grad=True on the detached copy.
            z = nn.Parameter(init_z.detach().clone())
        else:
            log.info("Cold start: Random initialization for z")
            # Sampled on CPU and then moved, so the draw comes from the seeded CPU RNG.
            z = nn.Parameter(torch.randn(self.latent_dim).to(self.device) * 0.1)

        optimizer = torch.optim.Adam([z], lr=lr)
        loss_history = []

        # The network is conditioned only on the first state of the demo.
        cond_state = demo_states[0:1]  # (1, obs_dim)
        demo_actions_batch = demo_actions.unsqueeze(0)  # (1, T, action_dim)

        pbar = tqdm(range(n_epochs), desc="Optimizing z (PBC)")
        for _ in pbar:
            optimizer.zero_grad()

            z_batch = z.unsqueeze(0)  # (1, z_dim)

            # Direct forward pass - no timesteps, no noise. enable_grad guards
            # against this method being called under an outer no_grad context.
            with torch.enable_grad():
                predicted_actions = self.model.network(
                    cond_state, z_batch
                )  # expected (1, T, action_dim)

            loss = F.mse_loss(predicted_actions, demo_actions_batch)
            if z_reg_weight > 0:
                loss = loss + z_reg_weight * torch.norm(z, p=2)

            loss.backward()
            torch.nn.utils.clip_grad_norm_([z], max_norm=1.0)
            optimizer.step()

            loss_value = loss.item()
            loss_history.append(loss_value)
            pbar.set_postfix(
                {"loss": f"{loss_value:.4f}", "z_norm": f"{z.norm().item():.3f}"}
            )

        return z.detach(), loss_history

    def run(self):
        """Run the finetune process (compute and save optimized z), then evaluate.

        Steps: read optional metadata, load and normalize the demo, get an
        initial z (encoder or random), optimize it when
        ``cfg.ftl_igm.n_epochs > 0``, save it as a (1, z_dim) ``.npy``, plot the
        loss curve, and launch the eval agent.
        """
        log.info("=" * 70)
        log.info(f"PBC Finetune for grasp (pick_up_cup) - grasp_style={self.grasp_style}")
        log.info("=" * 70)

        demo_path = self.cfg.demo_path

        # Metadata is informational only (logged approach angle).
        metadata_path = os.path.join(demo_path, "metadata.npy")
        if os.path.exists(metadata_path):
            metadata = np.load(metadata_path, allow_pickle=True).item()
            approach_angle_deg = np.degrees(metadata.get("approach_angle", 0))
            log.info(f"Target demo metadata: approach_angle={approach_angle_deg:.1f}°")
        else:
            log.warning(f"No metadata found at {metadata_path}")
            metadata = {}

        # Load and normalize demo.
        states, actions = self._load_demo_from_pkl(demo_path)
        states_norm, actions_norm = self._normalize_states_actions(states, actions)
        log.info(
            f"Normalized demo: states={states_norm.shape}, actions={actions_norm.shape}"
        )

        # Initial z from the encoder if available, else None (random init later).
        if self.encoder is not None:
            init_z = self._encode_demo(states_norm, actions_norm)
            log.info(f"Encoder z from demo: {init_z.cpu().numpy()}")
        else:
            init_z = None
            log.info("No encoder - will use random z initialization")

        n_epochs = self.cfg.ftl_igm.n_epochs
        if n_epochs > 0:
            log.info("\n--- PBC Z Optimization (Direct Action Regression) ---")
            optimized_z, loss_history = self.optimize_z(
                demo_states=states_norm,
                demo_actions=actions_norm,
                n_epochs=n_epochs,
                lr=self.cfg.ftl_igm.lr,
                z_reg_weight=self.cfg.ftl_igm.z_reg_weight,
                warm_start=self.cfg.ftl_igm.warm_start if init_z is not None else False,
                init_z=init_z,
            )
            log.info(f"Optimized z: {optimized_z.cpu().numpy()}")
            log.info(f"Final loss: {loss_history[-1]:.4f}")
        else:
            if init_z is not None:
                log.info("\n--- Skipping Z Optimization (n_epochs=0) ---")
                log.info("Using encoder z directly from demo")
                optimized_z = init_z
            else:
                log.info("\n--- Using random z (no encoder, n_epochs=0) ---")
                optimized_z = torch.randn(self.latent_dim, device=self.device) * 0.1
                log.info(f"Random z: {optimized_z.cpu().numpy()}")
            loss_history = []

        # Save optimized z as a 2D array for compatibility with the eval agent.
        z_save_path = os.path.join(
            self.output_dir, f"optimized_z_style{self.grasp_style}.npy"
        )
        z_to_save = optimized_z.cpu().numpy().reshape(1, -1)  # Shape: (1, z_dim)
        np.save(z_save_path, z_to_save)
        log.info(f"Saved optimized z to {z_save_path} (shape: {z_to_save.shape})")

        self._plot_loss_curve(loss_history)

        log.info("\n" + "=" * 70)
        log.info("PBC Z Optimization Complete! Running evaluation...")
        log.info("=" * 70)

        self._run_eval(z_save_path)

    def _plot_loss_curve(self, loss_history):
        """Plot the optimization loss curve to ``<output_dir>/loss_curve.png``.

        Does nothing (besides logging) when ``loss_history`` is empty.
        """
        if not loss_history:
            log.info("No loss history to plot (optimization skipped)")
            return

        save_path = os.path.join(self.output_dir, "loss_curve.png")
        # Explicit figure handle, closed even if saving fails.
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(loss_history)
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Loss (MSE on Actions)")
            ax.set_title(f"PBC Z Optimization Loss (grasp_style={self.grasp_style})")
            ax.grid(True, alpha=0.3)
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
        finally:
            plt.close(fig)
        log.info(f"Saved loss curve to {save_path}")

    def _run_eval(self, z_file_path):
        """Run the eval agent with the optimized z as a subprocess from ``dppo_root``.

        A non-zero exit code is logged as an error but not raised (best effort).
        """
        eval_config_name = self.cfg.get("eval_config_name", "eval_pbc")

        eval_cmd = [
            # Use the running interpreter so the eval runs in the same environment;
            # a bare "python" may resolve to a different interpreter on PATH.
            sys.executable or "python",
            "script/run.py",
            "--config-dir=cfg/rlbench/eval/grasp",
            f"--config-name={eval_config_name}",
            f"z_file={z_file_path}",
            f"grasp_style={self.grasp_style}",  # Pass grasp_style for angle-based success check
        ]

        log.info(f"Running eval command: {' '.join(eval_cmd)}")

        result = subprocess.run(
            eval_cmd,
            cwd=dppo_root,
            capture_output=False,
        )

        if result.returncode != 0:
            log.error(f"Eval command failed with return code {result.returncode}")
        else:
            log.info("Evaluation completed successfully!")