import argparse 
import datetime 
import glob 
import inspect 
import os 
import sys 

import lightning as L 
import torch 
from lightning .pytorch .loggers import WandbLogger 
from omegaconf import OmegaConf 

from ltm .utils import get_checkpoint_name ,instantiate_from_config 
from ltm .modules .lightning .fsdp import fsdp_huggingface 

# Set at import time, before any training code runs.
# NOTE(review): presumably this switches huggingface_hub downloads to the
# `hf_transfer` backend — confirm against the huggingface_hub documentation.
os .environ ["HF_HUB_ENABLE_HF_TRANSFER"]="1"


def default_trainer_args():
    """Collect the default values of the ``L.Trainer`` constructor arguments.

    Returns:
        dict: maps each parameter name of ``L.Trainer.__init__`` (``self``
        excluded) to its declared default. Parameters that declare no default
        (including ``*args`` / ``**kwargs`` style parameters, if any) are
        omitted.
    """
    params = dict(inspect.signature(L.Trainer.__init__).parameters)
    params.pop("self")
    # Compare the parameter's *default* with the sentinel. Comparing the
    # Parameter object itself is always unequal, so it would never filter
    # anything and would leak the sentinel as a "default" value.
    return {
        name: param.default
        for name, param in params.items()
        if param.default is not inspect.Parameter.empty
    }


