import argparse
import csv
import json
import os
import random
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import yaml

# ============================================================================
# Seed Fixing for Reproducibility
# ============================================================================

def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this script may touch.

    Python's ``random`` module and NumPy's global generator are always
    seeded, and ``PYTHONHASHSEED`` is exported. PyTorch, TensorFlow and
    Genesis are seeded only when they can be imported; a framework that is
    not installed is skipped silently.

    Args:
        seed: Value handed to every seeding call.
    """
    # NOTE: the running interpreter fixed its hash seed at startup, so this
    # export can only influence processes launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Standard-library and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # Optional: PyTorch (CPU, current GPU, every GPU) with deterministic cuDNN.
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass

    # Optional: TensorFlow.
    try:
        import tensorflow
        tensorflow.random.set_seed(seed)
    except ImportError:
        pass

    # Optional: Genesis. AttributeError is tolerated as well so that a
    # Genesis module without ``set_seed`` does not abort the run.
    try:
        import genesis
        genesis.set_seed(seed)
    except (ImportError, AttributeError):
        pass

    print(f"[SEED] Random seed set to: {seed}")

from genesis_loop import run_task_episode, EpisodeResult

# ============================================================================
# NumPy Type Conversion
# ============================================================================

def convert_numpy_types(obj: Any) -> Any:
    """Recursively replace NumPy values with plain Python equivalents.

    Intended to make nested result structures JSON-serializable.

    Conversions:
        * ``np.ndarray``  -> nested ``list`` (via ``tolist()``)
        * ``np.bool_``    -> ``bool``
        * ``np.integer``  -> ``int``
        * ``np.floating`` -> ``float``
        * ``dict``        -> ``dict`` with values converted, and keys
          converted when they are NumPy scalars
        * ``list``/``tuple`` -> ``list`` with items converted

    Anything else is returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) object.

    Returns:
        A structure with the same shape as ``obj`` built from Python types.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.bool_):
        return bool(obj)
    # The abstract base classes already cover every concrete width
    # (int8..int64, float16..float64), so no explicit list is needed.
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, dict):
        # json.dump rejects keys such as np.int64, so NumPy scalar keys are
        # converted too. Other keys (e.g. tuples) are left untouched because
        # converting them to lists would make them unhashable.
        return {
            (convert_numpy_types(k) if isinstance(k, np.generic) else k):
                convert_numpy_types(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_types(item) for item in obj]
    return obj

# ============================================================================
# Result Data Structures
# ============================================================================

class TaskResult:
    """Outcome record for a single task of an experiment run.

    Holds the task's metadata, its success flag, timing information, the
    LLM calls made for it, any error, and the per-subtask result dicts.
    """

    # Attribute names emitted by ``to_dict``, in output order.
    _FIELDS = (
        "task_id",
        "task_name",
        "difficulty",
        "description",
        "success",
        "execution_time_sec",
        "llm_calls",
        "total_llm_time_sec",
        "error",
        "error_type",
        "subtask_results",
    )

    def __init__(self):
        # Task metadata.
        self.task_id: str = ""
        self.task_name: str = ""
        self.difficulty: str = ""
        self.description: str = ""
        # Outcome and timing.
        self.success: bool = False
        self.execution_time_sec: float = 0.0
        # LLM usage.
        self.llm_calls: List[Dict[str, Any]] = []
        self.total_llm_time_sec: float = 0.0
        # Failure details (None when nothing went wrong).
        self.error: Optional[str] = None
        self.error_type: Optional[str] = None
        # One dict per subtask.
        self.subtask_results: List[Dict[str, Any]] = []

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields plus a computed ``subtask_summary`` as a dict."""
        payload: Dict[str, Any] = {name: getattr(self, name) for name in self._FIELDS}
        payload["subtask_summary"] = self._compute_subtask_summary()
        return payload

    def _compute_subtask_summary(self) -> Dict[str, Any]:
        """Aggregate the subtask dicts into counts and a success rate.

        Returns:
            An empty dict when there are no subtask results; otherwise the
            total, successful, skipped and failed counts, the summed
            ``repair_count`` values and the fraction of successful subtasks.
        """
        entries = self.subtask_results
        if not entries:
            return {}

        n_total = len(entries)
        n_success = 0
        n_skipped = 0
        repairs = 0
        for entry in entries:
            if entry.get("success", False):
                n_success += 1
            if entry.get("skipped", False):
                n_skipped += 1
            repairs = repairs + entry.get("repair_count", 0)

        if n_total > 0:
            rate = n_success / n_total
        else:
            rate = 0.0

        return {
            "total_subtasks": n_total,
            "successful_subtasks": n_success,
            "skipped_subtasks": n_skipped,
            "failed_subtasks": n_total - n_success - n_skipped,
            "total_repairs": repairs,
            "subtask_success_rate": rate,
        }

class ExperimentResult:
    """Aggregate record for a whole experiment: metadata plus task results."""

    def __init__(self):
        self.experiment_id: str = ""
        # Timestamps are stored as strings.
        self.start_time: str = ""
        self.end_time: str = ""
        self.config: Dict[str, Any] = {}
        self.tasks: List[TaskResult] = []
        self.overall_success_rate: float = 0.0

    def compute_metrics(self):
        """Refresh ``overall_success_rate`` from the recorded tasks.

        With no tasks recorded the stored rate is left as it is.
        """
        if self.tasks:
            passed = len([task for task in self.tasks if task.success])
            self.overall_success_rate = passed / len(self.tasks)

    def to_dict(self) -> Dict[str, Any]:
        """Recompute the metrics and return the experiment as a plain dict."""
        self.compute_metrics()

        n_total = len(self.tasks)
        n_success = sum(1 for task in self.tasks if task.success)
        serialized_tasks = [task.to_dict() for task in self.tasks]

        return {
            "experiment_id": self.experiment_id,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "config": self.config,
            "overall_success_rate": self.overall_success_rate,
            "total_tasks": n_total,
            "successful_tasks": n_success,
            "tasks": serialized_tasks,
        }

# ============================================================================
# File I/O
# ============================================================================

def load_experiment_tasks(yaml_path: str) -> List[Dict[str, Any]]:
    """Read the task list from an experiment YAML file.

    Args:
        yaml_path: Path to a YAML document whose top level is a mapping with
            a ``tasks`` entry.

    Returns:
        The value of ``tasks``, or an empty list when the file is empty, the
        key is missing, or the key is null.

    Raises:
        ValueError: If the document's top level is not a mapping.
    """
    with open(yaml_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # safe_load yields None for an empty document; treat that as "no tasks"
    # instead of failing on ``None.get``.
    if data is None:
        return []
    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a mapping at the top level of {yaml_path}, "
            f"got {type(data).__name__}"
        )
    # ``tasks:`` with no value parses to None; normalize that to a list too.
    return data.get("tasks") or []

def save_result_json(result: Dict[str, Any], output_path: str) -> None:
    """Write ``result`` to ``output_path`` as indented UTF-8 JSON.

    NumPy values are converted with ``convert_numpy_types`` first. The data
    is written to ``<output_path>.tmp`` and then moved into place, so an
    interrupted write does not leave a truncated file behind (this function
    is also used for the resumable checkpoint).

    Args:
        result: JSON-compatible structure, possibly containing NumPy values.
        output_path: Destination file; missing parent directories are created.
    """
    # dirname is "" for a bare filename, and os.makedirs("") raises.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    result_converted = convert_numpy_types(result)

    tmp_path = f"{output_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(result_converted, f, indent=2, ensure_ascii=False)
        # Replace the previous file only once the new one is complete.
        os.replace(tmp_path, output_path)
    except BaseException:
        # Do not leave a partial temp file around; the old output is intact.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    print(f"[SAVE] Result saved to: {output_path}")

def load_checkpoint(checkpoint_path: str) -> Optional[ExperimentResult]:
    """Rebuild an ``ExperimentResult`` from a checkpoint JSON file.

    Args:
        checkpoint_path: Path of the checkpoint file.

    Returns:
        The restored experiment (ID, start time, config and all recorded
        tasks), or ``None`` when the file does not exist.
    """
    if not os.path.exists(checkpoint_path):
        return None

    with open(checkpoint_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    result = ExperimentResult()
    result.experiment_id = data.get("experiment_id", "")
    result.start_time = data.get("start_time", "")
    result.config = data.get("config", {})

    for task_data in data.get("tasks", []):
        task_result = TaskResult()
        task_result.task_id = task_data.get("task_id", "")
        task_result.task_name = task_data.get("task_name", "")
        task_result.difficulty = task_data.get("difficulty", "")
        task_result.description = task_data.get("description", "")
        task_result.success = task_data.get("success", False)
        task_result.execution_time_sec = task_data.get("execution_time_sec", 0.0)
        task_result.llm_calls = task_data.get("llm_calls", [])
        task_result.total_llm_time_sec = task_data.get("total_llm_time_sec", 0.0)
        task_result.error = task_data.get("error")
        task_result.error_type = task_data.get("error_type")
        # Restore the per-subtask records as well; otherwise tasks completed
        # before a resume would be re-saved with no subtask data or summary.
        task_result.subtask_results = task_data.get("subtask_results", [])
        result.tasks.append(task_result)

    print(f"[CHECKPOINT] Loaded checkpoint from: {checkpoint_path}")
    return result

# ============================================================================
# LLM Call Logger
# ============================================================================

class LLMCallLogger:
    """Append one CSV row per LLM call and echo a summary line to stdout.

    Rows are written with the ``csv`` module so that fields containing
    commas, quotes or newlines are quoted instead of corrupting the file.
    """

    # Column names, in the order the row values are written.
    _HEADER = ("timestamp", "task", "subtask", "action", "duration_sec", "num_tokens")

    def __init__(self, log_file: str):
        """Create (or truncate) ``log_file`` and write the header row.

        Args:
            log_file: Path of the CSV file; missing parent directories are
                created.
        """
        self.log_file = log_file
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        parent_dir = os.path.dirname(log_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Write header
        with open(self.log_file, "w", encoding="utf-8", newline="") as f:
            csv.writer(f, lineterminator="\n").writerow(self._HEADER)

    def log(self, call_info: Dict[str, Any]):
        """Record one LLM call.

        Args:
            call_info: Dict that may carry ``task``, ``subtask``, ``action``,
                ``duration_sec`` and ``num_tokens``. Missing keys fall back to
                an empty string (``0.0`` for the duration).
        """
        timestamp = datetime.now().isoformat()
        task = call_info.get("task", "")
        subtask = call_info.get("subtask", "")
        action = call_info.get("action", "")
        # An explicit None would break the float format specs below.
        duration = call_info.get("duration_sec") or 0.0
        num_tokens = call_info.get("num_tokens", "")

        with open(self.log_file, "a", encoding="utf-8", newline="") as f:
            csv.writer(f, lineterminator="\n").writerow(
                [timestamp, task, subtask, action, f"{duration:.3f}", num_tokens]
            )

        tokens_str = f", tokens: {num_tokens}" if num_tokens else ""
        print(f"[LLM_CALL] {task}/{subtask} - {action}: {duration:.2f}s{tokens_str}")

# ============================================================================
# Experiment Runner
# ============================================================================

def _save_subtask_artifacts(
    episode_result: Any,
    output_dir: str,
    target_robot: str,
    task_id: str,
    task_name: str,
) -> str:
    """Write each subtask's generated code, pre-repair code and repair log.

    Files go to ``<output_dir>/subtasks_<target_robot>/<task_num>_<task_name>/``
    and are prefixed with the subtask's two-digit position. Subtasks without
    generated code are skipped.

    Returns:
        The directory the files were written to.
    """
    task_num = task_id.replace("task_", "") if task_id.startswith("task_") else task_id
    subtask_dir = os.path.join(output_dir, f"subtasks_{target_robot}", f"{task_num}_{task_name}")
    os.makedirs(subtask_dir, exist_ok=True)

    for idx, subtask_result in enumerate(episode_result.subtask_results):
        generated_code = subtask_result.generated_code
        if not generated_code:
            continue

        subtask_name = subtask_result.subtask_name
        original_code = subtask_result.original_code
        repair_history = subtask_result.repair_history

        # Include obj_name in filename if present
        obj_suffix = f"_{subtask_result.obj_name}" if subtask_result.obj_name else ""
        stem = f"{idx:02d}_{subtask_name}{obj_suffix}"

        with open(os.path.join(subtask_dir, f"{stem}.py"), "w", encoding="utf-8") as f:
            f.write(generated_code)

        # Save original code if there were repairs
        if original_code and original_code != generated_code:
            with open(os.path.join(subtask_dir, f"{stem}_original.py"), "w", encoding="utf-8") as f:
                f.write(original_code)

        # Save repair history if any repairs occurred
        if repair_history:
            repair_data = {
                "subtask_name": subtask_name,
                "obj_name": subtask_result.obj_name,
                "repair_count": subtask_result.repair_count,
                "bg_repair_count": subtask_result.bg_repair_count,
                "repairs": [r.to_dict() for r in repair_history],
            }
            with open(os.path.join(subtask_dir, f"{stem}_repairs.json"), "w", encoding="utf-8") as f:
                # Same NumPy normalization as the other JSON outputs, so a
                # NumPy scalar in the history cannot break serialization.
                json.dump(convert_numpy_types(repair_data), f, indent=2)

    return subtask_dir


def run_experiment(
    tasks_config: List[Dict[str, Any]],
    output_dir: str,
    source_robot: str = "panda",
    target_robot: str = "suction",
    max_subtask_repairs: int = 5,
    code_source: str = "static",
    model: str = "code_agent",
    remote_host: str = "127.0.0.1",
    remote_port: int = 9000,
    task_ids: Optional[List[str]] = None,
    difficulty: Optional[str] = None,
    resume_from_checkpoint: bool = True,
    show_viewer: bool = False,
    skip_env: bool = False,
    enable_background_validation: bool = True,
    enable_code_cache: bool = False,
) -> ExperimentResult:
    """Run every selected task once and persist the results.

    For each task this calls ``run_task_episode``, copies the outcome into a
    ``TaskResult``, saves the generated subtask code, and rewrites
    ``checkpoint.json`` in ``output_dir``. After the last task it writes
    ``final_result.json`` and prints a summary. An exception raised by an
    episode is recorded on that task (``error`` / ``error_type``) and the run
    continues with the next task.

    Args:
        tasks_config: Task dicts; the keys read are ``id``, ``task_name``,
            ``difficulty`` and ``description``.
        output_dir: Directory for the checkpoint, final result, LLM call log
            and subtask code.
        source_robot: Forwarded to ``run_task_episode``.
        target_robot: Forwarded to ``run_task_episode``; also part of the
            subtask code directory name.
        max_subtask_repairs: Forwarded to ``run_task_episode``.
        code_source: Forwarded to ``run_task_episode``.
        model: Forwarded to ``run_task_episode``.
        remote_host: Forwarded to ``run_task_episode``.
        remote_port: Forwarded to ``run_task_episode``.
        task_ids: When given, only tasks whose ``id`` is listed are run.
        difficulty: When given, only tasks with this ``difficulty`` are run.
        resume_from_checkpoint: When true and ``checkpoint.json`` exists in
            ``output_dir``, tasks recorded there are skipped.
        show_viewer: Forwarded to ``run_task_episode``.
        skip_env: Forwarded to ``run_task_episode``.
        enable_background_validation: Forwarded to ``run_task_episode``.
        enable_code_cache: Forwarded to ``run_task_episode``.

    Returns:
        The populated ``ExperimentResult``.
    """
    import traceback

    # Generate experiment ID
    experiment_id = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Setup output directory
    os.makedirs(output_dir, exist_ok=True)

    # File paths
    checkpoint_path = os.path.join(output_dir, "checkpoint.json")
    final_result_path = os.path.join(output_dir, "final_result.json")
    llm_log_path = os.path.join(output_dir, "llm_calls.csv")

    # Create LLM logger
    # NOTE(review): the logger opens its file in "w" mode, so a resumed run
    # starts with a fresh llm_calls.csv - confirm this is intended.
    llm_logger = LLMCallLogger(llm_log_path)

    # Try to load checkpoint
    result = None
    completed_tasks = set()

    if resume_from_checkpoint:
        result = load_checkpoint(checkpoint_path)
        if result is not None:
            experiment_id = result.experiment_id
            completed_tasks = {t.task_id for t in result.tasks}
            print(f"[RESUME] Resuming experiment {experiment_id}")
            print(f"[RESUME] Already completed tasks: {completed_tasks}")

    if result is None:
        result = ExperimentResult()
        result.experiment_id = experiment_id
        result.start_time = datetime.now().isoformat()

    # Save config
    result.config = {
        "source_robot": source_robot,
        "target_robot": target_robot,
        "max_subtask_repairs": max_subtask_repairs,
        "code_source": code_source,
        "model": model,
        "remote_host": remote_host,
        "remote_port": remote_port,
        "task_ids_filter": task_ids,
        "difficulty_filter": difficulty,
    }

    # Filter tasks
    filtered_tasks = tasks_config
    if task_ids:
        filtered_tasks = [t for t in filtered_tasks if t.get("id") in task_ids]
    if difficulty:
        filtered_tasks = [t for t in filtered_tasks if t.get("difficulty") == difficulty]

    print(f"\n{'='*60}")
    print(f"[EXPERIMENT] Starting experiment: {experiment_id}")
    print(f"[EXPERIMENT] Total tasks: {len(filtered_tasks)}")
    print(f"[EXPERIMENT] Source robot: {source_robot}")
    print(f"[EXPERIMENT] Target robot: {target_robot}")
    print(f"[EXPERIMENT] Code source: {code_source}")
    print(f"[EXPERIMENT] Output directory: {output_dir}")
    print(f"{'='*60}\n")

    # Execute each task
    for task_idx, task_config in enumerate(filtered_tasks):
        task_id = task_config.get("id", f"task_{task_idx:03d}")
        task_name = task_config.get("task_name", "")
        task_difficulty = task_config.get("difficulty", "unknown")
        task_description = task_config.get("description", "")

        # Skip already completed tasks
        if task_id in completed_tasks:
            print(f"[SKIP] Task {task_id} already completed")
            continue

        print(f"\n{'='*60}")
        print(f"[TASK {task_idx + 1}/{len(filtered_tasks)}] {task_id}: {task_name}")
        print(f"[TASK] Difficulty: {task_difficulty}")
        print(f"[TASK] Description: {task_description}")
        print(f"{'='*60}\n")

        task_result = TaskResult()
        task_result.task_id = task_id
        task_result.task_name = task_name
        task_result.difficulty = task_difficulty
        task_result.description = task_description

        try:
            # Run task episode
            episode_result: EpisodeResult = run_task_episode(
                task_name=task_name,
                source_robot=source_robot,
                target_robot=target_robot,
                max_subtask_repairs=max_subtask_repairs,
                code_source=code_source,
                model=model,
                remote_host=remote_host,
                remote_port=remote_port,
                show_viewer=show_viewer,
                llm_call_logger=llm_logger.log,
                skip_env=skip_env,
                enable_background_validation=enable_background_validation,
                enable_code_cache=enable_code_cache,
            )

            # Copy results
            task_result.success = episode_result.success
            task_result.execution_time_sec = episode_result.execution_time
            task_result.llm_calls = episode_result.llm_calls
            task_result.total_llm_time_sec = sum(
                c.get("duration_sec", 0) for c in episode_result.llm_calls
            )
            task_result.error = episode_result.error
            task_result.error_type = episode_result.error_type
            # Copy subtask results
            task_result.subtask_results = [
                s.to_dict() for s in episode_result.subtask_results
            ]

            # Save subtask codes (each subtask in its own file)
            if episode_result.subtask_results:
                # Saving these files is bookkeeping. A failure here is
                # reported but must not fall through to the handler below,
                # which would flip an already-recorded outcome to "failed".
                try:
                    subtask_dir = _save_subtask_artifacts(
                        episode_result, output_dir, target_robot, task_id, task_name
                    )
                    print(f"  - Subtask codes saved to: {subtask_dir}")
                except Exception as save_err:
                    print(f"[WARN] Could not save subtask codes for {task_name}: {save_err}")
                    traceback.print_exc()

            status = "SUCCESS" if task_result.success else "FAILED"
            print(f"\n[TASK RESULT] {task_name}: {status}")
            print(f"  - Execution time: {task_result.execution_time_sec:.2f}s")
            print(f"  - LLM time: {task_result.total_llm_time_sec:.2f}s")
            print(f"  - LLM calls: {len(task_result.llm_calls)}")

        except Exception as e:
            # Record the failure on the task and carry on with the next one.
            print(f"[TASK ERROR] {task_name}: {e}")
            traceback.print_exc()
            task_result.success = False
            task_result.error = str(e)
            task_result.error_type = type(e).__name__

        result.tasks.append(task_result)

        # Save checkpoint after each task
        save_result_json(result.to_dict(), checkpoint_path)

    # Save final result
    result.end_time = datetime.now().isoformat()
    result.compute_metrics()
    save_result_json(result.to_dict(), final_result_path)

    # Print summary
    print(f"\n{'='*60}")
    print("[EXPERIMENT COMPLETE]")
    print(f"  - Overall success rate: {result.overall_success_rate:.2%}")
    print(f"  - Successful: {sum(1 for t in result.tasks if t.success)}/{len(result.tasks)}")

    # Print subtask success rate per task
    print(f"\n  [Subtask Success Rate]")
    total_subtasks = 0
    successful_subtasks = 0
    for task in result.tasks:
        summary = task._compute_subtask_summary()
        if summary and summary["total_subtasks"] > 0:
            total_subtasks += summary["total_subtasks"]
            successful_subtasks += summary["successful_subtasks"]
            print(f"    {task.task_name}: {summary['subtask_success_rate']:.2%} ({summary['successful_subtasks']}/{summary['total_subtasks']})")
    if total_subtasks > 0:
        overall_subtask_rate = successful_subtasks / total_subtasks
        print(f"  - Overall subtask success rate: {overall_subtask_rate:.2%} ({successful_subtasks}/{total_subtasks})")

    print(f"\n  - Results saved to: {output_dir}")
    print(f"{'='*60}\n")

    return result

# ============================================================================
# Main
# ============================================================================

def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser with every experiment option."""
    robot_choices = ["panda", "suction", "robotiq"]

    parser = argparse.ArgumentParser(
        description="Genesis Experiment Runner - Run experiments from experiment_tasks.yml"
    )
    add = parser.add_argument

    # Input / output locations.
    add("--yaml_path", type=str, default="./static/experiment_tasks.yml",
        help="Path to experiment_tasks.yml")
    add("--output_dir", type=str, default=None,
        help="Output directory for results (default: results/exp_YYYYMMDD_HHMMSS)")

    # Robots and code generation.
    add("--source_robot", "-s", choices=robot_choices, default="panda",
        help="Source robot type (for reference code)")
    add("--target_robot", "-r", choices=robot_choices, default="panda",
        help="Target robot type (for execution)")
    add("--code_source", "-c", choices=["static", "remote_llm"], default="remote_llm",
        help="Where to get task code")
    add("--model", choices=["code_agent", "ours"], default="code_agent",
        help="Model type for code generation")
    add("--max_repairs", type=int, default=5,
        help="Maximum repair attempts (default: 5)")

    # Remote LLM server.
    add("--remote_host", type=str, default="127.0.0.1",
        help="Remote LLM server host")
    add("--remote_port", type=int, default=6000,
        help="Remote LLM server port")

    # Task selection.
    add("--task_ids", nargs="+", type=str, default=None,
        help="Specific task IDs to run (e.g., task_000 task_001)")
    add("--difficulty", type=str, choices=["easy", "medium", "hard"], default=None,
        help="Filter tasks by difficulty")

    # Switches.
    add("--no_resume", action="store_true",
        help="Do not resume from checkpoint")
    add("--viewer", action="store_true",
        help="Show Genesis viewer")
    add("--skip_env", action="store_true",
        help="Skip Genesis environment initialization (faster code generation)")
    add("--no_bg_validation", action="store_true",
        help="Disable background validation in ours mode")
    add("--enable_code_cache", action="store_true",
        help="Enable subtask code caching (default: disabled)")
    add("--seed", type=int, default=42,
        help="Random seed for reproducibility (default: 42)")

    return parser


def main():
    """Command-line entry point.

    Parses the arguments, seeds the random number generators, loads the task
    list from the YAML file and runs the experiment. Exits with status 1 when
    the YAML file does not exist.
    """
    args = _build_arg_parser().parse_args()

    # Seed first so everything that follows is reproducible.
    set_seed(args.seed)

    # Without an explicit output directory, use a timestamped one.
    output_dir = args.output_dir
    if output_dir is None:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = f"./results/exp_{stamp}"

    # Refuse to continue without a task file.
    yaml_path = args.yaml_path
    if not os.path.exists(yaml_path):
        print(f"[ERROR] YAML file not found: {yaml_path}")
        sys.exit(1)

    tasks_config = load_experiment_tasks(yaml_path)
    print(f"[INFO] Loaded {len(tasks_config)} tasks from {yaml_path}")

    # The two "no_*" switches are inverted into the positive flags that
    # run_experiment expects.
    run_experiment(
        tasks_config=tasks_config,
        output_dir=output_dir,
        source_robot=args.source_robot,
        target_robot=args.target_robot,
        max_subtask_repairs=args.max_repairs,
        code_source=args.code_source,
        model=args.model,
        remote_host=args.remote_host,
        remote_port=args.remote_port,
        task_ids=args.task_ids,
        difficulty=args.difficulty,
        resume_from_checkpoint=not args.no_resume,
        show_viewer=args.viewer,
        skip_env=args.skip_env,
        enable_background_validation=not args.no_bg_validation,
        enable_code_cache=args.enable_code_cache,
    )

# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
