"""
FTL-IGM Finetune Agent for grasp (pick_up_cup) task with PDP-FiLM-Categorical model.

For categorical conditioning, there is NO encoder.
Instead, we directly optimize the z embedding (4D one-hot or soft vector) using FTL-IGM.

Given a target demo, this agent:
1. Loads the demo from the specified path
2. Initializes z randomly or from a prior (e.g., uniform over categories)
3. Optimizes z using FTL-IGM (denoising score matching)
4. Saves the optimized z to a file
5. Automatically runs the eval agent with the optimized z

Demo selection via grasp_style:
- grasp_style=1: Mode 0° (train, episode0)
- grasp_style=2: Mode 270° (train, episode30, noise-free)
- grasp_style=3: Test 45° (novel angle)
"""

import os
import sys
import logging
import subprocess
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import hydra

# Make the project importable: `dppo_root` is the directory reached by going
# up three levels from this file's absolute path (the file's directory, its
# parent, then its grandparent) -- presumably the repository root. It is
# appended to sys.path so the `model.*` imports below resolve, and it is also
# used as the working directory when the eval subprocess is launched.
dppo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(dppo_root)

from model.diffusion.parameterized_diffusion_eval import ParameterizedDiffusionEval
from model.diffusion.parameterized_diffusion import FiLMDiffusionMLP

log = logging.getLogger(__name__)