def get_parser():
    """Build the argument parser for the training script.

    The parser exposes the script's own options (config files, resume paths,
    logging/run naming, seed, debug flag) followed by one ``--<name>`` option
    for every entry returned by ``default_trainer_args()``. Those trainer
    options are registered without a ``type``, so a value given on the command
    line is kept as a string.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    cli = argparse.ArgumentParser()

    base_help = (
        "paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`."
    )
    cli.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help=base_help,
        default=[],
    )

    # Script-level options, registered in the order they appear in --help.
    script_options = (
        ("--resume", {"type": str, "default": None, "help": "resume from checkpoint"}),
        (
            "--resume_from_checkpoint",
            {
                "type": str,
                "default": None,
                "help": "single checkpoint file to resume from",
            },
        ),
        (
            "--projectname",
            {"type": str, "default": "lltm", "help": "project name in wandb"},
        ),
        ("--debug", {"action": "store_true", "help": "debug mode"}),
        ("--postfix", {"type": str, "default": "", "help": "postfix for the run"}),
        ("--logdir", {"type": str, "default": "logs", "help": "logdir for the run"}),
        ("--seed", {"type": int, "default": 42, "help": "seed for reproducibility"}),
        ("--name", {"type": str, "default": None, "help": "name for the run"}),
    )
    for flag, spec in script_options:
        cli.add_argument(flag, **spec)

    # Mirror every Trainer constructor argument as an optional CLI flag.
    for trainer_key, trainer_default in default_trainer_args().items():
        cli.add_argument("--" + trainer_key, default=trainer_default)

    return cli


def main():
    """Command-line entry point: assemble a run from YAML configs and train.

    Steps, in order:
      1. Parse the known options; every unrecognised argument is kept and
         later interpreted as an OmegaConf dotlist override.
      2. Determine the log directory: the one being resumed (``--resume``), or
         a fresh ``<--logdir>/<run name>`` derived from the first base config.
      3. Seed, merge the configs, and instantiate the model, the W&B logger,
         the callbacks and (optionally) the strategy.
      4. Build the ``L.Trainer``, instantiate the data module and call
         ``trainer.fit``.

    Raises:
        ValueError: if ``--resume`` points to a path that does not exist or is
            neither a checkpoint nor a directory, or if neither ``--base`` nor
            ``--resume`` is given.
    """
    # Timestamp used in fresh run names; fixed UTC+9 offset labelled "JST".
    now = datetime.datetime.now(
        tz=datetime.timezone(datetime.timedelta(hours=9), name="JST")
    ).strftime("%Y-%m-%dT%H-%M-%S")
    # NOTE(review): presumably so that dotted `target` paths in the configs
    # can resolve to modules in the working directory — confirm.
    sys.path.append(os.getcwd())
    parser = get_parser()
    args, unknown = parser.parse_known_args()

    melk_ckpt_name = None
    if args.resume:
        if not os.path.exists(args.resume):
            raise ValueError("Cannot find {}".format(args.resume))
        if os.path.isfile(args.resume) or args.resume.rstrip("/").endswith(".ckpt"):
            # A checkpoint was given: the log directory is two levels above
            # it, i.e. the path is expected to look like <logdir>/<dir>/<ckpt>.
            # Strip a trailing "/" first so a checkpoint *directory* written as
            # ".../last.ckpt/" resolves to the same logdir as ".../last.ckpt".
            paths = args.resume.rstrip("/").split("/")
            logdir = "/".join(paths[:-2])
            ckpt = args.resume
            # Only the second value returned by get_checkpoint_name is needed
            # here; the checkpoint itself is the one the user passed.
            _, melk_ckpt_name = get_checkpoint_name(logdir)
        else:
            if not os.path.isdir(args.resume):
                raise ValueError(
                    "Expected a checkpoint or a log directory, got {}".format(
                        args.resume
                    )
                )
            logdir = args.resume.rstrip("/")
            ckpt, melk_ckpt_name = get_checkpoint_name(logdir)

        print("#" * 100)
        print(f'Resuming from checkpoint "{ckpt}"')
        print("#" * 100)

        args.resume_from_checkpoint = ckpt
        # Configs saved in the resumed run come first, so configs passed with
        # -b are merged later and take precedence.
        configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        args.base = configs + args.base
        nowname = logdir.split("/")[-1]
    else:
        if not args.base:
            raise ValueError(
                "No config given: pass at least one -b/--base config or use --resume"
            )
        # Run name = directories below "configs" + config file stem, joined
        # with "-". A config that does not live under a "configs" directory
        # is treated like one sitting directly inside it (empty sub-path).
        cfg_dir, cfg_file = os.path.split(args.base[0])
        dir_parts = cfg_dir.split(os.sep)
        if "configs" in dir_parts:
            cfg_path = dir_parts[dir_parts.index("configs") + 1:]
        else:
            cfg_path = []
        cfg_name = "-".join(cfg_path) + f"-{os.path.splitext(cfg_file)[0]}"
        if args.name:
            nowname = args.name + args.postfix
        else:
            nowname = now + "_" + cfg_name + args.postfix
        if nowname.startswith("_"):
            nowname = nowname[1:]
        logdir = os.path.join(args.logdir, nowname)
        print("Logdir:", logdir)
    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    L.seed_everything(args.seed, workers=True)

    # Allow TF32 for float32 matmuls and cuDNN operations.
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Merge left-to-right; command-line dotlist overrides are applied last.
    configs = [OmegaConf.load(cfg) for cfg in args.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)
    lightning_config = config.pop("lightning", OmegaConf.create())

    # A Trainer option whose command-line value differs from the Trainer
    # default overrides the value from the config files.
    trainer_config = lightning_config.get("trainer", OmegaConf.create())
    standard_args = default_trainer_args()
    for k in standard_args:
        if getattr(args, k) != standard_args[k]:
            trainer_config[k] = getattr(args, k)
    # Plain-dict snapshot of the trainer options, passed to L.Trainer below.
    trainer_opt = dict(**trainer_config)
    trainer_kwargs = {}
    lightning_config.trainer = trainer_config

    model: L.LightningModule = instantiate_from_config(
        OmegaConf.to_container(config.model, resolve=True)
    )

    # W&B group = first path component of the last base config, after
    # dropping everything up to and including "configs/" when present.
    ignore_base = "configs" + os.sep
    path = args.base[-1]
    start_idx = path.find(ignore_base)
    if start_idx != -1:
        path = path[start_idx + len(ignore_base):]
    path_dirs = path.split(os.sep)[:1]
    group_name = "-".join(path_dirs)
    trainer_kwargs["logger"] = WandbLogger(
        name=nowname,
        project=args.projectname,
        offline=args.debug,
        save_dir=logdir,
        group=group_name,
    )

    # Checkpoint callback defaults; `lightning.modelcheckpoint` in the config
    # is merged on top of them.
    default_modelckpt_cfg = {
        "target": "lightning.pytorch.callbacks.ModelCheckpoint",
        "params": {
            "dirpath": ckptdir,
            "filename": "{epoch:06}-{step:09}",
            "verbose": True,
            "save_last": True,
        },
    }
    if "modelcheckpoint" in lightning_config:
        modelckpt_cfg = lightning_config.modelcheckpoint
    else:
        modelckpt_cfg = OmegaConf.create()
    modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
    print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")

    # Default callbacks; `lightning.callbacks` in the config is merged on top
    # and may override or extend them.
    default_callbacks_cfg = {
        "setup_callback": {
            "target": "ltm.modules.lightning.SetupCallback",
            "params": {
                "resume": args.resume,
                "now": now,
                "logdir": logdir,
                "ckptdir": ckptdir,
                "cfgdir": cfgdir,
                "config": config,
                "lightning_config": lightning_config,
                "debug": args.debug,
                "ckpt_name": melk_ckpt_name,
            },
        },
        "learning_rate_logger": {
            "target": "lightning.pytorch.callbacks.LearningRateMonitor",
            "params": {
                "logging_interval": "step",
            },
        },
        "checkpoint_callback": modelckpt_cfg,
    }
    if "callbacks" in lightning_config:
        callbacks_cfg = lightning_config.callbacks
    else:
        callbacks_cfg = OmegaConf.create()
    callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
    trainer_kwargs["callbacks"] = [
        instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
    ]

    if "strategy" in lightning_config:
        strategy_cfg = lightning_config.strategy
        print(f"Strategy: \n{strategy_cfg}")
        if strategy_cfg.target == "fsdp_huggingface":
            # This strategy is built from the model instance, so it cannot go
            # through the generic config instantiation.
            trainer_kwargs["strategy"] = fsdp_huggingface(
                model=model, **strategy_cfg.get("params", {})
            )
        else:
            trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)

    # Options set through the trainer config win over the logger / callbacks /
    # strategy built above (and avoid duplicate keyword arguments).
    trainer_kwargs = {
        key: val for key, val in trainer_kwargs.items() if key not in trainer_opt
    }
    # Trainer options given on the command line arrive as strings because the
    # parser registers them without a type.
    if "num_nodes" in trainer_opt:
        trainer_opt["num_nodes"] = int(trainer_opt["num_nodes"])
    trainer = L.Trainer(**trainer_opt, **trainer_kwargs)
    trainer.logdir = logdir

    data: L.LightningDataModule = instantiate_from_config(
        OmegaConf.to_container(config.data, resolve=True)
    )
    model.tokenizer = data.tokenizer
    data.prepare_data()

    trainer.fit(model=model, datamodule=data, ckpt_path=args.resume_from_checkpoint)

    trainer.print("Training successfully completed!")
    trainer.print(
        f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
    )


# Run the training entry point only when this file is executed as a script,
# not when it is imported.
if __name__ =="__main__":
    main ()
