"""
Experiment 0: FE L1 Loss Validation
====================================
Test FE model's L1 loss on various dataset validation splits.
Uses 512 random samples from training split for calibration.
"""

import gzip
import json
import pickle
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
from huggingface_hub import snapshot_download
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor

# Add project root to path
sys.path.append(str(Path(__file__).parent.parent))

from experiments.robot.openvla_utils import (
    check_model_logic_mismatch,
    model_is_on_hf_hub,
    update_auto_map,
)
from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
from prismatic.models.action_heads import FunctionEncoderActionHead
from prismatic.models.backbones.llm.prompting import PurePromptBuilder
from prismatic.training.train_utils import get_current_action_mask, get_next_actions_mask
from prismatic.vla.calibration_buffer import CalibrationManager
from prismatic.vla.constants import (
    ACTION_DIM,
    ACTION_PROPRIO_NORMALIZATION_TYPE,
    NUM_ACTIONS_CHUNK,
    NormalizationType,
)
from prismatic.vla.datasets.rlds.oxe.mixtures import OXE_NAMED_MIXTURES
from prismatic.vla.materialize import get_vla_dataset_and_collator


@dataclass
class Exp0Config:
    """Settings for Experiment 0 (FE L1 loss validation).

    Every field has a default, so ``Exp0Config()`` is a usable configuration.
    """

    # --- Model locations ---
    # Base VLA: a Hugging Face Hub repo id or a local path.
    vla_path: str = "openvla/openvla-7b"
    # Directory holding FE checkpoint files; None means no checkpoint is loaded.
    fe_checkpoint_path: Optional[str] = None

    # --- Function-encoder head ---
    fe_basis_functions: int = 16
    n_continuous_actions: int = 6
    # Number of training samples buffered for calibration.
    calibration_samples: int = 512

    # --- Datasets ---
    data_root_dir: str = "DATA_ROOT_DIR"
    # None selects the evaluator's built-in default dataset list.
    dataset_names: Optional[List[str]] = None

    # --- Evaluation ---
    batch_size: int = 8
    # How many validation episodes to evaluate per dataset.
    num_val_episodes: int = 20

    # --- LoRA ---
    use_lora: bool = True
    lora_rank: int = 32
    lora_dropout: float = 0.0

    # --- Hardware ---
    # Resolved once, when the class body is executed (i.e. at import time).
    device: str = "cpu" if not torch.cuda.is_available() else "cuda"
    dtype: torch.dtype = torch.bfloat16

    # --- Output / reproducibility ---
    output_dir: str = "runs/exp0_fe_l1_validation"
    seed: int = 42


