"""
Finetune Agent for Parameterized Behavior Cloning (PBC) on close_drawer task.

Given a wall-avoiding demo (with full states/actions from low_dim_obs.pkl), this agent:
1. Loads the demo directly from block_setting/style{id}/episodes/episode0/
2. Encodes the demo to get warm start z
3. Optimizes z using direct action regression (much simpler than diffusion!)
4. Saves the optimized z to a file
5. Automatically runs the eval agent with the optimized z

Key difference from PDP-FiLM finetuning:
- PDP-FiLM: MSE(predicted_noise, actual_noise) with timestep sampling
- PBC: MSE(predicted_actions, demo_actions) - direct regression, no timesteps, no noise
"""

import os
import sys
import logging
import subprocess
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import hydra

# Make the project importable when this file is run directly:
# `dppo_root` is the directory three levels above this file, and both it and
# its "RLBench" subdirectory are appended to sys.path so that the `model.*`
# and `RLBench.*` imports below can be resolved.
dppo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(dppo_root)
sys.path.append(os.path.join(dppo_root, "RLBench"))

from model.baselines.parameterized_bc_eval import ParameterizedBCEval
from model.baselines.parameterized_bc import FiLMBCNetwork

# Import encoder and relative trajectory conversion
from RLBench.trajectory_encoder import TrajectoryVAE
from RLBench.train_relative_trajectory_encoder import make_trajectory_relative

# Module-level logger, named after this module, used by every method below.
log = logging.getLogger(__name__)


