import os
import csv
import logging
import numpy as np

# Number of buffered step-energy rows that triggers an automatic write to disk.
BATCH_WRITE_THRESHOLD = 100

class SudokuDataRecorder:
    """Buffer per-step energy / constraint-error records and append them to a CSV.

    Rows are collected in memory and written to
    ``<csv_dir>/step_energy_evolution.csv`` whenever ``BATCH_WRITE_THRESHOLD``
    rows are pending, or when :meth:`flush_all` is called. The recorder can
    also be used as a context manager; pending rows are flushed on exit.
    """

    STEP_ENERGY_FILENAME = "step_energy_evolution.csv"
    STEP_ENERGY_HEADER = (
        "epoch", "batch_idx", "sample_idx",
        "step", "step_free_energy",
        "total_constraint_error", "row_error", "col_error", "box_error",
    )

    def __init__(self, csv_dir, exp_id):
        """Create the output directory/CSV (if needed) for experiment ``exp_id``.

        Args:
            csv_dir: Parent directory for records; when falsy,
                ``./training_records`` is used instead.
            exp_id: Experiment identifier, used as the sub-directory name.
                Converted with ``str`` so non-string ids behave the same with
                and without an explicit ``csv_dir``.
        """
        self.csv_dir = (
            os.path.join(csv_dir, str(exp_id))
            if csv_dir
            else f"./training_records/{exp_id}"
        )
        self.exp_id = exp_id
        self.logger = logging.getLogger("TLAD-DataRecorder")

        # Rows accepted by append_step_energy but not yet written to disk.
        self.step_energy_buffer = []

        self._init_csv_files()
        self.logger.info(f"Data recorder initialized, CSV save dir: {self.csv_dir}")

    def __enter__(self):
        """Return the recorder itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Flush any buffered rows on exit; never suppresses exceptions."""
        self.flush_all()
        return False

    def _init_csv_files(self):
        """Ensure the output directory exists and the CSV starts with a header row.

        The header is written when the file is missing, and also when it
        exists but is empty (e.g. left behind by an interrupted run), so that
        row counts derived from the file stay correct. An existing non-empty
        file is left untouched and new rows are appended to it.
        """
        os.makedirs(self.csv_dir, exist_ok=True)

        step_energy_path = self._get_csv_path(self.STEP_ENERGY_FILENAME)
        if not os.path.exists(step_energy_path) or os.path.getsize(step_energy_path) == 0:
            with open(step_energy_path, 'w', newline='', encoding='utf-8') as f:
                csv.writer(f).writerow(self.STEP_ENERGY_HEADER)

    def _get_csv_path(self, filename):
        """Return the full path of ``filename`` inside the recorder's directory."""
        return os.path.join(self.csv_dir, filename)

    def append_step_energy(self, epoch, batch_idx, sample_idx, step_idx,
                          step_free_energy, constraint_error_tuple):
        """Buffer one step record, flushing to disk once the threshold is reached.

        Args:
            epoch, batch_idx, sample_idx, step_idx: Position of the record.
            step_free_energy: Free-energy value for this step (written as-is).
            constraint_error_tuple: ``(total, row, col, box)`` error values.

        Raises:
            ValueError: If ``constraint_error_tuple`` is not a 4-element tuple.
        """
        if not isinstance(constraint_error_tuple, tuple) or len(constraint_error_tuple) != 4:
            length = len(constraint_error_tuple) if isinstance(constraint_error_tuple, tuple) else 'N/A'
            raise ValueError(f"constraint_error_tuple must be 4-element tuple, got {type(constraint_error_tuple)} with length {length}")

        total_err, row_err, col_err, box_err = constraint_error_tuple
        self.step_energy_buffer.append([
            epoch, batch_idx, sample_idx,
            step_idx, step_free_energy,
            total_err, row_err, col_err, box_err,
        ])

        if len(self.step_energy_buffer) >= BATCH_WRITE_THRESHOLD:
            self._flush_step_energy()

    def _flush_step_energy(self):
        """Append all buffered rows to the CSV, then clear the buffer.

        The buffer is cleared only after the write succeeds, so rows are kept
        in memory if the file operation raises.
        """
        if not self.step_energy_buffer:
            return

        step_energy_path = self._get_csv_path(self.STEP_ENERGY_FILENAME)
        with open(step_energy_path, 'a', newline='', encoding='utf-8') as f:
            csv.writer(f).writerows(self.step_energy_buffer)

        self.logger.debug(f"Flushed {len(self.step_energy_buffer)} step energy records to CSV")
        self.step_energy_buffer = []

    def flush_all(self):
        """Write every pending buffered record to disk."""
        self._flush_step_energy()
        self.logger.info(f"All cached data flushed to CSV, save dir: {self.csv_dir}")

    def get_csv_summary(self):
        """Return the output directory and the number of data rows on disk.

        Only rows already written to the CSV are counted; records still in
        ``step_energy_buffer`` are not included (call :meth:`flush_all` first).

        Returns:
            dict with keys ``"csv_dir"`` and ``"step_energy_count"``.
        """
        step_energy_path = self._get_csv_path(self.STEP_ENERGY_FILENAME)

        # Use a context manager so the file handle is closed deterministically.
        with open(step_energy_path, 'r', encoding='utf-8') as f:
            line_count = sum(1 for _ in f)

        # Subtract the header row; clamp so an empty file reports 0, not -1.
        step_count = max(line_count - 1, 0)

        return {
            "csv_dir": self.csv_dir,
            "step_energy_count": step_count,
        }