import csv
import os
import json
from datetime import datetime

class BaseDataLogger:
    """Base class for loggers that write one directory of files per session.

    Each session gets its own sub-directory of ``log_dir`` containing a
    ``metadata.json`` file and a ``trial_data.csv`` file. Subclasses supply
    the CSV layout by implementing ``write_csv_header`` and ``log_trial``.
    """

    def __init__(self, log_dir="experiment_logs"):
        """Create the logger and make sure the root log directory exists.

        Args:
            log_dir: Root directory under which session directories are
                created. Defaults to ``"experiment_logs"``.
        """
        self.log_dir = log_dir
        self.ensure_log_directory()
        self.current_session = None
        # Defined up front so end_session() is safe before any session starts.
        self.csv_file = None
        self.csv_writer = None

    def ensure_log_directory(self):
        """Create ``self.log_dir`` (and any missing parents) if needed."""
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(self.log_dir, exist_ok=True)

    def start_new_session(self, task, model, prompt_type, presentation_mode, is_human, session, subject,impairment_type):
        """Open a new session directory, write its metadata and CSV header.

        The directory name is built from the arguments plus a
        ``YYYYmmdd_HHMMSS`` timestamp. For non-human sessions a truthy
        ``impairment_type`` is appended as a suffix; for human sessions the
        model name is replaced by ``human_subject<subject>``.

        Any CSV file still open from a previous session is closed first.

        Raises:
            FileExistsError: If the session directory already exists.
            NotImplementedError: If the subclass does not implement
                ``write_csv_header``.
        """
        # Do not leak the handle of a session that was never ended.
        self._close_csv()

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        if is_human:
            session_name = f"{task}_human_subject{subject}_{prompt_type}_{presentation_mode}_session{session}_{timestamp}"
        else:
            session_name = f"{task}_{model}_{prompt_type}_{presentation_mode}_session{session}_{timestamp}"
            if impairment_type:
                session_name = f"{session_name}_{impairment_type}"
        self.current_session = os.path.join(self.log_dir, session_name)
        # Deliberately strict: an existing directory means a name collision,
        # and reusing it would overwrite that session's data.
        os.makedirs(self.current_session)

        metadata = {
            "task": task,
            "model": model,
            "prompt_type": prompt_type,
            "presentation_mode": presentation_mode,
            "is_human": is_human,
            "session": session,
            "subject": subject,
            "impairment_type": impairment_type,
            "timestamp": timestamp,
        }
        with open(os.path.join(self.current_session, "metadata.json"), "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        # newline="" lets the csv module control line endings itself.
        self.csv_file = open(os.path.join(self.current_session, "trial_data.csv"), "w", newline="", encoding="utf-8")
        self.csv_writer = csv.writer(self.csv_file)
        try:
            self.write_csv_header()
        except Exception:
            # Header could not be written; release the handle before failing.
            self._close_csv()
            raise

    def write_csv_header(self):
        """Write the CSV header row; must be provided by subclasses."""
        raise NotImplementedError("Subclasses must implement this method")

    def log_trial(self, trial_data):
        """Record one trial; must be provided by subclasses."""
        raise NotImplementedError("Subclasses must implement this method")

    def end_session(self):
        """Close the current session's CSV file, if one is open.

        Safe to call repeatedly and before any session has been started.
        """
        self._close_csv()

    def _close_csv(self):
        """Close the CSV handle and drop the writer bound to it."""
        # getattr guards subclasses whose __init__ skipped the base __init__.
        csv_file = getattr(self, "csv_file", None)
        if csv_file:
            csv_file.close()
        self.csv_file = None
        self.csv_writer = None

class WCSTDataLogger(BaseDataLogger):
    """Logger for WCST trials: one CSV row and one JSON file per trial."""

    def write_csv_header(self):
        """Write the column-title row of the trial CSV."""
        column_titles = [
            "Trial",
            "Image",
            "Prompt",
            "Response",
            "Prompt_tokens",
            "Completion_tokens",
            "Total_tokens",
            "Is_Correct",
            "Correct_Card",
            "Correct_In_Row",
            "Current_Rule",
            "Category_Completed",
            "Applied_Rules",
            "Impairment_Type",
        ]
        self.csv_writer.writerow(column_titles)

    def log_trial(self, trial_data):
        """Append one trial to the CSV and dump it to its own JSON file.

        Args:
            trial_data: Mapping that must contain every key listed in
                ``field_order`` below; a missing key raises ``KeyError``
                before anything is written. The whole mapping is also
                serialised to ``trial_<trial_number>_data.json`` in the
                current session directory.
        """
        write_row = self.csv_writer.writerow
        # Key order mirrors the column order produced by write_csv_header.
        field_order = (
            "trial_number",
            "image_path",
            "prompt",
            "response",
            "prompt_tokens",
            "completion_tokens",
            "total_tokens",
            "is_correct",
            "correct_card",
            "correct_in_row",
            "current_rule",
            "category_completed",
            "applied_rules",
            "impairment_type",
        )
        write_row([trial_data[key] for key in field_order])
        # Push the row to disk straight away rather than leaving it buffered.
        self.csv_file.flush()

        json_name = f"trial_{trial_data['trial_number']}_data.json"
        json_path = os.path.join(self.current_session, json_name)
        with open(json_path, "w") as handle:
            json.dump(trial_data, handle, indent=2)

