"""
FTL-IGM Finetune Agent for grasp (pick_up_cup) task with PDP-FiLM model.

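Given a target demo, this agent:
1. Loads the demo from the configured path (cfg.demo_path)
2. Encodes the demo to get a warm-start z (only if an encoder checkpoint is
   configured; otherwise z is randomly initialized)
3. Optimizes z using FTL-IGM (denoising score matching against the frozen diffusion model)
4. Saves the optimized z to a file
5. Automatically runs the eval agent with the optimized z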

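Target styles (grasp_style labels the output files and is forwarded to the eval
agent; the demo itself is always read from cfg.demo_path, which should match):
- grasp_style=1: Mode 0° (train, episode0)
- grasp_style=2: Mode 270° (train, episode30, noise-free)
- grasp_style=3: Test 45° (novel angle)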
"""

import os
import sys
import logging
import subprocess
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import hydra

# Add paths
# Repository root: three directory levels above this file. It is appended to
# sys.path so the `model.*` imports below resolve, and it is reused as the cwd
# of the eval subprocess launched by FinetuneGraspAgent._run_eval.
dppo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(dppo_root)

# Not referenced directly in this file.
# NOTE(review): presumably imported so the classes named in cfg.model are
# importable when Hydra instantiates the model — confirm before removing.
from model.diffusion.parameterized_diffusion_eval import ParameterizedDiffusionEval
from model.diffusion.parameterized_diffusion import FiLMDiffusionMLP

# Import encoder and relative trajectory conversion from the correct grasp encoder directory
# Using importlib to avoid sys.path conflicts with other encoder directories
# (other tasks presumably ship modules with the same names). Each file is
# loaded by explicit path and executed immediately (exec_module), so any
# top-level side effects in those files run when this module is imported.
import importlib.util
_encoder_path = os.path.join(dppo_root, 'RLBench_grasp', 'encoder', 'trajectory_encoder.py')
_spec = importlib.util.spec_from_file_location("grasp_trajectory_encoder", _encoder_path)
_encoder_module = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(_encoder_module)
TrajectoryVAE = _encoder_module.TrajectoryVAE

# Same mechanism for the encoder training script, used only for its
# `make_trajectory_relative` helper (applied before encoding a demo).
_train_encoder_path = os.path.join(dppo_root, 'RLBench_grasp', 'encoder', 'train_grasp_encoder.py')
_spec2 = importlib.util.spec_from_file_location("train_grasp_encoder", _train_encoder_path)
_train_encoder_module = importlib.util.module_from_spec(_spec2)
_spec2.loader.exec_module(_train_encoder_module)
make_trajectory_relative = _train_encoder_module.make_trajectory_relative

log = logging.getLogger(__name__)


class FinetuneGraspAgent:
    """FTL-IGM finetuning agent for grasp PDP-FiLM.

    Pipeline (see ``run``): load a target demo, optionally encode it into a
    warm-start latent ``z``, refine ``z`` against the frozen diffusion network
    with a denoising objective, save ``z`` to
    ``<logdir>/optimized_z_style<grasp_style>.npy`` and launch the eval script
    on it.
    """

    def __init__(self, cfg):
        """Build the agent from a Hydra config.

        Reads ``cfg.device``, ``cfg.grasp_style``, ``cfg.seed``,
        ``cfg.normalization_path``, the optional ``cfg.encoder_checkpoint``,
        ``cfg.model`` (instantiated via Hydra) and ``cfg.logdir`` (created if
        missing).
        """
        self.cfg = cfg
        self.device = cfg.device
        self.grasp_style = cfg.grasp_style

        # Seed every RNG touched here so z initialisation and noise sampling
        # are reproducible.
        self.seed = cfg.seed
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)

        # Load normalization stats
        self.norm_stats = self._load_normalization(cfg.normalization_path)

        # Load encoder for warm start (optional - can be null for random init)
        encoder_checkpoint = cfg.get('encoder_checkpoint', None)
        if encoder_checkpoint is not None:
            # _load_encoder also sets self.latent_dim from the checkpoint config.
            self.encoder = self._load_encoder(encoder_checkpoint)
        else:
            log.info("No encoder checkpoint provided - will use random z initialization")
            self.encoder = None
            # Without an encoder, take the latent size from the model config.
            self.latent_dim = cfg.model.network.z_dim

        # Load diffusion model
        self.model = hydra.utils.instantiate(cfg.model)
        self.model.eval()

        # Freeze diffusion model weights (only optimize z)
        for param in self.model.network.parameters():
            param.requires_grad = False

        # Create output directory
        self.output_dir = cfg.logdir
        os.makedirs(self.output_dir, exist_ok=True)

    def _load_normalization(self, norm_path):
        """Load min/max normalization statistics from an ``.npz`` archive.

        Args:
            norm_path: path to a NumPy archive with ``obs_min``, ``obs_max``,
                ``action_min`` and ``action_max`` arrays.

        Returns:
            dict with float32 tensors on ``self.device`` (``obs_min``,
            ``obs_max``, ``obs_range``, ``action_min``, ``action_max``,
            ``action_range``) and the raw numpy arrays (``*_np`` keys).
            Ranges below 1e-6 are replaced by 1 so constant dimensions do not
            cause a division by zero.
        """
        log.info(f"Loading normalization from {norm_path}")
        # Use the archive as a context manager so the underlying file handle
        # is closed; the arrays are read into memory before it closes.
        with np.load(norm_path) as norm_data:
            raw = {key: norm_data[key]
                   for key in ('obs_min', 'obs_max', 'action_min', 'action_max')}

        def to_tensor(arr):
            return torch.from_numpy(arr).float().to(self.device)

        obs_min = to_tensor(raw['obs_min'])
        obs_max = to_tensor(raw['obs_max'])
        action_min = to_tensor(raw['action_min'])
        action_max = to_tensor(raw['action_max'])

        obs_range = obs_max - obs_min
        obs_range = torch.where(obs_range < 1e-6, torch.ones_like(obs_range), obs_range)
        action_range = action_max - action_min
        action_range = torch.where(action_range < 1e-6, torch.ones_like(action_range), action_range)

        return {
            'obs_min': obs_min, 'obs_max': obs_max, 'obs_range': obs_range,
            'action_min': action_min, 'action_max': action_max, 'action_range': action_range,
            'obs_min_np': raw['obs_min'], 'obs_max_np': raw['obs_max'],
            'action_min_np': raw['action_min'], 'action_max_np': raw['action_max'],
        }

    def _load_encoder(self, checkpoint_path):
        """Load a frozen ``TrajectoryVAE`` from a checkpoint.

        The checkpoint must contain ``config`` (architecture hyper-parameters)
        and ``model_state_dict``. Side effect: sets ``self.latent_dim``.
        """
        log.info(f"Loading encoder from {checkpoint_path}")
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        config = checkpoint['config']

        encoder = TrajectoryVAE(
            state_dim=config['state_dim'],
            action_dim=config['action_dim'],
            hidden_dim=config['hidden_dim'],
            num_layers=config['num_layers'],
            num_heads=config['num_heads'],
            latent_dim=config['latent_dim'],
            horizon=config['horizon']
        ).to(self.device)

        encoder.load_state_dict(checkpoint['model_state_dict'])
        encoder.eval()
        for param in encoder.parameters():
            param.requires_grad = False

        self.latent_dim = config['latent_dim']
        return encoder

    def _load_demo_from_pkl(self, demo_path):
        """
        Load demo from low_dim_obs.pkl and extract states/actions.

        State format: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7) = 22
        Action format: joint_pos(7) + gripper_open(1) = 8

        The result is truncated or padded to ``cfg.horizon_steps`` steps.

        Raises:
            FileNotFoundError: if ``<demo_path>/low_dim_obs.pkl`` is missing.
            ValueError: if the demo contains no observations.
        """
        pkl_path = os.path.join(demo_path, "low_dim_obs.pkl")
        if not os.path.exists(pkl_path):
            raise FileNotFoundError(f"Demo file not found: {pkl_path}")

        log.info(f"Loading demo from {pkl_path}")
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only load demos from trusted sources.
        with open(pkl_path, "rb") as f:
            demo = pickle.load(f)

        if len(demo) == 0:
            raise ValueError(f"Demo at {pkl_path} contains no observations")

        states = []
        actions = []
        for obs in demo:
            # State: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7)
            states.append(np.concatenate([
                obs.joint_positions,
                obs.joint_velocities,
                [obs.gripper_open],
                obs.gripper_pose,
            ]))

            # Prefer the recorded action; `misc` may be absent or None.
            misc = getattr(obs, 'misc', None) or {}
            if 'joint_position_action' in misc:
                action = misc['joint_position_action']
            else:
                # Fallback: use joint_positions + gripper_open
                action = np.concatenate([obs.joint_positions, [obs.gripper_open]])
            actions.append(action)

        states = np.array(states)
        actions = np.array(actions)

        log.info(f"Loaded demo: {len(demo)} steps, states={states.shape}, actions={actions.shape}")

        # Adjust to match expected horizon
        horizon = self.cfg.horizon_steps
        if len(states) > horizon:
            log.info(f"Truncating demo from {len(states)} to {horizon} steps")
        elif len(states) < horizon:
            log.info(f"Padding demo from {len(states)} to {horizon} steps")
        return self._fit_to_horizon(states, actions, horizon)

    @staticmethod
    def _fit_to_horizon(states, actions, horizon):
        """Truncate, or pad by repeating the last row, ``states``/``actions`` to ``horizon`` rows.

        Args:
            states: 2-D array with one row per time step.
            actions: 2-D array with the same number of rows as ``states``.
            horizon: desired number of rows.

        Returns:
            (states, actions) with exactly ``horizon`` rows each.

        Raises:
            ValueError: if ``states`` is empty (there is no last row to repeat).
        """
        n_steps = len(states)
        if n_steps == 0:
            raise ValueError("Cannot fit an empty demo to the horizon")
        if n_steps >= horizon:
            return states[:horizon], actions[:horizon]
        pad_len = horizon - n_steps
        states = np.concatenate([states, np.tile(states[-1:], (pad_len, 1))], axis=0)
        actions = np.concatenate([actions, np.tile(actions[-1:], (pad_len, 1))], axis=0)
        return states, actions

    def _normalize_states_actions(self, states, actions):
        """Normalize states and actions to [-1, 1] with the stored min/max stats.

        Dimensions whose range is below 1e-6 use a range of 1 (no division by zero).
        """
        obs_min = self.norm_stats['obs_min_np']
        obs_max = self.norm_stats['obs_max_np']
        action_min = self.norm_stats['action_min_np']
        action_max = self.norm_stats['action_max_np']

        obs_range = obs_max - obs_min
        obs_range = np.where(obs_range < 1e-6, 1.0, obs_range)
        action_range = action_max - action_min
        action_range = np.where(action_range < 1e-6, 1.0, action_range)

        # Normalize to [-1, 1]
        states_norm = 2.0 * (states - obs_min) / obs_range - 1.0
        actions_norm = 2.0 * (actions - action_min) / action_range - 1.0

        return states_norm, actions_norm

    def _encode_demo(self, states, actions):
        """Encode a normalized demo into a latent z (leading batch dim squeezed).

        The trajectory is converted with ``make_trajectory_relative`` before
        encoding; no gradients are tracked.
        """
        states_tensor = torch.FloatTensor(states).to(self.device)
        actions_tensor = torch.FloatTensor(actions).to(self.device)

        with torch.no_grad():
            states_rel, actions_rel = make_trajectory_relative(
                states_tensor.unsqueeze(0), actions_tensor.unsqueeze(0)
            )
            z = self.encoder.encode(states_rel, actions_rel).squeeze(0)

        return z

    def optimize_z(self, demo_states, demo_actions, n_epochs=300, lr=3e-3,
                   n_timestep_samples=16, z_reg_weight=0.0, warm_start=True, init_z=None):
        """
        Optimize z using FTL-IGM (denoising score matching).

        Each epoch samples ``n_timestep_samples`` diffusion timesteps, noises
        the demo actions accordingly, and updates z (Adam, gradient norm
        clipped to 1.0) so the frozen network's prediction matches the target.
        The target is the added noise for epsilon-prediction models, or the
        clean actions when the model reports ``predict_epsilon=False``.

        Args:
            demo_states: (T, state_dim) - normalized states in [-1, 1]
            demo_actions: (T, action_dim) - normalized actions in [-1, 1];
                must have shape (cfg.horizon_steps, cfg.action_dim)
            n_epochs: number of optimization epochs
            lr: learning rate
            n_timestep_samples: number of timesteps to sample per iteration
            z_reg_weight: weight for the L2-norm penalty on z (0 = no reg)
            warm_start: if True, use init_z as starting point
            init_z: initial z value (from encoder)

        Returns:
            z: optimized latent code (detached)
            loss_history: list of per-epoch loss values

        Raises:
            ValueError: if ``demo_actions`` does not match the configured
                horizon / action dimension.
        """
        demo_states = torch.FloatTensor(demo_states).to(self.device)
        demo_actions = torch.FloatTensor(demo_actions).to(self.device)

        horizon_steps = self.cfg.horizon_steps
        action_dim = self.cfg.action_dim
        if tuple(demo_actions.shape) != (horizon_steps, action_dim):
            raise ValueError(
                f"demo_actions has shape {tuple(demo_actions.shape)}, expected "
                f"(horizon_steps={horizon_steps}, action_dim={action_dim})"
            )

        # Initialize z
        if warm_start and init_z is not None:
            log.info(f"Warm start: Using encoder z (norm={init_z.norm().item():.3f})")
            z = nn.Parameter(init_z.clone().detach().requires_grad_(True))
        else:
            log.info("Cold start: Random initialization for z")
            z = nn.Parameter(torch.randn(self.latent_dim).to(self.device) * 0.1)

        optimizer = torch.optim.Adam([z], lr=lr)
        loss_history = []

        # Forward-diffusion coefficients: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
        alphas_cumprod = self.model.alphas_cumprod
        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)

        # NOTE(review): assumes a DPPO-style `predict_epsilon` flag on the model;
        # when the attribute is absent, epsilon prediction is assumed.
        predict_epsilon = getattr(self.model, 'predict_epsilon', True)

        # Loop-invariant batch views (expand does not copy).
        K = n_timestep_samples
        cond_state = demo_states[0:1].unsqueeze(0)  # condition on the first state: (1, 1, obs_dim)
        cond_batch = {'state': cond_state.expand(K, -1, -1)}
        demo_actions_batch = demo_actions.unsqueeze(0).expand(K, -1, -1)  # (K, T, action_dim)

        pbar = tqdm(range(n_epochs), desc="Optimizing z")
        for _ in pbar:
            optimizer.zero_grad()

            # Sample K timesteps and noises
            t = torch.randint(0, self.model.denoising_steps, (K,), device=self.device).long()
            noise = torch.randn(K, horizon_steps, action_dim).to(self.device)

            noised_actions = (sqrt_alphas_cumprod[t].view(K, 1, 1) * demo_actions_batch +
                              sqrt_one_minus_alphas_cumprod[t].view(K, 1, 1) * noise)

            z_batch = z.unsqueeze(0).expand(K, -1)

            # enable_grad keeps z differentiable even if called under no_grad.
            with torch.enable_grad():
                prediction = self.model.network(noised_actions, t, cond=cond_batch, z=z_batch)

            # Loss: denoising score matching + optional z regularization
            target = noise if predict_epsilon else demo_actions_batch
            loss = F.mse_loss(prediction, target)
            if z_reg_weight > 0:
                loss = loss + z_reg_weight * torch.norm(z, p=2)

            loss.backward()
            torch.nn.utils.clip_grad_norm_([z], max_norm=1.0)
            optimizer.step()

            loss_value = loss.item()
            loss_history.append(loss_value)
            pbar.set_postfix({'loss': f'{loss_value:.4f}', 'z_norm': f'{z.norm().item():.3f}'})

        return z.detach(), loss_history

    def run(self):
        """Run the finetune process (compute and save optimized z), then evaluate it.

        Writes ``optimized_z_style<grasp_style>.npy`` (shape (1, z_dim)) and,
        when optimization ran, ``loss_curve.png`` into ``self.output_dir``.
        """
        log.info("=" * 70)
        log.info(f"FTL-IGM Finetune for grasp (pick_up_cup) - grasp_style={self.grasp_style}")
        log.info("=" * 70)

        demo_path = self.cfg.demo_path

        # Metadata is informational only (logged).
        metadata_path = os.path.join(demo_path, 'metadata.npy')
        if os.path.exists(metadata_path):
            metadata = np.load(metadata_path, allow_pickle=True).item()
            approach_angle_deg = np.degrees(metadata.get('approach_angle', 0))
            log.info(f"Target demo metadata: approach_angle={approach_angle_deg:.1f}°")
        else:
            log.warning(f"No metadata found at {metadata_path}")
            metadata = {}

        # Load and normalize demo
        states, actions = self._load_demo_from_pkl(demo_path)
        states_norm, actions_norm = self._normalize_states_actions(states, actions)
        log.info(f"Normalized demo: states={states_norm.shape}, actions={actions_norm.shape}")

        # Get initial z (from encoder if available, else None for random init)
        if self.encoder is not None:
            init_z = self._encode_demo(states_norm, actions_norm)
            log.info(f"Encoder z from demo: {init_z.cpu().numpy()}")
        else:
            init_z = None
            log.info("No encoder - will use random z initialization for FTL-IGM")

        # Optimize z using FTL-IGM
        n_epochs = self.cfg.ftl_igm.n_epochs
        if n_epochs > 0:
            log.info("\n--- FTL-IGM Optimization ---")
            optimized_z, loss_history = self.optimize_z(
                demo_states=states_norm,
                demo_actions=actions_norm,
                n_epochs=n_epochs,
                lr=self.cfg.ftl_igm.lr,
                n_timestep_samples=self.cfg.ftl_igm.n_timestep_samples,
                z_reg_weight=self.cfg.ftl_igm.z_reg_weight,
                warm_start=self.cfg.ftl_igm.warm_start if init_z is not None else False,
                init_z=init_z,
            )
            log.info(f"Optimized z: {optimized_z.cpu().numpy()}")
            log.info(f"Final loss: {loss_history[-1]:.4f}")
        else:
            if init_z is not None:
                log.info("\n--- Skipping FTL-IGM Optimization (n_epochs=0) ---")
                log.info("Using encoder z directly from demo")
                optimized_z = init_z
            else:
                log.info("\n--- Using random z (no encoder, n_epochs=0) ---")
                optimized_z = torch.randn(self.latent_dim, device=self.device) * 0.1
                log.info(f"Random z: {optimized_z.cpu().numpy()}")
            loss_history = []

        # Save optimized z (as 2D array for compatibility with eval agent)
        z_save_path = os.path.join(self.output_dir, f'optimized_z_style{self.grasp_style}.npy')
        z_to_save = optimized_z.cpu().numpy().reshape(1, -1)  # Shape: (1, z_dim)
        np.save(z_save_path, z_to_save)
        log.info(f"Saved optimized z to {z_save_path} (shape: {z_to_save.shape})")

        self._plot_loss_curve(loss_history)

        log.info("\n" + "=" * 70)
        log.info("FTL-IGM Optimization Complete! Running evaluation...")
        log.info("=" * 70)

        self._run_eval(z_save_path)

    def _plot_loss_curve(self, loss_history):
        """Plot and save the optimization loss curve to ``<output_dir>/loss_curve.png``.

        Does nothing (beyond logging) when ``loss_history`` is empty.
        """
        if len(loss_history) == 0:
            log.info("No loss history to plot (optimization skipped)")
            return

        save_path = os.path.join(self.output_dir, 'loss_curve.png')
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(loss_history)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title(f'FTL-IGM Optimization Loss (grasp_style={self.grasp_style})')
            ax.grid(True, alpha=0.3)
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)
        log.info(f"Saved loss curve to {save_path}")

    def _build_eval_cmd(self, z_file_path):
        """Build the argv list for the eval run on ``z_file_path``.

        Uses the current interpreter (``sys.executable``) so eval runs in the
        same environment. The z path is made absolute because the subprocess
        runs with ``cwd=dppo_root``, where a relative path would not resolve.
        """
        eval_config_name = self.cfg.get('eval_config_name', 'eval_para_diffusion_film_finetuned')
        return [
            sys.executable or "python", "script/run.py",
            "--config-dir=cfg/rlbench/eval/grasp",
            f"--config-name={eval_config_name}",
            f"z_file={os.path.abspath(z_file_path)}",
            f"grasp_style={self.grasp_style}",  # Pass grasp_style for angle-based success check
        ]

    def _run_eval(self, z_file_path):
        """Run the eval agent with the optimized z in a subprocess.

        Output is streamed to the console. A non-zero exit status is logged
        as an error rather than raised.

        Returns:
            The subprocess return code.
        """
        eval_cmd = self._build_eval_cmd(z_file_path)

        log.info(f"Running eval command: {' '.join(eval_cmd)}")

        result = subprocess.run(
            eval_cmd,
            cwd=dppo_root,
            capture_output=False,
        )

        if result.returncode != 0:
            log.error(f"Eval command failed with return code {result.returncode}")
        else:
            log.info("Evaluation completed successfully!")
        return result.returncode