"""
Finetune Agent for baseline models (DP, BC, BC-GMM, IBC) on pick_place task.

For pick_place with TWO learned phases (REACH and CARRY), "finetuning" means:
- BC: Finetune the model on both REACH and CARRY segments of the wall-avoiding demo
- DP: Optimize x_T separately for REACH and CARRY phases (DP Inversion)
- BC-GMM: Select best GMM component separately for each phase via posterior
- IBC: Save demo trajectory for each phase to initialize Langevin inference

This agent:
1. Loads the target wall demo (from low_dim_obs.pkl)
2. Segments into REACH and CARRY phases
3. Applies phase-specific finetuning for each method
4. Automatically runs the eval agent with finetuned parameters
"""

import os
import sys
import logging
import subprocess
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import hydra

# Add paths
# dppo_root is the directory three levels above this file. It and its
# 'RLBench' subdirectory are appended to sys.path, and dppo_root is also the
# working directory used when the eval subprocess is launched (see _run_eval).
dppo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(dppo_root)
sys.path.append(os.path.join(dppo_root, 'RLBench'))

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Phase configuration for pick_place
# Each entry maps a phase name to a (start, end) pair. _load_demo_from_pkl
# uses the pair as a half-open slice [start:end] into the per-step demo
# arrays. Only 'reach' and 'carry' are read in this file; the other entries
# are not referenced here.
PHASE_CONFIG = {
    'reach': (0, 64),           # REACH phase: observations[0:64]
    'descend': (64, 72),
    'grasp': (72, 80),
    'lift': (80, 88),
    'carry': (88, 152),         # CARRY phase: observations[88:152]
    'descend_release': (152, 160),
    'release': (160, 168),
}

# EE position normalization bounds
# Three-element lower/upper bounds consumed by normalize_ee_position, which
# maps EE_POS_MIN to -1 and EE_POS_MAX to +1 element-wise.
EE_POS_MIN = np.array([0.0, -0.6, 0.0], dtype=np.float32)
EE_POS_MAX = np.array([1.0, 0.6, 1.6], dtype=np.float32)


def normalize_ee_position(ee_raw):
    """Rescale a raw end-effector position to [-1, 1] per axis.

    The affine map sends ``EE_POS_MIN`` to -1 and ``EE_POS_MAX`` to +1.
    No clipping is applied, so inputs outside the bounds map outside [-1, 1].
    """
    span = EE_POS_MAX - EE_POS_MIN
    offset = ee_raw - EE_POS_MIN
    return 2.0 * offset / span - 1.0


