"""
Finetune Agent for real robot 3D close_drawer task.

Supports:
- PDP-FiLM: FTL-IGM optimization of z (denoising score matching)
- DP: Backprop/inversion optimization of x_T

Given a finetune demo (3D EE trajectory), this agent:
1. Loads the preprocessed demo from finetune_data/style{id}_h{horizon_steps}.npz (e.g. style1_h100.npz)
2. For PDP: Encodes demo to get warm start z, optimizes via FTL-IGM
3. For DP: Optimizes initial noise x_T to reconstruct demo
4. Saves optimized parameters and a loss curve; for PDP, also generates a
   trajectory with the optimized z and plots it against the demo

Usage:
    python script/run.py --config-name=ft_para_diffusion_film --config-dir=cfg/rlbench/finetune/close_drawer_real style=1
    python script/run.py --config-name=ft_diffusion_mlp --config-dir=cfg/rlbench/finetune/close_drawer_real style=1
"""

import os
import sys
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import hydra

# Make the repository root (three directory levels above this file)
# importable, and put the encoder directory at the front of sys.path so
# `trajectory_encoder_3d` is looked up there first.
dppo_root = os.path.dirname(
    os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))
    )
)
sys.path.append(dppo_root)
sys.path.insert(
    0,
    os.path.join(dppo_root, 'RLBench_close_drawer_real', 'encoder'),
)

# Module-level logger used throughout this file.
log = logging.getLogger(__name__)


