"""
Prompt Token Data Collection Module for Graph-Based ARC Analysis

This module scans the datasets directory structure and collects token statistics
for graph descriptions in different encodings, tasks, and sizes.
Uses the existing token counting logic for consistency.
"""

import os
import glob
import re
from typing import Dict, List, Tuple, Any, Optional
import pandas as pd

# Import existing token counting utilities for consistency
from scripts.utils.token_utils import estimate_token_count


def count_tokens_in_file(filepath: str) -> int:
    """
    Count tokens in a text file using the existing token estimation logic.

    The file is read as UTF-8 and surrounding whitespace is stripped before
    counting. This is a best-effort helper used while scanning many files:
    a file that cannot be opened or decoded is reported with a warning and
    counted as 0 tokens instead of aborting the scan.

    Parameters:
    - filepath: Path to the text file

    Returns:
    - Number of tokens in the file (0 for empty, whitespace-only, or
      unreadable files)
    """
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            content = f.read().strip()
    except (OSError, UnicodeDecodeError) as e:
        # OSError covers FileNotFoundError as well as PermissionError,
        # IsADirectoryError, etc., which previously escaped and aborted the
        # whole dataset scan despite this function's best-effort contract.
        print(f"Warning: Could not read {filepath}: {e}")
        return 0

    if not content:
        return 0

    # Use existing token estimation for consistency with the rest of the
    # system. Kept outside the try so an estimator failure is not misreported
    # as an unreadable file (and silently counted as 0).
    return estimate_token_count(content)


def discover_available_sizes(datasets_dir: str = "datasets") -> List[int]:
    """
    Discover all available graph sizes by scanning the datasets directory structure.

    Sizes are the numeric directory names found at
    ``<datasets_dir>/<task>/<graph_type>/textual/<input|output>/<size>/``;
    non-numeric directory names at that level are ignored.

    Parameters:
    - datasets_dir: Path to datasets directory

    Returns:
    - Sorted list of distinct sizes (empty if nothing matches)
    """
    sizes = set()

    # Pattern to find size directories: datasets/*/*/textual/*/*/
    # (the trailing separator makes glob match directories only)
    pattern = os.path.join(datasets_dir, "*/*/textual/*/*/")

    for size_dir in glob.glob(pattern):
        # normpath drops the trailing separator whichever style glob produced.
        # On Windows glob joins results with "\\", so the old rstrip("/") left
        # it in place, basename() returned "" and every size was skipped.
        size_name = os.path.basename(os.path.normpath(size_dir))
        try:
            sizes.add(int(size_name))
        except ValueError:
            # Skip non-numeric directory names
            continue

    return sorted(sizes)


def discover_dataset_structure(datasets_dir: str = "datasets") -> Dict[str, Any]:
    """
    Summarize which tasks, graph types, encodings and sizes exist on disk.

    Every file matching ``<datasets_dir>/*/*/textual/*/*/*.txt`` is counted.
    Each file whose path ``parse_graph_file_path`` can parse contributes its
    task, graph type, encoding and size to the summary.

    Parameters:
    - datasets_dir: Path to datasets directory

    Returns:
    - Dictionary with sorted lists under "tasks", "graph_types", "encodings"
      and "sizes", plus "total_files" (number of matched files, including
      any whose path could not be parsed)
    """
    graph_files = glob.glob(os.path.join(datasets_dir, "*/*/textual/*/*/*.txt"))

    # Output key -> metadata key whose values it collects
    field_map = {
        "tasks": "task",
        "graph_types": "graph_type",
        "encodings": "encoding",
        "sizes": "size",
    }
    seen = {out_key: set() for out_key in field_map}

    for path in graph_files:
        meta = parse_graph_file_path(path)
        if not meta:
            continue
        for out_key, meta_key in field_map.items():
            seen[out_key].add(meta[meta_key])

    summary = {out_key: sorted(values) for out_key, values in seen.items()}
    summary["total_files"] = len(graph_files)
    return summary


