from pathlib import Path
import csv
import json

import fire
from beartype import beartype

from helpers import logger


@beartype
def csv_to_json(csv_path: str):
    """Translate a csv file downloaded from a wandb plot to json (JSON Lines).

    The CSV must have a ``Step`` column and exactly one column whose name ends
    with ``/return``. Each data row becomes one JSON object
    ``{"timestep": <int>, "return": <float>}`` written on its own line. The
    output is written next to the input, with the ``.csv`` suffix replaced by
    ``.json``.

    Args:
        csv_path: path to the ``.csv`` file to convert.

    Raises:
        FileNotFoundError: if ``csv_path`` does not exist.
        ValueError: if the file is not a ``.csv``, has no header, lacks the
            required columns, or contains a row whose step/return cell cannot
            be parsed. In every error case no output file is written.
    """
    path = Path(csv_path)
    if not path.exists():  # does the file exist?
        raise FileNotFoundError(f"File not found: {path}")
    if path.suffix.lower() != ".csv":  # is it a csv file?
        raise ValueError(f"Expected a .csv file, got: {path.suffix}")

    step_col = "Step"
    return_suffix = "/return"

    # Parse every row before touching the output file, so a malformed row
    # aborts the conversion instead of leaving a truncated .json behind.
    records = []
    with path.open(newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)

        fieldnames = reader.fieldnames
        if not fieldnames:
            raise ValueError("CSV file has no header")
        # find the step column
        if step_col not in fieldnames:
            raise ValueError("CSV must contain a 'Step' column")
        # find the single column that ends with "/return"
        return_columns = [col for col in fieldnames if col.endswith(return_suffix)]
        if len(return_columns) != 1:
            raise ValueError(f"CSV must contain exactly one column that ends with {return_suffix}")
        return_col = return_columns[0]

        for row in reader:
            try:
                records.append({
                    "timestep": int(row[step_col]),
                    "return": float(row[return_col]),
                })
            except (TypeError, ValueError) as err:
                # TypeError: short row (DictReader fills missing cells with None);
                # ValueError: empty or non-numeric cell.
                raise ValueError(
                    f"Malformed row at line {reader.line_num} of {path}: {row!r}"
                ) from err

    # write the JSON file with one record per line (JSON Lines format)
    json_path = path.with_suffix(".json")
    with json_path.open("w", encoding="utf-8") as ff:
        ff.writelines(json.dumps(record) + "\n" for record in records)

    logger.warn(f"JSON file saved @ {json_path}")


if __name__ == "__main__":
    # Expose csv_to_json as a command-line tool: its parameters become CLI args.
    fire.Fire(component=csv_to_json)
