# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
import os
import pdb

# Allow a launcher to pass the GPU selection via "_CUDA_VISIBLE_DEVICES"; it is
# copied into CUDA_VISIBLE_DEVICES here, before torch is imported further down
# (presumably so CUDA picks it up at initialisation — keep this above the imports).
_forwarded_devices = os.environ.get("_CUDA_VISIBLE_DEVICES")
if _forwarded_devices is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = _forwarded_devices
del _forwarded_devices
import argparse
import logging
from pathlib import Path

import torch,platform
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger 
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
# Raise numba's and matplotlib's loggers to WARNING so their debug/info records stay out of the console.
for _quiet_logger_name in ('numba', 'matplotlib'):
    logging.getLogger(_quiet_logger_name).setLevel(logging.WARNING)
del _quiet_logger_name
# Let float32 matmuls use faster reduced-precision internal algorithms (e.g. TF32) when the hardware supports them.
torch.set_float32_matmul_precision('high')
from AR.utils import get_newest_ckpt

from collections import OrderedDict
class my_model_ckpt(ModelCheckpoint):
    """``ModelCheckpoint`` that can keep only the newest checkpoint and export fp16 weights.

    Args:
        config: Training config; embedded under the ``"config"`` key of every
            exported half-precision weight file.
        if_save_latest: If truthy, after each periodic save every entry that was
            already in ``dirpath`` before that save is deleted (best effort).
        if_save_every_weights: If truthy, each periodic save also writes the
            model ``state_dict`` cast to fp16 into ``half_weights_save_dir``.
        half_weights_save_dir: Destination directory for the fp16 exports
            (created if it does not exist).
        exp_name: File-name prefix of the exports (``<exp_name>-e<epoch>.ckpt``).
        **kwargs: Forwarded unchanged to ``ModelCheckpoint``.
    """

    def __init__(self, config, if_save_latest, if_save_every_weights, half_weights_save_dir, exp_name, **kwargs):
        super().__init__(**kwargs)
        self.if_save_latest = if_save_latest
        self.if_save_every_weights = if_save_every_weights
        self.half_weights_save_dir = half_weights_save_dir
        self.exp_name = exp_name
        self.config = config

    def on_train_epoch_end(self, trainer, pl_module):
        """Checkpoint at the end of a training epoch.

        Mirrors ``ModelCheckpoint.on_train_epoch_end``. On epochs selected by
        ``every_n_epochs``, it can also prune older checkpoints and/or export
        fp16 weights. Those extra filesystem operations run on global rank zero
        only, so DDP ranks never race on the same files.
        """
        if self._should_skip_saving_checkpoint(trainer) or not self._should_save_on_train_epoch_end(trainer):
            return

        monitor_candidates = self._monitor_candidates(trainer)
        epoch = trainer.current_epoch + 1
        if self._every_n_epochs >= 1 and epoch % self._every_n_epochs == 0:
            is_zero = trainer.is_global_zero
            # Snapshot the directory *before* saving so the new checkpoint is not
            # in the deletion list. Only rank zero does this: under DDP another
            # rank could list the directory after rank zero already wrote the new
            # file and would then delete it.
            stale = os.listdir(self.dirpath) if (self.if_save_latest and is_zero) else []
            self._save_topk_checkpoint(trainer, monitor_candidates)
            self._remove_stale(stale)
            if self.if_save_every_weights and is_zero:
                self._export_half_weights(trainer, epoch)

        self._save_last_checkpoint(trainer, monitor_candidates)

    def _remove_stale(self, names):
        """Best-effort deletion of ``names`` inside ``dirpath``; entries that cannot be removed are skipped."""
        for name in names:
            try:
                os.remove(os.path.join(self.dirpath, name))
            except OSError:
                # e.g. already gone or a sub-directory: pruning stays best-effort,
                # but unrelated errors (KeyboardInterrupt, bugs) are no longer hidden.
                pass

    def _export_half_weights(self, trainer, epoch):
        """Save the model ``state_dict`` cast to fp16, plus config and an info tag, for ``epoch``."""
        state_dict = trainer.strategy._lightning_module.state_dict()
        to_save_od = OrderedDict()
        to_save_od["weight"] = OrderedDict((key, value.half()) for key, value in state_dict.items())
        to_save_od["config"] = self.config
        to_save_od["info"] = "GPT-e%s" % epoch
        os.makedirs(self.half_weights_save_dir, exist_ok=True)
        torch.save(to_save_od, os.path.join(self.half_weights_save_dir, "%s-e%s.ckpt" % (self.exp_name, epoch)))


