import argparse
import json
import logging
import random
from pathlib import Path

import litellm
from rich.columns import Columns
from rich.console import Console
from rich.logging import RichHandler
from rich.panel import Panel
from tqdm import tqdm

from graphs.prompts import ADDITIONAL_CONV_GRAPH_PROMPT, ADDITIONAL_GRAPH_PROMPT, BASE_SYSTEM_PROMPT

# NOTE(review): this calls a private litellm hook (leading underscores), so it
# may change between litellm versions; presumably it silences litellm's own
# debug output -- confirm against the installed version.
litellm._logging._disable_debugging()

# Route root-logger output through rich so log lines render alongside the
# rich panels printed below; only the bare message is formatted.
logging.basicConfig(
    level='INFO',
    format='%(message)s',
    datefmt='[%X]',
    handlers=[RichHandler()],
)
# Shared console used for the optional side-by-side panel output.
console = Console(record=True)

# IN_DIR is read from: messages/<id>.json, reports/<id>.yaml and log.jsonl.
IN_DIR = Path('graphs', 'outputs')
# OUT_DIR is written to: base/<id>.txt, with_report/<id>.txt and survey.jsonl.
OUT_DIR = Path('graphs', 'explanations')


def generate_explanation(
    id: str,
    model: str = 'openai/gpt-4.1-mini',
    mode: str = 'base',
) -> str:
    """Generate an explanation for a graph in a given mode.

    The raw text of ``IN_DIR/messages/<id>.json`` is sent to the model as a
    fenced JSON user message. In ``'with_graph'`` mode the system prompt is
    extended with ``ADDITIONAL_GRAPH_PROMPT`` and the raw text of
    ``IN_DIR/reports/<id>.yaml`` is appended as a second, fenced YAML user
    message.

    Args:
        id: Stem of the input files to read.
        model: Model identifier passed to ``litellm.completion``.
        mode: Either ``'base'`` or ``'with_graph'``.

    Returns:
        The content of the first completion choice, stripped of surrounding
        whitespace.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
        OSError: If a required input file cannot be read.
        RuntimeError: If the model response carries no text content.
    """
    if mode not in ('base', 'with_graph'):
        raise ValueError("mode must be 'base' or 'with_graph'")

    # The messages file is embedded verbatim in the prompt; it is not parsed.
    tool_calls = (IN_DIR / 'messages' / f'{id}.json').read_text(encoding='utf-8')

    if mode == 'base':
        system_prompt = BASE_SYSTEM_PROMPT
        report = None
    else:
        system_prompt = BASE_SYSTEM_PROMPT + '\n\n' + ADDITIONAL_GRAPH_PROMPT
        report = (IN_DIR / 'reports' / f'{id}.yaml').read_text(encoding='utf-8')

    messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': f'```json\n{tool_calls}\n```'},
    ]
    if report is not None:
        messages.append({'role': 'user', 'content': f'```yaml\n{report}\n```'})

    response = litellm.completion(
        model=model,
        messages=messages,
        temperature=0,
        api_base=None,
    )

    content = response.choices[0].message.content
    if content is None:
        # Fail with a clear message instead of an AttributeError on None.strip().
        raise RuntimeError(f'Model {model!r} returned no content for {id!r} (mode={mode!r})')
    return content.strip()


def generate_two_stage_explanation(
    id: str,
    model: str = 'openai/gpt-4.1-mini',
) -> str:
    """Generate explanation in two-stage mode (base + augmentation) using two completions.

    The first completion produces a base explanation from the raw text of
    ``IN_DIR/messages/<id>.json``. The second completion replays that exchange
    (including the base explanation as an assistant turn), then adds
    ``ADDITIONAL_CONV_GRAPH_PROMPT`` as a system message and the raw text of
    ``IN_DIR/reports/<id>.yaml`` as a user message.

    Args:
        id: Stem of the input files to read.
        model: Model identifier passed to ``litellm.completion``.

    Returns:
        The content of the second completion's first choice, stripped of
        surrounding whitespace.

    Raises:
        OSError: If a required input file cannot be read.
        RuntimeError: If either model response carries no text content.
    """
    # Step 1: load tool calls (as text; the file is not parsed)
    tool_calls = (IN_DIR / 'messages' / f'{id}.json').read_text(encoding='utf-8')

    # Step 2: generate base explanation (system + user/tool_calls)
    base_messages = [
        {'role': 'system', 'content': BASE_SYSTEM_PROMPT},
        {'role': 'user', 'content': f'```json\n{tool_calls}\n```'},
    ]
    base_response = litellm.completion(
        model=model,
        messages=base_messages,
        temperature=0,
        api_base=None,
    )
    base_content = base_response.choices[0].message.content
    if base_content is None:
        # Fail with a clear message instead of an AttributeError on None.strip().
        raise RuntimeError(f'Model {model!r} returned no content for {id!r} (base stage)')
    base_expl = base_content.strip()

    # Step 3: load provenance report (read only after the first completion succeeds)
    report = (IN_DIR / 'reports' / f'{id}.yaml').read_text(encoding='utf-8')

    # Step 4: two-stage augmentation -- replay the first exchange, then append
    # the augmentation instructions (as a second system message) and the report.
    two_stage_messages = [
        *base_messages,
        {'role': 'assistant', 'content': base_expl},
        {'role': 'system', 'content': ADDITIONAL_CONV_GRAPH_PROMPT},
        {'role': 'user', 'content': f'Provenance report:\n```yaml\n{report}\n```'},
    ]
    two_stage_response = litellm.completion(
        model=model,
        messages=two_stage_messages,
        temperature=0,
        api_base=None,
    )
    augmented_content = two_stage_response.choices[0].message.content
    if augmented_content is None:
        raise RuntimeError(f'Model {model!r} returned no content for {id!r} (augmentation stage)')

    return augmented_content.strip()


