#!/usr/bin/env python3
"""
Franka FE Calibration-Only Real Robot Evaluation
===============================================
Based on run_franka_eval.py and new eval-scripts architecture.
Implements calibration-only evaluation on Franka real robot.
"""

import gc
import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union, List, Dict
from datetime import datetime

import cv2
import draccus
import gym_franka
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

# Add project paths
sys.path.append(".")
sys.path.append("experiments/robot")

from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
from huggingface_hub import snapshot_download
from peft import LoraConfig, get_peft_model

from experiments.robot.libero.libero_utils import save_rollout_video
from experiments.robot.robot_utils import (
    DATE_TIME,
    normalize_gripper_action,
    set_seed_everywhere,
)
from experiments.robot.openvla_utils import (
    check_model_logic_mismatch,
    model_is_on_hf_hub,
    update_auto_map,
)

from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
from prismatic.models.action_heads import FunctionEncoderActionHead
from prismatic.models.backbones.llm.prompting import PurePromptBuilder
from prismatic.util.data_utils import PaddedCollatorForActionPrediction
from prismatic.vla.action_tokenizer import ActionTokenizer
from prismatic.vla.constants import ACTION_DIM, NUM_ACTIONS_CHUNK
from prismatic.vla.datasets.datasets import RLDSBatchTransform


@dataclass
class FrankaEvalConfig:
    """Command-line settings for the Franka FE calibration-only evaluation.

    Every field has a default, so the script can be started without arguments.
    Field order is part of the interface (positional construction), keep it.
    """
    # fmt: off

    # ---- Model ----------------------------------------------------------
    vla_path: str = "openvla/openvla-7b"        # HF Hub id (downloaded on demand) or local model directory
    fe_checkpoint_path: Optional[str] = None    # FE pre-trained checkpoint; skipped when None/empty
    fe_basis_functions: int = 32                # passed as `k` to FunctionEncoderActionHead
    n_continuous_actions: int = 6               # passed as `n_continuous` to FunctionEncoderActionHead
    use_lora: bool = True                       # wrap the VLA with a LoRA adapter before loading the checkpoint
    lora_rank: int = 32                         # LoRA rank `r`; lora_alpha is min(lora_rank, 16)

    # ---- Demonstrations -------------------------------------------------
    demo_path: str = "$DEMO_PATH"               # root directory holding one sub-directory per task id

    # ---- Evaluation -----------------------------------------------------
    num_trials_per_task: int = 5                # episodes executed each time a task id is entered
    max_steps: int = 60                         # upper bound on policy steps per episode
    disable_rotation: bool = False              # when True, action[3:6] is zeroed before execution

    # ---- Logging / reproducibility --------------------------------------
    run_id_note: Optional[str] = None           # optional suffix appended to the run id
    local_log_dir: str = "./franka-scripts/logs"  # directory receiving the JSON results file
    seed: int = 7                               # forwarded to set_seed_everywhere

    # fmt: on


