import argparse, json, logging, os, sys

from typing import List
from utils import get_instances

# Configure root logging once at import time so every record carries a timestamp,
# the emitting logger's name, and its severity.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def attach_predictions_to_task(
    predictions_path: str,
    instances_path: str,
    pred_col: str,
) -> tuple:
    """
    Given predictions, attach them to their corresponding task instances

    Predictions are grouped by repo, where the repo name is the part of the
    prediction's `instance_id` before its last "-" with "__" replaced by "/".
    Each kept prediction is reduced to its `instance_id` and `pred_col` fields and
    then extended with every field of the task instance that has the same
    `instance_id` within that repo.

    Args:
        predictions_path: Path to predictions file
        instances_path: Path to task instances file
        pred_col: Key under which each prediction record stores the prediction.
            Every record must contain this key (a missing key raises KeyError).
    Returns:
        Tuple of three items:
            - dict mapping repo name to the list of predictions for that repo,
              each merged with the fields of its task instance
            - number of predictions dropped because their `pred_col` value was None
            - number of predictions dropped because no task instance matched
    """
    # Group predictions by repo
    throw_out_empty = 0
    map_repo_to_predictions = {}
    for prediction in get_instances(predictions_path):
        if prediction[pred_col] is None:
            throw_out_empty += 1
            continue
        repo = prediction["instance_id"].rsplit("-", 1)[0].replace("__", "/")
        # Remove unnecessary key/value pairs (the original key order is kept)
        trimmed = {
            key: value
            for key, value in prediction.items()
            if key in ("instance_id", pred_col)
        }
        map_repo_to_predictions.setdefault(repo, []).append(trimmed)
    grouped_sizes = {k: len(v) for k, v in map_repo_to_predictions.items()}
    logger.info(f"Predictions grouped by repo: {grouped_sizes}")
    logger.info(f"Total predictions: {sum(grouped_sizes.values())}")

    # Get task instances, organize by repo
    map_repo_to_instances = {}
    for instance in get_instances(instances_path):
        map_repo_to_instances.setdefault(instance["repo"], []).append(instance)

    # Attach predictions to instances
    throw_out_not_in_eval = 0
    for repo, predictions in map_repo_to_predictions.items():
        # A repo with no task instances at all simply has no matches; its
        # predictions are counted as "not in eval" instead of raising KeyError.
        map_id_to_instances = {
            x["instance_id"]: x for x in map_repo_to_instances.get(repo, [])
        }

        new_predictions = []
        for prediction in predictions:
            instance = map_id_to_instances.get(prediction["instance_id"])
            if instance is None:
                throw_out_not_in_eval += 1
                continue
            predicted = prediction[pred_col]
            prediction.update(instance)
            # Restore the prediction in case the task instance has a field that
            # is also named `pred_col`; the update above would have replaced it.
            prediction[pred_col] = predicted
            new_predictions.append(prediction)
        # Rebinding an existing key does not resize the dict, so this is safe
        # while iterating over its items.
        map_repo_to_predictions[repo] = new_predictions
    logger.info("Predictions attached to original task instances")

    return map_repo_to_predictions, throw_out_empty, throw_out_not_in_eval


def map_to_harness_keys(
    prediction_path: str,
    predictions: List,
    model: str = None,
    pred_col: str = None,
    repo: str = None,
):
    """
    Converts the keys in the predictions to match the (model, prediction) keys in the harness and
    saves the file to the same directory as the original predictions file

    The output file is named `cleaned_<original name>`, or
    `cleaned_<repo with "/" replaced by "__">_<original name>` when `repo` is given.
    One JSON object is written per line; a name ending in `.json` is changed to
    `.jsonl`. The records in `predictions` are not modified; the key changes are
    applied to shallow copies.

    Args:
        prediction_path: Path to predictions file
        predictions: List of predictions (dicts)
        model: Name of model, stored under the `model` key of every record
        pred_col: Column name for prediction. When given, that key is renamed to
            `prediction` in every record; when None, no key is renamed.
        repo: Repo name
    Returns:
        Path to new predictions file
    """
    # os.path handles the platform's separators, unlike splitting on "/"
    directory, filename = os.path.split(prediction_path)
    if repo is not None:
        filename = f"{repo.replace('/', '__')}_{filename}"
    new_path = os.path.join(directory, f"cleaned_{filename}")
    # Always save as .jsonl file
    if new_path.endswith(".json"):
        new_path += "l"

    with open(new_path, "w") as f:
        for p in predictions:
            # Work on a copy so the caller's records stay intact and the function
            # can be called again with the same list.
            record = dict(p)
            record["model"] = model
            if pred_col is not None:
                # Pop before assigning: with pred_col == "prediction" an
                # assign-then-delete would drop the prediction altogether.
                record["prediction"] = record.pop(pred_col)
            f.write(json.dumps(record) + "\n")
    if pred_col is not None:
        logger.info(f"Renamed column `{pred_col}` to `prediction`")
    return new_path


def main(args):
    """
    Attach predictions to their task instances, then write them out with harness keys

    Args:
        args: Command line arguments, which include the following:
            predictions_path: Path to predictions file
            instances_path: Path to task instances file
            split_by_repo: Write one output file per repo instead of a single file
            model: Name of model
            pred_col: Column name for prediction
    """
    grouped, num_empty, num_not_in_eval = attach_predictions_to_task(
        args.predictions_path,
        args.instances_path,
        pred_col=args.pred_col,
    )
    logger.info(f"{num_empty} predictions thrown out because they were None")
    logger.info(
        f"{num_not_in_eval} predictions thrown out because they were not in eval"
    )

    if not args.split_by_repo:
        # Single output file: concatenate every repo's predictions in dict order
        everything = []
        for repo_predictions in grouped.values():
            everything.extend(repo_predictions)
        total_predictions = len(everything)
        new_path = map_to_harness_keys(
            args.predictions_path,
            everything,
            model=args.model,
            pred_col=args.pred_col,
        )
        logger.info(f"Harness-ready predictions by model saved to {new_path}")
    else:
        # One output file per repo
        total_predictions = 0
        for repo, repo_predictions in grouped.items():
            total_predictions += len(repo_predictions)
            new_path = map_to_harness_keys(
                args.predictions_path,
                repo_predictions,
                model=args.model,
                pred_col=args.pred_col,
                repo=repo,
            )
            logger.info(
                f"Harness-ready predictions for {repo} by model saved to {new_path}"
            )

    logger.info(f"Total predictions modified: {total_predictions}")
    # Sanity check: every prediction in the input file was either written or
    # counted as thrown out.
    assert len(get_instances(args.predictions_path)) == (
        total_predictions + num_empty + num_not_in_eval
    )


if __name__ == "__main__":
    # Command line entry point: two positional paths followed by optional switches
    cli = argparse.ArgumentParser()
    for positional, description in (
        ("predictions_path", "Path to predictions file"),
        ("instances_path", "Path to task instances file"),
    ):
        cli.add_argument(positional, type=str, help=description)
    cli.add_argument(
        "--split_by_repo", action="store_true", help="Split predictions by repo"
    )
    for option, description in (
        ("--model", "Name of model"),
        ("--pred_col", "Column name for prediction"),
    ):
        cli.add_argument(option, type=str, help=description)
    main(cli.parse_args())
