import os
import atexit
import csv
from collections import defaultdict

class Logger:
    """Send scalar metrics to a TensorBoard-style writer, a text log and a CSV file.

    ``record`` forwards each value to ``writer.add_scalar`` immediately,
    optionally echoes it to stdout and the text log, and buffers it per step.
    ``flush(step)`` then writes that step's buffered values as one CSV row.

    Both files are created inside ``writer.get_logdir()``. The writer is
    presumably a ``torch.utils.tensorboard.SummaryWriter`` (TODO confirm); any
    object exposing ``get_logdir()`` and ``add_scalar(tag, value, step)`` works.
    """

    def __init__(self, writer, output_fname="progress.txt", log_path="log", csv_fname="progress.csv"):
        """Open the text log and prepare the CSV path inside the writer's log dir.

        Args:
            writer: Object providing ``get_logdir()`` and
                ``add_scalar(tag, value, step)``.
            output_fname: File name of the human-readable log. An existing
                file is truncated.
            log_path: Ignored. The directory always comes from
                ``writer.get_logdir()``; the parameter is kept only so existing
                callers keep working.
            csv_fname: File name of the CSV written by ``flush``.
        """
        self.writer = writer
        self.log_path = writer.get_logdir()
        self.output_file = open(os.path.join(self.log_path, output_fname), 'w', encoding='utf-8')
        self.csv_path = os.path.join(self.log_path, csv_fname)

        self.metric_buffer = defaultdict(dict)  # step -> {tag: value}, kept until flushed
        self.all_tags = set()  # every tag recorded so far; defines the CSV columns
        self._csv_initialized = False  # True once this Logger has written the CSV header
        self._csv_fieldnames = None  # the header row currently on disk
        # Close the text log at interpreter exit so its buffered lines are written out.
        atexit.register(self._close_files)

    def _close_files(self):
        """Close the text log (safe to call more than once)."""
        self.output_file.close()

    def record(self, tag, scalar_value, global_step, printed=True):
        """Log one scalar value.

        The value goes to TensorBoard right away. It is written to the CSV
        only when ``flush(global_step)`` is called.

        Args:
            tag: Metric name; also used as the CSV column name.
            scalar_value: Numeric value; must support the ``.6f`` format spec
                when ``printed`` is true.
            global_step: Step the value belongs to.
            printed: If true, also echo ``"tag: value"`` to stdout and the text log.
        """
        self.writer.add_scalar(tag, scalar_value, global_step)

        if printed:
            info = f"{tag}: {scalar_value:.6f}"
            print("\033[1;32m [info]\033[0m: " + info)
            self.output_file.write(info + '\n')

        self.metric_buffer[global_step][tag] = scalar_value
        self.all_tags.add(tag)

    def flush(self, step):
        """Write the buffered metrics for ``step`` as one CSV row.

        The header is ``step`` followed by every tag recorded so far, in sorted
        order. If new tags have appeared since the header was written, the file
        is rewritten with the widened header so every column stays aligned.
        Earlier rows get empty cells for the new columns.

        The first flush of a Logger starts the CSV afresh. This matches the
        text log, which is opened in ``'w'`` mode, and means a stale file from
        an earlier run never gets a second header appended mid-file.

        Args:
            step: Global step whose buffered values should be written. A step
                with nothing recorded yields a row holding only the step number.
        """
        # ``or "."`` keeps makedirs from failing if the log dir is "".
        os.makedirs(os.path.dirname(self.csv_path) or ".", exist_ok=True)

        fieldnames = ['step'] + sorted(self.all_tags)
        # .get avoids inserting an empty entry into the defaultdict.
        row = {'step': step}
        row.update(self.metric_buffer.get(step, {}))

        header_unchanged = (
            self._csv_initialized
            and fieldnames == self._csv_fieldnames
            and os.path.exists(self.csv_path)
        )
        if header_unchanged:
            with open(self.csv_path, 'a', newline='', encoding='utf-8') as f:
                csv.DictWriter(f, fieldnames=fieldnames).writerow(row)
        else:
            # Either this is the first flush or the column set has grown:
            # (re)write the header followed by any rows this Logger already wrote.
            previous_rows = self._read_csv_rows() if self._csv_initialized else []
            with open(self.csv_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction='ignore')
                writer.writeheader()
                writer.writerows(previous_rows)
                writer.writerow(row)
            self._csv_initialized = True
            self._csv_fieldnames = fieldnames

        # Drop the step only after its row is safely written.
        self.metric_buffer.pop(step, None)

    def _read_csv_rows(self):
        """Return the data rows already in the CSV as dicts (empty if the file is gone)."""
        try:
            with open(self.csv_path, newline='', encoding='utf-8') as f:
                return list(csv.DictReader(f))
        except FileNotFoundError:
            return []

    def print(self, info):
        """Echo a free-form message to stdout and append it to the text log."""
        print("\033[1;32m [info]\033[0m: " + info)
        self.output_file.write(info + '\n')
