import gzip
import json
import logging
import os
import pickle
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


def save_chat_messages_to_pkl(
    chat_data: Dict[str, Any],
    output_dir: str,
    task_name: str,
    seed: str,
    timestamp: Optional[str] = None,
) -> Optional[str]:
    """
    Save chat messages and related data to a gzip-compressed pickle file.

    The file is named ``chat_{task_name}_{seed}_{timestamp}.pkl.gz`` and is
    written directly inside ``output_dir``, which is created if missing. Path
    separators inside the name components are replaced with ``_``. Without
    this, a task name such as ``"suite/task"`` would point the file at a
    sub-directory that does not exist.

    Args:
        chat_data: Dictionary containing chat messages and metadata
        output_dir: Directory to save the files
        task_name: Name of the task
        seed: Seed value for the task
        timestamp: Optional timestamp string; defaults to the current local
            time formatted as ``%Y%m%d_%H%M%S``

    Returns:
        The path of the written file, or ``None`` if writing failed. The
        failure is logged, and a partially written file is removed.
    """
    if timestamp is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    def _safe(part: Any) -> str:
        # Neutralise path separators so the file always lands in output_dir.
        text = str(part)
        for sep in (os.sep, os.altsep, "/"):
            if sep:
                text = text.replace(sep, "_")
        return text

    # Create a unique filename for this chat session
    filename = f"chat_{_safe(task_name)}_{_safe(seed)}_{_safe(timestamp)}.pkl.gz"
    filepath = os.path.join(output_dir, filename)

    opened = False
    try:
        with gzip.open(filepath, "wb") as f:
            opened = True
            pickle.dump(chat_data, f)
    except Exception as e:
        logger.error(f"Failed to save chat messages to {filepath}: {e}")
        # Only delete a file we created/truncated ourselves; a truncated
        # archive would otherwise be picked up (and fail) at load time.
        if opened:
            try:
                os.remove(filepath)
            except OSError:
                pass
        return None

    logger.info(f"Chat messages saved to {filepath}")
    return filepath


