from transformers import TrainerCallback
from accelerate import Accelerator
import os
import logging
import json


class LoggerCallback(TrainerCallback):
    """Trainer callback that writes log records to ``<exp_dir>/logs/<mode>.log``
    and to the console, from the main process only.

    It also creates the experiment directory layout (``exp_dir``,
    ``exp_dir/logs``, ``exp_dir/eval``). It records ``exp_dir/model`` as
    ``model_dir`` but does not create that directory.
    """

    def __init__(self, exp_dir, mode="train", use_accel=False, rank=None):
        """
        Args:
            exp_dir: Root directory of the experiment.
            mode: Name of the run (e.g. ``"train"``); used as the log file stem.
            use_accel: If True, create an ``Accelerator`` and use its
                ``is_main_process`` flag to decide who logs.
            rank: Process rank; rank 0 is treated as the main process.
                May be supplied later via :meth:`set_rank`.
        """
        self.accelerator = Accelerator() if use_accel else None
        super().__init__()
        self.exp_dir = exp_dir
        self.mode = mode
        # Stored so set_rank() can recompute main_process later.
        self.use_accel = use_accel
        self.rank = rank
        self.model_dir = f"{exp_dir}/model"
        self.log_dir = f"{exp_dir}/logs"
        self.eval_dir = f"{exp_dir}/eval"
        self.log_path = f"{self.log_dir}/{self.mode}.log"
        # One child logger per mode so that e.g. train and eval callbacks do
        # not write into each other's files; records still propagate to the
        # module logger, so handlers attached there keep working.
        self.logger = logging.getLogger(f"{__name__}.{mode}")
        self._handlers_ready = False
        os.makedirs(self.exp_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.eval_dir, exist_ok=True)
        self.main_process = self._is_main(rank)
        if self.main_process:
            self._setup_handlers()

    def _is_main(self, rank):
        """Return True if this process should emit logs / write files."""
        accel_main = bool(self.use_accel and self.accelerator.is_main_process)
        return accel_main or rank == 0

    def _setup_handlers(self):
        """Attach a fresh file handler and console handler (idempotent)."""
        if self._handlers_ready:
            return
        # Remove handlers left on this logger by an earlier instance with the
        # same mode; otherwise every message would be emitted several times.
        for stale in list(self.logger.handlers):
            self.logger.removeHandler(stale)
            stale.close()
        formatter = logging.Formatter('%(message)s')
        # mode="w" starts the log file empty, replacing the old
        # remove + recreate + append sequence.
        handlers = (
            logging.FileHandler(self.log_path, mode='w', encoding='utf-8'),
            logging.StreamHandler(),
        )
        for handler in handlers:
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)
        self._handlers_ready = True

    def set_rank(self, rank):
        """Update the process rank and re-evaluate whether this process logs.

        If the process becomes the main process, the log handlers are
        attached now, so messages are not silently dropped.
        """
        self.rank = rank
        self.main_process = self._is_main(rank)
        if self.main_process:
            self._setup_handlers()

    def log(self, message):
        """Log ``message`` at INFO level if this is the main process."""
        if self.main_process:
            self.logger.info(message)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Trainer hook: write the ``logs`` dict as tab-separated ``key: value`` pairs."""
        if logs is not None:
            self.log("\t".join([f"{k}: {v}" for k, v in logs.items()]))

    def dump_json(self, data, file_path):
        """Write ``data`` as indented JSON to ``file_path`` (main process only)."""
        if self.main_process:
            with open(file_path, 'w') as f:
                json.dump(data, f, indent=4)