import csv
import os
from typing import List

import numpy as np
import torch


class CSVLogger:
    """Buffer scalars for one timestep and append them to a CSV file as a row.

    Each configured key (plus a ``"timestep"`` column) must receive a value
    for the current timestep. As soon as every column has a value, the row is
    appended to the file and the buffer is cleared for the next timestep.
    """

    def __init__(
        self,
        path: str,
        file_name: str,
        keys: List[str],
        check_keys: bool = True,
        overwrite: bool = False,
    ):
        """Create the logger and write the CSV header if needed.

        Args:
            path: Directory that holds the CSV file. It is not created here,
                so it must already exist.
            file_name: Name of the CSV file inside ``path``.
            keys: Column names to log. A ``"timestep"`` column is appended
                unless it is already present. The caller's list is copied and
                left unmodified.
            check_keys: If True, assert that a key is not set twice before the
                current row has been flushed.
            overwrite: If True, truncate an existing file and rewrite the
                header. Otherwise an existing file is appended to as-is.
        """
        self.path = os.path.join(path, file_name)
        # Copy so the caller's list is not mutated. Otherwise reusing the same
        # list would keep accumulating "timestep" entries.
        self.keys = list(keys)
        if "timestep" not in self.keys:
            self.keys.append("timestep")
        self.check_keys = check_keys
        self.overwrite = overwrite

        if not os.path.exists(self.path) or overwrite:
            # The csv module requires newline="". Without it, platforms that
            # translate "\n" produce blank lines between rows.
            with open(self.path, "w", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=self.keys)
                writer.writeheader()

        # Pending row: column name -> value, or None while not yet set.
        self.buffer = dict.fromkeys(self.keys)

    def add_scalar(self, key, value, timestep):
        """Record ``value`` under ``key`` for ``timestep``.

        Keys that are not configured columns are silently ignored. Torch
        tensors are converted with ``.item()``. For numpy arrays, a 0-d array
        is converted with ``.item()``; otherwise the first element is used.
        Once every column is filled, the row is written to disk.

        Raises:
            AssertionError: if ``check_keys`` is set and ``key`` already has a
                value in the pending row, or if ``timestep`` differs from the
                timestep of the pending row.
        """
        if key not in self.keys:
            return
        if self.check_keys:
            assert self.buffer[key] is None, f"{key} is already set"
        if self.buffer["timestep"] is not None:
            # All values in one row must belong to the same timestep.
            assert (
                self.buffer["timestep"] == timestep
            ), f"timestep is set to {self.buffer['timestep']} but got {timestep}"
        else:
            self.buffer["timestep"] = timestep
        if isinstance(value, torch.Tensor):
            value = value.item()
        elif isinstance(value, np.ndarray):
            # A 0-d array cannot be indexed, so use .item() for it. Otherwise
            # keep the original behaviour of taking the first element.
            value = value.item() if value.ndim == 0 else value[0]
        self.buffer[key] = value
        if self._check_full():
            self._flush()

    def _check_full(self):
        """Return True when every column of the pending row has a value."""
        return all(v is not None for v in self.buffer.values())

    def _flush(self):
        """Append the pending row to the CSV file and reset the buffer."""
        with open(self.path, "a", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=self.keys)
            writer.writerow(self.buffer)
        self.buffer = dict.fromkeys(self.keys)
