#-*- coding:utf-8 -*-

import datetime
import os
from typing import Any

import torch

class CheckPointManager:
    """Save model weights and resumable training checkpoints for a single run.

    Every run writes into its own folder under ``base_folder`` named
    ``{base_name}-{task_name}-{task_tag}-{control_type}-{model_type}-{month}.{day}.{hour}``.
    Per-epoch weights are written by :meth:`update`, the final weights by
    :meth:`save_training_end`, and full resumable state (model, optimizer,
    LR scheduler) by :meth:`save_checkpoint`. All files are written
    atomically, so an interrupted save never leaves a truncated file behind.
    """

    def __init__(
            self,
            epochs: int,
            base_folder: str = "./weights",
            save_last_epochs: int = 10,
            save_per_epochs: int = 50,
            base_name: str = "DDPM",
            task_name: str = 'PUSHT',
            task_tag: str = 'NONE',
            control_type: str = 'STATE',
            model_type: str = 'CNN',
            ignore_make_folder: bool = False,
            rank: int = 0,
        ) -> None:
        """Compute the run folder path and create it unless told not to.

        Args:
            epochs: Total number of training epochs.
            base_folder: Parent directory of the run folder. It is created
                (including missing parents) together with the run folder.
            save_last_epochs: Weights are saved every epoch during this many
                final epochs.
            save_per_epochs: Before that final window, weights are saved every
                ``save_per_epochs`` epochs. A value <= 0 disables these
                periodic saves.
            base_name, task_name, task_tag, control_type, model_type: Parts of
                the run folder name.
            ignore_make_folder: If True, never create the run folder.
            rank: The folder is only created when ``rank == 0`` (presumably
                the distributed process rank; confirm with callers).
        """
        self.epochs = epochs
        self.base_folder = base_folder
        self.save_last_epochs = save_last_epochs
        self.save_per_epochs = save_per_epochs
        self.now = datetime.datetime.now()
        # NOTE(review): hour granularity means two runs with an identical
        # configuration started within the same hour share one folder and
        # overwrite each other's files.
        self.timestamp = f"{self.now.month}.{self.now.day}.{self.now.hour}"
        self.save_folder_name = (
            f"{base_name}-{task_name}-{task_tag}-{control_type}-{model_type}-{self.timestamp}"
        )
        self.folder_path = os.path.join(self.base_folder, self.save_folder_name)

        if not ignore_make_folder and rank == 0:
            # makedirs also creates a missing base_folder, and exist_ok
            # removes the race between an exists() check and mkdir().
            os.makedirs(self.folder_path, exist_ok=True)

    @staticmethod
    def _atomic_save(obj: Any, path: str) -> None:
        """``torch.save`` ``obj`` to ``path`` via a temp file and an atomic rename.

        If the process dies mid-write, only the ``.tmp`` file is damaged and
        any previous file at ``path`` is left intact.
        """
        tmp_path = path + ".tmp"
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)

    def update(self, current_epoch: int, model: Any) -> None:
        """Save ``model.state_dict()`` if this epoch is due for a save.

        ``current_epoch`` is 0-based, so the saved epoch number is
        ``current_epoch + 1``. Epochs within the final ``save_last_epochs``
        are saved as ``last_epoch{n}.pt``. Otherwise, epochs that are a
        multiple of ``save_per_epochs`` (when it is positive) are saved as
        ``[training]epoch{n}.pt``. All other epochs save nothing.
        """
        epoch_num = current_epoch + 1
        if self.epochs - epoch_num < self.save_last_epochs:
            filename = f'last_epoch{epoch_num}.pt'
        elif self.save_per_epochs > 0 and epoch_num % self.save_per_epochs == 0:
            # The positivity guard avoids ZeroDivisionError and lets 0
            # disable periodic saves.
            filename = f'[training]epoch{epoch_num}.pt'
        else:
            return
        self._atomic_save(model.state_dict(), os.path.join(self.folder_path, filename))

    def save_training_end(self, model: Any) -> None:
        """Save ``model.state_dict()`` as ``final_epoch{epochs}.pt``."""
        self._atomic_save(
            model.state_dict(),
            os.path.join(self.folder_path, f'final_epoch{self.epochs}.pt'),
        )

    def save_checkpoint(
            self,
            model: Any,
            current_epoch: int,
            optimizer: Any,
            lr_scheduler: Any,
            wandb_id: Any,
        ) -> None:
        """Save full resumable training state to ``checkpoint_{wandb_id}.pth``.

        The saved dict has the keys ``epoch`` (``current_epoch`` exactly as
        passed), ``model``, ``optimizer``, ``lr_scheduler`` (their
        ``state_dict()``s) and ``save_folder_path``. The file is overwritten
        on each call, so the write is atomic to avoid losing the only resume
        point on a crash.
        """
        checkpoint = {
            'epoch': current_epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'save_folder_path': self.folder_path,
        }
        self._atomic_save(checkpoint, os.path.join(self.folder_path, f'checkpoint_{wandb_id}.pth'))