import os
import json
import logging
from typing import Literal
from RoboMemory.agent_utils import save_image_from_base64

class DataLogger:
    """Append per-step agent records to a JSONL file and a plain-text log.

    On construction, ``ckpt_path`` and an image sub-directory inside it are
    created. Structured records are appended to
    ``<ckpt_path>/<record_name>.jsonl`` (one JSON object per line, via
    :meth:`log`); banner-wrapped messages are printed and appended to
    ``<ckpt_path>/<record_name>.txt`` (via :meth:`console_log`).
    """

    def __init__(
        self,
        ckpt_path: str = "./ckpt",
        record_name: str = "TASK",
        image_dir_name: str = "img",
        image_type: str = "jpeg",
    ):
        """Create the output directories and compute the output file paths.

        Args:
            ckpt_path: Root directory for all outputs; created if missing.
            record_name: Base name (without extension) shared by the
                ``.jsonl`` record file and the ``.txt`` log file.
            image_dir_name: Name of the image sub-directory under ``ckpt_path``;
                created if missing.
            image_type: Image format string forwarded to the image saver.
        """
        self.ckpt_path = ckpt_path
        self.record_name = record_name
        self.image_dir_name = image_dir_name
        self.image_type = image_type

        # Create directories if they do not exist.
        os.makedirs(self.ckpt_path, exist_ok=True)
        self.image_dir = os.path.join(self.ckpt_path, self.image_dir_name)
        os.makedirs(self.image_dir, exist_ok=True)

        # Structured records (JSON Lines) and the human-readable text log.
        self.file = os.path.join(self.ckpt_path, f"{self.record_name}.jsonl")
        self.log_file = os.path.join(self.ckpt_path, f"{self.record_name}.txt")
        # Count of records written by ``log`` on this instance; shown as
        # "local" in console banners.
        self.local_step = 0

    def console_log(self, global_step: int, returns, agent_name=None):
        """Print a banner-wrapped message and append it to the text log.

        Args:
            global_step: Caller-supplied step counter shown in the banner.
            returns: Payload to display; rendered with ``str()`` by the f-string.
            agent_name: Optional module name. When truthy it is included in
                the banner as ``module: <agent_name>``; otherwise it is omitted.
        """
        # Include the module name only when one was actually given. The
        # previous inverted check printed "module: None" for missing names and
        # dropped real ones.
        if agent_name:
            logging_str = f"\n###### local: {self.local_step}, global: {global_step}, module: {agent_name} ######\n{returns}\n################################\n"
        else:
            logging_str = f"\n############# local: {self.local_step}, global: {global_step} ###################\n{returns}\n################################\n"

        print(logging_str)

        # Append mode: the text log accumulates across calls and across runs.
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(logging_str + "\n")

    def log(
        self,
        global_step: int,
        memory: str,
        obs_image: str,
        obs_text: str,
        CoT: str,
        AC: Literal["Actor", "Critic"],
        save_outputs: dict,
        save_inputs: dict,
    ):
        """Save the observation image and append one record to the JSONL file.

        Increments ``self.local_step`` after the record is written.

        Args:
            global_step: Caller's step counter. NOTE(review): it is currently
                neither written into the record nor used to name the image
                file (the name is decided by ``save_image_from_base64``) —
                confirm this is intended.
            memory: Stored under the ``"memory"`` key.
            obs_image: Image payload passed to ``save_image_from_base64``
                (presumably base64-encoded, per the helper's name — verify).
            obs_text: Stored under the ``"obs_text"`` key.
            CoT: Stored under the ``"CoT"`` key.
            AC: Role that produced the record, ``"Actor"`` or ``"Critic"``;
                stored under ``"ActorCritic"``.
            save_outputs: Must be JSON-serializable; stored under ``"output"``.
            save_inputs: Must be JSON-serializable; stored under ``"input"``.

        Raises:
            TypeError: If any field is not JSON-serializable (from
                ``json.dumps``). The image has already been saved at that
                point, but nothing is written to the JSONL file.
        """
        # The helper's return value (presumably the saved file's path — TODO
        # confirm) is what gets recorded under "obs_image".
        image_path = save_image_from_base64(obs_image, self.image_dir, self.image_type)

        data = {
            "ActorCritic": AC,
            "memory": memory,
            "obs_image": image_path,
            "obs_text": obs_text,
            "CoT": CoT,
            "input": save_inputs,
            "output": save_outputs,
        }

        # Serialize before opening so a non-serializable payload fails without
        # touching the record file.
        line = json.dumps(data) + "\n"
        with open(self.file, "a", encoding="utf-8") as f:
            f.write(line)

        # Advance the per-instance counter for the next entry.
        self.local_step += 1