def batch_generate_explanations(
    model: str = 'openai/gpt-4.1-mini',
    save: bool = True,
    n: int = None,
    do_print: bool = False,
    two_stage: bool = False,
) -> dict:
    """Batch generate explanations in different modes.

    IDs are taken from ``IN_DIR/log.jsonl``: every entry that has an ``id``
    and a positive ``num_issues`` is selected. For each ID a base explanation
    and a report-augmented explanation are generated. Failures for a single
    ID are logged and do not abort the batch.

    Args:
        model: Model identifier forwarded to the generation functions.
        save: If true, write per-ID text files under ``OUT_DIR/base`` and
            ``OUT_DIR/with_report`` plus ``OUT_DIR/survey.jsonl``.
        n: If given, process only a random selection of ``n`` IDs.
        do_print: If true, print both explanations side by side with rich.
        two_stage: If true, produce the augmented explanation with
            ``generate_two_stage_explanation`` instead of the single-call
            ``'with_graph'`` mode.

    Returns:
        A dict mapping each processed ID to a dict with the keys ``'base'``
        and ``'with_report'``. An ID whose generation failed maps to a dict
        holding only the explanations produced before the failure.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    # Create subfolders for explanations
    base_dir = OUT_DIR / 'base'
    with_report_dir = OUT_DIR / 'with_report'
    base_dir.mkdir(parents=True, exist_ok=True)
    with_report_dir.mkdir(parents=True, exist_ok=True)

    # --- Read IDs from log.jsonl ---
    log_path = IN_DIR / 'log.jsonl'
    with open(log_path, encoding='utf-8') as f:
        log_entries = [json.loads(line) for line in f if line.strip()]

    # A missing or null 'num_issues' counts as zero; comparing None > 0 would
    # raise TypeError and abort the whole batch.
    ids = sorted(
        entry['id']
        for entry in log_entries
        if 'id' in entry and (entry.get('num_issues') or 0) > 0
    )
    if n is not None:
        random.shuffle(ids)
        ids = ids[:n]

    explanations = {}
    survey_rows = []

    for item_id in tqdm(ids, desc='Generating explanations'):
        explanations[item_id] = {}
        try:
            msg_path = IN_DIR / 'messages' / f'{item_id}.json'

            # Step 1: generate base explanation
            base_expl = generate_explanation(id=item_id, model=model, mode='base')
            explanations[item_id]['base'] = base_expl

            # Step 2: augment with graph (two-stage or not)
            # NOTE: the two-stage helper runs its own base completion internally,
            # so in that mode the base explanation is requested twice.
            if two_stage:
                explanation = generate_two_stage_explanation(id=item_id, model=model)
            else:
                explanation = generate_explanation(id=item_id, model=model, mode='with_graph')
            explanations[item_id]['with_report'] = explanation

            # Step 3: get question text from messages file
            # (first message with role 'user' and non-empty content)
            with open(msg_path, encoding='utf-8') as f:
                messages = json.load(f)
            question = next(
                (msg['content'] for msg in messages if msg.get('role') == 'user' and msg.get('content')),
                '',
            )

            # Collect for survey output
            survey_rows.append(
                {
                    'id': item_id,
                    'question': question,
                    'base_explanation': base_expl,
                    'with_report_explanation': explanation,
                }
            )

            if save:
                (base_dir / f'{item_id}.txt').write_text(base_expl, encoding='utf-8')
                (with_report_dir / f'{item_id}.txt').write_text(explanation, encoding='utf-8')

            if do_print:
                # Two panels, each half the console width, shown side by side.
                panel_width = console.size.width // 2
                panels = [
                    Panel(
                        base_expl,
                        title=f'[bold cyan]{item_id}[/bold cyan] - base',
                        border_style='blue',
                        width=panel_width,
                    ),
                    Panel(
                        explanation,
                        title=f'[bold magenta]{item_id}[/bold magenta] - graph-augmented',
                        border_style='magenta',
                        width=panel_width,
                    ),
                ]
                console.print(Columns(panels))

        except Exception as e:
            # Best-effort batch: log the failure (with traceback) and move on.
            logging.exception('Error generating explanation for %s: %s', item_id, e)

    # Save structured survey file if requested
    if save:
        survey_path = OUT_DIR / 'survey.jsonl'
        with open(survey_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(row, ensure_ascii=False) + '\n' for row in survey_rows)
        logging.info('Survey file written: %s (%d rows)', survey_path, len(survey_rows))

    return explanations


if __name__ == '__main__':
    # Command-line entry point: parse the flags and forward them to the batch
    # generator as keyword arguments.
    cli = argparse.ArgumentParser(description='Generate explanations for graph files')
    cli.add_argument(
        '--model',
        type=str,
        default='openai/gpt-4.1-mini',
        help='Model to use for generation',
    )
    cli.add_argument(
        '--save',
        action='store_true',
        help='If set, save explanations to files',
    )
    cli.add_argument(
        '--n',
        type=int,
        default=None,
        help='Limit the number of examples to generate',
    )
    cli.add_argument(
        '--print',
        action='store_true',
        help='If set, pretty-print explanations with rich',
    )
    cli.add_argument(
        '--two-stage',
        action='store_true',
        help='Use two-stage augmentation mode',
    )

    # The namespace holds model, save, n, print and two_stage; only 'print'
    # needs renaming to match the callee's 'do_print' parameter.
    options = vars(cli.parse_args())
    options['do_print'] = options.pop('print')
    batch_generate_explanations(**options)