class FinetuneGraspCategoricalAgent:
    """FTL-IGM finetuning agent for grasp PDP-FiLM-Categorical (no encoder)."""

    def __init__(self, cfg):
        """Build the agent from a Hydra-style config.

        Reads ``device``, ``grasp_style``, ``seed``, ``normalization_path``,
        ``model`` and ``logdir`` from ``cfg`` (plus the optional
        ``num_categories``, default 4). Seeds the Python, NumPy and PyTorch
        RNGs, loads the normalization statistics, instantiates the diffusion
        model in eval mode with its network parameters frozen, and creates
        the output directory if it does not exist.
        """
        self.cfg = cfg
        self.device = cfg.device
        self.grasp_style = cfg.grasp_style

        # z has one entry per category, so the latent size equals the
        # number of categories.
        n_categories = cfg.get('num_categories', 4)
        self.num_categories = n_categories
        self.latent_dim = n_categories

        # Seed every RNG in use so repeated runs draw the same numbers.
        seed = cfg.seed
        self.seed = seed
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

        # Normalization statistics (min/max/range for observations and actions).
        self.norm_stats = self._load_normalization(cfg.normalization_path)

        # Diffusion model, switched to inference mode.
        model = hydra.utils.instantiate(cfg.model)
        model.eval()
        self.model = model

        # Only z is optimized, so the network weights must not receive updates.
        for weight in model.network.parameters():
            weight.requires_grad = False

        # Destination for the optimized z and the loss plot.
        self.output_dir = cfg.logdir
        os.makedirs(self.output_dir, exist_ok=True)

    def _load_normalization(self, norm_path):
        """Load normalization statistics."""
        log.info(f"Loading normalization from {norm_path}")
        norm_data = np.load(norm_path)

        obs_min = torch.from_numpy(norm_data['obs_min']).float().to(self.device)
        obs_max = torch.from_numpy(norm_data['obs_max']).float().to(self.device)
        action_min = torch.from_numpy(norm_data['action_min']).float().to(self.device)
        action_max = torch.from_numpy(norm_data['action_max']).float().to(self.device)

        obs_range = obs_max - obs_min
        obs_range = torch.where(obs_range < 1e-6, torch.ones_like(obs_range), obs_range)
        action_range = action_max - action_min
        action_range = torch.where(action_range < 1e-6, torch.ones_like(action_range), action_range)

        return {
            'obs_min': obs_min, 'obs_max': obs_max, 'obs_range': obs_range,
            'action_min': action_min, 'action_max': action_max, 'action_range': action_range,
            'obs_min_np': norm_data['obs_min'], 'obs_max_np': norm_data['obs_max'],
            'action_min_np': norm_data['action_min'], 'action_max_np': norm_data['action_max'],
        }

    def _load_demo_from_pkl(self, demo_path):
        """
        Load demo from low_dim_obs.pkl and extract states/actions.

        State format: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7) = 22
        Action format: joint_pos(7) + gripper_open(1) = 8
        """
        pkl_path = os.path.join(demo_path, "low_dim_obs.pkl")
        if not os.path.exists(pkl_path):
            raise FileNotFoundError(f"Demo file not found: {pkl_path}")

        log.info(f"Loading demo from {pkl_path}")
        with open(pkl_path, "rb") as f:
            demo = pickle.load(f)

        # Extract states and actions from observations
        states = []
        actions = []

        for obs in demo:
            # State: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7)
            state = np.concatenate([
                obs.joint_positions,           # 7
                obs.joint_velocities,          # 7
                [obs.gripper_open],            # 1
                obs.gripper_pose,              # 7 (pos + quat)
            ])
            states.append(state)

            # Action: from obs.misc['joint_position_action'] if available
            if hasattr(obs, 'misc') and 'joint_position_action' in obs.misc:
                action = obs.misc['joint_position_action']
            else:
                # Fallback: use joint_positions + gripper_open
                action = np.concatenate([obs.joint_positions, [obs.gripper_open]])
            actions.append(action)

        states = np.array(states)
        actions = np.array(actions)

        log.info(f"Loaded demo: {len(demo)} steps, states={states.shape}, actions={actions.shape}")

        # Adjust to match expected horizon
        horizon = self.cfg.horizon_steps
        if len(states) > horizon:
            log.info(f"Truncating demo from {len(states)} to {horizon} steps")
            states = states[:horizon]
            actions = actions[:horizon]
        elif len(states) < horizon:
            # Pad by repeating last state/action
            log.info(f"Padding demo from {len(states)} to {horizon} steps")
            pad_len = horizon - len(states)
            states = np.concatenate([states, np.tile(states[-1:], (pad_len, 1))], axis=0)
            actions = np.concatenate([actions, np.tile(actions[-1:], (pad_len, 1))], axis=0)

        return states, actions

    def _normalize_states_actions(self, states, actions):
        """Normalize states and actions to [-1, 1]."""
        obs_min = self.norm_stats['obs_min_np']
        obs_max = self.norm_stats['obs_max_np']
        action_min = self.norm_stats['action_min_np']
        action_max = self.norm_stats['action_max_np']

        obs_range = obs_max - obs_min
        obs_range = np.where(obs_range < 1e-6, 1.0, obs_range)
        action_range = action_max - action_min
        action_range = np.where(action_range < 1e-6, 1.0, action_range)

        # Normalize to [-1, 1]
        states_norm = 2.0 * (states - obs_min) / obs_range - 1.0
        actions_norm = 2.0 * (actions - action_min) / action_range - 1.0

        return states_norm, actions_norm

    def _get_init_z(self):
        """
        Get initial z for categorical model.

        Options:
        - uniform: start with uniform distribution over categories
        - random: random initialization
        - one_hot: start with a specific one-hot based on grasp_style
        """
        init_mode = self.cfg.get('z_init_mode', 'uniform')

        if init_mode == 'uniform':
            # Uniform distribution over categories (soft one-hot)
            z = torch.ones(self.num_categories, device=self.device) / self.num_categories
            log.info(f"Init z: uniform distribution {z.cpu().numpy()}")
        elif init_mode == 'random':
            # Random initialization
            z = torch.randn(self.num_categories, device=self.device) * 0.1
            log.info(f"Init z: random {z.cpu().numpy()}")
        elif init_mode == 'one_hot':
            # Use grasp_style to determine initial one-hot
            # Style 1 (0°) -> mode 0, Style 2 (270°) -> mode 3, Style 3 (45°) -> novel
            style_to_mode = {1: 0, 2: 3, 3: 0}  # Style 3 has no direct mapping, use mode 0
            mode_idx = style_to_mode.get(self.grasp_style, 0)
            z = torch.zeros(self.num_categories, device=self.device)
            z[mode_idx] = 1.0
            log.info(f"Init z: one-hot for mode {mode_idx} {z.cpu().numpy()}")
        else:
            raise ValueError(f"Unknown z_init_mode: {init_mode}")

        return z

    def optimize_z(self, demo_states, demo_actions, n_epochs=300, lr=3e-3,
                   n_timestep_samples=16, z_reg_weight=0.0):
        """
        Optimize z using FTL-IGM (denoising score matching).

        For categorical, z is a 4D vector that can be any values (not necessarily one-hot).

        Args:
            demo_states: (T, state_dim) - normalized states in [-1, 1]
            demo_actions: (T, action_dim) - normalized actions in [-1, 1]
            n_epochs: number of optimization epochs
            lr: learning rate
            n_timestep_samples: number of timesteps to sample per iteration
            z_reg_weight: weight for z regularization (0 = no reg)

        Returns:
            z: optimized latent code
            loss_history: list of loss values
        """
        demo_states = torch.FloatTensor(demo_states).to(self.device)
        demo_actions = torch.FloatTensor(demo_actions).to(self.device)

        # Initialize z
        init_z = self._get_init_z()
        z = nn.Parameter(init_z.clone().detach().requires_grad_(True))

        optimizer = torch.optim.Adam([z], lr=lr)
        loss_history = []

        # Diffusion parameters
        alphas_cumprod = self.model.alphas_cumprod
        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)

        # Condition: first state
        cond_state = demo_states[0:1].unsqueeze(0)  # (1, 1, obs_dim)

        horizon_steps = self.cfg.horizon_steps
        action_dim = self.cfg.action_dim

        pbar = tqdm(range(n_epochs), desc="Optimizing z (categorical)")
        for epoch in pbar:
            optimizer.zero_grad()

            # Sample K timesteps and noises
            K = n_timestep_samples
            t = torch.randint(0, self.model.denoising_steps, (K,), device=self.device).long()
            noise = torch.randn(K, horizon_steps, action_dim).to(self.device)

            # Expand demo actions for batch: (1, T, action_dim) -> (K, T, action_dim)
            demo_actions_batch = demo_actions.unsqueeze(0).expand(K, -1, -1)

            # Add noise to actions
            noised_actions = (sqrt_alphas_cumprod[t].view(K, 1, 1) * demo_actions_batch +
                             sqrt_one_minus_alphas_cumprod[t].view(K, 1, 1) * noise)

            # Expand condition and z for batch
            cond_batch = {'state': cond_state.expand(K, -1, -1)}
            z_batch = z.unsqueeze(0).expand(K, -1)

            # Predict noise with z
            with torch.enable_grad():
                predicted_noise = self.model.network(noised_actions, t, cond=cond_batch, z=z_batch)

            # Loss: denoising score matching + optional z regularization
            loss = F.mse_loss(predicted_noise, noise)
            if z_reg_weight > 0:
                loss = loss + z_reg_weight * torch.norm(z, p=2)

            loss.backward()
            torch.nn.utils.clip_grad_norm_([z], max_norm=1.0)
            optimizer.step()

            loss_history.append(loss.item())
            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'z_norm': f'{z.norm().item():.3f}'})

        return z.detach(), loss_history

    def run(self):
        """Execute the full finetune pipeline for the configured demo.

        Steps: log the optional demo metadata, load and normalize the demo,
        obtain z (optimized with FTL-IGM when ``cfg.ftl_igm.n_epochs`` is
        positive, otherwise the initial z), save z as a ``(1, z_dim)`` array
        in the output directory, plot the loss curve, and launch the eval
        agent on the saved file.
        """
        banner = "=" * 70
        log.info(banner)
        log.info(f"FTL-IGM Finetune for grasp (pick_up_cup) - Categorical - grasp_style={self.grasp_style}")
        log.info(banner)

        # Demo directory comes from the config.
        demo_path = self.cfg.demo_path

        # Metadata is optional and only used for logging here.
        metadata_path = os.path.join(demo_path, 'metadata.npy')
        if not os.path.exists(metadata_path):
            log.warning(f"No metadata found at {metadata_path}")
            metadata = {}
        else:
            metadata = np.load(metadata_path, allow_pickle=True).item()
            approach_angle_deg = np.degrees(metadata.get('approach_angle', 0))
            log.info(f"Target demo metadata: approach_angle={approach_angle_deg:.1f}°")

        # Demo -> normalized arrays.
        raw_states, raw_actions = self._load_demo_from_pkl(demo_path)
        states_norm, actions_norm = self._normalize_states_actions(raw_states, raw_actions)
        log.info(f"Normalized demo: states={states_norm.shape}, actions={actions_norm.shape}")

        # The categorical model has no encoder, so z itself is the variable.
        log.info("Categorical model: No encoder, directly optimizing z with FTL-IGM")

        ftl_cfg = self.cfg.ftl_igm
        n_epochs = ftl_cfg.n_epochs
        if not n_epochs > 0:
            log.info("\n--- Skipping FTL-IGM Optimization (n_epochs=0) ---")
            log.info("Using initial z directly")
            optimized_z = self._get_init_z()
            loss_history = []
        else:
            log.info("\n--- FTL-IGM Optimization ---")
            optimized_z, loss_history = self.optimize_z(
                demo_states=states_norm,
                demo_actions=actions_norm,
                n_epochs=n_epochs,
                lr=ftl_cfg.lr,
                n_timestep_samples=ftl_cfg.n_timestep_samples,
                z_reg_weight=ftl_cfg.z_reg_weight,
            )
            log.info(f"Optimized z: {optimized_z.cpu().numpy()}")
            log.info(f"Final loss: {loss_history[-1]:.4f}")

        # Persist z as a 2D array, shape (1, z_dim), for the eval agent.
        z_filename = f'optimized_z_style{self.grasp_style}.npy'
        z_save_path = os.path.join(self.output_dir, z_filename)
        z_array = optimized_z.cpu().numpy()
        z_to_save = z_array.reshape(1, -1)
        np.save(z_save_path, z_to_save)
        log.info(f"Saved optimized z to {z_save_path} (shape: {z_to_save.shape})")

        # Loss curve (no-op when the optimization was skipped).
        self._plot_loss_curve(loss_history)

        # Hand the saved z over to the eval agent.
        log.info("\n" + banner)
        log.info("FTL-IGM Optimization Complete! Running evaluation...")
        log.info(banner)

        self._run_eval(z_save_path)

    def _plot_loss_curve(self, loss_history):
        """Plot the optimization loss per epoch and save it as a PNG.

        The image is written to ``loss_curve.png`` inside ``self.output_dir``.
        Nothing is plotted when ``loss_history`` is empty (optimization
        skipped).

        Args:
            loss_history: Sequence of loss values, one per epoch.
        """
        if len(loss_history) == 0:
            log.info("No loss history to plot (optimization skipped)")
            return

        save_path = os.path.join(self.output_dir, 'loss_curve.png')

        # Work on an explicit figure handle instead of pyplot's implicit
        # "current figure" so exactly this figure is saved and closed.
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(loss_history)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title(f'FTL-IGM Optimization Loss - Categorical (grasp_style={self.grasp_style})')
            ax.grid(True, alpha=0.3)
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
        finally:
            # Release the figure even if plotting or saving fails; pyplot
            # otherwise keeps it registered.
            plt.close(fig)
        log.info(f"Saved loss curve to {save_path}")

    def _run_eval(self, z_file_path):
        """Launch the eval agent as a subprocess using the optimized z.

        Runs ``script/run.py`` from ``dppo_root`` with the eval config named
        by ``cfg.eval_config_name`` (default
        ``eval_para_diffusion_film_categorical_finetuned``), passing the z
        file and the grasp style as overrides. A non-zero exit code is
        logged as an error; no exception is raised.

        Args:
            z_file_path: Path to the saved ``.npy`` file holding z.
        """
        eval_config_name = self.cfg.get('eval_config_name', 'eval_para_diffusion_film_categorical_finetuned')

        # The child process runs with cwd=dppo_root, so a path relative to
        # this process's working directory would point somewhere else there.
        z_file_abs = os.path.abspath(z_file_path)

        # Use the interpreter running this script so the child sees the same
        # environment; fall back to "python" if it cannot be determined.
        python_exe = sys.executable or "python"

        eval_cmd = [
            python_exe, "script/run.py",
            "--config-dir=cfg/rlbench/eval/grasp",
            f"--config-name={eval_config_name}",
            f"z_file={z_file_abs}",
            f"grasp_style={self.grasp_style}",  # Pass grasp_style for angle-based success check
        ]

        log.info(f"Running eval command: {' '.join(eval_cmd)}")

        # Output is not captured, so the child writes straight to this
        # process's stdout/stderr.
        result = subprocess.run(
            eval_cmd,
            cwd=dppo_root,
            capture_output=False,
        )

        if result.returncode != 0:
            log.error(f"Eval command failed with return code {result.returncode}")
        else:
            log.info("Evaluation completed successfully!")