class FrankaFECalibrationEvaluator:
    """Franka real robot evaluator with FE calibration-only capability.

    Loads the base VLA model plus a Function Encoder (FE) action head, fits
    the head to one demonstration (``calibrate_with_demo``) and afterwards
    predicts robot actions from camera observations (``predict_action``).

    Attributes:
        cfg: The evaluation configuration passed to the constructor.
        vla: The loaded VLA model (LoRA-wrapped when ``cfg.use_lora``).
        processor: The HF processor loaded alongside the model.
        fe_action_head: The Function Encoder action head.
        tokenizer / action_tokenizer / batch_transform / collator: Input
            pipeline objects used to turn a raw step into model inputs.
        last_task_id: Bookkeeping slot for callers; this class only
            initialises it to ``None``.
        is_calibrated: ``True`` only after a calibration run that actually
            fitted the FE head on at least one sample.
    """

    # Dataset identifier attached to every sample and FE head call.
    _DATASET_NAME = "Franka"

    def __init__(self, cfg: "FrankaEvalConfig"):
        """Seed all RNGs, load the models and build the input pipeline.

        Args:
            cfg: Evaluation configuration (model paths, LoRA and FE settings,
                seed).
        """
        self.cfg = cfg
        set_seed_everywhere(cfg.seed)

        # Load model and FE components
        self.vla, self.processor, self.fe_action_head = self._load_model()

        # Setup tokenizers and transforms
        self.tokenizer = self.processor.tokenizer
        self.action_tokenizer = ActionTokenizer(self.tokenizer)

        self.batch_transform = RLDSBatchTransform(
            action_tokenizer=self.action_tokenizer,
            base_tokenizer=self.tokenizer,
            image_transform=self.processor.image_processor.apply_transform,
            prompt_builder_fn=PurePromptBuilder,
            predict_stop_token=True,
        )

        self.collator = PaddedCollatorForActionPrediction(
            self.tokenizer.model_max_length, self.tokenizer.pad_token_id, padding_side="right"
        )

        # Current calibration state
        self.last_task_id = None
        self.is_calibrated = False

    def _load_model(self):
        """Load VLA model and FE components based on pretrain_fe.py architecture.

        Returns:
            Tuple ``(vla, processor, fe_action_head)``; both modules are on
            CUDA in bfloat16 and switched to eval mode.
        """
        print(f"Loading base VLA model: {self.cfg.vla_path}")

        # Download model if needed
        if model_is_on_hf_hub(self.cfg.vla_path):
            print(f"Downloading model from HuggingFace Hub: {self.cfg.vla_path}")
            cached_model_path = snapshot_download(self.cfg.vla_path)
        else:
            cached_model_path = self.cfg.vla_path

        # Load configuration and processor
        config = AutoConfig.from_pretrained(cached_model_path, trust_remote_code=True)
        processor = AutoProcessor.from_pretrained(cached_model_path, trust_remote_code=True)

        # Load base VLA model
        vla = OpenVLAForActionPrediction.from_pretrained(
            cached_model_path,
            config=config,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            device_map="cuda",
        )

        # Create Function Encoder Action Head
        llm_dim = vla.llm_backbone.embed_dim
        fe_action_head = FunctionEncoderActionHead(
            input_dim=llm_dim * ACTION_DIM,
            hidden_dim=llm_dim,
            action_dim=ACTION_DIM,
            k=self.cfg.fe_basis_functions,
            n_continuous=self.cfg.n_continuous_actions,
        )
        fe_action_head.to(device="cuda", dtype=torch.bfloat16)

        # Apply LoRA if enabled. This must happen before the checkpoint is
        # loaded so that adapter weights stored in it find matching keys.
        if self.cfg.use_lora:
            print("Applying LoRA configuration...")
            lora_config = LoraConfig(
                r=self.cfg.lora_rank,
                lora_alpha=min(self.cfg.lora_rank, 16),
                lora_dropout=0.1,
                target_modules="all-linear",
            )
            vla = get_peft_model(vla, lora_config)

        # Load FE checkpoint if provided
        if self.cfg.fe_checkpoint_path:
            print(f"Loading FE checkpoint: {self.cfg.fe_checkpoint_path}")
            # NOTE(security): torch.load unpickles the file; only point
            # fe_checkpoint_path at checkpoints from a trusted source.
            checkpoint = torch.load(self.cfg.fe_checkpoint_path, map_location="cuda")

            if "fe_action_head" in checkpoint:
                fe_action_head.load_state_dict(checkpoint["fe_action_head"])
                print("Loaded FE action head weights")
            else:
                print("WARNING: checkpoint has no 'fe_action_head' entry; FE action head weights were not loaded")

            if "model" in checkpoint:
                # strict=False tolerates key mismatches, so surface them
                # instead of discarding the report.
                load_report = vla.load_state_dict(checkpoint["model"], strict=False)
                print("Loaded VLA model weights")
                if load_report.missing_keys or load_report.unexpected_keys:
                    print(
                        f"  (non-strict load: {len(load_report.missing_keys)} missing keys, "
                        f"{len(load_report.unexpected_keys)} unexpected keys)"
                    )
            else:
                print("WARNING: checkpoint has no 'model' entry; VLA weights were not loaded from it")

        vla.eval()
        fe_action_head.eval()

        return vla, processor, fe_action_head

    def process_image(self, img: np.ndarray) -> np.ndarray:
        """Convert an int64 image to uint8 and resize it to 224x224.

        Args:
            img: Camera image as a NumPy array.

        Returns:
            The image resized with ``cv2.resize`` to 224x224.
        """
        if img.dtype == np.int64:
            img = img.astype(np.uint8)
        img = cv2.resize(img, (224, 224))
        return img

    def _build_vla_step(self, image: np.ndarray, action: np.ndarray, timestep: int, instruction: str) -> Dict:
        """Assemble the single-step sample dict handed to ``self.batch_transform``.

        Shared by calibration and prediction so both feed the model the same
        structure.
        """
        return {
            "observation": {
                "image_primary": torch.tensor(image).unsqueeze(0),
                "timestep": torch.tensor(timestep),
                "pad_mask_dict": {
                    "image_primary": torch.tensor([True]),
                    "timestep": torch.tensor([True]),
                },
                "pad_mask": torch.tensor([True]),
            },
            "task": {
                "language_instruction": instruction,
                "pad_mask_dict": {
                    "language_instruction": torch.tensor([True]),
                },
            },
            "action": torch.tensor(action).unsqueeze(0),
            "dataset_name": self._DATASET_NAME,
            "absolute_action_mask": torch.tensor([False] * 6 + [True]),
        }

    def _action_hidden_states(self, batched_inputs: Dict, batch_size: int) -> Optional[torch.Tensor]:
        """Run the VLA forward pass and gather the action-token hidden states.

        Callers are expected to invoke this under ``torch.no_grad()`` and
        ``torch.autocast``; this helper does not set up those contexts.

        Args:
            batched_inputs: Output of ``self.collator``.
            batch_size: Number of rows the gathered states are reshaped to.

        Returns:
            Tensor of shape ``(batch_size, -1)``, or ``None`` when no position
            is marked as an action token.
        """
        input_ids = batched_inputs["input_ids"].to("cuda")
        attention_mask = batched_inputs["attention_mask"].to("cuda")
        pixel_values = {k: v.to("cuda") for k, v in batched_inputs["pixel_values"].items()}
        # After the shift, positions holding action tokens are exactly those >= 0.
        labels = (batched_inputs["labels"] - self.action_tokenizer.action_token_begin_idx - 1).to("cuda")

        outputs = self.vla.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=None,
            pixel_values=pixel_values,
        )

        action_preds_idx = torch.nonzero(labels >= 0, as_tuple=True)
        if len(action_preds_idx[0]) == 0:
            return None

        action_hidden_states = outputs.last_hidden_state[action_preds_idx[0], action_preds_idx[1]]
        # Reshape to match FE input format
        return action_hidden_states.view(batch_size, -1)

    def calibrate_with_demo(self, demo_path: str, instruction: str):
        """Calibrate FE with demonstration data.

        Reads ``images.npy`` and ``actions.npy`` from ``demo_path``, turns
        every step into a model sample and fits the FE head on them.
        ``self.is_calibrated`` is reset at the start and only becomes ``True``
        when the head was actually fitted on at least one sample.

        Args:
            demo_path: Directory containing ``images.npy`` and ``actions.npy``.
            instruction: Language instruction attached to every step.
        """
        print(f"Calibrating FE with demo: {demo_path}")
        # A failed or empty calibration must not leave a stale "calibrated" flag.
        self.is_calibrated = False

        images = np.load(f"{demo_path}/images.npy")[..., :3]  # remove alpha channel
        actions = np.load(f"{demo_path}/actions.npy")

        # Pair images and actions over their common length rather than
        # running off the end of the shorter array.
        num_steps = min(len(images), len(actions))
        if len(images) != len(actions):
            print(
                f"WARNING: demo has {len(images)} images but {len(actions)} actions; "
                f"using the first {num_steps} steps"
            )
        if num_steps == 0:
            print("WARNING: demo contains no usable steps; FE was NOT calibrated")
            return

        # clip gripper action to OXE range
        actions[:, -1] = actions[:, -1].clip(0, 1)

        calibration_traj = [
            self.batch_transform(self._build_vla_step(images[i].copy(), actions[i], i, instruction))
            for i in range(num_steps)
        ]

        # Execute calibration
        num_samples = self._calibrate_fe_model(calibration_traj)
        self.is_calibrated = bool(num_samples)
        if self.is_calibrated:
            print("FE calibration completed!")
        else:
            print("WARNING: no action-token hidden states were collected; FE was NOT calibrated")

    def _calibrate_fe_model(self, calibration_traj: List[Dict], batch_size: int = 8) -> int:
        """Execute FE model calibration process.

        Args:
            calibration_traj: Transformed samples (outputs of
                ``self.batch_transform``).
            batch_size: Number of samples per VLA forward pass.

        Returns:
            Number of samples the FE head was calibrated on (0 when nothing
            could be collected and the head was left untouched).
        """
        print("Running FE calibration...")

        all_hidden_states = []
        all_actions = []

        with torch.no_grad():
            for start in range(0, len(calibration_traj), batch_size):
                batched_inputs = self.collator(calibration_traj[start:start + batch_size])
                ground_truth_actions = batched_inputs["ground_truth_actions"].to("cuda")

                with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
                    hidden = self._action_hidden_states(batched_inputs, ground_truth_actions.shape[0])

                if hidden is not None:
                    all_hidden_states.append(hidden)
                    all_actions.append(ground_truth_actions)

        if not all_hidden_states:
            return 0

        # Concatenate all data
        hidden_states = torch.cat(all_hidden_states, dim=0)
        actions = torch.cat(all_actions, dim=0)

        # Use FE calibration (CVX optimization)
        with torch.no_grad():
            self.fe_action_head.calibrate_on_batch(
                hidden_states, actions, dataset_names=[self._DATASET_NAME] * len(actions)
            )

        print(f"Calibrated FE with {len(actions)} samples")
        return len(actions)

    def predict_action(self, obs: Dict, instruction: str, timestep: int) -> np.ndarray:
        """Predict action using FE model.

        Args:
            obs: Environment observation; the image is read from
                ``obs["agentview_rgb"]``.
            instruction: Language instruction for the current task.
            timestep: Index of the current step within the episode.

        Returns:
            The predicted action clipped to [-1, 1] as a float32 NumPy array,
            or a zero action of length 7 when the model input contains no
            action-token positions.
        """
        img = self.process_image(obs["agentview_rgb"])

        # The transform requires an action entry; a zero placeholder is used
        # and doubles as the fallback return value.
        action = np.zeros(7)
        vla_step = self._build_vla_step(img, action, timestep, instruction)
        batched_inputs = self.collator([self.batch_transform(vla_step)])

        with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
            hidden = self._action_hidden_states(batched_inputs, 1)
            if hidden is None:
                print("WARNING: no action-token positions found; returning zero action")
                return action

            # Predict action using FE
            action_preds = self.fe_action_head.predict_action(hidden, dataset_names=[self._DATASET_NAME])
            action = action_preds.clip(-1, 1)[0].detach().cpu().to(torch.float32).numpy()

        return action