class FinetuneReal3DAgent:
    """Finetune agent for real robot 3D policies (PDP and DP).

    Depending on ``cfg.model_type`` the agent either optimizes a latent code
    ``z`` for a PDP-FiLM model (FTL-IGM, denoising score matching) or the
    initial DDIM noise ``x_T`` for a plain DP model (backprop through the
    deterministic DDIM sampler). Model weights are frozen in both cases; only
    ``z`` or ``x_T`` is updated.
    """

    def __init__(self, cfg):
        """Build the agent from a Hydra config.

        Reads ``device``, ``style``, ``model_type`` ('pdp' or 'dp'),
        ``normalization_path``, ``model`` (instantiated via Hydra) and
        ``logdir``; for PDP additionally ``encoder_checkpoint_path`` and
        ``z_dim``.
        """
        self.cfg = cfg
        self.device = cfg.device
        self.style = cfg.style
        self.model_type = cfg.model_type  # 'pdp' or 'dp'

        # Load normalization stats
        self.norm_stats = self._load_normalization(cfg.normalization_path)

        # Load model
        self.model = hydra.utils.instantiate(cfg.model)
        self.model.eval()

        # For PDP, load encoder
        if self.model_type == 'pdp':
            self.encoder = self._load_encoder(cfg.encoder_checkpoint_path)
            self.latent_dim = cfg.z_dim

        # Freeze model weights (only optimize z or x_T). Gradients can still
        # flow *through* the network to its inputs.
        for param in self.model.parameters():
            param.requires_grad = False

        # Create output directory
        self.output_dir = cfg.logdir
        os.makedirs(self.output_dir, exist_ok=True)

    def _load_normalization(self, norm_path):
        """Load min/max normalization statistics from an ``.npz`` file.

        Returns:
            dict with keys 'obs_min', 'obs_max', 'action_min', 'action_max'.
        """
        log.info(f"Loading normalization from {norm_path}")
        # Use the NpzFile as a context manager so the underlying file handle
        # is closed; indexing loads each array fully into memory first.
        with np.load(norm_path) as norm_data:
            return {
                key: norm_data[key]
                for key in ('obs_min', 'obs_max', 'action_min', 'action_max')
            }

    def _load_encoder(self, checkpoint_path):
        """Load the frozen 3D trajectory encoder (TrajectoryVAE3D) from a checkpoint.

        The checkpoint must contain a 'config' dict with the constructor
        arguments and a 'model_state_dict'.
        """
        from trajectory_encoder_3d import TrajectoryVAE3D

        log.info(f"Loading encoder from {checkpoint_path}")
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        config = checkpoint['config']

        encoder = TrajectoryVAE3D(
            pos_dim=config['pos_dim'],
            hidden_dim=config['hidden_dim'],
            num_layers=config['num_layers'],
            num_heads=config['num_heads'],
            latent_dim=config['latent_dim'],
            horizon=config['horizon']
        ).to(self.device)

        encoder.load_state_dict(checkpoint['model_state_dict'])
        encoder.eval()
        for param in encoder.parameters():
            param.requires_grad = False

        return encoder

    def _load_finetune_demo(self, style):
        """
        Load preprocessed finetune demo.

        Returns:
            states: (H, 3) - EE positions
            actions: (H, 3) - EE position targets
            states_norm: (H, 3) - normalized
            actions_norm: (H, 3) - normalized

        Raises:
            FileNotFoundError: if the preprocessed demo file does not exist.
        """
        demo_path = os.path.join(
            dppo_root, 'RLBench_close_drawer_real', 'finetune_data',
            f'style{style}_h{self.cfg.horizon_steps}.npz'
        )

        if not os.path.exists(demo_path):
            raise FileNotFoundError(
                f"Finetune demo not found: {demo_path}\n"
                f"Please run: python RLBench_close_drawer_real/make_dataset/preprocess_finetune_data.py --style {style}"
            )

        log.info(f"Loading finetune demo from {demo_path}")
        # Context manager closes the .npz file; slices are views of arrays
        # that were already read fully into memory.
        with np.load(demo_path) as data:
            # Get first episode only (single demo for finetuning)
            traj_length = int(data['traj_lengths'][0])
            states = data['states'][:traj_length]
            actions = data['actions'][:traj_length]
            states_norm = data['states_norm'][:traj_length]
            actions_norm = data['actions_norm'][:traj_length]

        log.info(f"Loaded demo: {traj_length} steps, states={states.shape}")

        return states, actions, states_norm, actions_norm

    def _encode_demo(self, states_norm):
        """
        Encode demo to get z using 3D encoder.

        Args:
            states_norm: (H, 3) normalized EE positions

        Returns:
            z: (z_dim,) latent code (the encoder's mean)
        """
        positions = torch.FloatTensor(states_norm).unsqueeze(0).to(self.device)  # (1, H, 3)

        # Make relative (encoder expects relative positions)
        p0 = positions[:, 0:1, :]
        positions_rel = positions - p0

        with torch.no_grad():
            z_mean, _, _ = self.encoder.encode(positions_rel)
            z = z_mean.squeeze(0)  # (z_dim,)

        return z

    def optimize_z_ftl_igm(self, demo_actions_norm, n_epochs=300, lr=3e-3,
                           n_timestep_samples=16, z_reg_weight=0.0,
                           warm_start=True, init_z=None, demo_states_norm=None):
        """
        Optimize z using FTL-IGM (denoising score matching).

        Args:
            demo_actions_norm: (H, 3) - normalized actions
            n_epochs: number of optimization epochs
            lr: learning rate
            n_timestep_samples: number of timesteps to sample per iteration
            z_reg_weight: weight for z regularization (L2 norm of z)
            warm_start: if True, use init_z as starting point
            init_z: initial z value (from encoder)
            demo_states_norm: optional (H, 3) normalized states. When given,
                the first state is used as the diffusion condition, matching
                the condition used at generation time. When omitted, the
                first action is used (legacy behavior).

        Returns:
            z: optimized latent code
            loss_history: list of loss values
        """
        demo_actions = torch.FloatTensor(demo_actions_norm).to(self.device)

        # Initialize z
        if warm_start and init_z is not None:
            log.info(f"Warm start: Using encoder z (norm={init_z.norm().item():.3f})")
            z = nn.Parameter(init_z.clone().detach().requires_grad_(True))
        else:
            log.info("Cold start: Random initialization for z")
            z = nn.Parameter(torch.randn(self.latent_dim, device=self.device) * 0.1)

        optimizer = torch.optim.Adam([z], lr=lr)
        loss_history = []

        # Diffusion parameters
        alphas_cumprod = self.model.alphas_cumprod
        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)

        # Condition: first EE position, shape (1, cond_steps, obs_dim) = (1, 1, 3).
        # Prefer the first *state* so z is optimized under the same condition
        # that _generate_and_save_trajectory (and DP inversion) use; fall back
        # to the first action when no states are supplied.
        if demo_states_norm is not None:
            first_state = torch.FloatTensor(demo_states_norm[0:1]).to(self.device)
        else:
            first_state = demo_actions[0:1]
        cond_state = first_state.unsqueeze(0)  # (1, 1, 3)

        horizon_steps = self.cfg.horizon_steps
        action_dim = self.cfg.action_dim

        # Batch-invariant expansions, hoisted out of the loop.
        K = n_timestep_samples
        demo_actions_batch = demo_actions.unsqueeze(0).expand(K, -1, -1)  # (K, H, 3)
        cond_batch = {'state': cond_state.expand(K, -1, -1)}

        pbar = tqdm(range(n_epochs), desc="FTL-IGM Optimization")
        for epoch in pbar:
            optimizer.zero_grad()

            # Sample K timesteps and noises
            t = torch.randint(0, self.model.denoising_steps, (K,), device=self.device).long()
            noise = torch.randn(K, horizon_steps, action_dim, device=self.device)

            # Forward diffusion q(x_t | x_0) applied to the demo actions
            noised_actions = (sqrt_alphas_cumprod[t].view(K, 1, 1) * demo_actions_batch +
                              sqrt_one_minus_alphas_cumprod[t].view(K, 1, 1) * noise)

            # Same z shared across all sampled timesteps
            z_batch = z.unsqueeze(0).expand(K, -1)

            # Predict noise with z
            with torch.enable_grad():
                predicted_noise = self.model.network(noised_actions, t, cond=cond_batch, z=z_batch)

            # Loss: denoising score matching + optional z regularization
            loss = F.mse_loss(predicted_noise, noise)
            if z_reg_weight > 0:
                loss = loss + z_reg_weight * torch.norm(z, p=2)

            loss.backward()
            torch.nn.utils.clip_grad_norm_([z], max_norm=1.0)
            optimizer.step()

            loss_history.append(loss.item())
            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'z_norm': f'{z.norm().item():.3f}'})

        return z.detach(), loss_history

    def optimize_x_T_inversion(self, demo_states_norm, demo_actions_norm, n_epochs=300, lr=1e-2):
        """
        DP Inversion: Optimize initial noise x_T to reconstruct demo trajectory.

        Runs the deterministic DDIM sampler (eta=0) from x_T and backpropagates
        the reconstruction MSE through every denoising step (including the
        frozen noise-prediction network) to x_T.

        Args:
            demo_states_norm: (H, 3) - normalized states
            demo_actions_norm: (H, 3) - normalized actions
            n_epochs: number of optimization epochs
            lr: learning rate

        Returns:
            x_T: (1, H, 3) optimized initial noise
            loss_history: list of loss values
        """
        demo_actions = torch.FloatTensor(demo_actions_norm).to(self.device)

        # Initialize x_T as learnable parameter
        horizon_steps = self.cfg.horizon_steps
        action_dim = self.cfg.action_dim
        x_T = nn.Parameter(torch.randn(1, horizon_steps, action_dim, device=self.device))

        optimizer = torch.optim.Adam([x_T], lr=lr)
        loss_history = []

        # Get DDIM parameters
        ddim_t = self.model.ddim_t
        ddim_alphas = self.model.ddim_alphas
        ddim_alphas_prev = self.model.ddim_alphas_prev
        ddim_sqrt_one_minus_alphas = self.model.ddim_sqrt_one_minus_alphas

        # Condition: first state
        first_state = torch.FloatTensor(demo_states_norm[0:1]).to(self.device)
        cond = {'state': first_state.unsqueeze(0)}  # (1, 1, 3)

        # Target: full action trajectory
        target_actions = demo_actions.unsqueeze(0)  # (1, H, 3)

        pbar = tqdm(range(n_epochs), desc="DP Inversion")
        for epoch in pbar:
            optimizer.zero_grad()

            # Forward pass through DDIM sampler (deterministic)
            x = x_T.clone()

            for i, t in enumerate(ddim_t):
                t_b = torch.tensor([t], device=self.device).long()
                index_b = torch.tensor([i], device=self.device).long()

                # Get alpha values
                alpha = ddim_alphas[index_b].view(1, 1, 1)
                alpha_prev = ddim_alphas_prev[index_b].view(1, 1, 1)
                sqrt_one_minus_alpha = ddim_sqrt_one_minus_alphas[index_b].view(1, 1, 1)

                # Predict noise. The network's weights are frozen
                # (requires_grad=False), but the gradient must still flow
                # through it back to x_T; wrapping this call in torch.no_grad()
                # would detach noise_pred and drop the network's Jacobian.
                noise_pred = self.model.actor(x, t_b, cond=cond)

                # DDIM step (deterministic, eta=0)
                x_recon = (x - sqrt_one_minus_alpha * noise_pred) / (alpha ** 0.5)

                # Clip x_0
                if self.model.denoised_clip_value is not None:
                    x_recon = x_recon.clamp(-self.model.denoised_clip_value, self.model.denoised_clip_value)

                # mu = sqrt(alpha_prev) * x_0 + sqrt(1 - alpha_prev) * eps
                dir_xt = (1.0 - alpha_prev).sqrt() * noise_pred
                x = (alpha_prev ** 0.5) * x_recon + dir_xt

            # Compute loss
            loss = F.mse_loss(x, target_actions)

            # Backward through the sampler; only x_T is updated
            loss.backward()
            optimizer.step()

            loss_history.append(loss.item())
            pbar.set_postfix({'loss': f'{loss.item():.6f}', 'x_T_norm': f'{x_T.norm().item():.3f}'})

        return x_T.detach(), loss_history

    def run(self):
        """Run the finetune process: load demo, optimize, save, plot."""
        log.info("=" * 70)
        log.info(f"Finetune {self.model_type.upper()} for real robot close_drawer (style={self.style})")
        log.info("=" * 70)

        # 1. Load finetune demo
        states, actions, states_norm, actions_norm = self._load_finetune_demo(self.style)

        if self.model_type == 'pdp':
            # 2. Encode demo to get warm start z
            init_z = self._encode_demo(states_norm)
            log.info(f"Encoder z from demo: {init_z.cpu().numpy()}")

            # 3. Optimize z using FTL-IGM
            ftl_cfg = self.cfg.ftl_igm
            if ftl_cfg.n_epochs > 0:
                log.info("\n--- FTL-IGM Optimization ---")
                optimized_z, loss_history = self.optimize_z_ftl_igm(
                    demo_actions_norm=actions_norm,
                    n_epochs=ftl_cfg.n_epochs,
                    lr=ftl_cfg.lr,
                    n_timestep_samples=ftl_cfg.n_timestep_samples,
                    z_reg_weight=ftl_cfg.z_reg_weight,
                    warm_start=ftl_cfg.warm_start,
                    init_z=init_z,
                    # Condition on the first state, as generation does below.
                    demo_states_norm=states_norm,
                )
                log.info(f"Optimized z: {optimized_z.cpu().numpy()}")
                log.info(f"Final loss: {loss_history[-1]:.4f}")
            else:
                log.info("Skipping FTL-IGM (n_epochs=0), using encoder z directly")
                optimized_z = init_z
                loss_history = []

            # 4. Save optimized z
            z_save_path = os.path.join(self.output_dir, f'optimized_z_style{self.style}.npy')
            z_to_save = optimized_z.cpu().numpy().reshape(1, -1)
            np.save(z_save_path, z_to_save)
            log.info(f"Saved optimized z to {z_save_path}")

            # 5. Plot loss curve
            self._plot_loss_curve(loss_history, "FTL-IGM")

            # 6. Generate and save evaluation trajectory
            self._generate_and_save_trajectory(states_norm, optimized_z=optimized_z)

        else:  # DP
            # 2. Optimize x_T using inversion
            dp_cfg = self.cfg.dp_inversion
            log.info("\n--- DP Inversion ---")
            optimized_x_T, loss_history = self.optimize_x_T_inversion(
                demo_states_norm=states_norm,
                demo_actions_norm=actions_norm,
                n_epochs=dp_cfg.n_epochs,
                lr=dp_cfg.lr,
            )
            # Guard: with n_epochs == 0 there is no loss to report.
            if loss_history:
                log.info(f"Final loss: {loss_history[-1]:.6f}")

            # 3. Save optimized x_T
            x_T_save_path = os.path.join(self.output_dir, f'optimized_x_T_style{self.style}.npy')
            np.save(x_T_save_path, optimized_x_T.cpu().numpy())
            log.info(f"Saved optimized x_T to {x_T_save_path}")

            # 4. Plot loss curve
            self._plot_loss_curve(loss_history, "DP Inversion")

            # 5. Generate and save evaluation trajectory
            self._generate_and_save_trajectory(states_norm, optimized_x_T=optimized_x_T)

        log.info("\n" + "=" * 70)
        log.info("Finetuning complete!")
        log.info("=" * 70)

    def _plot_loss_curve(self, loss_history, title_prefix):
        """Plot and save the optimization loss curve (no-op for an empty history)."""
        if len(loss_history) == 0:
            return

        plt.figure(figsize=(10, 6))
        plt.plot(loss_history)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(f'{title_prefix} Loss (style={self.style})')
        plt.grid(True, alpha=0.3)
        save_path = os.path.join(self.output_dir, 'loss_curve.png')
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()
        log.info(f"Saved loss curve to {save_path}")

    def _generate_and_save_trajectory(self, states_norm, optimized_z=None, optimized_x_T=None):
        """Generate a trajectory with the optimized parameters and save it.

        PDP: sets ``self.model.current_z`` to the optimized z, samples
        deterministically from the first state, and saves normalized and
        denormalized actions plus a comparison plot. DP: only logs, since
        ``run`` has already saved x_T.
        """
        # Condition: first state
        first_state = torch.FloatTensor(states_norm[0:1]).to(self.device)
        cond = {'state': first_state.unsqueeze(0)}  # (1, 1, 3)

        with torch.no_grad():
            if self.model_type == 'pdp' and optimized_z is not None:
                # Set z and generate
                self.model.current_z = optimized_z.unsqueeze(0)
                samples = self.model(cond=cond, deterministic=True)
            elif self.model_type == 'dp' and optimized_x_T is not None:
                # Use optimized x_T and run DDIM
                # For simplicity, just save the x_T and note that eval should use it
                log.info("DP: Saved optimized x_T for evaluation")
                return
            else:
                log.warning("No optimized parameters provided")
                return

            actions_norm = samples.trajectories.cpu().numpy()[0]  # (H, 3)

        # Denormalize: maps [-1, 1] back to [action_min, action_max]
        action_range = self.norm_stats['action_max'] - self.norm_stats['action_min']
        actions_raw = (actions_norm + 1.0) * action_range / 2.0 + self.norm_stats['action_min']

        # Save trajectory
        traj_save_path = os.path.join(self.output_dir, f'generated_trajectory_style{self.style}.npz')
        np.savez(
            traj_save_path,
            actions_norm=actions_norm,
            actions_raw=actions_raw,
            style=self.style,
        )
        log.info(f"Saved generated trajectory to {traj_save_path}")

        # Plot comparison with demo
        self._plot_trajectory_comparison(states_norm, actions_norm)

    def _plot_trajectory_comparison(self, demo_states_norm, generated_actions_norm):
        """Plot the demo vs. the generated trajectory (3D view + per-dimension)."""
        # Imported for its side effect of registering the '3d' projection
        # (required on older Matplotlib versions).
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

        fig = plt.figure(figsize=(12, 5))

        # 3D plot
        ax1 = fig.add_subplot(121, projection='3d')
        ax1.plot(demo_states_norm[:, 0], demo_states_norm[:, 1], demo_states_norm[:, 2],
                 'b-', linewidth=2, label='Demo', alpha=0.8)
        ax1.plot(generated_actions_norm[:, 0], generated_actions_norm[:, 1], generated_actions_norm[:, 2],
                 'r--', linewidth=2, label='Generated', alpha=0.8)
        ax1.scatter(demo_states_norm[0, 0], demo_states_norm[0, 1], demo_states_norm[0, 2],
                    c='green', s=100, marker='o', label='Start')
        ax1.set_xlabel('X')
        ax1.set_ylabel('Y')
        ax1.set_zlabel('Z')
        ax1.set_title(f'3D Trajectory Comparison (style={self.style})')
        ax1.legend()

        # Per-dimension plot. Each series gets its own time axis so a demo
        # shorter/longer than the generated horizon does not break plotting.
        ax2 = fig.add_subplot(122)
        t_demo = np.arange(len(demo_states_norm))
        t_gen = np.arange(len(generated_actions_norm))
        for i, (name, color) in enumerate(zip(['X', 'Y', 'Z'], ['r', 'g', 'b'])):
            ax2.plot(t_demo, demo_states_norm[:, i], f'{color}-', label=f'Demo {name}', alpha=0.7)
            ax2.plot(t_gen, generated_actions_norm[:, i], f'{color}--', label=f'Gen {name}', alpha=0.7)
        ax2.set_xlabel('Time step')
        ax2.set_ylabel('Position (normalized)')
        ax2.set_title('Per-dimension Comparison')
        ax2.legend(ncol=2, fontsize=8)
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        save_path = os.path.join(self.output_dir, 'trajectory_comparison.png')
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()
        log.info(f"Saved trajectory comparison to {save_path}")
