import ast
import os
import platform
from typing import Any, Dict, List, Optional


class Logger:
    """Collects per-run results and system stats and appends them to a ``.pyd`` log.

    Each call to :meth:`write` appends one line holding ``str()`` of a dict with the
    keys ``args``, ``results``, ``cpu_memory_usage``, ``gpu_memory_usage`` and
    ``device_name`` (in that order).
    """

    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        """
        Initializes a Logger object that tracks accuracy values and other metrics.

        Args:
            args: The experiment arguments as a dict (e.g. parsed command-line options).
                :meth:`write` reads the ``job``, ``dataset`` and ``model`` keys, plus
                ``dist_method`` and ``disable_dynamic`` when ``model == "lwp"``.
                May be ``None`` (e.g. when a logger is rebuilt from a saved file);
                no banner is printed in that case.
        """
        self.args = args
        self.results: List[Dict[str, Any]] = []
        self.cpu_res: List[Dict[str, Any]] = []
        self.gpu_res: List[Dict[str, Any]] = []
        # Network name of the host running the experiment.
        self.device_name = platform.node()
        if args is not None:
            # ANSI escape codes render the banner in green.
            print(
                "\033[92m"
                + f"Logger init - Device: {self.device_name}"
                + "\033[0m"
            )

    def results_append(self, result: Dict[str, Any]) -> None:
        """
        Appends a result to the logger.

        Args:
            result: The result to append.
        """
        self.results.append(result)

    def write(self, num_tasks: int = -1, root: str = "./experiment_results") -> None:
        """
        Appends the logged values, together with their arguments, to a ``.pyd`` file.

        The target file is
        ``<root>/<JOB>/<dataset>/<model>[/<dist_method>/<dynamic|raw>]/<name>``,
        where ``<JOB>`` is ``args["job"]`` upper-cased and the bracketed part is only
        added when ``args["model"] == "lwp"``. ``<name>`` is ``logs.pyd`` or
        ``logs_<num_tasks>tasks.pyd``. Missing directories are created.

        Args:
            num_tasks: Number of tasks to tag the file name with; ``-1`` (default)
                writes to ``logs.pyd``.
            root: Base directory for all experiment logs.

        Raises:
            TypeError: If the logger was created without ``args``.
        """
        if self.args is None:
            raise TypeError("Logger.write() requires a Logger created with args.")

        record = {
            "args": self.args,
            "results": self.results,
            "cpu_memory_usage": self.cpu_res,
            "gpu_memory_usage": self.gpu_res,
            "device_name": self.device_name,
        }

        path = os.path.join(
            root,
            self.args["job"].upper(),
            self.args["dataset"],
            self.args["model"],
        )
        if self.args["model"] == "lwp":
            path = os.path.join(
                path,
                self.args["dist_method"],
                "raw" if self.args["disable_dynamic"] else "dynamic",
            )

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(path, exist_ok=True)

        filename = "logs.pyd" if num_tasks == -1 else f"logs_{num_tasks}tasks.pyd"
        # Append mode: each write() adds one more record line to the same file.
        with open(os.path.join(path, filename), "a") as f:
            f.write(str(record) + "\n")

    # NOTE: log_system_stats is intended to be fed from an external stats tracker
    # (e.g. a track_system_stats helper); that integration is not shown here.

    def log_system_stats(
        self,
        cpu_res: Optional[Dict[str, Any]] = None,
        gpu_res: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Records one snapshot of system stats.

        Args:
            cpu_res: CPU memory stats to append to ``cpu_res``; skipped if ``None``.
            gpu_res: GPU memory stats to append to ``gpu_res``; skipped if ``None``.
        """
        if cpu_res is not None:
            self.cpu_res.append(cpu_res)
        if gpu_res is not None:
            self.gpu_res.append(gpu_res)


def _logger_from_line(line: str) -> "Logger":
    """Parse one record line written by ``Logger.write`` into a ``Logger``."""
    # ast.literal_eval only accepts Python literals; plain eval() would execute
    # any code that ended up in the log file.
    record = ast.literal_eval(line.strip())
    logger = Logger()
    logger.args = record["args"]
    logger.results = record["results"]
    logger.cpu_res = record["cpu_memory_usage"]
    logger.gpu_res = record["gpu_memory_usage"]
    logger.device_name = record["device_name"]
    return logger


def read_logger(path: str) -> List[Logger]:
    """
    Reads the loggers stored in a ``.pyd`` file, one record per line.

    Malformed lines are reported and skipped, so a single bad record does not
    discard the records that follow it. Blank lines are ignored. If the file
    cannot be opened or decoded, the error is printed and whatever was parsed
    so far is returned.

    Args:
        path: Path to an existing ``.pyd`` log file.

    Returns:
        One ``Logger`` per successfully parsed line, in file order.
    """
    assert os.path.exists(path), f"Path {path} does not exist."
    assert path.endswith(".pyd"), f"Path {path} is not a pyd file."
    results = []
    try:
        with open(path, "r") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    results.append(_logger_from_line(line))
                except (
                    ValueError,
                    SyntaxError,
                    TypeError,
                    KeyError,
                    MemoryError,
                    RecursionError,
                ) as e:
                    # Malformed literal, non-dict record, or missing field.
                    print(f"Error reading {path} (line {lineno}): {e}")
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading {path}: {e}")
    return results