def parse_graph_file_path(filepath: str) -> Optional[Dict[str, Any]]:
    """
    Parse graph file path to extract metadata.

    Expected layout of the trailing path components (the name and depth of
    the datasets root directory do not matter):
        .../<task>/<graph_type>/textual/<input|output>/<size>/<encoding>_<number>.txt
    or  .../<task>/<graph_type>/textual/<input|output>/<size>/<encoding><number>.txt
    In the second form the encoding must be letters only, immediately
    followed by the digits of the number.

    Parameters:
    - filepath: Full path to graph file ('/' or '\\' separators)

    Returns:
    - Dictionary with keys "task", "graph_type", "input_output", "size" (int),
      "encoding", "number" (int) and "filepath" (with '/' separators),
      or None if the path cannot be parsed (a warning is printed)
    """
    # Normalize path separators
    filepath = filepath.replace("\\", "/")
    parts = filepath.split("/")

    try:
        # Read the fixed layout from the end of the path. Anchoring on a
        # component literally named "datasets" broke for any other root
        # directory name, and picked the wrong component when "datasets"
        # appeared earlier in the path.
        task, graph_type, _textual, input_output, size_str, filename = parts[-6:]
        size = int(size_str)

        # Strip only the trailing extension (str.replace would also remove
        # ".txt" occurring elsewhere in the name)
        encoding_number, _ext = os.path.splitext(filename)

        # Handle both formats: "adjacency_1" and "adjacency1"
        if "_" in encoding_number:
            encoding, number = encoding_number.rsplit("_", 1)
        else:
            # fullmatch: reject names with trailing junk such as "adjacency1b"
            # instead of silently truncating them
            match = re.fullmatch(r"([a-zA-Z]+)(\d+)", encoding_number)
            if match is None:
                raise ValueError(f"Cannot parse filename format: {filename}")
            encoding, number = match.groups()

        return {
            "task": task,
            "graph_type": graph_type,
            "input_output": input_output,
            "size": size,
            "encoding": encoding,
            "number": int(number),
            "filepath": filepath,
        }

    except (ValueError, IndexError) as e:
        # ValueError also covers a path too short to unpack into the layout
        print(f"Warning: Could not parse path {filepath}: {e}")
        return None


