
import argparse
import torch
import datetime
import os
import glob

from dateutil import tz
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
)

import lgdea
# Request deterministic cuDNN kernels.
torch.backends.cudnn.deterministic = True
# NOTE(review): benchmark=True turns on cuDNN's algorithm autotuner, which the
# PyTorch reproducibility notes list as a source of run-to-run variation --
# confirm that combining it with deterministic=True above is intended.
torch.backends.cudnn.benchmark = True
# Set at import time; this overwrites any CUDA_VISIBLE_DEVICES value inherited
# from the shell. Presumably it must run before CUDA is first initialised to
# take effect -- TODO confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"


def get_parser():
    """Build the command-line parser for the multi-stage script.

    Script-specific options are registered first (config path, train/test
    switches, seed, training-data fraction and one ``--skip_stageN`` switch per
    stage); the parser is then passed through ``Trainer.add_argparse_args`` and
    the result of that call is returned.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-c",
        "--config",
        required=True,
        metavar="all_stages_config.yaml",
        help="paths to config file",
    )

    # Plain on/off switches selecting which phases run_stage executes.
    for flag, description in (
        ("--train", "specify to train model"),
        ("--test", "specify to test model"),
    ):
        cli.add_argument(flag, action="store_true", default=False, help=description)

    cli.add_argument("--random_seed", type=int, default=23, help="Random seed")
    cli.add_argument(
        "--train_pct", type=float, default=1.0, help="Percent of training data"
    )

    # One skip switch per stage, registered in stage order.
    for stage_number in (1, 2, 3):
        cli.add_argument(
            f"--skip_stage{stage_number}",
            action="store_true",
            default=False,
            help=f"Skip Stage {stage_number} if checkpoint exists",
        )

    return Trainer.add_argparse_args(cli)


def find_best_checkpoint(checkpoint_dir):
    """Locate the most suitable ``.ckpt`` file under ``checkpoint_dir``.

    Search order (the first hit wins):

    1. ``<checkpoint_dir>/best.ckpt``
    2. ``<checkpoint_dir>/last.ckpt``
    3. the most recently modified ``*.ckpt`` directly inside the directory
    4. recursively in sub-directories: the newest ``best.ckpt``, otherwise the
       newest ``last.ckpt``, otherwise the newest ``*.ckpt`` of any name

    "Newest" is decided by file modification time, not by any metric.

    Args:
        checkpoint_dir: Directory to search. It is treated as a literal path,
            so glob metacharacters in it (``[``, ``]``, ``*``, ``?``) are
            matched verbatim.

    Returns:
        The path of the selected checkpoint, or ``None`` when the directory
        does not exist or contains no ``.ckpt`` file.
    """
    if not os.path.exists(checkpoint_dir):
        return None

    preferred_names = ("best.ckpt", "last.ckpt")

    for name in preferred_names:
        candidate = os.path.join(checkpoint_dir, name)
        if os.path.exists(candidate):
            return candidate

    # Escape the directory part so that names such as "run[1]" are not
    # interpreted as glob character classes (which would match nothing).
    literal_dir = glob.escape(checkpoint_dir)

    direct_hits = glob.glob(os.path.join(literal_dir, "*.ckpt"))
    if direct_hits:
        return max(direct_hits, key=os.path.getmtime)

    nested_hits = glob.glob(os.path.join(literal_dir, "**/*.ckpt"), recursive=True)
    if not nested_hits:
        return None

    for name in preferred_names:
        named_hits = [path for path in nested_hits if os.path.basename(path) == name]
        if named_hits:
            return max(named_hits, key=os.path.getmtime)

    return max(nested_hits, key=os.path.getmtime)


def run_stage(cfg, stage_name, stage_cfg, args):
    """Configure and run one pipeline stage (training and/or testing).

    The function rewrites the stage's output, checkpoint and logger
    directories so they are stage/run specific, saves the resulting stage
    config to ``<output_dir>/config.yaml``, builds the data module and model
    through ``lgdea.builder``, creates the callbacks, logger and ``Trainer``,
    and then runs ``fit`` and/or ``test`` depending on the CLI switches.

    Args:
        cfg: Top-level config; ``cfg.extension`` and ``cfg.output_dir`` are
            read here.
        stage_name: Stage label (e.g. ``"stage1"``). Used as a sub-directory
            name and to pick a default W&B project for stage2/stage3.
        stage_cfg: Config of this stage. It is modified in place (paths,
            ``extension``, logger name/project, ``auto_lr_find`` default).
        args: Parsed CLI namespace; ``args.train`` and ``args.test`` are read.

    Returns:
        The checkpoint callback's ``best_model_path`` when it is non-empty,
        otherwise whatever ``find_best_checkpoint`` finds in the stage's
        checkpoint directory, or ``None`` when the stage has no
        ``checkpoint_callback`` section.

    Raises:
        ValueError: If a logger section is configured without ``logger_type``,
            or if ``--test`` is requested and no checkpoint path is available.
    """
    print(f"\n{'=' * 80}")
    print(f"Starting {stage_name}")
    print(f"{'=' * 80}\n")

    stage_cfg.extension = cfg.extension
    stage_cfg.output_dir = os.path.join(cfg.output_dir, stage_name)

    lightning_cfg = stage_cfg.lightning
    has_ckpt_cfg = "checkpoint_callback" in lightning_cfg
    has_logger_cfg = "logger" in lightning_cfg

    if has_ckpt_cfg:
        # Checkpoints go to <dirpath>/<experiment_name>/<run timestamp>.
        lightning_cfg.checkpoint_callback.dirpath = os.path.join(
            lightning_cfg.checkpoint_callback.dirpath,
            f"{stage_cfg.experiment_name}/{stage_cfg.extension}",
        )

    logger_type = None
    if has_logger_cfg:
        if "logger_type" in lightning_cfg.logger:
            logger_type = lightning_cfg.logger.logger_type
        if logger_type is None:
            # Fail before any directory, data module or model is created; the
            # logger class could not be resolved later without this value.
            raise ValueError(
                f"{stage_name}: lightning.logger.logger_type must be set when a logger is configured"
            )

    # W&B logging gets no per-stage local directory (its save_dir is dropped
    # further down before the logger is constructed).
    uses_local_log_dir = (
        has_logger_cfg
        and "save_dir" in lightning_cfg.logger
        and logger_type != "WandbLogger"
    )
    if uses_local_log_dir:
        lightning_cfg.logger.save_dir = os.path.join(
            lightning_cfg.logger.save_dir,
            stage_name,
        )

    if has_ckpt_cfg:
        os.makedirs(lightning_cfg.checkpoint_callback.dirpath, exist_ok=True)
    if uses_local_log_dir:
        os.makedirs(lightning_cfg.logger.save_dir, exist_ok=True)
    os.makedirs(stage_cfg.output_dir, exist_ok=True)

    # Persist the effective stage config next to the stage outputs.
    config_path = os.path.join(stage_cfg.output_dir, "config.yaml")
    OmegaConf.save(config=stage_cfg, f=config_path)

    dm = lgdea.builder.build_data_module(stage_cfg)
    model = lgdea.builder.build_lightning_model(stage_cfg, dm)

    callbacks = [LearningRateMonitor(logging_interval="step")]
    checkpoint_callback = None
    if has_ckpt_cfg:
        checkpoint_callback = ModelCheckpoint(**lightning_cfg.checkpoint_callback)
        callbacks.append(checkpoint_callback)
    if "early_stopping_callback" in lightning_cfg:
        callbacks.append(EarlyStopping(**lightning_cfg.early_stopping_callback))

    logger = None
    if has_logger_cfg:
        logger_cfg = lightning_cfg.logger
        # logger_type selects the class and is not a constructor argument, so
        # it is removed while the remaining keys are expanded as kwargs and
        # put back afterwards.
        logger_cfg.pop("logger_type")
        logger_class = getattr(pl_loggers, logger_type)
        logger_cfg.name = f"{stage_cfg.experiment_name}_{stage_cfg.extension}"

        if logger_type == "WandbLogger":
            if "project" not in logger_cfg:
                if stage_name == "stage2":
                    logger_cfg["project"] = "graph_evalign_stage2"
                elif stage_name == "stage3":
                    logger_cfg["project"] = "graph_evalign_stage3"
            if "save_dir" in logger_cfg:
                logger_cfg.pop("save_dir")

        logger = logger_class(**logger_cfg)
        logger_cfg.logger_type = logger_type

    if 'auto_lr_find' not in lightning_cfg.trainer:
        lightning_cfg.trainer.auto_lr_find = False

    trainer_args = argparse.Namespace(**lightning_cfg.trainer)
    # NOTE(review): accelerator/devices/strategy/deterministic are hard-coded
    # here and are passed in addition to the trainer section of the config;
    # presumably these keywords win over the config values -- confirm.
    trainer = Trainer.from_argparse_args(
        args=trainer_args,
        deterministic=False,
        callbacks=callbacks,
        logger=logger,
        accelerator="gpu",
        devices=2,
        strategy='ddp'
    )

    # auto_lr_find always exists at this point (defaulted above); anything
    # other than a literal False triggers the learning-rate finder.
    if trainer_args.auto_lr_find is not False:
        try:
            lr_finder = trainer.tuner.lr_find(model, datamodule=dm)
            model.lr = lr_finder.suggestion()
            print(f"Learning rate updated to {model.lr}")
        except Exception as e:
            # Best effort: a failed search keeps the configured learning rate.
            print(f"Learning rate finder failed: {e}")

    if args.train:
        trainer.fit(model, dm)

    if args.test:
        if checkpoint_callback is not None and args.train:
            ckpt_path = checkpoint_callback.best_model_path
        else:
            ckpt_path = getattr(stage_cfg.model, 'checkpoint', None)
        # best_model_path is an empty string when no checkpoint was saved, so
        # test for falsiness rather than only for None.
        if not ckpt_path:
            raise ValueError("Checkpoint path is required for testing")
        trainer.test(model=model, datamodule=dm, ckpt_path=ckpt_path)

    if checkpoint_callback is None:
        return None

    checkpoint_callback.to_yaml(filepath=os.path.join(
        lightning_cfg.checkpoint_callback.dirpath, "best_ckpts.yaml"
    ))

    if checkpoint_callback.best_model_path:
        return checkpoint_callback.best_model_path
    return find_best_checkpoint(lightning_cfg.checkpoint_callback.dirpath)


def _configured_value(section, key):
    """Return ``section.<key>`` when it is set to a non-empty value, else ``None``."""
    if not hasattr(section, key):
        return None
    value = getattr(section, key)
    if value is None or value == "":
        return None
    return value


def _prepare_stage_cfg(cfg, stage_name, args):
    """Create the working config of one stage from ``cfg.<stage_name>``."""
    source = cfg[stage_name]
    stage_cfg = OmegaConf.create(source)
    stage_cfg.experiment_name = source.experiment_name
    stage_cfg.phase = source.phase
    stage_cfg.data.frac = args.train_pct

    # trial_name is optional; .get() tolerates configs that omit the key.
    trial_name = cfg.get("trial_name")
    if trial_name is not None:
        stage_cfg.experiment_name = f"{stage_cfg.experiment_name}_{trial_name}"
    return stage_cfg


def _find_skipped_stage1_checkpoint(cfg):
    """Locate an existing Stage 1 checkpoint when Stage 1 is skipped.

    Returns the checkpoint path, or ``None`` when nothing was found on disk
    but ``stage2.model.stage1_checkpoint_path`` is set (later stages then fall
    back to their own configured path). Raises ``ValueError`` when nothing was
    found and no path is configured either.
    """
    config_specified_path = None
    if hasattr(cfg, 'stage2') and hasattr(cfg.stage2, 'model'):
        config_specified_path = _configured_value(cfg.stage2.model, 'stage1_checkpoint_path')

    if config_specified_path is not None:
        if os.path.exists(config_specified_path):
            print(f"\nUsing Stage 1 checkpoint from stage2 config: {config_specified_path}\n")
            return config_specified_path
        print(
            f"Warning: Stage1 checkpoint path specified in stage2 config does not exist: {config_specified_path}")
        print(
            f"Will try to search for checkpoint in other locations, but will fall back to config path if search fails.")

    stage1_cfg = OmegaConf.create(cfg.stage1)
    experiment_name = stage1_cfg.experiment_name
    trial_name = cfg.get("trial_name")
    if trial_name is not None:
        experiment_name = f"{experiment_name}_{trial_name}"

    searched_dirs = []
    if "checkpoint_callback" in stage1_cfg.lightning:
        base_dir = stage1_cfg.lightning.checkpoint_callback.dirpath
        # Most specific location first: this run, then any run of the
        # experiment, then the whole checkpoint root.
        # NOTE(review): if several stages share the same root, the last
        # fallback could return another stage's checkpoint -- verify configs.
        candidate_dirs = (
            os.path.join(base_dir, f"{experiment_name}/{cfg.extension}"),
            os.path.join(base_dir, experiment_name),
            base_dir,
        )
        for candidate_dir in candidate_dirs:
            searched_dirs.append(candidate_dir)
            found = find_best_checkpoint(candidate_dir)
            if found is not None:
                print(f"\nUsing existing Stage 1 checkpoint: {found}\n")
                return found

    if config_specified_path is not None:
        print(
            f"\nStage 1 checkpoint not found in searched locations, but will use path from yaml config: {config_specified_path}\n")
        return None

    report = [
        "Stage 1 checkpoint not found. Please run Stage 1 first or remove --skip_stage1 flag.\n",
        "Searched paths:\n",
    ]
    for path in searched_dirs:
        exists = "exists" if os.path.exists(path) else "not exists"
        report.append(f"  - {path} ({exists})\n")
    raise ValueError("".join(report))


def main(cfg, args):
    """Run the three-stage pipeline, wiring each stage's checkpoint into the next.

    Stage 1 is either executed or, with ``--skip_stage1``, an existing
    checkpoint is looked up. Stage 2 and Stage 3 receive the Stage 1
    checkpoint (falling back to the path in their own config), and Stage 3
    additionally receives the Stage 2 checkpoint (from the Stage 2 run or from
    its config). A summary of the checkpoints is printed at the end.

    Args:
        cfg: Top-level config with ``stage1``/``stage2``/``stage3`` sections,
            ``extension`` and ``output_dir``; ``trial_name`` is optional.
        args: Parsed CLI namespace (``skip_stage1..3``, ``train_pct`` and the
            switches consumed by ``run_stage``).

    Raises:
        ValueError: If a required checkpoint path cannot be determined or a
            path given in the config does not exist.
    """
    if not args.skip_stage1:
        stage1_cfg = _prepare_stage_cfg(cfg, "stage1", args)
        stage1_ckpt = run_stage(cfg, "stage1", stage1_cfg, args)
        if stage1_ckpt:
            print(f"\nStage 1 completed. Checkpoint: {stage1_ckpt}\n")
        else:
            print("\nStage 1 completed but checkpoint not found.\n")
    else:
        stage1_ckpt = _find_skipped_stage1_checkpoint(cfg)

    stage2_ckpt = None
    if not args.skip_stage2:
        stage2_cfg = _prepare_stage_cfg(cfg, "stage2", args)

        if stage1_ckpt is not None:
            stage2_cfg.model.stage1_checkpoint_path = stage1_ckpt
        else:
            configured = _configured_value(stage2_cfg.model, 'stage1_checkpoint_path')
            if configured is None:
                raise ValueError(
                    "Stage 1 checkpoint is required for Stage 2. Please run Stage 1 first, provide a valid checkpoint in yaml config (stage2.model.stage1_checkpoint_path), or remove --skip_stage1 flag.")
            print(f"\nUsing Stage 1 checkpoint path from yaml config: {configured}\n")

        stage2_ckpt = run_stage(cfg, "stage2", stage2_cfg, args)
        if stage2_ckpt:
            print(f"\nStage 2 completed. Checkpoint: {stage2_ckpt}\n")
        else:
            print("\nStage 2 completed but checkpoint not found.\n")

    stage3_ckpt = None
    if not args.skip_stage3:
        stage3_cfg = _prepare_stage_cfg(cfg, "stage3", args)

        if stage1_ckpt is not None:
            stage3_cfg.model.stage1_checkpoint_path = stage1_ckpt
        else:
            configured = _configured_value(stage3_cfg.model, 'stage1_checkpoint_path')
            if configured is None:
                raise ValueError(
                    "Stage 1 checkpoint is required for Stage 3. Please run Stage 1 first, provide a valid checkpoint in yaml config (stage3.model.stage1_checkpoint_path), or remove --skip_stage1 flag.")
            if not os.path.exists(configured):
                raise ValueError(f"Stage1 checkpoint path specified in yaml does not exist: {configured}")
            print(f"\nUsing Stage 1 checkpoint path from yaml config: {configured}\n")

        if stage2_ckpt is not None:
            stage3_cfg.model.stage2_checkpoint_path = stage2_ckpt
            print(f"\nUsing stage2_checkpoint_path from Stage 2 output: {stage2_ckpt}\n")
        else:
            configured = _configured_value(stage3_cfg.model, 'stage2_checkpoint_path')
            if configured is None:
                raise ValueError(
                    "stage2_checkpoint_path must be specified in yaml config (stage3.model.stage2_checkpoint_path) or Stage 2 must be run first.")
            if not os.path.exists(configured):
                raise ValueError(f"Stage2 checkpoint path specified in yaml does not exist: {configured}")
            print(f"\nUsing stage2_checkpoint_path from yaml config: {configured}\n")

        stage3_ckpt = run_stage(cfg, "stage3", stage3_cfg, args)
        if stage3_ckpt:
            print(f"\nStage 3 completed. Checkpoint: {stage3_ckpt}\n")
        else:
            print("\nStage 3 completed but checkpoint not found.\n")

    print(f"\n{'=' * 80}")
    print("All stages completed successfully!")
    print(f"{'=' * 80}\n")
    print(f"Stage 1 checkpoint: {stage1_ckpt}")
    print(f"Stage 2 checkpoint: {stage2_ckpt}")
    if not args.skip_stage3:
        print(f"Stage 3 checkpoint: {stage3_ckpt}")


if __name__ == "__main__":
    args = get_parser().parse_args()
    cfg = OmegaConf.load(args.config)

    # Tag the run with a local-time timestamp; it becomes the leaf directory
    # of the run's output path.
    cfg.extension = datetime.datetime.now(tz.tzlocal()).strftime("%Y_%m_%d_%H_%M_%S")
    seed_everything(args.random_seed)

    cfg.output_dir = f"./data/output/{cfg.experiment_name}/{cfg.extension}"
    os.makedirs(cfg.output_dir, exist_ok=True)

    # save main config
    OmegaConf.save(
        config=cfg,
        f=os.path.join(cfg.output_dir, "all_stages_config.yaml"),
    )

    main(cfg, args)
