"""
Enhanced logging system for OfficeArena with comprehensive action logging and timing.
"""

import logging
import json
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from PIL import Image, ImageDraw
import io


class OfficeArenaLogger:
    """
    Comprehensive logger for OfficeArena that tracks:
    - All agent actions with timestamps
    - Time taken for each action and total task time
    - Mouse click positions on screenshots
    - Creates both file logs and annotated screenshots
    """

    def __init__(self, log_dir: Optional[str] = None, task_id: Optional[str] = None):
        """
        Initialize the logger and create its output locations on disk.

        Args:
            log_dir: Directory to save log files. If None (or empty), a "logs"
                directory under the current working directory is used. Missing
                parent directories are created as needed.
            task_id: Unique identifier for the task (used in log file names).
                If None (or empty), "task_<unix seconds>" is generated.
        """
        self.task_id = task_id or f"task_{int(time.time())}"
        self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
        # parents=True: a nested log_dir whose parents do not exist yet must
        # not abort construction with FileNotFoundError.
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Create log file paths
        self.log_file_path = self.log_dir / f"{self.task_id}_detailed.log"
        self.json_log_path = self.log_dir / f"{self.task_id}_actions.json"
        self.timing_log_path = self.log_dir / f"{self.task_id}_timing.log"

        # Initialize logging (needs task_id and log_file_path set above)
        self._setup_logging()

        # Tracking variables
        self.start_time: Optional[float] = None
        self.last_action_time: Optional[float] = None
        self.action_log: List[Dict[str, Any]] = []
        self.step_counter = 0

        # Screenshot directories (raw and click-annotated copies)
        self.screenshot_dir = self.log_dir / f"{self.task_id}_screenshots"
        self.screenshot_dir.mkdir(parents=True, exist_ok=True)

        self.annotated_screenshot_dir = (
            self.log_dir / f"{self.task_id}_screenshots_annotated"
        )
        self.annotated_screenshot_dir.mkdir(parents=True, exist_ok=True)

    def _setup_logging(self):
        """Set up file logging configuration."""
        # Create custom logger
        self.logger = logging.getLogger(f"OfficeArena_{self.task_id}")
        self.logger.setLevel(logging.DEBUG)

        # Prevent duplicate handlers
        if self.logger.handlers:
            self.logger.handlers.clear()

        # File handler for detailed logs
        file_handler = logging.FileHandler(
            self.log_file_path, mode="w", encoding="utf-8"
        )
        file_handler.setLevel(logging.DEBUG)

        # Console handler
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)

        # Formatter
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        file_handler.setFormatter(formatter)
        console_handler.setFormatter(formatter)

        self.logger.addHandler(file_handler)
        self.logger.addHandler(console_handler)

    def start_task(self, file_path: str, task_instruction: str):
        """
        Log the start of a task.

        Args:
            file_path: Path to the Office file
            task_instruction: Description of the task
        """
        self.start_time = time.time()
        self.last_action_time = self.start_time

        start_info = {
            "event": "task_start",
            "timestamp": self.start_time,
            "datetime": datetime.fromtimestamp(self.start_time).isoformat(),
            "file_path": file_path,
            "task_instruction": task_instruction,
            "task_id": self.task_id,
        }

        self.action_log.append(start_info)

        self.logger.info("=" * 80)
        self.logger.info(f"🚀 TASK STARTED: {self.task_id}")
        self.logger.info(f"📁 File: {file_path}")
        self.logger.info(f"📋 Task: {task_instruction}")
        self.logger.info(
            f"⏰ Start time: {datetime.fromtimestamp(self.start_time).strftime('%Y-%m-%d %H:%M:%S')}"
        )
        self.logger.info("=" * 80)

    def log_action(
        self,
        action: str,
        action_result: Dict[str, Any],
        screenshot: Optional[bytes] = None,
    ):
        """
        Log an agent action with timing and result information.

        Args:
            action: The action string or description
            action_result: Result dictionary from action execution
            screenshot: Screenshot bytes (optional)
        """
        current_time = time.time()
        self.step_counter += 1

        # Calculate timing
        action_duration = (
            current_time - self.last_action_time if self.last_action_time else 0
        )
        total_elapsed = current_time - self.start_time if self.start_time else 0

        # Extract action details
        action_type = action_result.get("action_type", "unknown")
        success = action_result.get("success", False)
        task_completed = action_result.get("task_completed", False)

        # Parse action for mouse clicks
        click_position = None
        if action_type in ["left_click", "right_click", "click"]:
            click_position = self._extract_click_position(action)

        # Create action log entry
        action_entry = {
            "event": "action_executed",
            "step": self.step_counter,
            "timestamp": current_time,
            "datetime": datetime.fromtimestamp(current_time).isoformat(),
            "action": action,
            "action_type": action_type,
            "success": success,
            "task_completed": task_completed,
            "action_duration_seconds": round(action_duration, 3),
            "total_elapsed_seconds": round(total_elapsed, 3),
            "click_position": click_position,
            "full_result": action_result,
        }

        self.action_log.append(action_entry)

        # Log to file and console
        status_emoji = "✅" if success else "❌"
        time_emoji = "⏱️"

        self.logger.info(
            f"{status_emoji} Step {self.step_counter}: {action_type.upper()}"
        )
        self.logger.debug(f"   📝 Action: {action}")
        self.logger.info(
            f"   {time_emoji} Duration: {action_duration:.3f}s | Total: {total_elapsed:.3f}s"
        )

        if click_position:
            self.logger.info(
                f"   🖱️  Click position: ({click_position['x']}, {click_position['y']})"
            )

        if not success:
            error_msg = action_result.get("error", "Unknown error")
            self.logger.warning(f"   ⚠️  Error: {error_msg}")

        if task_completed:
            self.logger.info(f"   🎉 Task marked as completed!")

        # Handle screenshot saving and annotation
        if screenshot:
            self._save_screenshots(screenshot, click_position, action_type)

        # Update timing
        self.last_action_time = current_time

    def _extract_click_position(self, action: str) -> Optional[Dict[str, int]]:
        """
        Extract click position from action string.

        Args:
            action: Action string that may contain coordinates

        Returns:
            Dictionary with x, y coordinates or None
        """
        try:
            # Try to parse JSON action
            if action.strip().startswith("{"):
                action_data = json.loads(action)

                # Look for coordinates in different possible structures
                for item in action_data.get("output", []):
                    if item.get("type") == "computer_call":
                        action_details = item.get("action", {})
                        if "x" in action_details and "y" in action_details:
                            return {"x": action_details["x"], "y": action_details["y"]}

            # Fallback: try to extract coordinates from string representation
            import re

            coord_pattern = r'"x":\s*(\d+).*?"y":\s*(\d+)'
            match = re.search(coord_pattern, action)
            if match:
                return {"x": int(match.group(1)), "y": int(match.group(2))}

        except Exception as e:
            self.logger.debug(f"Could not extract click position from action: {e}")

        return None

    def _save_screenshots(
        self,
        screenshot_bytes: bytes,
        click_position: Optional[Dict[str, int]],
        action_type: str,
    ):
        """
        Save both original and annotated screenshots.

        Args:
            screenshot_bytes: Screenshot as bytes
            click_position: Click coordinates if applicable
            action_type: Type of action performed
        """
        try:
            # Save original screenshot
            original_path = (
                self.screenshot_dir / f"step_{self.step_counter:03d}_{action_type}.png"
            )
            with open(original_path, "wb") as f:
                f.write(screenshot_bytes)

            # Create annotated version if there's a click position
            if click_position:
                annotated_path = (
                    self.annotated_screenshot_dir
                    / f"step_{self.step_counter:03d}_{action_type}_annotated.png"
                )
                self._create_annotated_screenshot(
                    screenshot_bytes, click_position, action_type, annotated_path
                )

                self.logger.debug(
                    f"💾 Screenshots saved: {original_path.name} (original), {annotated_path.name} (annotated)"
                )
            else:
                self.logger.debug(f"💾 Screenshot saved: {original_path.name}")

        except Exception as e:
            self.logger.error(f"Failed to save screenshot: {e}")

    def _create_annotated_screenshot(
        self,
        screenshot_bytes: bytes,
        click_position: Dict[str, int],
        action_type: str,
        output_path: Path,
    ):
        """
        Create an annotated screenshot with click position marked.

        Draws an outlined circle (radius 15) with a filled dot (radius 5) at
        the click position plus a text label of the form "L(x,y)" or "R(x,y)",
        then saves the result to ``output_path``. Any failure is logged as an
        error rather than raised.

        Args:
            screenshot_bytes: Original screenshot
            click_position: Click coordinates (keys "x" and "y")
            action_type: Type of action (determines circle color: cyan/"R" for
                right clicks, orange/"L" for everything else)
            output_path: Where to save the annotated image
        """
        try:
            # Context manager so the decoded image is released even when
            # drawing or saving fails.
            with Image.open(io.BytesIO(screenshot_bytes)) as image:
                draw = ImageDraw.Draw(image)

                x, y = click_position["x"], click_position["y"]

                # Determine circle color based on action type
                if "right_click" in action_type.lower():
                    color, label = "cyan", "R"
                else:  # left_click or other click types
                    color, label = "orange", "L"

                # Draw circle (larger for visibility)
                radius = 15
                draw.ellipse(
                    [x - radius, y - radius, x + radius, y + radius],
                    outline=color,
                    width=3,
                )

                # Draw smaller filled circle in center
                small_radius = 5
                draw.ellipse(
                    [
                        x - small_radius,
                        y - small_radius,
                        x + small_radius,
                        y + small_radius,
                    ],
                    fill=color,
                )

                # Font lookup is best-effort: on failure fall back to None and
                # let draw.text pick its own default. Only Exception is caught
                # so KeyboardInterrupt/SystemExit still propagate.
                try:
                    from PIL import ImageFont

                    font = ImageFont.load_default()
                except Exception:
                    font = None

                # Position label near the circle
                label_x = x + radius + 5
                label_y = y - 10
                draw.text(
                    (label_x, label_y), f"{label}({x},{y})", fill=color, font=font
                )

                # Save annotated image
                image.save(output_path)

        except Exception as e:
            self.logger.error(f"Failed to create annotated screenshot: {e}")

    def end_task(self, success: bool, final_step_count: int):
        """
        Log the end of a task with final timing information.

        Args:
            success: Whether the task was completed successfully
            final_step_count: Final number of steps taken
        """
        end_time = time.time()
        total_duration = end_time - self.start_time if self.start_time else 0

        end_info = {
            "event": "task_end",
            "timestamp": end_time,
            "datetime": datetime.fromtimestamp(end_time).isoformat(),
            "success": success,
            "total_steps": final_step_count,
            "total_duration_seconds": round(total_duration, 3),
            "total_duration_formatted": self._format_duration(total_duration),
            "task_id": self.task_id,
        }

        self.action_log.append(end_info)

        # Log summary
        status_emoji = "🎉" if success else "💔"
        self.logger.info("=" * 80)
        self.logger.info(f"{status_emoji} TASK COMPLETED: {self.task_id}")
        self.logger.info(f"✅ Success: {success}")
        self.logger.info(f"📊 Total steps: {final_step_count}")
        self.logger.info(f"⏱️  Total time: {self._format_duration(total_duration)}")
        self.logger.info(
            f"📈 Average time per step: {(total_duration/final_step_count):.3f}s"
        )
        self.logger.info("=" * 80)

        # Save JSON log
        self._save_json_log()

        # Save timing summary
        self._save_timing_summary(total_duration, final_step_count)

    def _format_duration(self, seconds: float) -> str:
        """
        Format duration in a human-readable way.

        Args:
            seconds: Duration in seconds

        Returns:
            "S.sss s" style text: "<s>s" below one minute, "<m>m <s>s" below
            one hour, "<h>h <m>m <s>s" otherwise, seconds shown to 3 decimals.
        """
        # Round to the displayed precision first. Formatting the remainder
        # alone could round 59.9996 up to "60.000" and yield e.g. "1m 60.000s";
        # rounding up front lets that carry into the next unit instead.
        total = round(seconds, 3)

        if total < 60:
            return f"{total:.3f}s"

        whole_minutes, secs = divmod(total, 60)
        whole_minutes = int(whole_minutes)
        if total < 3600:
            return f"{whole_minutes}m {secs:.3f}s"

        hours, minutes = divmod(whole_minutes, 60)
        return f"{hours}h {minutes}m {secs:.3f}s"

    def _save_json_log(self):
        """Save complete action log as JSON."""
        try:
            # Create a JSON-serializable version of the action log
            json_serializable_log = []
            for entry in self.action_log:
                serializable_entry = {}
                for key, value in entry.items():
                    if key == "full_result" and isinstance(value, dict):
                        # Clean the full_result to remove non-serializable objects
                        cleaned_result = {}
                        for result_key, result_value in value.items():
                            if isinstance(result_value, bytes):
                                cleaned_result[result_key] = f"<bytes object: {len(result_value)} bytes>"
                            elif isinstance(result_value, dict):
                                # Recursively clean nested dictionaries
                                cleaned_result[result_key] = self._clean_dict_for_json(result_value)
                            else:
                                try:
                                    # Test if the value is JSON serializable
                                    json.dumps(result_value)
                                    cleaned_result[result_key] = result_value
                                except (TypeError, ValueError):
                                    cleaned_result[result_key] = str(result_value)
                        serializable_entry[key] = cleaned_result
                    elif isinstance(value, bytes):
                        serializable_entry[key] = f"<bytes object: {len(value)} bytes>"
                    else:
                        try:
                            # Test if the value is JSON serializable
                            json.dumps(value)
                            serializable_entry[key] = value
                        except (TypeError, ValueError):
                            serializable_entry[key] = str(value)
                json_serializable_log.append(serializable_entry)
            
            with open(self.json_log_path, "w", encoding="utf-8") as f:
                json.dump(json_serializable_log, f, indent=2, ensure_ascii=False)
            self.logger.debug(f"📄 JSON log saved: {self.json_log_path}")
        except Exception as e:
            self.logger.error(f"Failed to save JSON log: {e}")

    def _clean_dict_for_json(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively clean a dictionary to make it JSON serializable."""
        cleaned = {}
        for key, value in data.items():
            if isinstance(value, bytes):
                cleaned[key] = f"<bytes object: {len(value)} bytes>"
            elif isinstance(value, dict):
                cleaned[key] = self._clean_dict_for_json(value)
            elif isinstance(value, (list, tuple)):
                cleaned[key] = self._clean_list_for_json(value)
            else:
                try:
                    # Test if the value is JSON serializable
                    json.dumps(value)
                    cleaned[key] = value
                except (TypeError, ValueError):
                    cleaned[key] = str(value)
        return cleaned

    def _clean_list_for_json(self, data: List[Any]) -> List[Any]:
        """
        Return a JSON-serializable copy of a list (or tuple).

        Each element is converted independently: bytes become a
        "<bytes object: N bytes>" placeholder, dicts and nested lists/tuples
        are cleaned recursively, and any other element is kept when
        ``json.dumps`` accepts it or replaced by ``str(element)`` when not.
        """

        def scrub(element: Any) -> Any:
            # Convert a single element according to the rules above.
            if isinstance(element, bytes):
                return f"<bytes object: {len(element)} bytes>"
            if isinstance(element, dict):
                return self._clean_dict_for_json(element)
            if isinstance(element, (list, tuple)):
                return self._clean_list_for_json(element)
            try:
                # Probe: does the encoder accept this element as-is?
                json.dumps(element)
            except (TypeError, ValueError):
                return str(element)
            return element

        return [scrub(element) for element in data]

    def _save_timing_summary(self, total_duration: float, final_step_count: int):
        """Save timing summary to separate file."""
        try:
            with open(self.timing_log_path, "w", encoding="utf-8") as f:
                f.write(f"Task Timing Summary - {self.task_id}\n")
                f.write("=" * 50 + "\n\n")
                f.write(f"Total Duration: {self._format_duration(total_duration)}\n")
                f.write(f"Total Steps: {final_step_count}\n")
                f.write(
                    f"Average Time per Step: {(total_duration/final_step_count):.3f}s\n\n"
                )

                f.write("Step-by-Step Timing:\n")
                f.write("-" * 30 + "\n")

                for entry in self.action_log:
                    if entry.get("event") == "action_executed":
                        step = entry["step"]
                        action_type = entry["action_type"]
                        duration = entry["action_duration_seconds"]
                        f.write(
                            f"Step {step:3d}: {action_type:15s} - {duration:6.3f}s\n"
                        )

            self.logger.debug(f"⏰ Timing summary saved: {self.timing_log_path}")
        except Exception as e:
            self.logger.error(f"Failed to save timing summary: {e}")

    def log_error(self, error: str, context: Optional[str] = None):
        """
        Log an error with context.

        Args:
            error: Error message
            context: Additional context information
        """
        self.logger.error(f"❌ Error: {error}")
        if context:
            self.logger.error(f"   Context: {context}")

    def log_info(self, message: str):
        """Log an informational message."""
        self.logger.info(f"ℹ️  {message}")

    def get_log_paths(self) -> Dict[str, str]:
        """Get all log file paths."""
        return {
            "detailed_log": str(self.log_file_path),
            "json_log": str(self.json_log_path),
            "timing_log": str(self.timing_log_path),
            "screenshots_dir": str(self.screenshot_dir),
            "annotated_screenshots_dir": str(self.annotated_screenshot_dir),
        }

    def close(self):
        """Close the logger and clean up resources."""
        # Close file handlers
        for handler in self.logger.handlers:
            if isinstance(handler, logging.FileHandler):
                handler.close()

        # Clear handlers to prevent memory leaks
        self.logger.handlers.clear()