def collect_all_graph_token_data(
    datasets_dir: str = "datasets",
) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Collect token statistics for all graph files in the datasets directory.

    The function processes every ``*.txt`` file under
    ``<datasets_dir>/<task>/<graph_type>/textual/<input|output>/<size>/``:
    - The path is parsed with ``parse_graph_file_path``; unparseable paths
      are skipped.
    - Tokens are counted with ``count_tokens_in_file``.
    Progress is printed to stdout.

    Parameters:
    - datasets_dir: Path to datasets directory

    Returns:
    - Tuple of:
      - a DataFrame with one row per parsed file (the path metadata plus
        ``token_count``), in sorted file-path order so the row order is
        reproducible. If no file is parsed, the DataFrame is empty but keeps
        its columns, so callers can still select "task", "token_count", etc.
      - the structure info from ``discover_dataset_structure``.
    """
    # Columns of each record (metadata keys + token_count); used so an empty
    # result still carries the schema instead of being column-less.
    record_columns = [
        "task",
        "graph_type",
        "input_output",
        "size",
        "encoding",
        "number",
        "filepath",
        "token_count",
    ]

    print(f"Scanning {datasets_dir} for graph files...")

    # First discover the structure
    structure = discover_dataset_structure(datasets_dir)

    tasks = structure["tasks"]
    print("Discovered structure:")
    print(
        f"  Tasks: {len(tasks)} ({', '.join(tasks[:5])}{'...' if len(tasks) > 5 else ''})"
    )
    print(
        f"  Graph types: {len(structure['graph_types'])} ({', '.join(structure['graph_types'])})"
    )
    print(
        f"  Encodings: {len(structure['encodings'])} ({', '.join(structure['encodings'])})"
    )
    print(
        f"  Sizes: {len(structure['sizes'])} ({', '.join(map(str, structure['sizes']))})"
    )
    print(f"  Total files: {structure['total_files']}")

    # Find all graph text files; sorted so processing (and row) order does not
    # depend on the filesystem's directory listing order.
    pattern = os.path.join(datasets_dir, "*/*/textual/*/*/*.txt")
    all_files = sorted(glob.glob(pattern))
    n_files = len(all_files)

    print(f"Processing {n_files} graph files...")

    token_data = []

    for i, filepath in enumerate(all_files):
        if i % 100 == 0:
            print(f"Processing file {i + 1}/{n_files}...")

        # Parse metadata from filepath
        metadata = parse_graph_file_path(filepath)
        if metadata is None:
            continue

        # Count tokens using existing logic
        token_count = count_tokens_in_file(filepath)
        token_data.append({**metadata, "token_count": token_count})

    if token_data:
        df = pd.DataFrame(token_data)
    else:
        # pd.DataFrame([]) has no columns at all; keep the schema instead.
        df = pd.DataFrame(columns=record_columns)
    print(f"Successfully processed {len(df)} files")

    return df, structure


def calculate_prompt_token_statistics(token_df: pd.DataFrame) -> pd.DataFrame:
    """
    Calculate token requirements for prompts with different numbers of examples.

    For each task/encoding/size combination:
    - 1 example: 1 input + 1 output + 1 final input
    - 2 examples: 2 inputs + 2 outputs + 1 final input
    - 3 examples: 3 inputs + 3 outputs + 1 final input

    Parameters:
    - token_df: DataFrame with individual file token counts

    Returns:
    - DataFrame with prompt-level statistics
    """
    print("Calculating prompt token statistics...")

    # Group by task, encoding, and size to get average tokens per input/output
    grouped = (
        token_df.groupby(["task", "encoding", "size", "input_output"])["token_count"]
        .agg(["mean", "std", "count", "min", "max"])
        .reset_index()
    )

    # Pivot to get input and output columns
    pivot_df = grouped.pivot_table(
        index=["task", "encoding", "size"],
        columns="input_output",
        values=["mean", "std", "count", "min", "max"],
    ).reset_index()

    # Flatten column names
    pivot_df.columns = [
        "_".join(col).strip() if col[1] else col[0] for col in pivot_df.columns.values
    ]

    # Calculate prompt tokens for different example counts
    prompt_stats = []

    for _, row in pivot_df.iterrows():
        input_tokens = row.get("mean_input", 0)
        output_tokens = row.get("mean_output", 0)

        base_stats = {
            "task": row["task"],
            "encoding": row["encoding"],
            "size": row["size"],
            "avg_input_tokens": input_tokens,
            "avg_output_tokens": output_tokens,
            "input_std": row.get("std_input", 0),
            "output_std": row.get("std_output", 0),
            "input_samples": row.get("count_input", 0),
            "output_samples": row.get("count_output", 0),
        }

        # Calculate prompt tokens for different example counts
        for n_examples in [1, 2, 3]:
            prompt_tokens = (
                n_examples * (input_tokens + output_tokens) + input_tokens
            )  # +1 final input

            prompt_stats.append(
                {
                    **base_stats,
                    "n_examples": n_examples,
                    "total_prompt_tokens": prompt_tokens,
                    "example_tokens": n_examples * (input_tokens + output_tokens),
                    "final_input_tokens": input_tokens,
                }
            )

    return pd.DataFrame(prompt_stats)


def get_token_summary_statistics(
    token_df: pd.DataFrame, structure: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Build a nested summary of the per-file token counts.

    Parameters:
    - token_df: DataFrame with token statistics (reads the columns "task",
      "encoding", "graph_type", "size", "input_output", "token_count")
    - structure: Dataset structure information, passed through unchanged

    Returns:
    - Dictionary containing:
      - overall counts of files, tasks, encodings and graph types
      - the size range (min, max, sorted distinct values)
      - overall token_count statistics
      - per-encoding, per-input/output, per-task and per-size
        mean/std/count tables, each converted with ``DataFrame.to_dict()``
        (the per-task table is ordered by descending mean)
    """

    def _stats_by(column):
        # mean/std/count of token_count for each distinct value of `column`
        return token_df.groupby(column)["token_count"].agg(["mean", "std", "count"])

    return {
        "dataset_structure": structure,
        "total_files_analyzed": len(token_df),
        "unique_tasks": len(token_df["task"].unique()),
        "unique_encodings": len(token_df["encoding"].unique()),
        "unique_graph_types": len(token_df["graph_type"].unique()),
        "size_range": {
            "min": token_df["size"].min(),
            "max": token_df["size"].max(),
            "values": sorted(token_df["size"].unique()),
        },
        "token_statistics": {
            "overall_mean": token_df["token_count"].mean(),
            "overall_std": token_df["token_count"].std(),
            "overall_min": token_df["token_count"].min(),
            "overall_max": token_df["token_count"].max(),
        },
        "by_encoding": _stats_by("encoding").to_dict(),
        "by_input_output": _stats_by("input_output").to_dict(),
        "by_task": _stats_by("task").sort_values("mean", ascending=False).to_dict(),
        "by_size": _stats_by("size").to_dict(),
    }


if __name__ == "__main__":
    # Demo run: scan ./datasets, build the per-prompt table, print a summary.
    file_tokens, dataset_layout = collect_all_graph_token_data()
    prompt_table = calculate_prompt_token_statistics(file_tokens)
    report = get_token_summary_statistics(file_tokens, dataset_layout)

    summary_lines = [
        "\n=== PROMPT TOKEN ANALYSIS SUMMARY ===",
        f"Total files: {report['total_files_analyzed']}",
        f"Tasks: {report['unique_tasks']}",
        f"Encodings: {report['unique_encodings']}",
        f"Sizes: {report['size_range']['values']}",
        f"Average tokens per file: {report['token_statistics']['overall_mean']:.1f}",
    ]
    for line in summary_lines:
        print(line)