def _run_single_trial(env, evaluator: FrankaFECalibrationEvaluator, cfg: FrankaEvalConfig, instruction: str):
    """Reset the robot, run one policy episode and report its outcome.

    Returns:
        Tuple ``(success, steps_executed, hit_step_limit, replay_images)``
        where ``steps_executed`` counts the policy actions sent to the robot
        and ``hit_step_limit`` is True when the episode ran for
        ``cfg.max_steps`` steps without the environment ending it.
    """
    obs, _ = env.reset()

    # Initial positioning (from original run_franka_eval.py)
    for _ in range(10):
        obs, _, _, _, _ = env.step([1, 0, -1, 0, 0, 0, 0])
    for _ in range(8):
        obs, _, _, _, _ = env.step([0, 0, -1, 0, 0, 0, 0])

    replay_images = []
    success = False
    hit_step_limit = True
    steps_executed = 0

    for t in range(cfg.max_steps):
        # Get action from FE model
        action = evaluator.predict_action(obs, instruction, t)

        # Process action
        action = normalize_gripper_action(action, binarize=True)
        if cfg.disable_rotation:
            action[3:6] = 0

        print(f"Step {t}: Action: {action}")

        # Save image for replay
        replay_images.append(evaluator.process_image(obs["agentview_rgb"]))

        # Execute action
        obs, _, terminated, truncated, _ = env.step(action.tolist())
        steps_executed += 1

        # The episode is over on either flag; stepping further without a
        # reset is not valid. Only `terminated` counts as task success.
        if terminated or truncated:
            success = bool(terminated)
            hit_step_limit = False
            break

    return success, steps_executed, hit_step_limit, replay_images


