""" Merge the best LoRA checkpoint of a training run into its base model at the end of training. """

from overrides import overrides
from tasker import BaseTask
import os
import re
try:
    import ujson as json
except ImportError:
    import json
import torch
import logging
from typing import (
    Text,
    Optional
)
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig


# Module logger that prints INFO-and-above records to the console.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
# getLogger returns the same object on every call, so re-executing this module
# (e.g. importlib.reload) would otherwise stack another console handler and
# print each message multiple times. Only attach it when none is present.
if not logger.handlers:
    logger.addHandler(console_handler)


@BaseTask.register('merge-best-lora')
class MergeBestLoraTask(BaseTask):
    """Merge the best LoRA checkpoint of a training run into its base model.

    The task scans ``input_dir`` for ``checkpoint-<n>`` sub-directories,
    reads ``trainer_state.json`` from the one with the largest ``n`` to find
    ``best_model_checkpoint``, loads that PEFT adapter on top of the base
    model named in its ``PeftConfig``, merges the adapter weights, and writes
    the merged model together with the checkpoint's tokenizer.
    """

    __VERSION__ = "0.0.1"

    def __init__(
        self,
        input_dir,
        output_dir,
    ):
        """
        Args:
            input_dir: Directory containing the ``checkpoint-<n>``
                sub-directories produced during training.
            output_dir: Destination for the merged model and tokenizer;
                forwarded to ``BaseTask.__init__``.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir

    @staticmethod
    def _checkpoint_number(directory: Text) -> Optional[int]:
        """Return ``n`` for a directory named exactly ``checkpoint-<n>``, else None."""
        # fullmatch (not search) so names that merely contain the pattern are
        # not mistaken for checkpoints; mirrors transformers' get_last_checkpoint.
        match = re.fullmatch(r"checkpoint-(\d+)", directory)
        return int(match.group(1)) if match else None

    def _latest_checkpoint_dir(self) -> Text:
        """Return the name of the ``checkpoint-<n>`` sub-directory with the largest ``n``.

        Raises:
            FileNotFoundError: if ``input_dir`` has no checkpoint directories.
        """
        ckpts = []
        for ckpt_dir in os.listdir(self._input_dir):
            if not os.path.isdir(os.path.join(self._input_dir, ckpt_dir)):
                continue
            ckpt_num = self._checkpoint_number(ckpt_dir)
            if ckpt_num is None:
                # Unrelated sub-directory (e.g. a logging dir); the original
                # code crashed here with AttributeError.
                continue
            ckpts.append((ckpt_dir, ckpt_num))

        if not ckpts:
            raise FileNotFoundError(
                f"No 'checkpoint-<n>' directories found in {self._input_dir}."
            )
        return max(ckpts, key=lambda item: item[1])[0]

    def _best_checkpoint_path(self, latest_ckpt_dir: Text) -> Text:
        """Read ``best_model_checkpoint`` from ``latest_ckpt_dir``'s ``trainer_state.json``.

        If the recorded path is not an existing directory, the same-named
        directory under ``input_dir`` is used instead when it exists.

        Raises:
            ValueError: if the state file records no best checkpoint.
        """
        state_path = os.path.join(self._input_dir, latest_ckpt_dir, "trainer_state.json")
        with open(state_path, 'r', encoding='utf-8') as file_:
            training_state = json.load(file_)

        best_ckpt = training_state.get('best_model_checkpoint')
        if not best_ckpt:
            raise ValueError(
                f"{state_path} does not record a 'best_model_checkpoint'; "
                "was metric_for_best_model / load_best_model_at_end set during training?"
            )

        if not os.path.isdir(best_ckpt):
            # The stored path presumably reflects the training-time location
            # (it may be relative to another working directory, or the run
            # directory may have been moved), so look for it under input_dir.
            relocated = os.path.join(
                self._input_dir, os.path.basename(os.path.normpath(best_ckpt))
            )
            if os.path.isdir(relocated):
                logger.info(f"Recorded best checkpoint {best_ckpt} not found; using {relocated}.")
                best_ckpt = relocated
        return best_ckpt

    @overrides
    def _run(self):
        """Load the best LoRA checkpoint and merge it into its base model.

        Returns:
            A ``(merged_model, tokenizer)`` tuple, consumed by ``_write``.

        Raises:
            FileNotFoundError: if ``input_dir`` contains no checkpoint directories.
            ValueError: if the latest ``trainer_state.json`` has no best checkpoint.
        """
        latest_ckpt_dir = self._latest_checkpoint_dir()
        logger.info(f"Loading the state file checkpoint from {latest_ckpt_dir}.")

        best_ckpt = self._best_checkpoint_path(latest_ckpt_dir)
        logger.info(f"Best checkpoint is {best_ckpt}.")

        # The adapter's PeftConfig names the base model it was trained on.
        config = PeftConfig.from_pretrained(best_ckpt)
        model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
        peft_model = PeftModel.from_pretrained(model, best_ckpt)
        tokenizer = AutoTokenizer.from_pretrained(best_ckpt)
        # Fold the LoRA weights into the base weights and drop the PEFT wrapper.
        merged_model = peft_model.merge_and_unload()

        return (merged_model, tokenizer)

    @overrides
    def _write(self, outputs):
        """Save the merged model and tokenizer.

        Args:
            outputs: The ``(merged_model, tokenizer)`` tuple returned by ``_run``.
        """
        merged_model, tokenizer = outputs
        # self._output_dir is presumably set by BaseTask.__init__ from the
        # output_dir passed in __init__ — confirm against BaseTask.
        merged_model.save_pretrained(self._output_dir)
        tokenizer.save_pretrained(self._output_dir)