class FinetunePBCAgent:
    """Finetune agent for Parameterized BC on close_drawer task."""

    def __init__(self, cfg):
        """
        Build the finetune agent from a config object.

        Seeds the random number generators, loads the normalization
        statistics, optionally loads the trajectory encoder, instantiates the
        PBC model with its network weights frozen, and creates the output
        directory.

        Args:
            cfg: config object. This method reads ``device``, ``wall_style``,
                ``seed``, ``normalization_path``, ``encoder_checkpoint``,
                ``model`` (including ``model.network.z_dim`` when no encoder
                checkpoint is given) and ``logdir``.
        """
        self.cfg = cfg
        self.device = cfg.device
        self.wall_style = cfg.wall_style

        # Seed every RNG before anything random happens (reproducibility).
        self.seed = cfg.seed
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)

        self.norm_stats = self._load_normalization(cfg.normalization_path)

        # The encoder is optional. Without it there is no warm start, and the
        # latent size has to come from the model config instead of the
        # encoder checkpoint (_load_encoder sets self.latent_dim itself).
        if cfg.encoder_checkpoint is None:
            self.encoder = None
            self.latent_dim = cfg.model.network.z_dim
            log.info(
                f"No encoder provided, using z_dim={self.latent_dim} from model config"
            )
        else:
            self.encoder = self._load_encoder(cfg.encoder_checkpoint)

        # Only z is optimized later, so the PBC network weights are frozen.
        self.model = hydra.utils.instantiate(cfg.model)
        self.model.eval()
        for weight in self.model.network.parameters():
            weight.requires_grad = False

        self.output_dir = cfg.logdir
        os.makedirs(self.output_dir, exist_ok=True)

    def _load_normalization(self, norm_path):
        """
        Load min/max normalization statistics for observations and actions.

        Each array is read from the file exactly once and the handle returned
        by ``np.load`` is closed explicitly afterwards instead of being left
        to the garbage collector.

        Args:
            norm_path: path to a file readable by ``np.load`` that provides
                the keys "obs_min", "obs_max", "action_min" and "action_max".

        Returns:
            dict with
            - "obs_min", "obs_max", "action_min", "action_max": float32
              tensors on ``self.device``;
            - "obs_range", "action_range": ``max - min`` tensors in which
              entries smaller than 1e-6 are replaced by 1 so that dividing
              by the range is always safe;
            - "obs_min_np", "obs_max_np", "action_min_np", "action_max_np":
              the arrays as loaded from the file (numpy).
        """
        log.info(f"Loading normalization from {norm_path}")

        keys = ("obs_min", "obs_max", "action_min", "action_max")
        norm_data = np.load(norm_path)
        try:
            raw = {key: norm_data[key] for key in keys}
        finally:
            # .npz archives keep a file handle open until closed; plain
            # arrays have no close(), hence the getattr.
            close = getattr(norm_data, "close", None)
            if close is not None:
                close()

        # torch.tensor always copies, so the tensors never alias the numpy
        # arrays that are returned alongside them.
        tensors = {
            key: torch.tensor(raw[key], dtype=torch.float32, device=self.device)
            for key in keys
        }

        def safe_range(lower, upper):
            # Degenerate (constant) dimensions get range 1 to avoid div-by-0.
            span = upper - lower
            return torch.where(span < 1e-6, torch.ones_like(span), span)

        return {
            "obs_min": tensors["obs_min"],
            "obs_max": tensors["obs_max"],
            "obs_range": safe_range(tensors["obs_min"], tensors["obs_max"]),
            "action_min": tensors["action_min"],
            "action_max": tensors["action_max"],
            "action_range": safe_range(tensors["action_min"], tensors["action_max"]),
            "obs_min_np": raw["obs_min"],
            "obs_max_np": raw["obs_max"],
            "action_min_np": raw["action_min"],
            "action_max_np": raw["action_max"],
        }

    def _load_encoder(self, checkpoint_path):
        """
        Restore the trajectory encoder from a checkpoint and freeze it.

        The checkpoint must contain a "config" mapping with the encoder
        hyper-parameters and a "model_state_dict" with its weights. As a side
        effect, ``self.latent_dim`` is set from the checkpoint config.

        Args:
            checkpoint_path: path of the checkpoint passed to ``torch.load``.

        Returns:
            The ``TrajectoryVAE`` on ``self.device``, in eval mode, with
            ``requires_grad`` disabled on all of its parameters.
        """
        log.info(f"Loading encoder from {checkpoint_path}")
        # NOTE: weights_only=False unpickles arbitrary objects, so only
        # trusted checkpoints should be loaded here.
        ckpt = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        enc_cfg = ckpt["config"]

        # The constructor arguments are exactly these checkpoint config keys.
        ctor_keys = (
            "state_dim",
            "action_dim",
            "hidden_dim",
            "num_layers",
            "num_heads",
            "latent_dim",
            "horizon",
        )
        encoder = TrajectoryVAE(**{key: enc_cfg[key] for key in ctor_keys})
        encoder = encoder.to(self.device)

        encoder.load_state_dict(ckpt["model_state_dict"])
        encoder.eval()
        for weight in encoder.parameters():
            weight.requires_grad = False

        self.latent_dim = enc_cfg["latent_dim"]
        return encoder

    def _load_demo_from_pkl(self, demo_path):
        """
        Load a demo from low_dim_obs.pkl and extract per-step states/actions.

        State format: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7) = 22
        Action format: joint_pos(7) + gripper_open(1) = 8

        The action of a step is ``obs.misc["joint_position_action"]`` when
        that entry exists; otherwise (including when ``misc`` is missing or
        None) it falls back to the step's joint positions plus gripper state.
        The result is truncated, or padded by repeating the last step, so
        that it has exactly ``self.cfg.horizon_steps`` rows.

        Args:
            demo_path: directory containing ``low_dim_obs.pkl``.

        Returns:
            (states, actions): two numpy arrays with ``horizon_steps`` rows.

        Raises:
            FileNotFoundError: if ``low_dim_obs.pkl`` does not exist.
            ValueError: if the demo contains no observations.
        """
        pkl_path = os.path.join(demo_path, "low_dim_obs.pkl")
        if not os.path.exists(pkl_path):
            raise FileNotFoundError(
                f"Demo file not found: {pkl_path}\n"
                f"Please run: python RLBench_close_drawer/make_dataset/visualize_wall_trajectories.py --style {self.wall_style}"
            )

        log.info(f"Loading demo from {pkl_path}")
        # NOTE: unpickling can execute arbitrary code; only load demo files
        # that come from a trusted source.
        with open(pkl_path, "rb") as f:
            demo = pickle.load(f)

        states = []
        actions = []
        for obs in demo:
            # State: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7)
            states.append(
                np.concatenate(
                    [
                        obs.joint_positions,
                        obs.joint_velocities,
                        [obs.gripper_open],
                        obs.gripper_pose,  # pos + quat
                    ]
                )
            )

            # `misc` may be absent or None; treat both like an empty dict so
            # the fallback below is used instead of raising a TypeError.
            misc = getattr(obs, "misc", None) or {}
            if "joint_position_action" in misc:
                action = misc["joint_position_action"]
            else:
                # Fallback: use joint_positions + gripper_open
                action = np.concatenate([obs.joint_positions, [obs.gripper_open]])
            actions.append(action)

        if not states:
            # Padding needs a last step to repeat; fail with a clear message
            # rather than an obscure numpy shape error further down.
            raise ValueError(f"Demo at {pkl_path} contains no observations")

        states = np.array(states)
        actions = np.array(actions)

        log.info(
            f"Loaded demo: {len(demo)} steps, states={states.shape}, actions={actions.shape}"
        )

        # Adjust to match the expected horizon.
        horizon = self.cfg.horizon_steps
        n_steps = len(states)
        if n_steps > horizon:
            log.info(f"Truncating demo from {n_steps} to {horizon} steps")
            states = states[:horizon]
            actions = actions[:horizon]
        elif n_steps < horizon:
            # Pad by repeating the last state/action.
            log.info(f"Padding demo from {n_steps} to {horizon} steps")
            pad_len = horizon - n_steps
            states = np.concatenate(
                [states, np.repeat(states[-1:], pad_len, axis=0)], axis=0
            )
            actions = np.concatenate(
                [actions, np.repeat(actions[-1:], pad_len, axis=0)], axis=0
            )

        return states, actions

    def _normalize_states_actions(self, states, actions):
        """
        Rescale raw states and actions to [-1, 1] with the stored min/max.

        Args:
            states: array of raw states; scaled with "obs_min_np"/"obs_max_np".
            actions: array of raw actions; scaled with
                "action_min_np"/"action_max_np".

        Returns:
            (states_norm, actions_norm): the rescaled arrays.
        """

        def rescale(values, lower, upper):
            # Dimensions whose range is below 1e-6 are divided by 1 instead,
            # which keeps constant dimensions finite.
            span = upper - lower
            span = np.where(span < 1e-6, 1.0, span)
            return 2.0 * (values - lower) / span - 1.0

        stats = self.norm_stats
        states_norm = rescale(states, stats["obs_min_np"], stats["obs_max_np"])
        actions_norm = rescale(
            actions, stats["action_min_np"], stats["action_max_np"]
        )
        return states_norm, actions_norm

    def _encode_demo(self, states, actions):
        """
        Encode a demo trajectory into a latent z with the trajectory encoder.

        The trajectory is given a leading batch dimension of size 1, converted
        with ``make_trajectory_relative`` and passed to ``self.encoder.encode``;
        everything runs without gradient tracking.

        Args:
            states: per-step states of the demo.
            actions: per-step actions of the demo.

        Returns:
            The encoder output with the batch dimension squeezed away.
        """
        batched_states = torch.FloatTensor(states).to(self.device).unsqueeze(0)
        batched_actions = torch.FloatTensor(actions).to(self.device).unsqueeze(0)

        with torch.no_grad():
            rel_states, rel_actions = make_trajectory_relative(
                batched_states, batched_actions
            )
            latent = self.encoder.encode(rel_states, rel_actions)
            latent = latent.squeeze(0)

        return latent

    def optimize_z(
        self,
        demo_states,
        demo_actions,
        n_epochs=300,
        lr=3e-3,
        z_reg_weight=0.0,
        warm_start=True,
        init_z=None,
    ):
        """
        Fit the latent z to a demo by direct action regression.

        Unlike diffusion finetuning there is no timestep sampling and no
        noise: every step runs the network on the first demo state and the
        current z, and minimizes the MSE between the predicted and the demo
        actions with Adam (gradient norm clipped to 1.0).

        Args:
            demo_states: (T, state_dim) - normalized states in [-1, 1]
            demo_actions: (T, action_dim) - normalized actions in [-1, 1]
            n_epochs: number of optimization steps
            lr: Adam learning rate
            z_reg_weight: weight of the L2-norm penalty on z (0 = no penalty)
            warm_start: if True and init_z is given, start from init_z;
                otherwise z starts from a small random vector
            init_z: initial z value (from the encoder)

        Returns:
            z: the detached latent code after the last step
            loss_history: per-step loss values (penalty included when active)
        """
        states_t = torch.FloatTensor(demo_states).to(self.device)
        actions_t = torch.FloatTensor(demo_actions).to(self.device)

        # Pick the starting point for z.
        use_warm_start = warm_start and init_z is not None
        if use_warm_start:
            log.info(f"Warm start: Using encoder z (norm={init_z.norm().item():.3f})")
            z_start = init_z.clone().detach().requires_grad_(True)
        else:
            log.info("Cold start: Random initialization for z")
            z_start = torch.randn(self.latent_dim).to(self.device) * 0.1
        z = nn.Parameter(z_start)

        optimizer = torch.optim.Adam([z], lr=lr)

        # The network is conditioned on the first state only.
        first_state = states_t[0:1]  # (1, obs_dim)
        # Regression target with a batch dimension.
        target_actions = actions_t.unsqueeze(0)  # (1, T, action_dim)

        loss_history = []
        progress = tqdm(range(n_epochs), desc="Optimizing z (PBC)")
        for _ in progress:
            optimizer.zero_grad()

            z_batch = z.unsqueeze(0)  # (1, z_dim)

            # Plain forward pass; gradients are enabled explicitly here.
            with torch.enable_grad():
                predicted = self.model.network(
                    first_state, z_batch
                )  # (1, T, action_dim)

            loss = F.mse_loss(predicted, target_actions)
            if z_reg_weight > 0:
                loss = loss + z_reg_weight * torch.norm(z, p=2)

            loss.backward()
            torch.nn.utils.clip_grad_norm_([z], max_norm=1.0)
            optimizer.step()

            loss_value = loss.item()
            loss_history.append(loss_value)
            progress.set_postfix(
                {"loss": f"{loss_value:.4f}", "z_norm": f"{z.norm().item():.3f}"}
            )

        return z.detach(), loss_history

    def run(self):
        """
        Run the whole finetune pipeline for the configured wall style.

        Steps: load and normalize the target demo, encode it for a warm-start
        z (when an encoder is loaded), optimize z (skipped unless
        ``cfg.ftl_igm.n_epochs`` is positive), save z as a (1, z_dim) ``.npy``
        file in the output directory, plot the loss curve and finally launch
        the eval agent on the saved z.
        """
        banner = "=" * 70
        log.info(banner)
        log.info(f"PBC Finetune for close_drawer (wall_style={self.wall_style})")
        log.info(banner)

        # 1. Locate the target demo: <demo_base_path>/style<id>/episodes/episode0
        demo_path = os.path.join(
            self.cfg.demo_base_path, f"style{self.wall_style}", "episodes", "episode0"
        )

        # Metadata is optional and only logged.
        metadata_path = os.path.join(demo_path, "metadata.npy")
        if os.path.exists(metadata_path):
            metadata = np.load(metadata_path, allow_pickle=True).item()
            log.info(f"Target demo metadata: {metadata}")

        # 2. Load the full demo and normalize it.
        raw_states, raw_actions = self._load_demo_from_pkl(demo_path)
        states_norm, actions_norm = self._normalize_states_actions(
            raw_states, raw_actions
        )
        log.info(
            f"Normalized demo: states={states_norm.shape}, actions={actions_norm.shape}"
        )

        # 3. Warm-start z from the encoder, if there is one.
        if self.encoder is None:
            init_z = None
            log.info("No encoder available - will use random initialization for z")
        else:
            init_z = self._encode_demo(states_norm, actions_norm)
            log.info(f"Encoder z from demo: {init_z.cpu().numpy()}")

        # 4. Optimize z, or fall back to the encoder/random z when skipped.
        opt_cfg = self.cfg.ftl_igm
        n_epochs = opt_cfg.n_epochs
        if n_epochs > 0:
            log.info("\n--- PBC Z Optimization (Direct Action Regression) ---")
            optimized_z, loss_history = self.optimize_z(
                demo_states=states_norm,
                demo_actions=actions_norm,
                n_epochs=n_epochs,
                lr=opt_cfg.lr,
                z_reg_weight=opt_cfg.z_reg_weight,
                warm_start=opt_cfg.warm_start,
                init_z=init_z,
            )
            log.info(f"Optimized z: {optimized_z.cpu().numpy()}")
            log.info(f"Final loss: {loss_history[-1]:.4f}")
        else:
            log.info("\n--- Skipping Z Optimization (n_epochs=0) ---")
            if init_z is not None:
                log.info("Using encoder z directly from demo")
                optimized_z = init_z
            else:
                log.info("Using random z (no encoder and no optimization)")
                optimized_z = torch.randn(self.latent_dim).to(self.device) * 0.1
            loss_history = []

        # 5. Save z as a 2D array, the layout the eval agent expects.
        z_save_path = os.path.join(
            self.output_dir, f"optimized_z_wall{self.wall_style}.npy"
        )
        z_to_save = optimized_z.cpu().numpy().reshape(1, -1)  # (1, z_dim)
        np.save(z_save_path, z_to_save)
        log.info(f"Saved optimized z to {z_save_path} (shape: {z_to_save.shape})")

        # 6. Plot the loss curve (no-op when optimization was skipped).
        self._plot_loss_curve(loss_history)

        # 7. Evaluate with the saved z.
        log.info("\n" + banner)
        log.info("PBC Z Optimization Complete! Running evaluation...")
        log.info(banner)

        self._run_eval(z_save_path)

    def _plot_loss_curve(self, loss_history):
        """
        Plot the optimization loss and save it as ``loss_curve.png``.

        The image is written to ``self.output_dir``. Nothing is plotted when
        ``loss_history`` is empty (i.e. optimization was skipped). The figure
        is always closed, even if drawing or saving raises, so that a failed
        save does not leave an open figure behind in pyplot.

        Args:
            loss_history: sequence of per-epoch loss values.
        """
        if len(loss_history) == 0:
            log.info("No loss history to plot (optimization skipped)")
            return

        save_path = os.path.join(self.output_dir, "loss_curve.png")

        # Work on an explicit figure handle instead of pyplot's implicit
        # "current figure" so that exactly this figure is saved and closed.
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(loss_history)
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Loss (MSE on Actions)")
            ax.set_title(f"PBC Z Optimization Loss (wall_style={self.wall_style})")
            ax.grid(True, alpha=0.3)
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
        finally:
            plt.close(fig)

        log.info(f"Saved loss curve to {save_path}")

    def _run_eval(self, z_file_path):
        """
        Launch the eval agent as a subprocess using the optimized z.

        The command runs ``script/run.py`` from ``dppo_root`` with the eval
        config named by ``cfg.eval_config_name`` (default "eval_pbc"). Its
        output goes straight to this process's stdout/stderr. A non-zero exit
        status is logged as an error but does not raise.

        Args:
            z_file_path: path of the saved ``.npy`` file with the optimized z.
        """
        eval_config_name = self.cfg.get("eval_config_name", "eval_pbc")

        # The child process runs with cwd=dppo_root, which may differ from
        # this process's working directory. Resolve both paths now so that a
        # relative path still points at the files this process used.
        z_file_path = os.path.abspath(z_file_path)
        # Demo path for DTW comparison
        demo_path = os.path.abspath(
            os.path.join(
                self.cfg.demo_base_path,
                f"style{self.wall_style}",
                "episodes",
                "episode0",
            )
        )

        eval_cmd = [
            # Use the interpreter running this script (same environment)
            # rather than whichever "python" happens to be first on PATH.
            sys.executable or "python",
            "script/run.py",
            "--config-dir=cfg/rlbench/eval/close_drawer",
            f"--config-name={eval_config_name}",
            f"wall_style={self.wall_style}",
            f"z_file={z_file_path}",
            f"demo_trajectory_path={demo_path}",
        ]

        log.info(f"Running eval command: {' '.join(eval_cmd)}")

        # Run from the dppo root directory
        result = subprocess.run(
            eval_cmd,
            cwd=dppo_root,
            capture_output=False,  # Let output go to stdout/stderr directly
        )

        if result.returncode != 0:
            log.error(f"Eval command failed with return code {result.returncode}")
        else:
            log.info("Evaluation completed successfully!")
