import numpy as np
from typing import Callable
import os
import json
from datetime import datetime
from tqdm import tqdm


class Runner:
    """Run a function several times and record the results with mean and std.

    Every call to ``run`` stores the per-trial return values in ``self.values``
    and writes them, together with their mean and standard deviation, to a
    timestamped json file inside ``log_dir``. Saved files can be read back with
    ``load_results``.
    """

    def __init__(
        self,
        n_trials: int = 10,
        log_dir: str = os.path.join(os.path.dirname(__file__), "..", "runner_logs"),
        show_progress: bool = False,
    ):
        """Create a runner and make sure its log directory exists.

        Args:
            n_trials (int): How many times ``run`` calls the target function.
            log_dir (str): Directory that receives the json result files. It is
                created if it does not exist yet.
            show_progress (bool): Whether ``run`` displays a tqdm progress bar.
        """
        self.n_trials = n_trials
        self.values = []
        # Not read anywhere in this class; kept so existing callers that access
        # the attribute keep working.
        self.index = 0
        self.log_dir = log_dir
        os.makedirs(self.log_dir, exist_ok=True)
        self.show_progress = show_progress

    def run(self, target_func: Callable, task_name: str = "task"):
        """Call ``target_func`` ``n_trials`` times and save the results.

        Args:
            target_func (Callable): Function called with no arguments once per
                trial. Its return values are collected in ``self.values``.
            task_name (str): Name used for the saved json file.
        """
        trials = range(self.n_trials)
        if self.show_progress:
            # Only build a progress bar when one is actually wanted.
            trials = tqdm(trials)
        self.values = [target_func() for _ in trials]

        # Save results after running
        self.save_results(task_name)

    def get_values(self):
        """Return the stored values together with their mean and std.

        Returns:
            dict: ``{"values": ..., "mean": float, "std": float}`` where mean
            and std are computed with ``np.mean`` / ``np.std`` over
            ``self.values``.
        """
        return {
            "values": self.values,
            "mean": float(np.mean(self.values)),
            "std": float(np.std(self.values)),
        }

    @staticmethod
    def _json_default(obj):
        """Convert NumPy scalars and arrays to plain Python for ``json.dump``."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    def save_results(self, task_name: str):
        """Save the current results to a json file.

        The file is named ``{task_name}_{timestamp}.json`` and is written into
        ``self.log_dir``.

        Args:
            task_name (str): The name of the task.

        Returns:
            str: The path of the file that was written.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{task_name}_{timestamp}.json"
        filepath = os.path.join(self.log_dir, filename)

        results_dict = {
            "task_name": task_name,
            "timestamp": timestamp,
            "n_trials": self.n_trials,
            "values": self.get_values(),
        }

        with open(filepath, "w", encoding="utf-8") as f:
            # ``default`` lets NumPy scalars/arrays returned by the target
            # function be written instead of raising TypeError.
            json.dump(results_dict, f, indent=2, default=self._json_default)
        return filepath

    def load_results(self, filepath: str):
        """Load results from a json file.

        ``save_results`` stores the whole ``get_values()`` dict under the
        ``"values"`` key, so the per-trial list lives at
        ``results["values"]["values"]``. That list is what gets restored into
        ``self.values``; a file whose ``"values"`` entry is already a flat list
        is accepted as well.

        Args:
            filepath (str): The path to the json file.

        Returns:
            dict: The full parsed content of the json file.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            results = json.load(f)
        stored = results["values"]
        # Unwrap the nested summary dict so get_values() keeps working after a
        # load instead of receiving a dict.
        self.values = stored["values"] if isinstance(stored, dict) else stored
        return results


if __name__ == "__main__":
    # Small demo: run a trivial function ten times with a progress bar, then
    # print the collected values with their mean and std. The results are also
    # written to the runner's default log directory.

    def test_func(x=1):
        """Return ``x`` doubled.

        ``Runner.run`` calls the target with no arguments, so ``x`` needs a
        default value.
        """
        return x * 2

    runner = Runner(n_trials=10, show_progress=True)
    runner.run(test_func, "test")
    print(runner.get_values())
