from utils.s3_file_handler import S3FileHandler, make_json_serializable
import os
import json
from os.path import dirname, abspath


class CustomMetricLogger:
    """Persist per-timestep metric values as JSON files.

    Output goes to the ``input`` bucket via ``S3FileHandler`` when the
    ``S3_ENDPOINT`` environment variable is set at construction time;
    otherwise it goes to a ``local_results/...`` directory resolved three
    directory levels above this source file.
    """

    def __init__(self, args, metric_name):
        """Resolve (and, for local output, create) the output location.

        Args:
            args: Run configuration; must expose ``exp_main_name``, ``name``,
                ``seed`` and an ``env_args`` mapping containing ``"key"``.
            metric_name: Metric name; used as the last path component.
        """
        relative_path = (
            f"{args.exp_main_name}/{args.name}/{args.env_args['key']}/"
            f"{args.seed}/{metric_name}/"
        )

        # Decide the backend once so ``write`` always uses the attributes that
        # were initialised here, even if the environment changes afterwards.
        self.use_s3 = "S3_ENDPOINT" in os.environ

        if self.use_s3:
            self.data_path = relative_path
            self.file_handler = S3FileHandler(
                s3_endpoint=os.environ["S3_ENDPOINT"],
                bucket="input",
            )
        else:
            project_root = dirname(dirname(dirname(abspath(__file__))))
            full_path = os.path.join(project_root, f"local_results/{relative_path}")
            os.makedirs(os.path.dirname(full_path), exist_ok=True)
            self.full_path = full_path

    def write(self, metric_dict):
        """Write each ``{timestep: values}`` entry to ``<timestep>.json``.

        Args:
            metric_dict: Mapping from a timestep to the metric values recorded
                at that timestep. Every entry is written to its own file; an
                empty mapping writes nothing.
        """
        for timestep, values in metric_dict.items():
            json_data = make_json_serializable(values)
            file_name = f"{timestep!s}.json"

            if self.use_s3:
                # Save data to the S3 bucket.
                self.file_handler.save_json(
                    path=f"{self.data_path}{file_name}", data=json_data
                )
            else:
                # Save the same data locally as a JSON file.
                write_path = f"{self.full_path}{file_name}"
                # __init__ creates the directory, but it may have been removed
                # since then (or the timestep may contain a sub-path).
                os.makedirs(os.path.dirname(write_path), exist_ok=True)
                with open(write_path, "w", encoding="utf-8") as f:
                    json.dump(json_data, f, indent=4)
