import os
import datetime

class Logger:
    """Append-only text logger that writes training progress to a file.

    Constructing the logger creates the log directory when it is missing and
    appends a banner (separator line, timestamp, optional message) to the log
    file. Every ``log_*`` method reopens the file in append mode, so earlier
    content is never overwritten.
    """

    def __init__(self, log_file, log_dir="./logs", initial_message=None):
        """Prepare the log file and append a session banner.

        Args:
            log_file: Name of the log file, joined onto ``log_dir``.
            log_dir: Directory holding the log file; created if absent.
            initial_message: Optional line written right after the banner.
                Skipped when falsy.
        """
        self.log_dir = log_dir
        self.log_file_path = os.path.join(log_dir, log_file)

        os.makedirs(log_dir, exist_ok=True)

        banner = ["=" * 50 + "\n", f"Logger at {datetime.datetime.now()}\n"]
        if initial_message:
            banner.append(f"{initial_message}\n")
        self._append(*banner)

    def _append(self, *chunks):
        """Append each string in ``chunks`` to the log file, in order."""
        with open(self.log_file_path, "a") as handle:
            for chunk in chunks:
                handle.write(chunk)

    def log_training_info(self, train_seq, val_seq):
        """Write the train trajectory count and the loss-table header.

        Args:
            train_seq: Value reported as the number of train trajectories.
            val_seq: Accepted for interface compatibility; not written.
        """
        self._append(
            f"Number of train traj: {train_seq}" + "\n",
            "Epoch\tTrain Loss\n",
        )

    def log_ddpm_loss(self, epoch, train_loss):
        """Write one tab-separated row: epoch and loss with six decimals."""
        self._append(f"{epoch}\t{train_loss:.6f}" + "\n")

    def log_train_time(self, time):
        """Write the total training time as a single line."""
        self._append(f"Training time: {time}" + "\n")
            

class Lam_Logger:
    """Append-only logger that writes rows of values to ``<env_name>_lam.txt``.

    Constructing the logger creates ``save_path`` when it is missing and
    appends a banner (separator line, timestamp, optional message) to the log
    file. ``log_info`` then appends one comma-separated row per call.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, env_name, save_path, initial_message=None):
        """Prepare the log file and append a session banner.

        Args:
            env_name: Prefix of the log file name (``f"{env_name}_lam.txt"``).
            save_path: Directory holding the log file; created if absent.
            initial_message: Optional line written right after the banner.
                Skipped when falsy.
        """
        # exist_ok=True avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(save_path, exist_ok=True)

        # Directory that contains the log file.
        self.save_path = save_path
        # Full path of the log file; log_info() appends to this file.
        self.path = os.path.join(save_path, f"{env_name}_lam.txt")

        with open(self.path, "a") as f:
            f.write("=" * 50 + "\n")
            f.write(f"Logger at {datetime.datetime.now()}\n")
            if initial_message:
                f.write(f"{initial_message}\n")

    def log_info(self, lams):
        """Append one row with each value of ``lams`` formatted to 4 decimals.

        Args:
            lams: Iterable of numbers. An empty iterable writes an empty line.
        """
        with open(self.path, "a") as f:
            f.write(",".join(f"{v:.4f}" for v in lams) + "\n")
    