class FinetunePickPlaceBaselineAgent:
    """Finetune agent for baseline models (DP, BC, BC-GMM, IBC) on pick_place."""

    def __init__(self, cfg):
        """Set up the agent: seed RNGs, load stats, build the model, load waypoints.

        Args:
            cfg: Config object. Read here: ``device``, ``style``,
                ``model_type``, ``seed``, ``normalization_path``, ``model``,
                ``logdir`` and the optional ``do_training`` (defaults to
                False via ``cfg.get``).
        """
        self.cfg = cfg
        self.device = cfg.device
        # One style index is used for both the demo and the wall.
        self.style = cfg.style
        # One of 'dp', 'bc', 'bc_gmm', 'ibc'.
        self.model_type = cfg.model_type
        self.do_training = cfg.get('do_training', False)

        # Seed every RNG before the model is instantiated.
        self.seed = cfg.seed
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(self.seed)
        if torch.cuda.is_available():
            for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
                cuda_seed_fn(self.seed)

        # Normalization statistics (tensors on self.device plus raw arrays).
        self.norm_stats = self._load_normalization(cfg.normalization_path)

        # Instantiate the model from its config node and switch to eval mode.
        model = hydra.utils.instantiate(cfg.model)
        self.model = model
        self.model.eval()

        # Waypoints and the normalized REACH / CARRY subgoals.
        self._load_waypoints()

        # Directory receiving checkpoints, artefacts and plots.
        self.output_dir = cfg.logdir
        os.makedirs(self.output_dir, exist_ok=True)

    def _load_normalization(self, norm_path):
        """Load min/max normalization statistics and derive safe ranges.

        Args:
            norm_path: Path handed to ``np.load``. The loaded object must be
                indexable by ``'obs_min'``, ``'obs_max'``, ``'action_min'``
                and ``'action_max'``.

        Returns:
            dict with float tensors on ``self.device`` under ``obs_min``,
            ``obs_max``, ``obs_range``, ``action_min``, ``action_max`` and
            ``action_range``, plus the raw NumPy arrays under the matching
            ``*_np`` keys. Range entries below 1e-6 are replaced by 1 so a
            later division by the range cannot blow up.
        """
        log.info(f"Loading normalization from {norm_path}")

        stat_keys = ('obs_min', 'obs_max', 'action_min', 'action_max')
        norm_data = np.load(norm_path)
        try:
            # Read each array exactly once; indexing an .npz archive re-reads
            # the member from disk on every access.
            raw = {key: norm_data[key] for key in stat_keys}
        finally:
            # An .npz archive holds an open file handle until closed; a plain
            # ndarray has no close(), hence the getattr guard.
            close = getattr(norm_data, 'close', None)
            if close is not None:
                close()

        def _to_tensor(array):
            # Copy first so the tensor never shares memory with the array
            # stored under the corresponding *_np key.
            return torch.from_numpy(array.copy()).float().to(self.device)

        obs_min = _to_tensor(raw['obs_min'])
        obs_max = _to_tensor(raw['obs_max'])
        action_min = _to_tensor(raw['action_min'])
        action_max = _to_tensor(raw['action_max'])

        # Guard against (near-)constant dimensions.
        obs_range = obs_max - obs_min
        obs_range = torch.where(obs_range < 1e-6, torch.ones_like(obs_range), obs_range)
        action_range = action_max - action_min
        action_range = torch.where(action_range < 1e-6, torch.ones_like(action_range), action_range)

        return {
            'obs_min': obs_min, 'obs_max': obs_max, 'obs_range': obs_range,
            'action_min': action_min, 'action_max': action_max, 'action_range': action_range,
            'obs_min_np': raw['obs_min'], 'obs_max_np': raw['obs_max'],
            'action_min_np': raw['action_min'], 'action_max_np': raw['action_max'],
        }

    def _load_waypoints(self):
        """Load waypoint positions from demo metadata."""
        summary_path = os.path.join(
            self.cfg.demo_base_path,
            f'style{self.style}',
            'summary.npy'
        )
        if os.path.exists(summary_path):
            summary = np.load(summary_path, allow_pickle=True).item()
            self.waypoints = {
                'home_pos': np.array(summary['home_pos']),
                'pregrasp_pos': np.array(summary['pregrasp_pos']),
                'lift_pos': np.array(summary['lift_pos']),
                'prerelease_pos': np.array(summary['prerelease_pos']),
            }
        else:
            init_path = os.path.join(
                os.path.dirname(self.cfg.demo_base_path),
                'make_dataset', 'stack_blocks_init.npz'
            )
            if os.path.exists(init_path):
                init_data = np.load(init_path)
                self.waypoints = {
                    'home_pos': init_data['home_pos'],
                    'pregrasp_pos': init_data['pregrasp_pos'],
                    'lift_pos': init_data['lift_pos'],
                    'prerelease_pos': init_data['prerelease_pos'],
                }
            else:
                raise FileNotFoundError(f"Cannot find waypoints.")

        self.subgoal_reach = normalize_ee_position(
            np.array(self.waypoints['pregrasp_pos'], dtype=np.float32)
        )
        self.subgoal_carry = normalize_ee_position(
            np.array(self.waypoints['prerelease_pos'], dtype=np.float32)
        )

    def _load_demo_from_pkl(self, demo_path):
        """
        Load demo and segment into REACH and CARRY phases.

        Returns:
            reach_states, reach_actions: (64, state_dim), (64, action_dim)
            carry_states, carry_actions: (64, state_dim), (64, action_dim)
        """
        pkl_path = os.path.join(demo_path, "low_dim_obs.pkl")
        if not os.path.exists(pkl_path):
            raise FileNotFoundError(f"Demo file not found: {pkl_path}")

        log.info(f"Loading demo from {pkl_path}")
        with open(pkl_path, "rb") as f:
            demo = pickle.load(f)

        all_states = []
        all_actions = []

        for obs in demo:
            state = np.concatenate([
                obs.joint_positions,
                obs.joint_velocities,
                [obs.gripper_open],
                obs.gripper_pose,
            ])
            all_states.append(state)

            if hasattr(obs, 'misc') and 'joint_position_action' in obs.misc:
                action = obs.misc['joint_position_action']
            else:
                action = np.concatenate([obs.joint_positions, [obs.gripper_open]])
            all_actions.append(action)

        all_states = np.array(all_states)
        all_actions = np.array(all_actions)

        # Segment
        reach_start, reach_end = PHASE_CONFIG['reach']
        carry_start, carry_end = PHASE_CONFIG['carry']

        reach_states = all_states[reach_start:reach_end]
        reach_actions = all_actions[reach_start:reach_end]
        carry_states = all_states[carry_start:carry_end]
        carry_actions = all_actions[carry_start:carry_end]

        log.info(f"REACH: states={reach_states.shape}, actions={reach_actions.shape}")
        log.info(f"CARRY: states={carry_states.shape}, actions={carry_actions.shape}")

        return reach_states, reach_actions, carry_states, carry_actions

    def _normalize_states_actions(self, states, actions):
        """Normalize states and actions to [-1, 1]."""
        obs_min = self.norm_stats['obs_min_np']
        obs_max = self.norm_stats['obs_max_np']
        action_min = self.norm_stats['action_min_np']
        action_max = self.norm_stats['action_max_np']

        obs_range = obs_max - obs_min
        obs_range = np.where(obs_range < 1e-6, 1.0, obs_range)
        action_range = action_max - action_min
        action_range = np.where(action_range < 1e-6, 1.0, action_range)

        states_norm = 2.0 * (states - obs_min) / obs_range - 1.0
        actions_norm = 2.0 * (actions - action_min) / action_range - 1.0

        return states_norm, actions_norm

    def finetune_bc(self, reach_states, reach_actions, carry_states, carry_actions):
        """
        Finetune BC model on both REACH and CARRY segments.

        For pick_place BC, we train on both phases together since
        the model conditions on subgoal to differentiate phases.

        For each phase the network input is the FIRST state of the segment
        concatenated with that phase's subgoal; the remaining states are not
        used. The target is the whole action segment with a leading batch
        dimension, and the per-epoch loss is the sum of the two phase MSEs.

        Args:
            reach_states, reach_actions: Normalized REACH segment.
            carry_states, carry_actions: Normalized CARRY segment.

        Returns:
            list of float: Total loss per epoch
            (``cfg.finetune.n_epochs`` entries).
        """
        log.info("Finetuning BC model on REACH and CARRY segments...")

        reach_states_t = torch.FloatTensor(reach_states).to(self.device)
        reach_actions_t = torch.FloatTensor(reach_actions).to(self.device)
        carry_states_t = torch.FloatTensor(carry_states).to(self.device)
        carry_actions_t = torch.FloatTensor(carry_actions).to(self.device)

        subgoal_reach_t = torch.FloatTensor(self.subgoal_reach).to(self.device)
        subgoal_carry_t = torch.FloatTensor(self.subgoal_carry).to(self.device)

        # Conditioning inputs and targets do not change between epochs, so
        # build them once. The original comment gives the input as (1, 1, 25).
        reach_obs = torch.cat([
            reach_states_t[0:1],
            subgoal_reach_t.unsqueeze(0)
        ], dim=-1).unsqueeze(0)
        carry_obs = torch.cat([
            carry_states_t[0:1],
            subgoal_carry_t.unsqueeze(0)
        ], dim=-1).unsqueeze(0)
        reach_target = reach_actions_t.unsqueeze(0)
        carry_target = carry_actions_t.unsqueeze(0)

        network = self.model.network

        # Unfreeze model for training
        network.train()
        for param in network.parameters():
            param.requires_grad = True

        loss_history = []
        try:
            optimizer = torch.optim.Adam(
                network.parameters(),
                lr=self.cfg.finetune.lr
            )
            n_epochs = self.cfg.finetune.n_epochs

            pbar = tqdm(range(n_epochs), desc="Finetuning BC")
            for epoch in pbar:
                optimizer.zero_grad()

                reach_loss = F.mse_loss(network(reach_obs), reach_target)
                carry_loss = F.mse_loss(network(carry_obs), carry_target)
                total_loss = reach_loss + carry_loss

                total_loss.backward()
                optimizer.step()

                loss_value = total_loss.item()
                loss_history.append(loss_value)
                pbar.set_postfix({'loss': f'{loss_value:.6f}'})
        finally:
            # Always return to eval mode with frozen parameters, even when
            # training is interrupted by an exception.
            network.eval()
            for param in network.parameters():
                param.requires_grad = False

        return loss_history

    def finetune_dp_inversion(self, reach_states, reach_actions, carry_states, carry_actions):
        """
        DP Inversion: Optimize x_T separately for REACH and CARRY phases.

        The model is put in eval mode and all its parameters are frozen
        before the two independent optimizations run (REACH first, then
        CARRY).

        Args:
            reach_states, reach_actions: Normalized REACH segment.
            carry_states, carry_actions: Normalized CARRY segment.

        Returns:
            (x_T_reach, x_T_carry, loss_history) where ``loss_history`` is the
            REACH loss list followed by the CARRY loss list.
        """
        log.info("DP Inversion: Optimizing x_T for REACH and CARRY phases...")

        # Ensure model is frozen
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

        reach_states_t = torch.FloatTensor(reach_states).to(self.device)
        reach_actions_t = torch.FloatTensor(reach_actions).to(self.device)
        carry_states_t = torch.FloatTensor(carry_states).to(self.device)
        carry_actions_t = torch.FloatTensor(carry_actions).to(self.device)

        subgoal_reach_t = torch.FloatTensor(self.subgoal_reach).to(self.device)
        subgoal_carry_t = torch.FloatTensor(self.subgoal_carry).to(self.device)

        # Optimize x_T for REACH
        log.info("Optimizing x_T for REACH...")
        x_T_reach, loss_reach = self._optimize_x_T(
            reach_states_t, reach_actions_t, subgoal_reach_t, 'reach'
        )

        # Optimize x_T for CARRY
        log.info("Optimizing x_T for CARRY...")
        x_T_carry, loss_carry = self._optimize_x_T(
            carry_states_t, carry_actions_t, subgoal_carry_t, 'carry'
        )

        # List concatenation: REACH losses first, CARRY losses second.
        return x_T_reach, x_T_carry, loss_reach + loss_carry

    def _optimize_x_T(self, states, actions, subgoal, phase_name):
        """Fit an initial noise x_T so a deterministic DDIM rollout matches the demo.

        A learnable ``x_T`` of shape (1, horizon_steps, action_dim) is drawn
        from a standard normal and updated with Adam. Each epoch runs the
        DDIM update over ``self.model.ddim_t`` starting from ``x_T`` and
        minimizes the MSE between the final sample and ``actions``.

        Args:
            states: Tensor of phase states; only ``states[0:1]`` is used, as
                the conditioning state.
            actions: Target action tensor; a batch dimension is prepended.
                Presumably (horizon_steps, action_dim) - TODO confirm.
            subgoal: 1-D subgoal tensor concatenated after the first state.
            phase_name: Label shown in the progress bar.

        Returns:
            (x_T.detach(), loss_history) with one loss value per epoch.
        """
        horizon_steps = self.cfg.horizon_steps
        action_dim = self.cfg.action_dim

        x_T = nn.Parameter(torch.randn(1, horizon_steps, action_dim, device=self.device))
        optimizer = torch.optim.Adam([x_T], lr=self.cfg.dp_inversion.lr)
        loss_history = []

        # Get DDIM parameters
        ddim_alphas = self.model.ddim_alphas
        ddim_alphas_prev = self.model.ddim_alphas_prev
        ddim_sqrt_one_minus_alphas = self.model.ddim_sqrt_one_minus_alphas
        clip_value = self.model.denoised_clip_value

        # The schedule is identical in every epoch, so the per-step timestep
        # tensor and coefficients are built once instead of once per epoch.
        schedule = []
        with torch.no_grad():
            for i, t in enumerate(self.model.ddim_t):
                t_b = torch.tensor([t], device=self.device).long()
                index_b = torch.tensor([i], device=self.device).long()

                alpha = ddim_alphas[index_b].view(1, 1, 1)
                alpha_prev = ddim_alphas_prev[index_b].view(1, 1, 1)
                sqrt_one_minus_alpha = ddim_sqrt_one_minus_alphas[index_b].view(1, 1, 1)

                schedule.append((
                    t_b,
                    alpha ** 0.5,
                    sqrt_one_minus_alpha,
                    alpha_prev ** 0.5,
                    (1.0 - alpha_prev).sqrt(),
                ))

        # Condition: first state of the phase followed by the subgoal.
        first_state = states[0:1]
        cond_state = torch.cat([first_state, subgoal.unsqueeze(0)], dim=-1).unsqueeze(0)
        cond = {'state': cond_state}

        target_actions = actions.unsqueeze(0)

        n_epochs = self.cfg.dp_inversion.n_epochs

        pbar = tqdm(range(n_epochs), desc=f"DP Inversion ({phase_name})")
        for epoch in pbar:
            optimizer.zero_grad()

            x = x_T.clone()

            for t_b, sqrt_alpha, sqrt_one_minus_alpha, sqrt_alpha_prev, dir_coeff in schedule:
                # NOTE(review): the denoiser runs under no_grad, so noise_pred
                # is a constant for autograd and the gradient reaches x_T only
                # through the explicit `x` terms below, not through the
                # network. Confirm this approximation is intended before
                # relying on the optimized x_T.
                with torch.no_grad():
                    noise_pred = self.model.actor(x, t_b, cond=cond)

                x_recon = (x - sqrt_one_minus_alpha * noise_pred) / sqrt_alpha

                if clip_value is not None:
                    x_recon = x_recon.clamp(-clip_value, clip_value)

                x = sqrt_alpha_prev * x_recon + dir_coeff * noise_pred

            loss = F.mse_loss(x, target_actions)

            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            loss_history.append(loss_value)
            pbar.set_postfix({'loss': f'{loss_value:.6f}'})

        return x_T.detach(), loss_history

    def finetune_bc_gmm_component_selection(self, reach_states, reach_actions, carry_states, carry_actions):
        """
        BC-GMM: pick the best mixture component independently for REACH and CARRY.

        Returns:
            (reach_component, carry_component, reach_ll, carry_ll): the
            selected component index per phase and the per-component scores
            returned by ``_select_gmm_component``.
        """
        log.info("BC-GMM: Selecting best components for REACH and CARRY...")

        self.model.eval()

        def _on_device(array):
            return torch.FloatTensor(array).to(self.device)

        # Convert every input before any selection runs.
        phase_inputs = {
            'reach': (_on_device(reach_states), _on_device(reach_actions)),
            'carry': (_on_device(carry_states), _on_device(carry_actions)),
        }
        phase_subgoals = {
            'reach': _on_device(self.subgoal_reach),
            'carry': _on_device(self.subgoal_carry),
        }

        selections = {}
        for phase in ('reach', 'carry'):
            phase_states, phase_actions = phase_inputs[phase]
            selections[phase] = self._select_gmm_component(
                phase_states, phase_actions, phase_subgoals[phase], phase
            )

        reach_component, reach_ll = selections['reach']
        carry_component, carry_ll = selections['carry']
        return reach_component, carry_component, reach_ll, carry_ll

    def _select_gmm_component(self, states, actions, subgoal, phase_name):
        """Score each GMM component against the demo and return the best one.

        The network is queried with the first state concatenated with the
        subgoal. Each component k is scored as its log mixture weight plus
        the diagonal-Gaussian log density of the flattened action segment
        under (means[0, k], scales[0, k]).

        Returns:
            (selected_component, component_log_likelihoods): the argmax index
            and the list of per-component scores as Python floats.
        """
        with torch.no_grad():
            cond = torch.cat([states[0:1], subgoal.unsqueeze(0)], dim=-1)
            means, scales, logits = self.model.network(cond)

            flat_target = actions.view(1, -1)[0]
            log_mix = torch.log_softmax(logits, dim=-1)

            num_components = means.shape[1]
            # Normalization term of the Gaussian, shared by all components.
            gauss_const = means.shape[2] * np.log(2 * np.pi)

            def _score(k):
                mu, sigma = means[0, k], scales[0, k]
                residual = flat_target - mu
                log_density = -0.5 * (
                    (residual ** 2 / (sigma ** 2)).sum()
                    + gauss_const
                    + 2 * torch.log(sigma).sum()
                )
                return (log_mix[0, k] + log_density).item()

            component_log_likelihoods = []
            for k in range(num_components):
                score = _score(k)
                component_log_likelihoods.append(score)
                log.info(f"  [{phase_name}] Component {k}: total_log_prob = {score:.4f}")

            selected_component = int(np.argmax(component_log_likelihoods))
            log.info(f"  [{phase_name}] Selected component: {selected_component}")

        return selected_component, component_log_likelihoods

    def finetune_ibc_demo_init(self, reach_states, reach_actions, carry_states, carry_actions):
        """
        IBC: package the REACH and CARRY demo actions as initial trajectories.

        The model is switched to eval mode and frozen. The state arguments
        are accepted for signature parity with the other finetune methods
        but are not used.

        Returns:
            (reach_trajectory, carry_trajectory): the action segments as
            float tensors on ``self.device`` with a leading batch dimension.
        """
        log.info("IBC: Preparing demo trajectories for REACH and CARRY...")

        self.model.eval()
        for weight in self.model.parameters():
            weight.requires_grad = False

        reach_trajectory, carry_trajectory = (
            torch.FloatTensor(segment).to(self.device).unsqueeze(0)
            for segment in (reach_actions, carry_actions)
        )

        log.info(f"REACH demo trajectory shape: {reach_trajectory.shape}")
        log.info(f"CARRY demo trajectory shape: {carry_trajectory.shape}")

        return reach_trajectory, carry_trajectory

    def run(self):
        """Run the two-phase finetune process for baselines, then evaluate.

        When ``self.do_training`` is true the target demo is loaded, both
        phases are normalized, and the step matching ``self.model_type`` is
        run, saving its artefact into ``self.output_dir``:

        - 'bc': finetuned checkpoint (.pt) plus loss curve
        - 'dp': optimized x_T for both phases (.npz) plus loss curve
        - 'bc_gmm': selected components and their scores (.npy, pickled dict)
        - 'ibc': demo trajectories for both phases (.npz)

        Any other model type logs a warning. Evaluation is always launched
        afterwards, except when the demo file is missing: that case is
        logged and the method returns early. Only the demo load is guarded,
        so a FileNotFoundError raised while saving artefacts propagates
        instead of being misreported as a missing demo.
        """
        banner = "=" * 70
        log.info(banner)
        log.info(f"Finetune {self.model_type.upper()} for pick_place (style={self.style})")
        log.info(banner)

        # Load target demo
        demo_path = os.path.join(
            self.cfg.demo_base_path,
            f'style{self.style}',
            'episodes', 'episode0'
        )

        checkpoint_path = None
        x_T_path = None
        selected_component_path = None
        demo_trajectory_path = None

        if self.do_training:
            try:
                reach_states, reach_actions, carry_states, carry_actions = self._load_demo_from_pkl(demo_path)
            except FileNotFoundError as e:
                log.error(f"Demo not found: {e}")
                return

            reach_states_norm, reach_actions_norm = self._normalize_states_actions(reach_states, reach_actions)
            carry_states_norm, carry_actions_norm = self._normalize_states_actions(carry_states, carry_actions)

            log.info("\n--- Finetuning ---")

            if self.model_type == 'bc':
                loss_history = self.finetune_bc(
                    reach_states_norm, reach_actions_norm,
                    carry_states_norm, carry_actions_norm
                )

                if loss_history:
                    log.info(f"Final loss: {loss_history[-1]:.6f}")

                    checkpoint_path = os.path.join(
                        self.output_dir, f'finetuned_model_style{self.style}.pt'
                    )
                    torch.save({
                        'model': self.model.state_dict(),
                        'style': self.style,
                    }, checkpoint_path)
                    log.info(f"Saved finetuned model to {checkpoint_path}")

                    self._plot_loss_curve(loss_history)

            elif self.model_type == 'dp':
                x_T_reach, x_T_carry, loss_history = self.finetune_dp_inversion(
                    reach_states_norm, reach_actions_norm,
                    carry_states_norm, carry_actions_norm
                )

                if loss_history:
                    # loss_history is REACH losses followed by CARRY losses;
                    # the midpoint index assumes both phases ran the same
                    # number of epochs.
                    log.info(f"Final REACH loss: {loss_history[len(loss_history)//2 - 1]:.6f}")
                    log.info(f"Final CARRY loss: {loss_history[-1]:.6f}")

                    # Save both x_T values
                    x_T_path = os.path.join(
                        self.output_dir, f'optimized_x_T_style{self.style}.npz'
                    )
                    np.savez(
                        x_T_path,
                        x_T_reach=x_T_reach.cpu().numpy(),
                        x_T_carry=x_T_carry.cpu().numpy()
                    )
                    log.info(f"Saved optimized x_T to {x_T_path}")

                    self._plot_loss_curve(loss_history)

            elif self.model_type == 'bc_gmm':
                reach_comp, carry_comp, reach_ll, carry_ll = self.finetune_bc_gmm_component_selection(
                    reach_states_norm, reach_actions_norm,
                    carry_states_norm, carry_actions_norm
                )

                selected_component_path = os.path.join(
                    self.output_dir, f'selected_components_style{self.style}.npy'
                )
                # A dict passed to np.save is stored as a pickled object array.
                np.save(selected_component_path, {
                    'reach_component': reach_comp,
                    'carry_component': carry_comp,
                    'reach_log_likelihoods': reach_ll,
                    'carry_log_likelihoods': carry_ll,
                    'style': self.style,
                })
                log.info(f"Saved selected components to {selected_component_path}")

            elif self.model_type == 'ibc':
                reach_traj, carry_traj = self.finetune_ibc_demo_init(
                    reach_states_norm, reach_actions_norm,
                    carry_states_norm, carry_actions_norm
                )

                demo_trajectory_path = os.path.join(
                    self.output_dir, f'demo_trajectories_style{self.style}.npz'
                )
                np.savez(
                    demo_trajectory_path,
                    reach_trajectory=reach_traj.cpu().numpy(),
                    carry_trajectory=carry_traj.cpu().numpy()
                )
                log.info(f"Saved demo trajectories to {demo_trajectory_path}")

            else:
                log.warning(f"Training not implemented for {self.model_type}")
        else:
            log.info(f"\n--- No training for {self.model_type.upper()} (inference only) ---")

        # Run eval
        log.info("\n" + "=" * 70)
        log.info("Running evaluation...")
        log.info("=" * 70)

        self._run_eval(checkpoint_path, x_T_path, selected_component_path, demo_trajectory_path)

    def _plot_loss_curve(self, loss_history):
        """Save the loss curve as ``loss_curve.png`` inside ``self.output_dir``.

        Does nothing when ``loss_history`` is empty.
        """
        if not len(loss_history):
            return

        figure = plt.figure(figsize=(10, 6))
        axes = figure.gca()
        axes.plot(loss_history)
        axes.set_xlabel('Epoch')
        axes.set_ylabel('Loss')
        axes.set_title(f'{self.model_type.upper()} Finetuning Loss (style={self.style})')
        axes.grid(True, alpha=0.3)

        save_path = os.path.join(self.output_dir, 'loss_curve.png')
        figure.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close(figure)
        log.info(f"Saved loss curve to {save_path}")

    def _run_eval(self, checkpoint_path=None, x_T_path=None, selected_component_path=None, demo_trajectory_path=None):
        """Run the eval agent with 40 samples.

        Launches ``script/run.py`` in a subprocess with ``dppo_root`` as the
        working directory, using the interpreter that is running this agent
        so the evaluation sees the same environment and packages.

        Args:
            checkpoint_path: If given, passed as ``base_policy_path=...``.
            x_T_path: If given, passed as ``x_T_file=...``.
            selected_component_path: If given, passed as
                ``selected_component_file=...``.
            demo_trajectory_path: If given, passed as
                ``demo_trajectory_file=...``.

        A non-zero exit status is logged as an error; no exception is raised.
        """
        env_name = self.cfg.get('env_name', 'pick_place')
        eval_cmd = [
            # sys.executable can be empty in embedded interpreters; fall back
            # to whatever "python" resolves to on PATH in that case.
            sys.executable or "python", "script/run.py",
            f"--config-dir=cfg/rlbench/eval/{env_name}",
            f"--config-name={self.cfg.eval_config_name}",
            f"wall_style={self.style}",
            "n_eval_episodes=40",  # 40 evaluation samples
        ]

        # Optional overrides, appended only for artefacts that were produced.
        optional_overrides = (
            ("base_policy_path", checkpoint_path),
            ("x_T_file", x_T_path),
            ("selected_component_file", selected_component_path),
            ("demo_trajectory_file", demo_trajectory_path),
        )
        eval_cmd.extend(
            f"{key}={path}" for key, path in optional_overrides if path is not None
        )

        log.info(f"Running eval command: {' '.join(eval_cmd)}")

        result = subprocess.run(
            eval_cmd,
            cwd=dppo_root,
            capture_output=False,
        )

        if result.returncode != 0:
            log.error(f"Eval command failed with return code {result.returncode}")
        else:
            log.info("Evaluation completed successfully!")
