"""
Finetune Agent for baseline models (DP, BC, BC-GMM, IBC) on close_drawer task.

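"Finetuning" depends on the model type and is only performed when
``do_training`` is enabled in the config:
- BC: Actually finetune the model weights on the wall-avoiding demo trajectory
- DP: No weight updates; optimize the initial noise x_T so that DDIM sampling
  reproduces the demo
- BC-GMM: No weight updates; select the mixture component that best explains
  the demo
- IBC: No weight updates; save the demo trajectory to initialize inference
When ``do_training`` is disabled, no adaptation is done and only the eval
(with wall collision checking) is run.

This agent:
1. Loads the target wall demo (from low_dim_obs.pkl)
2. For BC: Finetunes model on the demo, saves checkpoint
3. For DP / BC-GMM / IBC: Saves the optimized x_T / selected component /
   demo trajectory for the eval agent
4. Automatically runs the eval agent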
"""

import os
import sys
import logging
import subprocess
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import hydra

# Add paths
# dppo_root is the directory three levels above this file. It is appended to
# sys.path (together with its 'RLBench' subdirectory) at import time, and it is
# also used as the working directory of the eval subprocess in `_run_eval`.
dppo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(dppo_root)
sys.path.append(os.path.join(dppo_root, 'RLBench'))

# Module-level logger shared by every method of FinetuneBaselineAgent.
log = logging.getLogger(__name__)


class FinetuneBaselineAgent:
    """Finetune agent for baseline models (DP, BC, BC-GMM, IBC)."""

    def __init__(self, cfg):
        """Set up the agent from a config object.

        Reads ``device``, ``wall_style``, ``model_type``, ``seed``,
        ``normalization_path``, ``model`` and ``logdir`` from ``cfg`` (plus the
        optional ``do_training`` flag, default ``False``), seeds the Python,
        NumPy and torch RNGs, loads the normalization statistics, instantiates
        the model through Hydra in eval mode and creates the output directory.

        Args:
            cfg: Config object exposing the attributes above and a ``get``
                method (presumably an OmegaConf DictConfig -- verify against
                the caller).
        """
        # Plain config fields
        self.cfg = cfg
        self.device = cfg.device
        self.wall_style = cfg.wall_style
        self.model_type = cfg.model_type  # 'dp', 'bc', 'bc_gmm', 'ibc'
        self.do_training = cfg.get('do_training', False)

        # Seed every RNG before anything random (e.g. model construction) runs
        self.seed = cfg.seed
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(self.seed)
        if torch.cuda.is_available():
            for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
                cuda_seed_fn(self.seed)

        # Normalization statistics (needs self.device to be set already)
        self.norm_stats = self._load_normalization(cfg.normalization_path)

        # Build the model from its Hydra config and keep it in eval mode
        self.model = hydra.utils.instantiate(cfg.model)
        self.model.eval()

        # Directory that receives checkpoints, plots and other artifacts
        self.output_dir = cfg.logdir
        os.makedirs(self.output_dir, exist_ok=True)

    def _load_normalization(self, norm_path):
        """Load min/max normalization statistics and derive safe ranges.

        Args:
            norm_path: Path handed to ``np.load``. The loaded object must be
                indexable by ``'obs_min'``, ``'obs_max'``, ``'action_min'``
                and ``'action_max'`` (e.g. an ``.npz`` archive).

        Returns:
            dict with float32 torch tensors on ``self.device`` under
            ``obs_min``, ``obs_max``, ``obs_range``, ``action_min``,
            ``action_max`` and ``action_range``, plus the raw NumPy arrays
            under ``obs_min_np``, ``obs_max_np``, ``action_min_np`` and
            ``action_max_np``. Ranges smaller than 1e-6 are replaced by 1 so
            they can be used as divisors.

        Raises:
            KeyError: If one of the four required entries is missing.
        """
        log.info(f"Loading normalization from {norm_path}")
        keys = ('obs_min', 'obs_max', 'action_min', 'action_max')

        norm_data = np.load(norm_path)
        try:
            # Read every array exactly once (an .npz archive re-reads the
            # member from disk on each lookup).
            raw = {key: np.asarray(norm_data[key]) for key in keys}
        finally:
            # For .npz archives np.load returns an object holding an open file
            # handle; release it. Plain arrays have no close() and are skipped.
            close = getattr(norm_data, 'close', None)
            if close is not None:
                close()

        def _to_tensor(array):
            # Copy first so the tensor never shares memory with the NumPy
            # array that is returned under the '*_np' key.
            return torch.from_numpy(array.copy()).float().to(self.device)

        def _safe_range(lower, upper):
            # Degenerate (near-constant) dimensions get range 1 to avoid
            # dividing by ~0 during normalization.
            span = upper - lower
            return torch.where(span < 1e-6, torch.ones_like(span), span)

        obs_min = _to_tensor(raw['obs_min'])
        obs_max = _to_tensor(raw['obs_max'])
        action_min = _to_tensor(raw['action_min'])
        action_max = _to_tensor(raw['action_max'])

        return {
            'obs_min': obs_min, 'obs_max': obs_max,
            'obs_range': _safe_range(obs_min, obs_max),
            'action_min': action_min, 'action_max': action_max,
            'action_range': _safe_range(action_min, action_max),
            'obs_min_np': raw['obs_min'], 'obs_max_np': raw['obs_max'],
            'action_min_np': raw['action_min'], 'action_max_np': raw['action_max'],
        }

    def _load_demo_from_pkl(self, demo_path):
        """
        Load demo from low_dim_obs.pkl and extract states/actions.

        State format: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7) = 22
        Action format: joint_pos(7) + gripper_open(1) = 8

        The action of a step is taken from ``obs.misc['joint_position_action']``
        when that entry exists; otherwise it falls back to the step's own
        ``joint_positions`` + ``gripper_open``. The result is truncated, or
        padded by repeating the last step, to exactly ``cfg.horizon_steps``.

        Args:
            demo_path: Directory containing ``low_dim_obs.pkl``.

        Returns:
            (states, actions): two NumPy arrays with ``cfg.horizon_steps`` rows.

        Raises:
            FileNotFoundError: If ``low_dim_obs.pkl`` does not exist.
            ValueError: If the demo contains no observations.
        """
        pkl_path = os.path.join(demo_path, "low_dim_obs.pkl")
        if not os.path.exists(pkl_path):
            raise FileNotFoundError(f"Demo file not found: {pkl_path}")

        log.info(f"Loading demo from {pkl_path}")
        # SECURITY NOTE: unpickling executes arbitrary code embedded in the
        # file. Only load demo files that come from a trusted source.
        with open(pkl_path, "rb") as f:
            demo = pickle.load(f)

        # Extract states and actions from observations
        states = []
        actions = []

        for obs in demo:
            # State: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7)
            states.append(np.concatenate([
                obs.joint_positions,           # 7
                obs.joint_velocities,          # 7
                [obs.gripper_open],            # 1
                obs.gripper_pose,              # 7 (pos + quat)
            ]))

            # Action: from obs.misc['joint_position_action'] if available.
            # `misc` may be absent or None, so guard before the membership test.
            misc = getattr(obs, 'misc', None)
            if misc is not None and 'joint_position_action' in misc:
                action = misc['joint_position_action']
            else:
                # Fallback: use joint_positions + gripper_open
                action = np.concatenate([obs.joint_positions, [obs.gripper_open]])
            actions.append(action)

        if not states:
            # Without this check the padding below fails with an obscure
            # NumPy dimension-mismatch error.
            raise ValueError(f"Demo at {pkl_path} contains no observations")

        states = np.array(states)
        actions = np.array(actions)

        log.info(f"Loaded demo: {len(states)} steps, states={states.shape}, actions={actions.shape}")

        # Adjust to match expected horizon
        horizon = self.cfg.horizon_steps
        n_steps = len(states)
        if n_steps > horizon:
            log.info(f"Truncating demo from {n_steps} to {horizon} steps")
            states = states[:horizon]
            actions = actions[:horizon]
        elif n_steps < horizon:
            # Pad by repeating last state/action
            log.info(f"Padding demo from {n_steps} to {horizon} steps")
            pad_len = horizon - n_steps
            states = np.concatenate([states, np.tile(states[-1:], (pad_len, 1))], axis=0)
            actions = np.concatenate([actions, np.tile(actions[-1:], (pad_len, 1))], axis=0)

        return states, actions

    def _normalize_states_actions(self, states, actions):
        """Min-max scale raw states and actions.

        Each dimension is mapped linearly so that its stored minimum becomes
        -1 and its stored maximum becomes +1, using the NumPy statistics in
        ``self.norm_stats``. Dimensions whose range is below 1e-6 are divided
        by 1 instead of their range.

        Args:
            states: array of raw states, broadcastable against the obs stats.
            actions: array of raw actions, broadcastable against the action stats.

        Returns:
            (states_norm, actions_norm)
        """
        stats = self.norm_stats

        def _rescale(values, lower, upper):
            # Guard against division by (near-)zero on constant dimensions
            span = upper - lower
            span = np.where(span < 1e-6, 1.0, span)
            return 2.0 * (values - lower) / span - 1.0

        states_norm = _rescale(states, stats['obs_min_np'], stats['obs_max_np'])
        actions_norm = _rescale(actions, stats['action_min_np'], stats['action_max_np'])
        return states_norm, actions_norm

    def finetune_bc(self, demo_states, demo_actions):
        """
        Fit the BC network to a single demo with full-batch Adam steps.

        The network receives only the first demo state and its output is
        regressed (MSE) onto the whole demo action sequence. Learning rate and
        number of steps come from ``cfg.finetune.lr`` / ``cfg.finetune.n_epochs``.
        After training the network is put back in eval mode and all of its
        parameters are frozen (``requires_grad = False``).

        Args:
            demo_states: (T, obs_dim) normalized states
            demo_actions: (T, action_dim) normalized actions

        Returns:
            loss_history: list with one loss value per epoch
        """
        log.info("Finetuning BC model on demo...")

        states_t = torch.FloatTensor(demo_states).to(self.device)
        actions_t = torch.FloatTensor(demo_actions).to(self.device)

        # Enable training on the policy network
        network = self.model.network
        network.train()
        for weight in network.parameters():
            weight.requires_grad = True

        optimizer = torch.optim.Adam(network.parameters(), lr=self.cfg.finetune.lr)
        n_epochs = self.cfg.finetune.n_epochs

        # Input and target are the same for every step, so build them once.
        first_obs = states_t[0:1].unsqueeze(0)  # (1, 1, obs_dim)
        target_actions = actions_t.unsqueeze(0)  # (1, T, action_dim)

        loss_history = []
        pbar = tqdm(range(n_epochs), desc="Finetuning BC")
        for _ in pbar:
            optimizer.zero_grad()

            # assumes the network returns (1, horizon, action_dim) -- TODO confirm
            pred_actions = network(first_obs)
            loss = F.mse_loss(pred_actions, target_actions)

            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            loss_history.append(loss_value)
            pbar.set_postfix({'loss': f'{loss_value:.6f}'})

        # Restore inference configuration
        network.eval()
        for weight in network.parameters():
            weight.requires_grad = False

        return loss_history

    def finetune_bc_gmm_component_selection(self, demo_states, demo_actions):
        """
        Pick the GMM component with the highest posterior score for the demo.

        No weights are updated. The network is queried once with the first
        demo state, and every mixture component k is scored with

            log π_k + log N(a*; μ_k, diag(σ_k²))

        where a* is the flattened demo action trajectory. The index with the
        largest score is returned.

        Args:
            demo_states: (T, obs_dim) normalized states
            demo_actions: (T, action_dim) normalized actions

        Returns:
            selected_component: int, the index of the best component
            component_log_likelihoods: list of per-component scores
                (log mixture weight + Gaussian log-density)
        """
        log.info("BC-GMM: Selecting best component via posterior...")

        states_t = torch.FloatTensor(demo_states).to(self.device)
        actions_t = torch.FloatTensor(demo_actions).to(self.device)

        # Inference only
        self.model.eval()

        with torch.no_grad():
            # Query the network with the first observation, shape (1, obs_dim).
            # The indexing below treats the outputs as
            # means/scales: (1, K, trajectory_dim) and logits: (1, K).
            means, scales, logits = self.model.network(states_t[0:1])

            # Flattened demo trajectory, shape (trajectory_dim,)
            target = actions_t.view(1, -1)[0]

            # log π_k for every component
            log_weights = torch.log_softmax(logits, dim=-1)

            num_components = means.shape[1]
            trajectory_dim = means.shape[2]
            # Constant term of the diagonal-Gaussian log-density
            gauss_const = trajectory_dim * np.log(2 * np.pi)

            component_log_likelihoods = []
            for k in range(num_components):
                mu_k = means[0, k]
                sigma_k = scales[0, k]

                # log N = -0.5 * (sum((x - μ)² / σ²) + D*log(2π) + 2*sum(log σ))
                residual = target - mu_k
                mahalanobis = (residual ** 2 / (sigma_k ** 2)).sum()
                log_det = 2 * torch.log(sigma_k).sum()
                log_prob = -0.5 * (mahalanobis + gauss_const + log_det)

                log_pi = log_weights[0, k]
                total_log_prob = log_pi + log_prob
                component_log_likelihoods.append(total_log_prob.item())

                log.info(f"  Component {k}: log π_k = {log_pi.item():.4f}, "
                        f"log N(a*|μ,Σ) = {log_prob.item():.4f}, "
                        f"total = {total_log_prob.item():.4f}")

            # Highest-scoring component wins
            selected_component = int(np.argmax(component_log_likelihoods))
            log.info(f"\nSelected component: k* = {selected_component} "
                    f"(log-likelihood = {component_log_likelihoods[selected_component]:.4f})")

            # Debug aid: how close is the winning mean to the demo itself?
            traj_shape = (self.cfg.horizon_steps, self.cfg.action_dim)
            selected_mean = means[0, selected_component].view(*traj_shape)
            demo_reshaped = actions_t.view(*traj_shape)
            mse = F.mse_loss(selected_mean, demo_reshaped).item()
            log.info(f"MSE between selected component mean and demo: {mse:.6f}")

        return selected_component, component_log_likelihoods

    def finetune_ibc_demo_init(self, demo_states, demo_actions):
        """
        Package the demo actions as an initial trajectory for IBC inference.

        Nothing is trained: the model is switched to eval mode, all of its
        parameters get ``requires_grad = False``, and the demo action sequence
        is returned with a leading batch dimension. The intended use (per the
        original design notes) is to start the inference-time optimizer
        (Langevin/GD) near the demo instead of from scratch; that consumer is
        not part of this method. As a diagnostic, the energy the network
        assigns to (first demo state, flattened demo trajectory) is logged.

        Args:
            demo_states: (T, obs_dim) normalized states
            demo_actions: (T, action_dim) normalized actions

        Returns:
            demo_trajectory: (1, T, action_dim) tensor on ``self.device``
        """
        log.info("IBC: Preparing demo trajectory for demo-initialized inference...")

        actions_t = torch.FloatTensor(demo_actions).to(self.device)

        # Keep the energy model frozen
        self.model.eval()
        for weight in self.model.parameters():
            weight.requires_grad = False

        # (T, action_dim) -> (1, T, action_dim)
        demo_trajectory = actions_t.unsqueeze(0)

        log.info(f"Demo trajectory shape: {demo_trajectory.shape}")
        log.info("Demo trajectory will be used to initialize Langevin/GD optimization")

        # Diagnostic only: energy of the demo under the current model
        with torch.no_grad():
            first_state = torch.FloatTensor(demo_states[0:1]).to(self.device)
            flat_trajectory = demo_trajectory.view(1, -1)  # (1, T * action_dim)
            demo_energy = self.model.network(first_state, flat_trajectory)
            log.info(f"Energy of demo trajectory: {demo_energy.item():.4f}")

        return demo_trajectory

    def finetune_dp_inversion(self, demo_states, demo_actions):
        """
        DP Inversion: Optimize initial noise x_T to reconstruct demo trajectory.

        Freeze denoiser weights θ and optimize:
            x_T* = argmin_{x_T} ||Sample_θ(x_T) - a*||²

        Uses DDIM sampling (deterministic, eta=0) so gradients can flow from
        the loss back to x_T through every denoising step, including through
        the denoiser's dependence on its input. The weights are frozen with
        ``requires_grad = False``; the forward pass must NOT run under
        ``torch.no_grad()``, because that would also cut the gradient path
        through the network and give the optimizer a wrong gradient for x_T.
        Note that backpropagating through all DDIM steps keeps the
        activations of every step in memory.

        Learning rate and number of steps come from ``cfg.dp_inversion.lr`` /
        ``cfg.dp_inversion.n_epochs``.

        Args:
            demo_states: (T, obs_dim) normalized states
            demo_actions: (T, action_dim) normalized actions

        Returns:
            optimized_x_T: (1, horizon, action_dim) optimized initial noise
                (detached)
            loss_history: list of loss values, one per epoch
        """
        log.info("DP Inversion: Optimizing initial noise x_T...")

        demo_states = torch.FloatTensor(demo_states).to(self.device)
        demo_actions = torch.FloatTensor(demo_actions).to(self.device)

        # Freeze the denoiser weights; only x_T is trainable.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

        # Initialize x_T as learnable parameter
        horizon_steps = self.cfg.horizon_steps
        action_dim = self.cfg.action_dim
        x_T = nn.Parameter(torch.randn(1, horizon_steps, action_dim, device=self.device))

        optimizer = torch.optim.Adam([x_T], lr=self.cfg.dp_inversion.lr)
        n_epochs = self.cfg.dp_inversion.n_epochs
        loss_history = []

        # Get DDIM parameters from model
        ddim_alphas = self.model.ddim_alphas
        ddim_alphas_prev = self.model.ddim_alphas_prev
        ddim_sqrt_one_minus_alphas = self.model.ddim_sqrt_one_minus_alphas
        clip_value = self.model.denoised_clip_value

        # The per-step schedule does not change between epochs: build it once.
        schedule = []
        for i, t in enumerate(self.model.ddim_t):
            t_b = torch.tensor([t], device=self.device).long()
            index_b = torch.tensor([i], device=self.device).long()
            schedule.append((
                t_b,
                ddim_alphas[index_b].view(1, 1, 1),
                ddim_alphas_prev[index_b].view(1, 1, 1),
                ddim_sqrt_one_minus_alphas[index_b].view(1, 1, 1),
            ))

        # Condition: first state
        cond = {'state': demo_states[0:1].unsqueeze(0)}  # (1, 1, obs_dim)

        # Target: full action trajectory
        target_actions = demo_actions.unsqueeze(0)  # (1, horizon, action_dim)

        pbar = tqdm(range(n_epochs), desc="DP Inversion")
        for epoch in pbar:
            optimizer.zero_grad()

            # Forward pass through DDIM sampler (deterministic)
            x = x_T.clone()

            for t_b, alpha, alpha_prev, sqrt_one_minus_alpha in schedule:
                # Predict noise. Gradients w.r.t. x are tracked on purpose;
                # the weights themselves stay fixed (requires_grad=False).
                noise_pred = self.model.actor(x, t_b, cond=cond)

                # DDIM step (deterministic, eta=0)
                # x_0 = (x_t - sqrt(1-alpha) * eps) / sqrt(alpha)
                x_recon = (x - sqrt_one_minus_alpha * noise_pred) / (alpha ** 0.5)

                # Clip x_0
                if clip_value is not None:
                    x_recon = x_recon.clamp(-clip_value, clip_value)

                # mu = sqrt(alpha_prev) * x_0 + sqrt(1 - alpha_prev) * eps
                dir_xt = (1.0 - alpha_prev).sqrt() * noise_pred
                x = (alpha_prev ** 0.5) * x_recon + dir_xt

            # Compute loss between sampled trajectory and demo
            loss = F.mse_loss(x, target_actions)

            # Only x_T receives a gradient (network weights are frozen)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            loss_history.append(loss_value)
            pbar.set_postfix({'loss': f'{loss_value:.6f}', 'x_T_norm': f'{x_T.norm().item():.3f}'})

        return x_T.detach(), loss_history

    def run(self):
        """Run the finetune process and then launch the eval agent.

        With ``do_training`` enabled, the demo at
        ``<cfg.demo_base_path>/style<wall_style>/episodes/episode0`` is loaded
        and normalized, and one adaptation routine is run depending on
        ``model_type``:

        - ``'bc'``: weight finetuning; saves a checkpoint and a loss curve
        - ``'dp'``: initial-noise inversion; saves the optimized x_T and a
          loss curve
        - ``'bc_gmm'``: component selection; saves the selection info
        - ``'ibc'``: saves the demo trajectory for demo-initialized inference

        Any other model type only logs a warning. With ``do_training``
        disabled, the demo metadata is merely looked up for logging. In both
        cases the eval agent is started afterwards with whatever artifacts
        were produced.

        If the demo file is missing, an explanatory error is logged and the
        method returns without running the eval.
        """
        log.info("=" * 70)
        log.info(f"Finetune {self.model_type.upper()} for close_drawer (wall_style={self.wall_style})")
        log.info("=" * 70)

        # 1. Load target demo from block_setting
        demo_path = os.path.join(
            self.cfg.demo_base_path,
            f'style{self.wall_style}',
            'episodes', 'episode0'
        )

        # Artifacts forwarded to the eval agent; None means "not produced".
        checkpoint_path = None
        x_T_path = None
        selected_component_path = None
        demo_trajectory_path = None

        if self.do_training:
            # Load full demo (states/actions) from low_dim_obs.pkl. Only this
            # call is guarded, so a FileNotFoundError raised later (e.g. while
            # saving artifacts) is not misreported as a missing demo.
            try:
                states, actions = self._load_demo_from_pkl(demo_path)
            except FileNotFoundError as e:
                log.error(f"Demo not found: {e}")
                log.error(
                    "Please run: python RLBench_close_drawer/make_dataset/visualize_wall_trajectories.py "
                    f"--style {self.wall_style}"
                )
                log.error("This will generate the full demo data needed for finetuning.")
                return

            states_norm, actions_norm = self._normalize_states_actions(states, actions)
            log.info(f"Normalized states: {states_norm.shape}, actions: {actions_norm.shape}")

            # Finetune the model based on type
            log.info("\n--- Finetuning ---")
            if self.model_type == 'bc':
                loss_history = self.finetune_bc(states_norm, actions_norm)

                if loss_history:
                    log.info(f"Final loss: {loss_history[-1]:.6f}")

                    # Save finetuned checkpoint in the same format as pretrained model
                    checkpoint_path = os.path.join(self.output_dir, f'finetuned_model_wall{self.wall_style}.pt')
                    torch.save({
                        'model': self.model.state_dict(),
                        'wall_style': self.wall_style,
                    }, checkpoint_path)
                    log.info(f"Saved finetuned model to {checkpoint_path}")

                    # Plot loss curve
                    self._plot_loss_curve(loss_history)

            elif self.model_type == 'dp':
                # DP Inversion: optimize x_T
                optimized_x_T, loss_history = self.finetune_dp_inversion(states_norm, actions_norm)

                if loss_history:
                    log.info(f"Final loss: {loss_history[-1]:.6f}")

                    # Save optimized x_T
                    x_T_path = os.path.join(self.output_dir, f'optimized_x_T_wall{self.wall_style}.npy')
                    np.save(x_T_path, optimized_x_T.cpu().numpy())
                    log.info(f"Saved optimized x_T to {x_T_path} (shape: {optimized_x_T.shape})")

                    # Plot loss curve
                    self._plot_loss_curve(loss_history)

            elif self.model_type == 'bc_gmm':
                # BC-GMM: Component selection via posterior (no weight updates)
                selected_component, component_log_likelihoods = self.finetune_bc_gmm_component_selection(
                    states_norm, actions_norm
                )

                # Save selected component info. The dict is stored by np.save
                # as a pickled object array, so readers must load it with
                # allow_pickle=True.
                selected_component_path = os.path.join(
                    self.output_dir, f'selected_component_wall{self.wall_style}.npy'
                )
                np.save(selected_component_path, {
                    'selected_component': selected_component,
                    'component_log_likelihoods': component_log_likelihoods,
                    'wall_style': self.wall_style,
                })
                log.info(f"Saved selected component info to {selected_component_path}")

            elif self.model_type == 'ibc':
                # IBC: Demo-initialized inference (no weight updates)
                demo_traj = self.finetune_ibc_demo_init(states_norm, actions_norm)

                # Save demo trajectory for use during inference
                demo_trajectory_path = os.path.join(
                    self.output_dir, f'demo_trajectory_wall{self.wall_style}.npy'
                )
                np.save(demo_trajectory_path, demo_traj.cpu().numpy())
                log.info(f"Saved demo trajectory to {demo_trajectory_path}")

            else:
                log.warning(f"Training not implemented for {self.model_type}, skipping...")
        else:
            log.info(f"\n--- No training for {self.model_type.upper()} (inference only) ---")
            # Just load metadata to confirm demo exists
            metadata_path = os.path.join(demo_path, 'metadata.npy')
            if os.path.exists(metadata_path):
                metadata = np.load(metadata_path, allow_pickle=True).item()
                log.info(f"Target demo metadata: wall_style={metadata.get('style')}")
            else:
                log.warning(f"Demo metadata not found at {metadata_path}")

        # Run eval agent
        log.info("\n" + "=" * 70)
        log.info("Running evaluation...")
        log.info("=" * 70)

        self._run_eval(checkpoint_path, x_T_path, selected_component_path, demo_trajectory_path)

    def _plot_loss_curve(self, loss_history):
        """Write the loss curve to ``<output_dir>/loss_curve.png``.

        Does nothing when ``loss_history`` is empty.

        Args:
            loss_history: sequence of per-epoch loss values.
        """
        if len(loss_history) == 0:
            return

        save_path = os.path.join(self.output_dir, 'loss_curve.png')

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(loss_history)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title(f'{self.model_type.upper()} Finetuning Loss (wall_style={self.wall_style})')
        ax.grid(True, alpha=0.3)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        # Release the figure so repeated calls do not accumulate open figures
        plt.close(fig)

        log.info(f"Saved loss curve to {save_path}")

    def _run_eval(self, checkpoint_path=None, x_T_path=None, selected_component_path=None, demo_trajectory_path=None):
        """Launch ``script/run.py`` as a subprocess to evaluate the policy.

        The command is run with ``dppo_root`` as working directory and uses
        the eval config named by ``cfg.eval_config_name``. Every artifact path
        that is not ``None`` is forwarded as a ``key=value`` override.

        Args:
            checkpoint_path: finetuned checkpoint, passed as ``base_policy_path``.
            x_T_path: optimized initial noise (DP inversion), passed as ``x_T_file``.
            selected_component_path: BC-GMM selection info, passed as
                ``selected_component_file``.
            demo_trajectory_path: IBC demo trajectory, passed as
                ``demo_trajectory_file``.

        A non-zero exit status of the eval process is logged as an error; no
        exception is raised for it.
        """
        # Use the interpreter running this process so the eval sees the same
        # environment; fall back to "python" if it cannot be determined.
        eval_cmd = [
            sys.executable or "python", "script/run.py",
            "--config-dir=cfg/rlbench/eval/close_drawer",
            f"--config-name={self.cfg.eval_config_name}",
            f"wall_style={self.wall_style}",
        ]

        artifact_overrides = (
            ("base_policy_path", checkpoint_path),            # finetuned checkpoint
            ("x_T_file", x_T_path),                           # DP inversion
            ("selected_component_file", selected_component_path),  # BC-GMM
            ("demo_trajectory_file", demo_trajectory_path),   # IBC
        )
        for key, path in artifact_overrides:
            if path is not None:
                # The subprocess runs in dppo_root, so a path relative to the
                # current working directory would point somewhere else there.
                eval_cmd.append(f"{key}={os.path.abspath(path)}")

        log.info(f"Running eval command: {' '.join(eval_cmd)}")

        # Run from the dppo root directory; output goes straight to the console
        result = subprocess.run(
            eval_cmd,
            cwd=dppo_root,
            capture_output=False,
        )

        if result.returncode != 0:
            log.error(f"Eval command failed with return code {result.returncode}")
        else:
            log.info("Evaluation completed successfully!")
