import argparse
import json
import logging
from pathlib import Path

import networkx as nx
import pandas as pd
from rich.logging import RichHandler

from graphs.graph import FrankensteinGraph
from graphs.report import GraphReport


def _stringify_container_attrs(attr_dicts):
    """Replace list/dict attribute values with their ``str()`` form, in place."""
    for attrs in attr_dicts:
        # Iterate over a snapshot of the items because values are reassigned.
        for key, value in list(attrs.items()):
            if isinstance(value, (list, dict)):
                attrs[key] = str(value)


def batch_generate_graphs_and_reports(
    df_path: str,
    tool_schema_path: str,
    out_dir: str = 'graphs',
    limit: int | None = None,
    save_fig: bool = True,
    include_errors: bool = True,
) -> None:
    """Batch generate GEXF graphs, YAML reports, and messages JSON for each row in a dataframe.

    Outputs are written below ``out_dir``: ``graphs/<id>.gexf``,
    ``reports/<id>.yaml``, ``messages/<id>.json`` (only for rows that have a
    ``messages`` value) and a single ``log.jsonl`` holding one summary per row.
    Rows without an ``id`` key are named ``row_<index>``.

    Only a failure while drawing the figure is caught (and logged as a
    warning); any other exception propagates and aborts the batch before
    ``log.jsonl`` is written.

    Parameters
    ----------
    df_path : str
        Path to the input dataframe (JSONL).
    tool_schema_path : str
        Path to the tool schema JSONL file.
    out_dir : str
        Directory to save all outputs (GEXF, YAML, messages).
    limit : int or None
        If set, process a random sample of at most `limit` rows. The sample is
        drawn with ``DataFrame.sample`` and is not seeded here, so the selected
        rows can differ between runs.
    save_fig : bool
        If True, also call ``G.draw()`` for each graph (default: True).
    include_errors : bool
        Forwarded unchanged to both ``FrankensteinGraph`` and ``GraphReport``
        (default: True).

    """
    out_root = Path(out_dir)
    out_root.mkdir(parents=True, exist_ok=True)

    # Create subfolders (no correct/incorrect split)
    graphs_dir = out_root / 'graphs'
    reports_dir = out_root / 'reports'
    messages_dir = out_root / 'messages'
    for subdir in (graphs_dir, reports_dir, messages_dir):
        subdir.mkdir(parents=True, exist_ok=True)

    df = pd.read_json(df_path, orient='records', lines=True)

    # Optionally restrict the run to a random subset of rows.
    if limit is not None:
        df = df.sample(n=min(limit, len(df))).reset_index(drop=True)

    logging.info(f'Loaded dataframe with {len(df)} rows from {df_path}')

    summary_log = []

    rows = list(df.to_dict(orient='records'))
    total = len(rows)
    for idx, row in enumerate(rows):
        row_id = row.get('id', f'row_{idx}')
        tag = f'[Row {idx + 1}/{total}]'
        logging.info(f'{tag} Processing id={row_id}')

        # Build graph
        G = FrankensteinGraph(row, enable_logging=False, include_errors=include_errors)

        # Flatten container-valued node/edge attributes to strings for GEXF.
        _stringify_container_attrs(attrs for _, attrs in G.nodes(data=True))
        _stringify_container_attrs(attrs for _, _, attrs in G.edges(data=True))

        gexf_path = graphs_dir / f'{row_id}.gexf'
        nx.write_gexf(G, gexf_path)
        logging.info(f'{tag} Graph written: {gexf_path}')

        # Optionally save figure; drawing is best-effort and must not stop the batch.
        if save_fig:
            try:
                G.draw()
                logging.info(f'{tag} Graph figure saved for id={row_id}')
            except Exception as e:
                logging.warning(f'{tag} Could not save figure for id={row_id}: {e}')

        # Generate YAML report from the GEXF file that was just written.
        report = GraphReport(
            path_to_graph_file=gexf_path,
            tool_schema_path=tool_schema_path,
            enable_logging=False,
            include_errors=include_errors,
        )
        yaml_path = reports_dir / f'{row_id}.yaml'
        report.report_args(yaml_path=str(yaml_path))
        logging.info(f'{tag} YAML report written: {yaml_path}')

        # Collect summary info for log.jsonl
        summary = report.get_summary()
        summary['id'] = row_id
        summary_log.append(summary)
        logging.info(
            f'{tag} Summary collected for id={row_id}: nodes={summary.get("num_nodes")}, edges={summary.get("num_edges")}, issues={summary.get("num_issues")}'
        )

        # Output messages column to JSON
        messages = row.get('messages')
        if messages is not None:
            # Clean before opening the file so a cleaning failure does not
            # leave an empty/truncated JSON file behind.
            cleaned_messages = clean_messages(messages)
            messages_path = messages_dir / f'{row_id}.json'
            with open(messages_path, 'w', encoding='utf-8') as f:
                json.dump(cleaned_messages, f, ensure_ascii=False, indent=2)
            logging.info(f'{tag} Messages file written: {messages_path}')

    # --- Write summary log.jsonl ---
    log_path = out_root / 'log.jsonl'
    with open(log_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(entry, ensure_ascii=False) + '\n' for entry in summary_log)
    logging.info(f'Summary log written: {log_path} ({len(summary_log)} rows)')


