import argparse
import logging
from pathlib import Path

import openai
import pandas as pd
from rich.logging import RichHandler
from runner import Runner

# Send every root-logger record through a single Rich handler. Console markup
# in log messages is enabled, and `openai` is listed as a module for Rich to
# suppress when it renders tracebacks.
_rich_handler = RichHandler(markup=True, tracebacks_suppress=[openai])

logging.basicConfig(handlers=[_rich_handler], level=logging.INFO, datefmt='[%X]', format='%(message)s')


class FrankensteinEvaluator:
    """Evaluate the performance of a transformer model on a split/portion/template of the dataset.

    The dataset is read from ``dataset/{split}.jsonl``. When saving is enabled,
    results are written as JSON lines under ``eval/runs``; if that file already
    exists, a run picks up from the rows it contains.
    """

    def __init__(
        self,
        model_name: str,
        toolbox: str = 'all',
        save: bool = False,
        num_samples: int = -1,
        split: str = 'answerable-full',
        n_shots: int = 0,
        debug: bool = False,
    ):
        """Initialize the evaluator.

        Parameters
        ----------
        model_name : str
            Path or name of the transformer model.
        toolbox : str
            Toolbox to use for the evaluation. Can be 'all', 'arithmetic', 'data', or 'none'.
        save : bool
            Whether to save the evaluation results.
        num_samples : int
            Number of samples to evaluate. Use -1 to evaluate the whole split. A value
            larger than the split is reduced to the split size.
        split : str
            Dataset split to use.
        n_shots : int
            Number of n-shot tool call examples to prepend to the prompt.
        debug : bool
            Debug flag forwarded unchanged to the runner.

        """
        self.model_name = model_name
        self.toolbox = toolbox
        self.save = save
        self.num_samples = num_samples
        self.split = split
        self.n_shots = n_shots
        self.debug = debug

        # Load dataset from dataset/{split}.jsonl
        dataset_path = Path('dataset', f'{self.split}.jsonl')
        self.dataset = pd.read_json(dataset_path, orient='records', lines=True, precise_float=True)
        if self.num_samples != -1:
            available = len(self.dataset)
            if self.num_samples > available:
                # DataFrame.sample raises when asked for more rows than exist
                # (without replacement), so cap the request instead of crashing.
                logging.warning(
                    f'Requested {self.num_samples} samples but the dataset only has {available}; using {available}.'
                )
            self.dataset = self.dataset.sample(min(self.num_samples, available))
        logging.info(f'Loaded dataset from {dataset_path} with {len(self.dataset)} samples.')

        self.log_config(vars(self))

    def run(
        self,
    ) -> list:
        """Evaluate the model on the dataset.

        Rows whose identifier ('id' if present, otherwise the question text) is
        already in a previous output file are skipped. When saving is enabled the
        full result list is rewritten after every processed row.

        Returns
        -------
        list
            List of messages generated by the model for each input in the dataset,
            including rows restored from a previous run. Empty if there are no results.

        """
        runner = Runner(
            model_name=self.model_name,
            toolbox=self.toolbox,
            n_shots=self.n_shots,
            debug=self.debug,
        )

        output_path = self._output_path()
        results, completed_ids = self._load_previous_results(output_path)

        if self.save:
            # Create the output directory once up front rather than on every row.
            output_path.parent.mkdir(parents=True, exist_ok=True)

        total = len(self.dataset)
        for idx, (_, row) in enumerate(self.dataset.iterrows()):
            # Use 'id' if present, else fallback to question text as unique identifier
            row_id = row['id'] if 'id' in row else row['question']
            if row_id in completed_ids:
                continue

            runner.reset()

            logging.info(f"✨ Processing question {idx + 1}/{total} of '{output_path}'")
            logging.info('🔎 Question Metadata')
            self.log_question_info(row)

            messages, tokens_used = runner.loop(row['question'])
            gold_answer = row['answer']
            answer_format = row['answer_format']

            correct, error = runner.match_results(messages, gold_answer, answer_format)
            pred = runner.matcher.extract_final_answer(messages)

            result_row = row.to_dict()
            result_row.update(
                {
                    'messages': runner.format_messages(messages),
                    'tokens': tokens_used,
                    'pred': pred,
                    # A missing verdict is recorded as a failure.
                    'correct': correct if correct is not None else False,
                    'error': error,
                }
            )
            results.append(result_row)
            completed_ids.add(row_id)

            # Save after every iteration so an interrupted run can be resumed.
            if self.save:
                self._save_results(results, output_path)

        results_df = pd.DataFrame(results)

        if self.save:
            self._save_results(results, output_path)
            logging.info(f'Saved evaluation results to {output_path}')

        # With no results at all the frame has no columns; avoid a KeyError.
        if 'messages' not in results_df.columns:
            return []
        return results_df['messages'].tolist()

    def _output_path(self) -> Path:
        """Build the results file path for the current model/split/toolbox/n-shot combination."""
        # Strip trailing slashes first: for a path such as '/models/foo/' the last
        # '/'-separated component would otherwise be the empty string.
        model_name = str(self.model_name).rstrip('/').split('/')[-1]
        return Path('eval', 'runs', f'{model_name}_{self.split}_{self.toolbox}-tools_{self.n_shots}-shot.jsonl')

    @staticmethod
    def _load_previous_results(output_path: Path) -> tuple:
        """Load rows and completed identifiers from an earlier run, or return empty containers."""
        if not output_path.exists():
            return [], set()
        try:
            prev_results_df = pd.read_json(output_path, orient='records', lines=True, precise_float=True)
            # Use a unique identifier for each row; fallback to question text if 'id' not present
            id_column = 'id' if 'id' in prev_results_df.columns else 'question'
            completed_ids = set(prev_results_df[id_column])
            results = prev_results_df.to_dict(orient='records')
        except Exception as e:
            # Best effort: an unreadable file means starting over. Rows and ids are
            # only returned together, so a partial load can never duplicate rows.
            logging.warning(f'Could not load previous results for resuming: {e}')
            return [], set()
        logging.info(f'Resuming from partial run: {len(completed_ids)} questions already processed.')
        return results, completed_ids

    @staticmethod
    def _save_results(results: list, output_path: Path) -> None:
        """Write all result rows to ``output_path`` as JSON lines, replacing the file."""
        output_path.parent.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(results).to_json(output_path, orient='records', lines=True)

    @staticmethod
    def _arrow(key: object, key_width: int) -> str:
        """Return the dashed arrow that pads ``key`` out to ``key_width`` characters."""
        return '-' * (key_width + 1 - len(str(key))) + '>'

    def log_config(
        self,
        config: dict,
    ) -> None:
        """Log the configuration in a formatted way.

        The 'dataset' entry is skipped. An empty configuration logs only the header.

        Parameters
        ----------
        config : dict
            Configuration dictionary to log.

        """
        key_width = max((len(str(k)) for k in config), default=0)
        logging.info('Model Config')
        for k, v in config.items():
            if k == 'dataset':
                continue
            arrow = self._arrow(k, key_width)
            logging.info(f"⚙️ '{k}' {arrow} {v!r}")

    def log_question_info(
        self,
        row: pd.Series,
    ) -> None:
        """Log the metadata fields of a dataset row in a formatted table.

        Parameters
        ----------
        row : pd.Series
            Dataset row. The fields 'question_template', 'slot_values', 'answerable',
            'answer_format' and 'data_availability' are logged; 'slot_values' is
            expanded into one line per slot and skipped when it is not a dict.

        """
        keys = ['question_template', 'slot_values', 'answerable', 'answer_format', 'data_availability']
        key_width = max(len(str(k)) for k in keys)
        for k in keys:
            if k == 'slot_values':
                slot_values = row.get(k)
                # The field may be absent, None or NaN; only a dict has slots to list.
                if not isinstance(slot_values, dict):
                    continue
                for sk, sv in slot_values.items():
                    arrow = self._arrow(sk, key_width)
                    logging.info(f"🔑 '{sk}' {arrow} {sv!r}")
            else:
                v = row.get(k)
                arrow = self._arrow(k, key_width)
                logging.info(f"🔑 '{k}' {arrow} {v!r}")