def main(args):
    """Train the text-to-semantic model described by a YAML config.

    Args:
        args: Parsed CLI namespace; only ``args.config_file`` (path to the YAML
            config) is read.

    Creates ``<output_dir>`` and ``<output_dir>/ckpt``. If ``get_newest_ckpt``
    picks a checkpoint in ``ckpt``, training resumes from it. Then
    ``Trainer.fit`` runs with a DDP strategy (NCCL backend, or gloo on Windows).
    """
    config = load_yaml_config(args.config_file)

    output_dir = Path(config["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)

    ckpt_dir = output_dir / 'ckpt'
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    seed_everything(config["train"]["seed"], workers=True)
    ckpt_callback: ModelCheckpoint = my_model_ckpt(
        config=config,
        if_save_latest=config["train"]["if_save_latest"],
        if_save_every_weights=config["train"]["if_save_every_weights"],
        half_weights_save_dir=config["train"]["half_weights_save_dir"],
        exp_name=config["train"]["exp_name"],
        save_top_k=-1,
        monitor='top_3_acc',
        mode='max',
        save_on_train_epoch_end=True,
        every_n_epochs=config["train"]["save_every_n_epoch"],
        dirpath=ckpt_dir,
    )
    logger = TensorBoardLogger(
        name=output_dir.stem,
        save_dir=output_dir
    )
    trainer: Trainer = Trainer(
        max_epochs=config["train"]["epochs"],
        accelerator='gpu',
        limit_val_batches=0,
        devices=-1,
        benchmark=False,
        fast_dev_run=False,
        strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo"),
        precision=config["train"]["precision"],
        logger=logger,
        num_sanity_val_steps=0,
        callbacks=[ckpt_callback])

    model: Text2SemanticLightningModule = Text2SemanticLightningModule(
        config, output_dir)

    data_module: Text2SemanticDataModule = Text2SemanticDataModule(
        config,
        train_semantic_path=config["train_semantic_path"],
        train_phoneme_path=config["train_phoneme_path"],
    )

    # Resume from the newest checkpoint when one exists; any failure in the
    # lookup (e.g. an empty directory) means "start from scratch".
    try:
        newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
        ckpt_path = ckpt_dir / newest_ckpt_name
    except Exception:
        ckpt_path = None
    print("ckpt_path:", ckpt_path)

    # Presumably meant to let checkpoints pickled on Linux (which may contain
    # pathlib.PosixPath objects) load on Windows, where PosixPath cannot be
    # instantiated. The alias must only be applied on Windows. On POSIX it
    # makes every new Path(...) resolve to WindowsPath and raise
    # NotImplementedError. It is restored afterwards so it never leaks.
    import pathlib
    original_posix_path = pathlib.PosixPath
    if platform.system() == "Windows":
        pathlib.PosixPath = pathlib.WindowsPath
    try:
        trainer.fit(model, data_module, ckpt_path=ckpt_path)
    finally:
        pathlib.PosixPath = original_posix_path


if __name__ == '__main__':
    # The only CLI option is the YAML config that drives the whole training run.
    cli = argparse.ArgumentParser()
    cli.add_argument('-c', '--config_file',
                     type=str,
                     default='configs/s1.yaml',
                     help='path of config file')
    main(cli.parse_args())