def clean_messages(messages):
    """Reduce a conversation to the parts needed for provenance explanations.

    Retained, in their original order:
      - every user message (role and content only)
      - assistant messages with a non-empty ``tool_calls`` value (role and
        tool calls only; any accompanying text is dropped)
      - every tool message (role, ``tool_call_id`` and content only)

    Dropped:
      - assistant messages without tool calls (narrative / <think> text)
      - messages with any other role, or no role at all

    Args:
        messages (list of dict): Raw message list

    Returns:
        list of dict: New list of freshly built message dicts

    """
    kept = []
    for entry in messages:
        speaker = entry.get('role')

        if speaker == 'user':
            # The question is carried over verbatim.
            kept.append({'role': 'user', 'content': entry.get('content')})
            continue

        if speaker == 'tool':
            # Tool results are carried over verbatim.
            kept.append(
                {
                    'role': 'tool',
                    'tool_call_id': entry.get('tool_call_id'),
                    'content': entry.get('content'),
                }
            )
            continue

        if speaker != 'assistant':
            continue

        # Assistant turns survive only when they actually invoke tools.
        calls = entry.get('tool_calls')
        if calls:
            kept.append({'role': 'assistant', 'tool_calls': calls})

    return kept


if __name__ == '__main__':
    # Send every log record (DEBUG and up) through Rich's console handler.
    logging.basicConfig(level=logging.DEBUG, format='%(message)s', datefmt='[%X]', handlers=[RichHandler()])

    # Command-line interface: only --df is mandatory.
    cli = argparse.ArgumentParser(description='Batch generate graphs and YAML reports from a dataframe.')
    cli.add_argument('--df', required=True, type=str, help='Path to the input dataframe (JSONL)')
    cli.add_argument(
        '--tool-schema',
        default='frankenstein/tools/tool_schema.jsonl',
        type=str,
        help='Path to the tool schema JSONL file',
    )
    cli.add_argument('--out-dir', default='graphs/outputs', type=str, help='Directory to save all outputs')
    cli.add_argument('--n', default=None, type=int, help='Limit the number of rows processed')
    cli.add_argument('--no-fig', action='store_true', help='Disable saving PNG figures of the graphs')
    # Off unless the flag is given on the command line.
    cli.add_argument(
        '--include-errors',
        action='store_true',
        default=False,
        help='Include tool calls with error/warning results',
    )
    opts = cli.parse_args()

    # Map parsed options onto the batch function's keyword arguments;
    # --no-fig is a negative flag, so it is inverted for save_fig.
    run_kwargs = {
        'df_path': opts.df,
        'tool_schema_path': opts.tool_schema,
        'out_dir': opts.out_dir,
        'limit': opts.n,
        'save_fig': not opts.no_fig,
        'include_errors': opts.include_errors,
    }
    batch_generate_graphs_and_reports(**run_kwargs)
