"""
RLBench Experiment Runner

Runs the tasks defined in static/experiment_tasks.yml one after another and
records per-task / per-episode success results.

Usage:
    python rlbench_experiment.py
    python rlbench_experiment.py --output_dir results/exp_001
    python rlbench_experiment.py --task_ids task_000 task_001
    python rlbench_experiment.py --difficulty easy
"""

import argparse
import csv
import json
import os
import random
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import yaml

# Import run_task_episode and EpisodeResult from rlbench_loop
from rlbench_loop import run_task_episode, EpisodeResult


# ============================================================================
# Seed Setting
# ============================================================================

def set_seed(seed: int = 42):
    """Seed every random number generator this script may rely on.

    Seeds Python's ``random`` module and NumPy, exports ``PYTHONHASHSEED``
    and, when PyTorch can be imported, seeds its CPU/CUDA generators and
    switches cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Seed value applied to all generators.
    """
    # NOTE: the hash seed of the running interpreter is fixed at startup, so
    # this export only influences processes started from here on.
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    # PyTorch is optional: seed it only when it is installed.
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None:
        torch_seeders = (
            torch.manual_seed,
            torch.cuda.manual_seed,
            torch.cuda.manual_seed_all,  # multi-GPU
        )
        for seeder in torch_seeders:
            seeder(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    print(f"[SEED] All random seeds set to {seed}")


# ============================================================================
# NumPy Type Conversion
# ============================================================================

def convert_numpy_types(obj: Any) -> Any:
    """Recursively convert NumPy values into JSON-serializable Python objects.

    Conversions performed:
      * ``np.ndarray``   -> nested ``list`` (via ``tolist()``)
      * ``np.bool_``     -> ``bool``
      * ``np.integer``   -> ``int``
      * ``np.floating``  -> ``float``
      * ``dict``         -> ``dict`` with converted values *and* converted
        NumPy scalar keys
      * ``list``/``tuple`` -> ``list`` with converted items

    Any other object is returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) value.

    Returns:
        A structure equivalent to ``obj`` without NumPy arrays or scalars.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.bool_):
        return bool(obj)
    # np.integer / np.floating are the abstract bases of all sized variants
    # (int8..int64, float16..float64), so listing each one is unnecessary.
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, dict):
        # json.dump rejects keys such as np.int64 (not an int subclass), so
        # NumPy scalar keys are converted too. Other key types are left
        # untouched so hashable keys (e.g. tuples) are never turned into lists.
        return {
            (convert_numpy_types(key) if isinstance(key, np.generic) else key):
                convert_numpy_types(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_types(item) for item in obj]
    return obj


# ============================================================================
# Result Data Structures
# ============================================================================

class SubtaskResult:
    """Outcome of a single subtask, i.e. one episode run."""

    # Attribute names exported by to_dict(), in output order.
    _EXPORTED_FIELDS = (
        "episode_name",
        "success",
        "execution_time_sec",
        "llm_calls",
        "total_llm_time_sec",
        "error",
        "error_type",
    )

    def __init__(self):
        # Identification and outcome of the episode.
        self.episode_name: str = ""
        self.success: bool = False
        # Timing information, in seconds.
        self.execution_time_sec: float = 0.0
        self.total_llm_time_sec: float = 0.0
        # One dict per LLM call made while running the episode.
        self.llm_calls: List[Dict[str, Any]] = []
        # Failure details; both stay None when nothing went wrong.
        self.error: Optional[str] = None
        self.error_type: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dict keyed by attribute name."""
        return {name: getattr(self, name) for name in self._EXPORTED_FIELDS}


class TaskResult:
    """Aggregated outcome of one task, which is made up of several subtasks."""

    def __init__(self):
        self.task_id: str = ""
        self.difficulty: str = ""
        self.subtasks: List[SubtaskResult] = []
        # True only when every subtask succeeded.
        self.success: bool = False
        self.subtask_success_rate: float = 0.0
        self.total_execution_time_sec: float = 0.0
        self.total_llm_time_sec: float = 0.0

    def compute_metrics(self):
        """Recompute the aggregate fields from the current subtasks.

        Leaves every aggregate untouched when there are no subtasks.
        """
        subtask_count = len(self.subtasks)
        if subtask_count == 0:
            return

        passed = 0
        execution_total = 0
        llm_total = 0
        for subtask in self.subtasks:
            if subtask.success:
                passed += 1
            execution_total += subtask.execution_time_sec
            llm_total += subtask.total_llm_time_sec

        self.subtask_success_rate = passed / subtask_count
        self.success = passed == subtask_count
        self.total_execution_time_sec = execution_total
        self.total_llm_time_sec = llm_total

    def to_dict(self) -> Dict[str, Any]:
        """Refresh the metrics and return the task as a plain dict."""
        self.compute_metrics()
        serialized: Dict[str, Any] = {
            "task_id": self.task_id,
            "difficulty": self.difficulty,
            "success": self.success,
            "subtask_success_rate": self.subtask_success_rate,
            "total_execution_time_sec": self.total_execution_time_sec,
            "total_llm_time_sec": self.total_llm_time_sec,
        }
        serialized["subtasks"] = [subtask.to_dict() for subtask in self.subtasks]
        return serialized


class ExperimentResult:
    """Result of a whole experiment: metadata plus one entry per task."""

    def __init__(self):
        # Run metadata.
        self.experiment_id: str = ""
        self.start_time: str = ""
        self.end_time: str = ""
        self.config: Dict[str, Any] = {}
        # Per-task results and the rates derived from them.
        self.tasks: List[TaskResult] = []
        self.overall_success_rate: float = 0.0
        self.overall_subtask_success_rate: float = 0.0

    def compute_metrics(self):
        """Recompute the overall task and subtask success rates.

        Does nothing when there are no tasks. The subtask rate is only
        updated when at least one subtask exists.
        """
        task_count = len(self.tasks)
        if task_count == 0:
            return

        tasks_passed = 0
        subtasks_seen = 0
        subtasks_passed = 0
        for task in self.tasks:
            # Refresh the task's own aggregates before reading its success flag.
            task.compute_metrics()
            if task.success:
                tasks_passed += 1
            for subtask in task.subtasks:
                subtasks_seen += 1
                if subtask.success:
                    subtasks_passed += 1

        self.overall_success_rate = tasks_passed / task_count
        if subtasks_seen > 0:
            self.overall_subtask_success_rate = subtasks_passed / subtasks_seen

    def to_dict(self) -> Dict[str, Any]:
        """Refresh the metrics and return the experiment as a plain dict."""
        self.compute_metrics()
        serialized: Dict[str, Any] = {
            "experiment_id": self.experiment_id,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "config": self.config,
            "overall_success_rate": self.overall_success_rate,
            "overall_subtask_success_rate": self.overall_subtask_success_rate,
        }
        serialized["tasks"] = [task.to_dict() for task in self.tasks]
        return serialized


# ============================================================================
# File I/O
# ============================================================================

def load_experiment_tasks(yaml_path: str) -> List[Dict[str, Any]]:
    """Load the task list from an experiment_tasks.yml file.

    Args:
        yaml_path: Path to the YAML file.

    Returns:
        The value of the top-level ``tasks`` key, or an empty list when the
        file is empty or the key is missing or null.

    Raises:
        ValueError: If the top level of the document is not a mapping.
    """
    with open(yaml_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # safe_load returns None for an empty document; treat it as "no tasks"
    # instead of failing on None.get.
    if data is None:
        return []
    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a mapping at the top level of {yaml_path}, "
            f"got {type(data).__name__}"
        )
    # "tasks:" with no value parses to None; normalize it to an empty list so
    # callers can always iterate and take len().
    return data.get("tasks") or []


def save_result_json(result: Dict[str, Any], output_path: str):
    """Write ``result`` to ``output_path`` as pretty-printed UTF-8 JSON.

    NumPy values are converted to plain Python types first. The data is
    written to a temporary sibling file and then moved into place, so an
    interrupted run never leaves a half-written (unparseable) file behind.

    Args:
        result: JSON-like structure to save.
        output_path: Destination file. Missing parent directories are created.
    """
    # dirname is "" for a bare filename, and os.makedirs("") would raise.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Convert NumPy types so json can serialize them.
    result_converted = convert_numpy_types(result)

    tmp_path = f"{output_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(result_converted, f, indent=2, ensure_ascii=False)
        # os.replace overwrites the destination in a single step.
        os.replace(tmp_path, output_path)
    except BaseException:
        # Do not leave a stale temporary file behind on failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    print(f"[SAVE] Result saved to: {output_path}")


def load_checkpoint(checkpoint_path: str) -> Optional[ExperimentResult]:
    """Rebuild an ExperimentResult from a checkpoint JSON file.

    Only the experiment id, start time, config and the per-task / per-subtask
    records are restored; aggregate metrics are not read from the file.

    Args:
        checkpoint_path: Path of the checkpoint file.

    Returns:
        The restored result, or None when the file does not exist.
    """
    if not os.path.exists(checkpoint_path):
        return None

    with open(checkpoint_path, "r", encoding="utf-8") as fh:
        saved = json.load(fh)

    restored = ExperimentResult()
    restored.experiment_id = saved.get("experiment_id", "")
    restored.start_time = saved.get("start_time", "")
    restored.config = saved.get("config", {})

    # Scalar subtask attributes and the default used when the key is missing.
    scalar_defaults = (
        ("episode_name", ""),
        ("success", False),
        ("execution_time_sec", 0.0),
        ("total_llm_time_sec", 0.0),
        ("error", None),
        ("error_type", None),
    )

    for task_entry in saved.get("tasks", []):
        task = TaskResult()
        task.task_id = task_entry.get("task_id", "")
        task.difficulty = task_entry.get("difficulty", "")

        for subtask_entry in task_entry.get("subtasks", []):
            subtask = SubtaskResult()
            for attr_name, fallback in scalar_defaults:
                setattr(subtask, attr_name, subtask_entry.get(attr_name, fallback))
            # Handled separately so each subtask gets its own fresh list.
            subtask.llm_calls = subtask_entry.get("llm_calls", [])
            task.subtasks.append(subtask)

        restored.tasks.append(task)

    print(f"[CHECKPOINT] Loaded checkpoint from: {checkpoint_path}")
    return restored


# ============================================================================
# LLM Call Logger
# ============================================================================

class LLMCallLogger:
    """Records LLM calls as rows of a CSV file and echoes them to stdout.

    Constructing a logger (re)creates the file and writes the header row, so
    any previous content of ``log_file`` is discarded.
    """

    # Column names of the CSV file, in order.
    _COLUMNS = ("timestamp", "task", "subtask", "action", "duration_sec", "num_tokens")

    def __init__(self, log_file: str):
        """Create the log file containing only the header row.

        Args:
            log_file: Path of the CSV file. Missing parent directories are
                created.
        """
        self.log_file = log_file
        # dirname is "" for a bare filename, and os.makedirs("") would raise.
        log_dir = os.path.dirname(log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        # newline="" is what the csv module expects; lineterminator="\n" keeps
        # the plain "\n" row endings this file has always used.
        with open(self.log_file, "w", encoding="utf-8", newline="") as f:
            csv.writer(f, lineterminator="\n").writerow(self._COLUMNS)

    def log(self, call_info: Dict[str, Any]):
        """Append one LLM call to the CSV file and print a one-line summary.

        Args:
            call_info: Call description. The keys ``task``, ``subtask``,
                ``action``, ``duration_sec`` and ``num_tokens`` are read; all
                of them are optional.
        """
        timestamp = datetime.now().isoformat()
        task = call_info.get("task", "")
        subtask = call_info.get("subtask", "")
        action = call_info.get("action", "")
        num_tokens = call_info.get("num_tokens", "")  # optional token count

        # A missing, None or non-numeric duration must not abort the episode
        # being logged, so fall back to 0.0 instead of raising.
        try:
            duration = float(call_info.get("duration_sec") or 0.0)
        except (TypeError, ValueError):
            duration = 0.0

        # csv.writer quotes fields containing commas, quotes or newlines, so a
        # free-form task/action name cannot shift the columns of a row.
        with open(self.log_file, "a", encoding="utf-8", newline="") as f:
            csv.writer(f, lineterminator="\n").writerow(
                [timestamp, task, subtask, action, f"{duration:.3f}", num_tokens]
            )

        tokens_str = f", tokens: {num_tokens}" if num_tokens else ""
        print(f"[LLM_CALL] {task}/{subtask} - {action}: {duration:.2f}s{tokens_str}")


# ============================================================================
# Experiment Runner
# ============================================================================

def _open_llm_logger(llm_log_path: str, keep_existing: bool) -> "LLMCallLogger":
    """Create the LLM call logger, optionally keeping rows already on disk."""
    previous_log = None
    if keep_existing and os.path.exists(llm_log_path):
        with open(llm_log_path, "r", encoding="utf-8") as f:
            previous_log = f.read()

    # LLMCallLogger (re)creates the file with a fresh header on construction,
    # which would wipe the calls logged before a resume.
    llm_logger = LLMCallLogger(llm_log_path)

    if previous_log:
        if not previous_log.endswith("\n"):
            previous_log += "\n"
        # Put the earlier content (header included) back; new rows are appended.
        with open(llm_log_path, "w", encoding="utf-8") as f:
            f.write(previous_log)

    return llm_logger


def _save_grounded_skill(output_dir: str, task_id: str, episode_name: str, code: str) -> str:
    """Write an episode's final code under <output_dir>/grounded_skills and return its path."""
    grounded_skills_dir = os.path.join(output_dir, "grounded_skills")
    os.makedirs(grounded_skills_dir, exist_ok=True)

    # File name: task_{task number}_{episode name}.py (e.g. task_000 -> 000).
    # Only the leading "task_" is stripped, not later occurrences.
    prefix = "task_"
    task_num = task_id[len(prefix):] if task_id.startswith(prefix) else task_id
    skill_filepath = os.path.join(grounded_skills_dir, f"task_{task_num}_{episode_name}.py")

    with open(skill_filepath, "w", encoding="utf-8") as f:
        f.write(code)
    return skill_filepath


def run_experiment(
    tasks_config: List[Dict[str, Any]],
    output_dir: str,
    mode: str = "step",
    source_robot: str = "panda",
    target_robot: str = "panda",
    max_skill_calls: int = 10,
    max_repair_attempts: int = 5,
    code_source: str = "remote_llm",
    model: str = "ours",
    remote_host: str = "172.17.0.1",
    remote_port: int = 5000,
    task_ids: Optional[List[str]] = None,
    difficulty: Optional[str] = None,
    resume_from_checkpoint: bool = True,
) -> ExperimentResult:
    """
    Run every selected task, episode by episode, and save the results.

    After each task the accumulated result is written to
    ``<output_dir>/checkpoint.json``; when all tasks are done it is written to
    ``<output_dir>/final_result.json``. LLM calls are logged to
    ``<output_dir>/llm_calls.csv`` and each episode's final code, if any, is
    stored under ``<output_dir>/grounded_skills``.

    An episode that raises, or for which ``run_task_episode`` returns None, is
    recorded as a failed subtask; the remaining episodes still run.

    Args:
        tasks_config: Task list loaded from experiment_tasks.yml
        output_dir: Directory that receives all result files
        mode: Execution mode (step or failure)
        source_robot: Source robot type
        target_robot: Target robot type
        max_skill_calls: Maximum number of primitive skill calls
        max_repair_attempts: Maximum number of repair attempts before giving up (default: 5)
        code_source: Where to get task skill code
        model: Model type
        remote_host: Remote LLM server host
        remote_port: Remote LLM server port
        task_ids: Task IDs to run (None runs all tasks)
        difficulty: Difficulty to run (None runs all difficulties)
        resume_from_checkpoint: Skip tasks already present in an existing checkpoint

    Returns:
        ExperimentResult: Result of the experiment, including resumed tasks
    """
    # Experiment ID (replaced by the checkpoint's ID when resuming).
    experiment_id = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Output directory.
    os.makedirs(output_dir, exist_ok=True)

    # Output file paths.
    checkpoint_path = os.path.join(output_dir, "checkpoint.json")
    final_result_path = os.path.join(output_dir, "final_result.json")
    llm_log_path = os.path.join(output_dir, "llm_calls.csv")

    # Load the checkpoint, if any.
    result = None
    completed_tasks = set()

    if resume_from_checkpoint:
        result = load_checkpoint(checkpoint_path)
        if result is not None:
            experiment_id = result.experiment_id
            completed_tasks = {t.task_id for t in result.tasks}
            print(f"[RESUME] Resuming experiment {experiment_id}")
            print(f"[RESUME] Already completed tasks: {completed_tasks}")

    # The logger is created only after the resume decision so that the calls
    # logged by the interrupted run are kept when resuming.
    llm_logger = _open_llm_logger(llm_log_path, keep_existing=result is not None)

    if result is None:
        result = ExperimentResult()
        result.experiment_id = experiment_id
        result.start_time = datetime.now().isoformat()

    # Store the settings of this run (overwrites the checkpoint's config).
    result.config = {
        "mode": mode,
        "source_robot": source_robot,
        "target_robot": target_robot,
        "max_skill_calls": max_skill_calls,
        "max_repair_attempts": max_repair_attempts,
        "code_source": code_source,
        "model": model,
        "remote_host": remote_host,
        "remote_port": remote_port,
        "task_ids_filter": task_ids,
        "difficulty_filter": difficulty,
    }

    # Filter the tasks.
    filtered_tasks = tasks_config
    if task_ids:
        filtered_tasks = [t for t in filtered_tasks if t.get("id") in task_ids]
    if difficulty:
        filtered_tasks = [t for t in filtered_tasks if t.get("difficulty") == difficulty]

    print(f"\n{'='*60}")
    print(f"[EXPERIMENT] Starting experiment: {experiment_id}")
    print(f"[EXPERIMENT] Total tasks: {len(filtered_tasks)}")
    print(f"[EXPERIMENT] Output directory: {output_dir}")
    print(f"{'='*60}\n")

    # Run each task.
    for task_idx, task_config in enumerate(filtered_tasks):
        task_id = task_config.get("id", f"task_{task_idx:03d}")
        task_difficulty = task_config.get("difficulty", "unknown")
        episodes = task_config.get("episodes", [])

        # Skip tasks that the checkpoint already contains.
        if task_id in completed_tasks:
            print(f"[SKIP] Task {task_id} already completed")
            continue

        print(f"\n{'='*60}")
        print(f"[TASK {task_idx + 1}/{len(filtered_tasks)}] {task_id} ({task_difficulty})")
        print(f"[TASK] Episodes: {episodes}")
        print(f"{'='*60}\n")

        task_result = TaskResult()
        task_result.task_id = task_id
        task_result.difficulty = task_difficulty

        # Run each subtask (episode).
        for ep_idx, episode_name in enumerate(episodes):
            print(f"\n[SUBTASK {ep_idx + 1}/{len(episodes)}] {episode_name}")
            print("-" * 40)

            subtask_result = SubtaskResult()
            subtask_result.episode_name = episode_name

            try:
                episode_result: EpisodeResult = run_task_episode(
                    task_name=episode_name,
                    mode=mode,
                    source_robot=source_robot,
                    target_robot=target_robot,
                    max_skill_calls=max_skill_calls,
                    max_repair_attempts=max_repair_attempts,
                    code_source=code_source,
                    model=model,
                    remote_host=remote_host,
                    remote_port=remote_port,
                    llm_call_logger=llm_logger.log,
                )

                if episode_result is None:
                    # Keep the episode in the results as a failure; dropping it
                    # would inflate the success rates and could mark the task
                    # as successful although this episode produced nothing.
                    print(f"[SUBTASK ERROR] {episode_name}: run_task_episode returned no result")
                    subtask_result.success = False
                    subtask_result.error = "run_task_episode returned no result"
                    subtask_result.error_type = "NoResult"
                else:
                    # Copy the episode outcome.
                    subtask_result.success = episode_result.success
                    subtask_result.execution_time_sec = episode_result.execution_time
                    subtask_result.llm_calls = episode_result.llm_calls
                    subtask_result.total_llm_time_sec = sum(
                        c.get("duration_sec", 0) for c in episode_result.llm_calls
                    )
                    subtask_result.error = episode_result.error
                    subtask_result.error_type = episode_result.error_type

                    # Save the final code to the grounded_skills directory.
                    if episode_result.final_code:
                        skill_filepath = _save_grounded_skill(
                            output_dir, task_id, episode_name, episode_result.final_code
                        )
                        print(f"  - Grounded skill saved to: {skill_filepath}")

                    status = "SUCCESS" if subtask_result.success else "FAILED"
                    print(f"\n[SUBTASK RESULT] {episode_name}: {status}")
                    print(f"  - Execution time: {subtask_result.execution_time_sec:.2f}s")
                    print(f"  - LLM time: {subtask_result.total_llm_time_sec:.2f}s")
                    print(f"  - LLM calls: {len(subtask_result.llm_calls)}")

            except Exception as e:
                # One failing episode must not stop the whole experiment;
                # record it as a failed subtask and carry on.
                import traceback
                print(f"[SUBTASK ERROR] {episode_name}: {e}")
                traceback.print_exc()
                subtask_result.success = False
                subtask_result.error = str(e)
                subtask_result.error_type = type(e).__name__

            task_result.subtasks.append(subtask_result)

        # Aggregate the task.
        task_result.compute_metrics()
        result.tasks.append(task_result)

        print(f"\n[TASK RESULT] {task_id}")
        print(f"  - Success: {task_result.success}")
        print(f"  - Subtask success rate: {task_result.subtask_success_rate:.2%}")
        print(f"  - Total execution time: {task_result.total_execution_time_sec:.2f}s")
        print(f"  - Total LLM time: {task_result.total_llm_time_sec:.2f}s")

        # Checkpoint after every task so an interrupted run can be resumed.
        save_result_json(result.to_dict(), checkpoint_path)

    # Save the final result.
    result.end_time = datetime.now().isoformat()
    result.compute_metrics()
    save_result_json(result.to_dict(), final_result_path)

    # Summary.
    print(f"\n{'='*60}")
    print("[EXPERIMENT COMPLETE]")
    print(f"  - Overall task success rate: {result.overall_success_rate:.2%}")
    print(f"  - Overall subtask success rate: {result.overall_subtask_success_rate:.2%}")
    print(f"  - Results saved to: {output_dir}")
    print(f"{'='*60}\n")

    return result


# ============================================================================
# Main
# ============================================================================

def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser of the experiment runner."""
    parser = argparse.ArgumentParser(
        description="RLBench Experiment Runner - Run experiments from experiment_tasks.yml"
    )

    # (option strings, add_argument keyword options), in --help order.
    cli_options = [
        (("--yaml_path",), {
            "type": str,
            "default": "./static/experiment_tasks.yml",
            "help": "Path to experiment_tasks.yml",
        }),
        (("--output_dir",), {
            "type": str,
            "default": None,
            "help": "Output directory for results (default: results/exp_YYYYMMDD_HHMMSS)",
        }),
        (("--mode", "-m"), {
            "choices": ["step", "failure"],
            "default": "step",
            "help": "Execution mode",
        }),
        (("--source_robot", "-s"), {
            "default": "panda",
            "help": "Source robot type (panda, ur5, sawyer, jaco)",
        }),
        (("--target_robot", "-r"), {
            "default": "ur5",
            "help": "Target robot type for environment setup",
        }),
        (("--max_steps", "-n"), {
            "type": int,
            "default": None,
            "help": "Maximum number of primitive skill calls per episode",
        }),
        (("--max_repairs",), {
            "type": int,
            "default": 5,
            "help": "Maximum number of repair attempts before giving up (default: 5)",
        }),
        (("--code_source", "-c"), {
            "type": str,
            "default": "remote_llm",
            "choices": ["static", "local_llm", "remote_llm"],
            "help": "Where to get task skill code",
        }),
        (("--model",), {
            "type": str,
            "default": "ours",
            "choices": ["code_agent", "ours", "codex"],
            "help": "Model type",
        }),
        (("--remote_host",), {
            "type": str,
            "default": "172.17.0.1",
            "help": "Remote LLM server host",
        }),
        (("--remote_port",), {
            "type": int,
            "default": 5000,
            "help": "Remote LLM server port",
        }),
        (("--task_ids",), {
            "nargs": "+",
            "type": str,
            "default": None,
            "help": "Specific task IDs to run (e.g., task_000 task_001)",
        }),
        (("--difficulty",), {
            "type": str,
            "choices": ["easy", "mixed", "hard"],
            "default": None,
            "help": "Filter tasks by difficulty",
        }),
        (("--no_resume",), {
            "action": "store_true",
            "help": "Do not resume from checkpoint",
        }),
        (("--seed",), {
            "type": int,
            "default": 42,
            "help": "Random seed for reproducibility (default: 42)",
        }),
    ]
    for option_strings, options in cli_options:
        parser.add_argument(*option_strings, **options)

    return parser


def main():
    """Command-line entry point: parse options, load the tasks, run the experiment.

    Exits with status 1 when the YAML task file does not exist.
    """
    args = _build_arg_parser().parse_args()

    # Seed the random number generators before anything else runs.
    set_seed(args.seed)

    # Default output directory: one fresh directory per run.
    output_dir = args.output_dir
    if output_dir is None:
        run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = f"./results/exp_{args.target_robot}_{run_stamp}"

    # Load the YAML task file.
    yaml_path = args.yaml_path
    if not os.path.exists(yaml_path):
        print(f"[ERROR] YAML file not found: {yaml_path}")
        sys.exit(1)

    tasks_config = load_experiment_tasks(yaml_path)
    print(f"[INFO] Loaded {len(tasks_config)} tasks from {yaml_path}")

    # NOTE(review): --max_steps defaults to None and is forwarded as-is, so it
    # replaces run_experiment's own default of 10 - confirm that
    # run_task_episode accepts None for max_skill_calls.
    experiment_options = {
        "tasks_config": tasks_config,
        "output_dir": output_dir,
        "mode": args.mode,
        "source_robot": args.source_robot,
        "target_robot": args.target_robot,
        "max_skill_calls": args.max_steps,
        "max_repair_attempts": args.max_repairs,
        "code_source": args.code_source,
        "model": args.model,
        "remote_host": args.remote_host,
        "remote_port": args.remote_port,
        "task_ids": args.task_ids,
        "difficulty": args.difficulty,
        "resume_from_checkpoint": not args.no_resume,
    }

    # Run the experiment.
    run_experiment(**experiment_options)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