class FEL1Evaluator:
    """Evaluator for FE L1 loss on validation datasets."""

    def __init__(self, config: Exp0Config):
        """Store the configuration, prepare output/seeds, then load models and dataset names.

        Args:
            config: Experiment settings. ``device``, ``dtype``, ``output_dir`` and
                ``seed`` are read here; the setup helpers read the remaining fields.
        """
        self.config = config
        self.device, self.dtype = config.device, config.dtype

        # Results are written below this directory; create it (and parents) up front.
        out_dir = Path(config.output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir = out_dir

        # Seed both RNG libraries with the configured seed.
        seed = config.seed
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Normalization statistics per dataset name, kept separately for the
        # training split and the evaluation split (used for unnormalization).
        self.dataset_stats_train: Dict[str, Dict] = {}
        self.dataset_stats_eval: Dict[str, Dict] = {}

        # Heavy initialization: models first, then the list of datasets to evaluate.
        self._setup_models()
        self._setup_datasets()

    def _setup_models(self):
        """Load the VLA and processor, build the FE action head, and restore checkpoint weights.

        Order of operations:
          1. Resolve ``config.vla_path``: download a Hub snapshot, or use the local path
             after registering the OpenVLA classes on the ``Auto*`` factories.
          2. Load processor and model in ``self.dtype`` and move the model to ``self.device``.
          3. Create the ``FunctionEncoderActionHead``.
          4. Optionally wrap the VLA with LoRA adapters.
          5. If ``config.fe_checkpoint_path`` is set, load weights from that directory.
             This happens after LoRA wrapping, so LoRA weights load into the wrapped model.
          6. Put both modules in eval mode.

        Raises:
            RuntimeError: A LoRA checkpoint contains keys the model does not have.
            FileNotFoundError: A checkpoint step was inferred but the FE head file is missing.
        """
        print("=" * 80)
        print("Loading VLA model...")

        # Load processor and VLA the same way the training script does (per original author).
        if model_is_on_hf_hub(self.config.vla_path):
            vla_path = snapshot_download(repo_id=self.config.vla_path)
        else:
            # Register models if loading locally
            AutoConfig.register("openvla", OpenVLAConfig)
            AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
            AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
            AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
            vla_path = self.config.vla_path

        # Update config
        update_auto_map(vla_path)
        check_model_logic_mismatch(vla_path)

        self.processor = AutoProcessor.from_pretrained(vla_path, trust_remote_code=True)
        self.vla = AutoModelForVision2Seq.from_pretrained(
            vla_path,
            torch_dtype=self.dtype,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        ).to(self.device)

        self.vla.vision_backbone.set_num_images_in_input(1)

        # Get model dimensions (unwrap a `.module` wrapper if one is present).
        vla_module = self.vla.module if hasattr(self.vla, "module") else self.vla
        llm_dim = vla_module.llm_dim
        # 6 DoF + gripper.
        # NOTE(review): hard-coded here instead of using prismatic.vla.constants.ACTION_DIM,
        # which evaluate_on_dataset uses when reshaping hidden states - the two must agree; confirm.
        action_dim = 7

        # Create FE action head
        print(f"Creating FE action head with {self.config.fe_basis_functions} basis functions...")
        self.fe_action_head = FunctionEncoderActionHead(
            input_dim=llm_dim * action_dim,
            hidden_dim=llm_dim,
            action_dim=action_dim,
            k=self.config.fe_basis_functions,
            n_continuous=self.config.n_continuous_actions,
        ).to(self.device, dtype=self.dtype)

        # Apply LoRA if configured
        if self.config.use_lora:
            print(f"Applying LoRA with rank {self.config.lora_rank}...")
            lora_config = LoraConfig(
                r=self.config.lora_rank,
                lora_alpha=min(self.config.lora_rank, 16),
                lora_dropout=self.config.lora_dropout,
                target_modules="all-linear",
                init_lora_weights="gaussian",
            )
            self.vla = get_peft_model(self.vla, lora_config)
            self.vla.print_trainable_parameters()

        # Load FE checkpoint directory if provided (expects files saved by pretrain_fe.py)
        if self.config.fe_checkpoint_path:
            self._load_fe_checkpoint(Path(self.config.fe_checkpoint_path))

        # Set to eval mode
        self.vla.eval()
        self.fe_action_head.eval()

    @staticmethod
    def _infer_checkpoint_step(ckpt_dir: Path) -> Optional[int]:
        """Return the training step of a checkpoint directory, or None if it cannot be inferred.

        The step is taken from a directory name ending in ``-<digits>`` (e.g.
        ``checkpoint-40000``). Otherwise the directory is scanned for
        ``fe_action_head--<step>_checkpoint.pt`` files and the highest parseable step is
        used, so the choice does not depend on the (unordered) glob result.
        """
        name_parts = ckpt_dir.name.split("-")
        if len(name_parts) > 1 and name_parts[-1].isdigit():
            return int(name_parts[-1])

        steps = []
        for match in ckpt_dir.glob("fe_action_head--*_checkpoint.pt"):
            # stem looks like "fe_action_head--40000_checkpoint"
            try:
                steps.append(int(match.stem.split("--")[1].split("_")[0]))
            except (IndexError, ValueError):
                continue  # file name does not carry a numeric step; ignore it
        return max(steps) if steps else None

    @staticmethod
    def _add_default_to_lora_keys(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Adjust PEFT LoRA key names to include `.default` for compatibility with PEFT>=0.10.

        Keys that already contain ``.default`` are returned unchanged.
        """
        new_state = {}
        for k, v in state_dict.items():
            k_new = k
            if "lora_A" in k and ".default" not in k:
                k_new = k.replace("lora_A", "lora_A.default")
            if "lora_B" in k_new and ".default" not in k_new:
                k_new = k_new.replace("lora_B", "lora_B.default")
            new_state[k_new] = v
        return new_state

    def _load_fe_checkpoint(self, ckpt_dir: Path) -> None:
        """Load VLA (LoRA-only preferred, else full) and FE head weights from ``ckpt_dir``.

        If the step cannot be inferred, a message is printed and nothing is loaded
        (best-effort, as before). A missing VLA weight file is only reported; a missing
        FE head file raises.

        NOTE(security): ``torch.load`` unpickles the file contents; only point
        ``fe_checkpoint_path`` at checkpoints you trust.

        Raises:
            RuntimeError: The LoRA checkpoint has keys the model does not have.
            FileNotFoundError: The FE action head weight file does not exist.
        """
        print(f"Loading FE checkpoint directory: {ckpt_dir}")

        step = self._infer_checkpoint_step(ckpt_dir)
        if step is None:
            print("Could not infer checkpoint step; skipping checkpoint load.")
            return

        # Load VLA state (prefer LoRA-only state if present)
        vla_lora_path = ckpt_dir / f"vla_lora--{step}_checkpoint.pt"
        vla_full_path = ckpt_dir / f"vla--{step}_checkpoint.pt"
        if vla_lora_path.exists():
            print(f"Loading VLA LoRA weights: {vla_lora_path}")
            vla_state = torch.load(vla_lora_path, map_location=self.device)
            vla_state = self._add_default_to_lora_keys(vla_state)
            # Missing keys are expected here because only LoRA weights are loaded.
            _missing, unexpected = self.vla.load_state_dict(vla_state, strict=False)
            if unexpected:
                print(f"[LoRA] Unexpected keys after load: {len(unexpected)}:")
                for key in unexpected:
                    print(f"  {key}")
                raise RuntimeError("Unexpected keys in LoRA checkpoint.")
        elif vla_full_path.exists():
            print(f"Loading full VLA weights: {vla_full_path}")
            vla_state = torch.load(vla_full_path, map_location=self.device)
            self.vla.load_state_dict(vla_state, strict=True)
        else:
            print("No VLA weight file found in checkpoint directory.")

        # Load FE action head weights
        fe_path = ckpt_dir / f"fe_action_head--{step}_checkpoint.pt"
        if not fe_path.exists():
            raise FileNotFoundError(f"No FE action head weight file found in checkpoint directory: {fe_path}")
        print(f"Loading FE action head weights: {fe_path}")
        fe_state = torch.load(fe_path, map_location=self.device)
        self.fe_action_head.load_state_dict(fe_state, strict=True)

        # NOTE: restoring the saved calibration state (dataset coefficients) is
        # intentionally disabled; the previous implementation is kept for reference.
        # calib_path = ckpt_dir / f"calibration_state--{step}.pt"
        # if calib_path.exists():
        #     print(f"Loading calibration state: {calib_path}")
        #     calibration_state = torch.load(calib_path, map_location=self.device)
        #     dataset_coeffs = calibration_state.get("dataset_coefficients", {})
        #     for ds_name, coeffs in dataset_coeffs.items():
        #         if isinstance(coeffs, dict) and "l1" in coeffs and "l2" in coeffs:
        #             self.fe_action_head.set_dataset_coefficients(ds_name, coeffs["l1"], coeffs["l2"])
        # else:
        #     print("No calibration state file found in checkpoint directory.")

    def _setup_datasets(self):
        """Decide which datasets to evaluate and store them in ``self.dataset_names``.

        A non-empty ``config.dataset_names`` is used verbatim. Otherwise the list is
        every dataset of the ``oxe_magic_soup_plus`` mixture followed by the
        ``rtx_franka`` datasets not already in it, with
        ``berkeley_rpt_converted_externally_to_rlds`` dropped.
        """
        selected = self.config.dataset_names
        if not selected:
            # Mixture entries are (name, weight) pairs; only the names are needed.
            seen = [entry[0] for entry in OXE_NAMED_MIXTURES["oxe_magic_soup_plus"]]
            unseen = [name for name, _ in OXE_NAMED_MIXTURES["rtx_franka"] if name not in seen]

            # Seen datasets first, then unseen ones, minus the excluded dataset.
            excluded = "berkeley_rpt_converted_externally_to_rlds"
            selected = [name for name in seen + unseen if name != excluded]

        self.dataset_names = selected
        print(f"Will evaluate on {len(self.dataset_names)} datasets")

    def _get_dataset_and_collator(self, dataset_name: str, *, train: bool, episodic: bool):
        """Materialize one dataset and its collator, and record its normalization statistics.

        Datasets belonging to the ``oxe_magic_soup_plus`` mixture are read from
        ``config.data_root_dir``; all others are read from ``gs://gresearch/robotics/``.

        Args:
            dataset_name: Name passed to ``get_vla_dataset_and_collator`` as ``data_mix``.
            train: Forwarded to the dataset factory; also selects which statistics
                dictionary (train vs. eval) is updated.
            episodic: Forwarded to the dataset factory.

        Returns:
            ``(dataset, collator)`` as produced by ``get_vla_dataset_and_collator``.

        Side effects:
            Updates ``self.dataset_stats_train`` (when ``train``) or
            ``self.dataset_stats_eval`` (otherwise) for ``dataset_name``.
        """
        seen_ds = {d[0] for d in OXE_NAMED_MIXTURES["oxe_magic_soup_plus"]}
        if dataset_name in seen_ds:
            # Honor the configured root (previously a hard-coded placeholder was used and
            # --data_root_dir had no effect). URL-style roots are kept as strings because
            # Path() would collapse the "//" after the scheme.
            root = self.config.data_root_dir
            data_root = root if "://" in str(root) else Path(root)
        else:
            data_root = "gs://gresearch/robotics/"

        dataset, _action_tokenizer, collator = get_vla_dataset_and_collator(
            data_root_dir=data_root,
            data_mix=dataset_name,
            image_transform=self.processor.image_processor.apply_transform,
            tokenizer=self.processor.tokenizer,
            prompt_builder_fn=PurePromptBuilder,
            default_image_resolution=(3, 224, 224),
            padding_side="right",
            predict_stop_token=True,
            shuffle_buffer_size=1000,
            train=train,
            episodic=episodic,
            image_aug=False,
        )

        # Extract dataset statistics from the dataset object
        if train:
            # dataset_statistics is a dict with dataset name as key.
            all_stats = dataset.dataset_statistics
            stats = all_stats[dataset_name] if dataset_name in all_stats else None
            if stats is None:
                # No exact key: fall back to the first key related by substring.
                for ds_name, candidate in all_stats.items():
                    if dataset_name in ds_name or ds_name in dataset_name:
                        stats = candidate
                        break
            if stats is None:
                print(
                    f"WARNING: no TRAIN stats matched {dataset_name}; "
                    f"available keys: {list(all_stats)}"
                )
            else:
                self.dataset_stats_train[dataset_name] = stats
                print(f"Loaded TRAIN stats for {dataset_name}")
        else:
            # Stored as-is. NOTE(review): unnormalize_action indexes this value with
            # ["action"] directly, so it presumably is the flat per-dataset statistics
            # dict for episodic datasets (not keyed by dataset name) - confirm.
            self.dataset_stats_eval[dataset_name] = dataset.dataset_statistics

        return dataset, collator

    def unnormalize_action(
        self, normalized_action: torch.Tensor, dataset_name: str, *, use_train_stats: bool
    ) -> torch.Tensor:
        """Map normalized actions back to the dataset's original scale.

        Masked-in dimensions are transformed with the q01/q99 bounds of the selected
        statistics: ``original = (normalized + 1) * (q99 - q01) / 2 + q01``. Dimensions
        outside the mask are returned unchanged; with the default mask that is the last
        (gripper) dimension, which per the original author is never normalized and stays
        in [0, 1].

        Args:
            normalized_action: Tensor whose last axis is indexed by the 7-entry mask.
            dataset_name: Key into the stored statistics.
            use_train_stats: True to use training-split statistics, False for eval-split.

        Returns:
            A new tensor (the input is not modified) with the masked dimensions unnormalized.

        Raises:
            AssertionError: Statistics for ``dataset_name`` are missing, a stored mask
                differs from ``[True] * 6 + [False]``, or the normalization type is not
                ``BOUNDS_Q99``.
        """
        if use_train_stats:
            split, stats_by_dataset = "train", self.dataset_stats_train
        else:
            split, stats_by_dataset = "eval", self.dataset_stats_eval
        assert dataset_name in stats_by_dataset, f"Missing {split} stats for {dataset_name}"
        action_stats = stats_by_dataset[dataset_name]["action"]

        device = normalized_action.device
        # Default: first 6 dims (position + rotation) are normalized, the last (gripper) is not.
        default_mask = torch.tensor([True] * 6 + [False], dtype=torch.bool, device=device)
        if "mask" in action_stats:
            # Use the stored mask, but insist that it matches the default layout.
            norm_mask = torch.tensor(action_stats["mask"], dtype=torch.bool, device=device)
            assert torch.all(norm_mask == default_mask)
        else:
            norm_mask = default_mask

        # The inverse transform below is only valid for q01/q99 bounds normalization.
        assert ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99
        q01 = torch.tensor(action_stats["q01"], dtype=normalized_action.dtype, device=device)
        q99 = torch.tensor(action_stats["q99"], dtype=normalized_action.dtype, device=device)

        low = q01[norm_mask]
        span = q99[norm_mask] - low

        # Work on a copy so unmasked dimensions pass through untouched.
        restored = normalized_action.clone()
        restored[..., norm_mask] = (normalized_action[..., norm_mask] + 1) * span / 2 + low
        return restored

    def calibrate_on_dataset(self, dataset_name: str) -> bool:
        """Calibrate the FE head on training samples of one dataset via ``CalibrationManager``.

        Up to ``config.calibration_samples`` individual samples are taken from the
        (non-episodic) training split, pushed into the manager's buffer, and the manager
        is then asked to calibrate ``self.fe_action_head`` using ``self.vla``.

        Args:
            dataset_name: Dataset to calibrate on.

        Returns:
            True when calibration ran; False when the dataset could not be loaded or no
            samples were buffered. (Previously the success path fell through and
            returned None despite the ``bool`` annotation.)
        """
        print(f"\nCalibrating on {dataset_name}...")

        # Get training dataset (non-episodic)
        dataset, collator = self._get_dataset_and_collator(dataset_name, train=True, episodic=False)
        if dataset is None:
            return False

        # Initialize CalibrationManager for this dataset
        target = self.config.calibration_samples
        calib_manager = CalibrationManager(dataset_names=[dataset_name], buffer_size=target)

        # Prefill buffer with raw samples batched by collator.
        # NOTE(review): if the dataset is an iterable (RLDS/TFDS-backed) dataset, worker
        # processes may each replay the same stream; confirm num_workers=2 is intended.
        loader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
            collate_fn=collator,
        )
        sample_keys = {"pixel_values", "input_ids", "attention_mask", "labels", "actions"}
        total_added = 0
        for batch in loader:
            remaining = target - total_added
            if remaining <= 0:
                break
            bsz = batch["input_ids"].shape[0]
            # Split the collated batch back into per-sample dicts, never exceeding the target.
            for i in range(min(bsz, remaining)):
                raw_sample = {k: v[i] for k, v in batch.items() if k in sample_keys}
                calib_manager.add_training_sample(raw_sample, dataset_name)
                total_added += 1

        if len(calib_manager.buffers[dataset_name]) == 0:
            print(f"No calibration samples collected for {dataset_name}")
            return False

        # Calibrate using manager (per original author: handles hidden extraction and CVX)
        calib_manager.eval_mode = True
        calib_manager.calibrate_all_datasets(self.fe_action_head, self.vla, collator)

        print(f"Calibration completed for {dataset_name}")
        return True

    def evaluate_on_dataset(self, dataset_name: str) -> Dict:
        """Run the FE head over validation episodes of one dataset and collect raw predictions.

        The dataset is calibrated first if the FE head has no coefficients for it. Each
        episode is processed in slices of ``config.batch_size``; for every slice the VLA
        is run, action-token hidden states are extracted and passed to the FE head, and
        both predictions and ground truth are stored normalized and unnormalized
        (predictions with TRAIN statistics, ground truth with EVAL statistics). Only the
        first action of each predicted chunk (index 0) is kept.

        Args:
            dataset_name: Dataset whose validation split is evaluated.

        Returns:
            ``{traj_idx: {"pred_norm", "gt_norm", "pred_unnorm", "gt_unnorm"}}`` where each
            value is a NumPy array stacked over the episode, or ``{"error": <message>}``
            if the dataset could not be loaded or calibration reported failure.
        """
        print(f"\nEvaluating on {dataset_name} validation split...")

        # Get episodic validation dataset
        val_dataset, collator = self._get_dataset_and_collator(dataset_name, train=False, episodic=True)
        if val_dataset is None:
            return {"error": "Failed to load dataset"}

        # Check if we have calibration coefficients
        if self.fe_action_head.has_dataset_coefficients(dataset_name):
            print(f"Using existing calibration coefficients for {dataset_name}")
            if dataset_name not in self.dataset_stats_train:
                # Predictions are unnormalized with TRAIN statistics, which are only
                # recorded when the training split is materialized; do that now so the
                # lookup in unnormalize_action cannot fail when calibration is skipped.
                self._get_dataset_and_collator(dataset_name, train=True, episodic=False)
        else:
            print(f"No calibration coefficients for {dataset_name}, calibrating first...")
            # `is False` (not falsiness): only an explicit failure aborts the evaluation.
            if self.calibrate_on_dataset(dataset_name) is False:
                return {"error": "Calibration failed"}

        # Evaluate on the first N episodes, collecting per-episode arrays
        # (mirrors eval_autoregressive.py, per the original author).
        raw_traj_data = {}
        traj_idx = 0

        # NOTE(review): assumes len(val_dataset) counts episodes - confirm.
        if self.config.num_val_episodes == -1 or self.config.num_val_episodes > len(val_dataset):
            num_val_episodes = len(val_dataset)
        else:
            num_val_episodes = self.config.num_val_episodes

        # Loop invariants: autocast device family and number of vision-patch positions.
        autocast_device = torch.device(self.device).type
        vision_backbone = self.vla.vision_backbone
        num_patches = vision_backbone.get_num_patches() * vision_backbone.get_num_images_in_input()

        with torch.no_grad():
            for episode in tqdm(val_dataset, desc="Episodes", total=num_val_episodes):
                if traj_idx >= num_val_episodes:
                    break

                # Collect raw arrays for current trajectory
                pred_norm_list, gt_norm_list = [], []
                pred_unnorm_list, gt_unnorm_list = [], []

                for i in range(0, len(episode), self.config.batch_size):
                    batch_items = episode[i : i + self.config.batch_size]
                    batch = collator(batch_items)
                    batch = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in batch.items()}
                    batch_size = batch["input_ids"].shape[0]

                    # Forward pass through VLA (autocast on the configured device, not always CUDA)
                    with torch.autocast(device_type=autocast_device, dtype=self.dtype, enabled=True):
                        outputs = self.vla(
                            input_ids=batch["input_ids"],
                            attention_mask=batch["attention_mask"],
                            pixel_values=batch["pixel_values"],
                            labels=batch["labels"],
                            output_hidden_states=True,
                        )

                    # Drop the leading vision-patch positions and the final position so the
                    # hidden states line up with labels[:, 1:].
                    last_hidden_states = outputs.hidden_states[-1]
                    text_hidden_states = last_hidden_states[:, num_patches:-1]

                    ground_truth_token_ids = batch["labels"][:, 1:]
                    current_action_mask = get_current_action_mask(ground_truth_token_ids)
                    next_actions_mask = get_next_actions_mask(ground_truth_token_ids)

                    # Keep only action-token positions, then regroup per sample. Cast to the
                    # configured dtype (the FE head was created in self.dtype).
                    actions_hidden_states = text_hidden_states[current_action_mask | next_actions_mask]
                    actions_hidden_states = actions_hidden_states.reshape(
                        batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1
                    ).to(self.dtype)

                    # Predict actions with FE
                    dataset_names_batch = [dataset_name] * batch_size
                    predicted_actions, _ = self.fe_action_head.predict_action(actions_hidden_states, dataset_names_batch)

                    # Ground truth continuous actions from collator (normalized);
                    # already on self.device after the batch-wide move above.
                    predicted_actions = predicted_actions.to(torch.float32)
                    gt_actions_normalized = batch["actions"]

                    # Unnormalize predicted with TRAIN stats; GT with EVAL stats
                    predicted_actions_unnorm = self.unnormalize_action(
                        predicted_actions, dataset_name, use_train_stats=True
                    )
                    gt_actions_unnorm = self.unnormalize_action(
                        gt_actions_normalized, dataset_name, use_train_stats=False
                    )

                    # Save current-step (index 0 in chunk) predictions/GT
                    pred_norm_list.append(predicted_actions[:, 0, :].detach().cpu().numpy())
                    gt_norm_list.append(gt_actions_normalized[:, 0, :].detach().cpu().numpy())
                    pred_unnorm_list.append(predicted_actions_unnorm[:, 0, :].detach().cpu().numpy())
                    gt_unnorm_list.append(gt_actions_unnorm[:, 0, :].detach().cpu().numpy())

                if not pred_norm_list:
                    # np.vstack raises on an empty list; an empty episode contributes nothing.
                    print(f"Skipping empty episode in {dataset_name}")
                    continue

                # Stack collected arrays for this trajectory
                raw_traj_data[traj_idx] = {
                    "pred_norm": np.vstack(pred_norm_list),
                    "gt_norm": np.vstack(gt_norm_list),
                    "pred_unnorm": np.vstack(pred_unnorm_list),
                    "gt_unnorm": np.vstack(gt_unnorm_list),
                }
                traj_idx += 1

        return raw_traj_data

    def run_full_evaluation(self):
        """Evaluate every dataset in ``self.dataset_names`` and pickle each result.

        For each dataset, whatever ``evaluate_on_dataset`` returns is written to
        ``<output_dir>/<dataset_name>.pkl``.

        Returns:
            ``{"status": "ok"}`` once all datasets have been processed.
        """
        rule = "=" * 80
        print(rule)
        print("Starting full evaluation on all datasets...")
        print(f"Timestamp: {datetime.now().isoformat()}")
        print(rule)

        # One pickle per dataset (same layout as eval_autoregressive.py, per the original author).
        for dataset_name in tqdm(self.dataset_names, desc="Evaluating datasets"):
            print(f"\n{rule}")
            results = self.evaluate_on_dataset(dataset_name)

            # Persist the raw arrays with pickle.
            destination = self.output_dir / f"{dataset_name}.pkl"
            with destination.open("wb") as handle:
                pickle.dump(results, handle)
            print(f"Saved raw predictions to: {destination}")

        print(f"\n{rule}")
        print("Evaluation completed!")
        return {"status": "ok"}


def main():
    """Parse command-line options into an ``Exp0Config`` and run the full evaluation."""
    import argparse

    cli = argparse.ArgumentParser(description="Experiment 0: FE L1 Loss Validation")
    add = cli.add_argument

    # Model
    add("--vla_path", type=str, default="openvla/openvla-7b", help="Path to base VLA model")
    add("--fe_checkpoint_path", type=str, default=None, help="Path to FE checkpoint")

    # Function-encoder head
    add("--fe_basis_functions", type=int, default=16, help="Number of FE basis functions")
    add("--n_continuous_actions", type=int, default=6, help="Number of continuous actions")
    add("--calibration_samples", type=int, default=512, help="Number of training samples for calibration")

    # Datasets
    add("--data_root_dir", type=str, default="DATA_ROOT_DIR", help="Root directory for datasets")
    add("--dataset_names", type=str, nargs="+", default=None, help="Specific datasets to evaluate")

    # Evaluation
    add("--batch_size", type=int, default=8, help="Batch size for evaluation")
    add("--num_val_episodes", type=int, default=20, help="Number of validation episodes per dataset")

    # LoRA (the flag is off unless given on the command line)
    add("--use_lora", action="store_true", help="Use LoRA adaptation")
    add("--lora_rank", type=int, default=32, help="LoRA rank")
    add("--lora_dropout", type=float, default=0.0, help="LoRA dropout")

    # Output
    add("--output_dir", type=str, default="eval-results/validation/exp0", help="Output directory for results")
    add("--seed", type=int, default=42, help="Random seed")

    options = cli.parse_args()

    # Every option name above is also an Exp0Config field name, so the parsed
    # namespace maps one-to-one onto the config; device and dtype keep their defaults.
    config = Exp0Config(**vars(options))

    # Constructing the evaluator loads the models; then evaluate all datasets.
    FEL1Evaluator(config).run_full_evaluation()


# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