@draccus.wrap()
def run_franka_evaluation(cfg: FrankaEvalConfig) -> Dict:
    """Run the interactive Franka real robot evaluation.

    Repeatedly asks for a task id on stdin. For each task it reads the
    language instruction from ``<demo_path>/<task_id>/language_instruction.txt``,
    calibrates the FE head on that task's demo when the task changed, runs
    ``cfg.num_trials_per_task`` episodes, and rewrites the JSON results file.
    ``p`` closes the environment, ``q`` (or end of input) quits. The robot
    environment is closed even if an error or Ctrl-C interrupts the session.

    Args:
        cfg: Evaluation configuration. Environment variables in
            ``cfg.demo_path`` (e.g. ``$DEMO_PATH``) are expanded.

    Returns:
        The results dictionary that is also written to the JSON log file.
    """
    print("Starting Franka FE Calibration-Only Evaluation")
    print(f"Configuration: {cfg}")

    # Initialize evaluator
    evaluator = FrankaFECalibrationEvaluator(cfg)

    # Initialize logging
    run_id = f"FRANKA-FE-EVAL-{DATE_TIME}"
    if cfg.run_id_note:
        run_id += f"--{cfg.run_id_note}"

    os.makedirs(cfg.local_log_dir, exist_ok=True)
    local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".json")

    results = {
        "run_id": run_id,
        "config": dict(vars(cfg)),
        "experiment_type": "FE_Calibration_Only",
        "timestamp": datetime.now().isoformat(),
        "tasks": {}
    }

    # Unset variables are left untouched by expandvars, so plain paths are unaffected.
    demo_root = os.path.expandvars(cfg.demo_path)

    env = None
    total_episodes = 0

    print("\nStarting interactive Franka evaluation...")
    print("Commands: Enter task_id to run task, 'p' to pause/reset env, 'q' to quit")

    try:
        while True:
            try:
                task_id = input("Enter task ID: ").strip()
            except EOFError:
                # stdin closed: treat like an explicit quit
                break

            if task_id == "q":
                break
            if task_id == "p":
                if env is not None:
                    env.close()
                    env = None
                    print("Environment closed")
                continue
            if not task_id:
                continue

            # Load task instruction
            task_demo_path = f"{demo_root}/{task_id}"
            instruction_file = f"{task_demo_path}/language_instruction.txt"

            if not os.path.exists(instruction_file):
                print(f"Task instruction file not found: {instruction_file}")
                continue

            with open(instruction_file, "r", encoding="utf-8") as f:
                instruction = f.read().strip()

            # Calibrate if new task or first time
            if task_id != evaluator.last_task_id:
                demo_files = [f"{task_demo_path}/images.npy", f"{task_demo_path}/actions.npy"]
                missing = [path for path in demo_files if not os.path.exists(path)]
                if missing:
                    print(f"Demo data not found: {', '.join(missing)}")
                    continue

                evaluator.calibrate_with_demo(task_demo_path, instruction)
                if not evaluator.is_calibrated:
                    print(f"Calibration did not succeed for task {task_id}; skipping")
                    continue
                evaluator.last_task_id = task_id

            # Connect to the robot only once there is a valid task to run
            if env is None:
                env = gym.make("FrankaRGB-v1")

            print(f"\nTask: {instruction}")

            # Run multiple trials for this task
            task_results = []
            for trial in range(cfg.num_trials_per_task):
                print(f"\nTrial {trial + 1}/{cfg.num_trials_per_task}")

                success, steps, hit_step_limit, replay_images = _run_single_trial(
                    env, evaluator, cfg, instruction
                )

                # Record trial results
                task_results.append({
                    "trial": trial + 1,
                    "success": success,
                    "steps": steps,
                    "max_steps_reached": hit_step_limit,
                })
                total_episodes += 1

                print(f"Trial {trial + 1} - Success: {success}, Steps: {steps}")

                # Save replay video
                save_rollout_video(
                    replay_images,
                    total_episodes,
                    success=success,
                    task_description=f"{instruction}_trial_{trial+1}",
                    log_file=None
                )

            # Calculate task statistics (guard against zero trials -> NaN)
            successes = [r["success"] for r in task_results]
            success_rate = float(np.mean(successes)) if successes else 0.0
            avg_steps = float(np.mean([r["steps"] for r in task_results])) if task_results else 0.0

            results["tasks"][task_id] = {
                "task_id": task_id,
                "instruction": instruction,
                "num_trials": cfg.num_trials_per_task,
                "success_rate": success_rate,
                "avg_steps": avg_steps,
                "trials": task_results
            }

            print(f"\nTask {task_id} Summary:")
            print(f"Success Rate: {success_rate:.2%} ({sum(successes)}/{len(successes)})")
            print(f"Average Steps: {avg_steps:.1f}")

            # Save results after every task so an aborted session keeps them
            with open(local_log_filepath, "w") as f:
                json.dump(results, f, indent=2)
    finally:
        # Always release the robot, including on errors and Ctrl-C
        if env is not None:
            env.close()

    if results["tasks"]:
        print(f"\nEvaluation completed! Results saved to: {local_log_filepath}")
    else:
        print("\nEvaluation completed! No tasks were evaluated, so no results file was written.")
    return results


# Script entry point. run_franka_evaluation is called without arguments; its
# `cfg` parameter is supplied by the @draccus.wrap() decorator (presumably
# parsed from the command line — confirm against the draccus documentation).
if __name__ == "__main__":
    run_franka_evaluation()