if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword options) pairs.
    # They are registered with the parser in exactly this order.
    cli_options = [
        (
            '--model-name',
            dict(
                type=str,
                default='/public/hf/models/meta-llama/Meta-Llama-3.1-8B-Instruct/',
                help='Path or name of the transformer model.',
            ),
        ),
        (
            '--split',
            dict(
                type=str,
                default='answerable-full',
                help='Dataset split to use (e.g., "answerable-full", "unanswerable-partial", etc.).',
            ),
        ),
        (
            '--num-samples',
            dict(
                type=int,
                default=-1,
                help='Number of samples to evaluate. Use -1 for all samples.',
            ),
        ),
        (
            '--toolbox',
            dict(
                type=str,
                choices=['all', 'arithmetic', 'data', 'none'],
                default='all',
                help='Toolbox to use for the evaluation. Can be "all", "arithmetic", "data", or "none".',
            ),
        ),
        (
            '--save',
            dict(
                action='store_true',
                help='Whether to save the evaluation results.',
            ),
        ),
        (
            '--n-shots',
            dict(
                type=int,
                default=0,
                help='Number of n-shot tool call examples to prepend to the prompt.',
            ),
        ),
        (
            '--debug',
            dict(
                action='store_true',
                help='If set, the loop will wait for user input after each message.',
            ),
        ),
    ]

    parser = argparse.ArgumentParser(description='Evaluate a transformer model.')
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Build the evaluator from the parsed options, keep the raw namespace on
    # the instance, then run the evaluation.
    evaluator = FrankensteinEvaluator(
        model_name=args.model_name,
        split=args.split,
        num_samples=args.num_samples,
        toolbox=args.toolbox,
        save=args.save,
        n_shots=args.n_shots,
        debug=args.debug,
    )
    evaluator.args = args
    evaluator.run()