def create_chat_data(
    task_name: str,
    seed: str,
    trace_info: List[Dict[str, Any]],
    system_prompt: str,
    user_prompt: str,
    llm_response: Any,
    model_name: str,
    agent_name: str,
    domain_name: str,
    user_name: str,
    source: str,
    metadata: Optional[Dict[str, Any]] = None,
    extracted_components: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """
    Create a structured chat data object for saving.

    Args:
        task_name: Name of the task
        seed: Seed value for the task
        trace_info: The trace information used for hint generation; each
            element must support ``len()`` (summed into ``trace_steps``)
        system_prompt: System prompt sent to the LLM
        user_prompt: User prompt sent to the LLM
        llm_response: Response from the LLM. If it has a non-None ``think``
            attribute, that text becomes the generated hint. Otherwise
            ``str(llm_response)`` is used.
        model_name: Name of the LLM model used
        agent_name: Name of the agent
        domain_name: Domain name
        user_name: User name
        source: Source of the hint generation
        metadata: Additional metadata merged into (and able to override) the
            ``"metadata"`` section
        extracted_components: Optional pre-extracted parts of the response;
            defaults to ``{"thinking": "", "hint": ""}``

    Returns:
        Dictionary containing all chat data for saving
    """
    raw_response = str(llm_response)
    # ``think`` may be absent or explicitly None; both fall back to the raw
    # response text instead of crashing on ``None.strip()``.
    think = getattr(llm_response, "think", None)
    hint_source = raw_response if think is None else str(think)

    chat_data = {
        "metadata": {
            "task_name": task_name,
            "seed": seed,
            "timestamp": datetime.now().isoformat(),
            "model_name": model_name,
            "agent_name": agent_name,
            "domain_name": domain_name,
            "user_name": user_name,
            "source": source,
            "trace_count": len(trace_info),
            "trace_steps": sum(len(trace) for trace in trace_info),
        },
        "prompts": {
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
        },
        "llm_response": {
            "raw_response": raw_response,
            "think": "" if think is None else think,
            "response_type": type(llm_response).__name__,
        },
        "trace_data": trace_info,
        "generated_hint": hint_source.strip(),
        "extracted_components": extracted_components or {"thinking": "", "hint": ""},
    }

    # Add any additional metadata (caller-provided keys win)
    if metadata:
        chat_data["metadata"].update(metadata)

    return chat_data


def _load_chat_pickles_from_dir(chat_dir: str) -> List[Dict[str, Any]]:
    """Load every ``chat_*.pkl.gz`` file under ``chat_dir`` recursively, in sorted path order."""
    loaded: List[Dict[str, Any]] = []
    for dirpath, dirnames, filenames in os.walk(chat_dir):
        # Sorting in place makes os.walk descend in a deterministic order.
        dirnames.sort()
        for fname in sorted(filenames):
            if not (fname.startswith("chat_") and fname.endswith(".pkl.gz")):
                continue
            fpath = os.path.join(dirpath, fname)
            try:
                with gzip.open(fpath, "rb") as f:
                    # SECURITY: unpickling can execute arbitrary code — only
                    # point this at directories whose files you trust.
                    data = pickle.load(f)  # type: ignore
            except Exception as e:
                # Best effort: skip unreadable/corrupt files but make it visible.
                logger.warning(f"Failed to load {fpath}: {e}")
                continue
            loaded.append({"file": fpath, "data": data})
    return loaded


def load_all_chat_pickles(root_dir, hint_db_path=None):
    """
    Load all chat message pickle files.

    Lookup order:
      1. ``<dirname(hint_db_path)>/chat_messages``, if ``hint_db_path`` has a
         directory component, that folder exists, and at least one file loads.
      2. ``<root_dir>/chat_messages`` otherwise.

    Args:
        root_dir: Traces folder whose ``chat_messages`` subdirectory is the fallback.
        hint_db_path: Optional hint database path whose folder is checked first.

    Returns:
        List of ``{"file": path, "data": unpickled_object}`` dicts, ordered by
        path. The list is empty if nothing was found.
    """
    # First try to find chat messages in the hint database folder
    if hint_db_path:
        hint_db_dir = os.path.dirname(hint_db_path)
        if hint_db_dir:
            chat_dir = os.path.join(hint_db_dir, "chat_messages")
            if os.path.exists(chat_dir):
                logger.info(f"Loading chat messages from hint database folder: {chat_dir}")
                chat_data = _load_chat_pickles_from_dir(chat_dir)
                if chat_data:
                    return chat_data

    # Fallback to looking in the root_dir/chat_messages subdirectory
    chat_dir = os.path.join(root_dir, "chat_messages")
    if not os.path.exists(chat_dir):
        logger.info(f"Chat messages directory not found: {chat_dir}")
        return []

    logger.info(f"Loading chat messages from traces folder: {chat_dir}")
    return _load_chat_pickles_from_dir(chat_dir)


def analyze_chat_messages(chat_data):
    """
    Analyze saved chat messages to extract simple aggregate insights.

    Args:
        chat_data: List of ``{"file": ..., "data": {...}}`` entries, as returned
            by ``load_all_chat_pickles``. Missing or ``None`` sections and
            fields are treated as empty.

    Returns:
        Dict with ``total_chats``, ``tasks_covered`` and ``models_used``
        (lists of distinct values), ``avg_prompt_length`` and
        ``avg_response_length`` (characters per chat, ``0`` when there are no
        chats), and ``hint_quality_metrics``. That last item holds one entry per
        non-empty hint, with its length and whether it mentions an action word.
    """
    action_words = ("click", "type", "select", "submit", "navigate")

    tasks_covered = set()
    models_used = set()
    hint_quality_metrics = []
    total_prompt_length = 0
    total_response_length = 0

    for chat in chat_data:
        data = chat["data"]
        metadata = data.get("metadata") or {}

        # Track tasks and models
        tasks_covered.add(metadata.get("task_name", "unknown"))
        models_used.add(metadata.get("model_name", "unknown"))

        # Track prompt and response lengths; `or ""` tolerates stored None values.
        prompts = data.get("prompts") or {}
        total_prompt_length += len(prompts.get("system_prompt") or "") + len(
            prompts.get("user_prompt") or ""
        )

        llm_response = data.get("llm_response") or {}
        response_text = llm_response.get("think") or llm_response.get("raw_response") or ""
        total_response_length += len(response_text)

        # Track hint quality (basic metrics)
        generated_hint = data.get("generated_hint") or ""
        if generated_hint:
            lowered = generated_hint.lower()
            hint_quality_metrics.append(
                {
                    "task": metadata.get("task_name", "unknown"),
                    "hint_length": len(generated_hint),
                    "has_actionable_content": any(word in lowered for word in action_words),
                    "timestamp": metadata.get("timestamp", ""),
                }
            )

    total_chats = len(chat_data)
    return {
        "total_chats": total_chats,
        # Sets converted to lists for JSON serialization
        "tasks_covered": list(tasks_covered),
        "models_used": list(models_used),
        "avg_prompt_length": total_prompt_length / total_chats if total_chats else 0,
        "avg_response_length": total_response_length / total_chats if total_chats else 0,
        "hint_quality_metrics": hint_quality_metrics,
    }


def export_chat_messages_to_json(
    traces_folder: str,
    output_file: Optional[str] = None,
    include_trace_data: bool = False,
    hint_db_path: Optional[str] = None,
) -> str:
    """
    Export chat messages from JephHinter to JSON format.

    Args:
        traces_folder: Path to the traces folder containing chat_messages subdirectory
        output_file: Output JSON file path (default: chat_messages_export.json in hint database folder or traces folder).
            Its parent directory is created if it does not exist.
        include_trace_data: Whether to include the full trace data (can make file very large)
        hint_db_path: Optional path to hint database to prioritize loading from hint database folder

    Returns:
        Path to the exported JSON file. If no chat messages are found, nothing is
        written (an existing file is left untouched), so the returned path may not exist.

    Raises:
        Exception: any error while writing the JSON file is logged and re-raised.
    """
    if output_file is None:
        # Prefer the hint database folder when it has a directory component.
        hint_db_dir = os.path.dirname(hint_db_path) if hint_db_path else ""
        output_file = os.path.join(hint_db_dir or traces_folder, "chat_messages_export.json")

    # Load chat messages
    logger.info(f"Loading chat messages from {traces_folder}")
    chat_data = load_all_chat_pickles(traces_folder, hint_db_path)

    if not chat_data:
        logger.warning("No chat messages found")
        return output_file

    logger.info(f"Loaded {len(chat_data)} chat message files")

    # Convert to exportable format
    export_data = []
    for chat in chat_data:
        data = chat["data"]

        export_entry = {
            "metadata": data.get("metadata", {}),
            "prompts": data.get("prompts", {}),
            "llm_response": data.get("llm_response", {}),
            "generated_hint": data.get("generated_hint", ""),
        }

        # Optionally include trace data
        if include_trace_data:
            export_entry["trace_data"] = data.get("trace_data", [])

        export_data.append(export_entry)

    # Save to JSON file
    try:
        # A user-supplied --output may point into a directory that doesn't exist yet.
        output_parent = os.path.dirname(output_file)
        if output_parent:
            os.makedirs(output_parent, exist_ok=True)

        with open(output_file, "w", encoding="utf-8") as f:
            # default=str: non-JSON values (e.g. custom objects in metadata) are stringified.
            json.dump(export_data, f, indent=2, default=str, ensure_ascii=False)

        logger.info(f"✅ Chat messages exported to {output_file}")
        logger.info(f"   Total chats: {len(export_data)}")
        logger.info(f"   File size: {os.path.getsize(output_file) / 1024 / 1024:.2f} MB")

        return output_file

    except Exception as e:
        logger.error(f"Failed to export chat messages to {output_file}: {e}")
        raise


def create_summary_report(export_data: List[Dict[str, Any]], output_dir: str):
    """
    Create a summary report of the exported chat messages.

    Writes ``chat_messages_summary.json`` into ``output_dir``. It contains
    overall averages, per-task stats (count, models, average prompt/hint
    length) and per-model stats (count, tasks). Does nothing when
    ``export_data`` is empty. Write failures are logged, not raised.

    Args:
        export_data: The exported chat data (entries shaped like those produced
            by ``export_chat_messages_to_json``); missing/None fields count as empty
        output_dir: Directory to save the report ("" means the current directory)
    """
    # Imported here so the invoking command is recorded even when this module
    # is imported as a library (``sys`` is otherwise only imported by __main__).
    import sys

    if not export_data:
        return

    tasks: Dict[str, Dict[str, Any]] = {}
    models: Dict[str, Dict[str, Any]] = {}
    task_prompt_totals: Dict[str, int] = {}
    task_hint_totals: Dict[str, int] = {}
    total_prompts = 0
    total_hints = 0

    for chat in export_data:
        metadata = chat.get("metadata") or {}
        task_name = metadata.get("task_name", "unknown")
        model_name = metadata.get("model_name", "unknown")

        # Count by task
        task_stats = tasks.setdefault(
            task_name,
            {"count": 0, "models": set(), "avg_prompt_length": 0, "avg_hint_length": 0},
        )
        task_stats["count"] += 1
        task_stats["models"].add(model_name)

        # Count by model
        model_stats = models.setdefault(model_name, {"count": 0, "tasks": set()})
        model_stats["count"] += 1
        model_stats["tasks"].add(task_name)

        # Calculate lengths; `or ""` tolerates explicit None values.
        prompts = chat.get("prompts") or {}
        prompt_len = len(prompts.get("system_prompt") or "") + len(
            prompts.get("user_prompt") or ""
        )
        hint_len = len(chat.get("generated_hint") or "")

        total_prompts += prompt_len
        total_hints += hint_len
        task_prompt_totals[task_name] = task_prompt_totals.get(task_name, 0) + prompt_len
        task_hint_totals[task_name] = task_hint_totals.get(task_name, 0) + hint_len

    # Finalize per-task averages and convert sets to lists for JSON serialization
    for task_name, task_stats in tasks.items():
        task_stats["avg_prompt_length"] = task_prompt_totals[task_name] / task_stats["count"]
        task_stats["avg_hint_length"] = task_hint_totals[task_name] / task_stats["count"]
        task_stats["models"] = list(task_stats["models"])

    for model_stats in models.values():
        model_stats["tasks"] = list(model_stats["tasks"])

    argv = getattr(sys, "argv", None)
    summary = {
        "export_info": {
            "total_chats": len(export_data),
            "total_tasks": len(tasks),
            "total_models": len(models),
            "avg_prompt_length": total_prompts / len(export_data),
            "avg_hint_length": total_hints / len(export_data),
        },
        "tasks": tasks,
        "models": models,
        "export_timestamp": datetime.now().isoformat(),
        "export_command": " ".join(argv) if argv else "unknown",
    }

    # Save summary
    summary_file = os.path.join(output_dir, "chat_messages_summary.json")
    try:
        with open(summary_file, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2, default=str, ensure_ascii=False)
        logger.info(f"✅ Summary report saved to {summary_file}")
    except Exception as e:
        logger.error(f"Failed to save summary report: {e}")


# Script execution
if __name__ == "__main__":
    import argparse
    import sys

    def main():
        """
        Parse CLI arguments, export chat messages and optionally write a summary.

        Returns:
            Process exit code: 0 on success, 1 on failure.
        """
        parser = argparse.ArgumentParser(description="Export JephHinter chat messages to JSON")
        parser.add_argument(
            "traces_folder",
            help="Path to the traces folder (chat messages will be loaded from hint database folder if --hint-db is specified)",
        )
        parser.add_argument(
            "--output",
            "-o",
            help="Output JSON file path (default: chat_messages_export.json in hint database folder or traces folder)",
        )
        parser.add_argument(
            "--include-traces",
            "-t",
            action="store_true",
            help="Include full trace data in export (can make file very large)",
        )
        parser.add_argument(
            "--summary",
            "-s",
            action="store_true",
            help="Create a summary report alongside the export",
        )
        parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
        parser.add_argument(
            "--hint-db",
            "-d",
            help="Path to hint database to prioritize loading from hint database folder",
        )

        args = parser.parse_args()

        # Without a configured handler, logging's last-resort handler only
        # prints WARNING and above, so every INFO progress message below would
        # be silently dropped.
        logging.basicConfig(
            level=logging.DEBUG if args.verbose else logging.INFO,
            format="%(message)s",
        )
        if args.verbose:
            logger.setLevel(logging.DEBUG)

        # Check if traces folder exists
        if not os.path.exists(args.traces_folder):
            logger.error(f"Traces folder not found: {args.traces_folder}")
            return 1

        try:
            # Export chat messages
            output_file = export_chat_messages_to_json(
                traces_folder=args.traces_folder,
                output_file=args.output,
                include_trace_data=args.include_traces,
                hint_db_path=args.hint_db,
            )

            # Create summary if requested
            if args.summary:
                # The export writes nothing when no chats were found, so the
                # file may legitimately be missing.
                if os.path.exists(output_file):
                    with open(output_file, "r", encoding="utf-8") as f:
                        export_data = json.load(f)

                    output_dir = os.path.dirname(output_file)
                    create_summary_report(export_data, output_dir)
                else:
                    logger.warning(f"Skipping summary: export file not found: {output_file}")

            logger.info("\n🎉 Export completed successfully!")
            logger.info(f"   Output file: {output_file}")
            logger.info("   You can now open the HTML viewer and load this JSON file")

            return 0

        except Exception as e:
            logger.error(f"Export failed: {e}")
            return 1

    sys.exit(main())
