import os 
from omegaconf import OmegaConf 

import lightning as L 


# Module-wide switch read by SetupCallback.on_fit_start. When True:
#   * global rank 0 sleeps 5 seconds before saving the project config
#     (presumably to let a shared filesystem settle across nodes -- TODO confirm);
#   * non-zero ranks never move an existing logdir into a ``child_runs`` folder.
MULTINODE_HACKS: bool = True



class SetupCallback(L.Callback):
    """Prepare run directories and persist configs when fitting starts.

    On global rank 0, this callback does two things:

    * It creates the log, checkpoint and config directories.
    * It writes the project config and the Lightning config as YAML.

    On other ranks it may move a pre-existing log directory aside. This
    happens only when ``MULTINODE_HACKS`` is False and the run is not
    being resumed.

    When an exception interrupts training, rank 0 saves a checkpoint
    unless ``debug`` is set.

    Args:
        resume: Whether this run resumes an earlier one. When truthy, the
            non-zero-rank log directory relocation is skipped.
        now: Filename prefix for the saved YAML configs (presumably a run
            timestamp string -- confirm with caller).
        logdir: Root log directory for the run.
        ckptdir: Directory where checkpoints are written.
        cfgdir: Directory where the config YAML files are written.
        config: Project config, printed and saved with ``OmegaConf``.
        lightning_config: Lightning config. Its ``callbacks`` entry is
            checked for ``metrics_over_trainsteps_checkpoint``.
        debug: When truthy, no checkpoint is saved on exception.
        ckpt_name: Filename for the exception checkpoint. Defaults to
            ``"last.ckpt"`` when None.
    """

    def __init__(
        self,
        resume,
        now,
        logdir,
        ckptdir,
        cfgdir,
        config,
        lightning_config,
        debug,
        ckpt_name=None,
    ):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config
        self.debug = debug
        self.ckpt_name = ckpt_name

    def on_exception(self, trainer: L.Trainer, pl_module, exception):
        """Save a checkpoint from global rank 0 when training is interrupted.

        Does nothing when ``debug`` is truthy or on non-zero ranks. The
        checkpoint is written to ``ckptdir/ckpt_name``, or to
        ``ckptdir/last.ckpt`` when ``ckpt_name`` is None.
        """
        if self.debug or trainer.global_rank != 0:
            return
        print("Summoning checkpoint.")
        name = "last.ckpt" if self.ckpt_name is None else self.ckpt_name
        # NOTE(review): Lightning documents that Trainer.save_checkpoint should
        # be called on all processes when the strategy handles distributed
        # checkpointing. Calling it from rank 0 alone may hang or produce an
        # incomplete checkpoint under such strategies -- confirm for the
        # strategy in use.
        trainer.save_checkpoint(os.path.join(self.ckptdir, name))

    def on_fit_start(self, trainer, pl_module):
        """Set up directories/configs on rank 0; maybe relocate logdir elsewhere."""
        if trainer.global_rank == 0:
            self._setup_rank_zero()
        else:
            self._relocate_logdir()

    def _setup_rank_zero(self):
        """Create the run directories and print/save both configs as YAML."""
        os.makedirs(self.logdir, exist_ok=True)
        os.makedirs(self.ckptdir, exist_ok=True)
        os.makedirs(self.cfgdir, exist_ok=True)

        if (
            "callbacks" in self.lightning_config
            and "metrics_over_trainsteps_checkpoint"
            in self.lightning_config["callbacks"]
        ):
            os.makedirs(
                os.path.join(self.ckptdir, "trainstep_checkpoints"),
                exist_ok=True,
            )

        print("Project config")
        print(OmegaConf.to_yaml(self.config))
        if MULTINODE_HACKS:
            import time

            # Fixed pause before writing configs; the reason is not visible
            # here (presumably shared-filesystem settling -- TODO confirm).
            time.sleep(5)
        OmegaConf.save(
            self.config,
            os.path.join(self.cfgdir, f"{self.now}-project.yaml"),
        )

        print("Lightning config")
        print(OmegaConf.to_yaml(self.lightning_config))
        # The Lightning config is nested under a top-level "lightning" key.
        OmegaConf.save(
            OmegaConf.create({"lightning": self.lightning_config}),
            os.path.join(self.cfgdir, f"{self.now}-lightning.yaml"),
        )

    def _relocate_logdir(self):
        """Move an existing logdir to ``<parent>/child_runs/<name>``.

        Skipped when MULTINODE_HACKS is True, when resuming, or when the
        logdir does not exist.
        """
        if MULTINODE_HACKS or self.resume or not os.path.exists(self.logdir):
            return
        # normpath drops any trailing separator. Without it, os.path.split
        # returns an empty name and the destination lands inside logdir
        # itself, so os.rename fails with OSError.
        parent, name = os.path.split(os.path.normpath(self.logdir))
        dst = os.path.join(parent, "child_runs", name)
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        try:
            os.rename(self.logdir, dst)
        except FileNotFoundError:
            # Best-effort: the directory may already have been moved or
            # removed (presumably by another rank -- confirm).
            pass
