# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import os
import threading

import torch
from torch.distributed.fsdp import FullOptimStateDictConfig, FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

from cosmos_predict1.utils import callback, distributed, log, misc
from cosmos_predict1.utils.config import CheckpointConfig, JobConfig
from cosmos_predict1.utils.easy_io import easy_io
from cosmos_predict1.utils.fsdp_optim_fix import scatter_full_optim_state_dict
from cosmos_predict1.utils.model import Model


class FSDPCheckpointer:
    """The checkpointer class. Supports checkpoint saving/loading to local disk."""

    def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup):
        """Constructor of the checkpointer.

        Args:
            config_checkpoint (CheckpointConfig): The config object for the checkpointer.
        """
        # Set the callback functions.
        self.callbacks = callbacks
        self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints"
        self.strict_resume = config_checkpoint.strict_resume
        self.load_path = config_checkpoint.load_path
        self.load_training_state = config_checkpoint.load_training_state
        self.save_thread = None
        self.config_checkpoint = config_checkpoint

    def _load_ckpt_file_during_init(self):
        latest_checkpoint_file = self._read_latest_checkpoint_file()
        if latest_checkpoint_file is not None:
            # 1. Resume training from latest_checkpoint.txt under the same name.
            checkpoint_dir = self.checkpoint_dir_local
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file)
            resume = True
            log.critical(f"[Checkpoint] Found latest checkpoint file: {latest_checkpoint_file}")
            log.critical(f"[Checkpoint] Loading from local path: {checkpoint_path}")
            log.critical("[Checkpoint] Will resume full training state (model, optimizer, scheduler)")
        else:
            if self.load_path:
                # 2. Load the module weights specified by config_checkpoint.path.
                checkpoint_path = self.load_path
                resume = self.load_training_state
                log.critical(f"[Checkpoint] Using specified checkpoint path: {checkpoint_path}")
                if resume:
                    log.critical("[Checkpoint] Will load complete training state (model, optimizer, scheduler)")
                else:
                    log.critical("[Checkpoint] Will load model weights only (no optimizer/scheduler state)")
            else:
                # 3. Randomly initialize the model parameters and train from scratch.
                checkpoint_path = None
                resume = False
                log.critical("[Checkpoint] No checkpoint path specified")
                log.critical("[Checkpoint] Starting fresh training with random initialization")
        return checkpoint_path, resume

    @misc.timer("FSDP.load_model_during_init")
    def load_model_during_init(self, model, is_ema=False, ema_id: int = 0):
        if ema_id > 0:
            assert is_ema, "ema_id should be used with is_ema=True"
        checkpoint_path, _ = self._load_ckpt_file_during_init()
        if checkpoint_path is not None:
            tag = "reg" if not is_ema else "ema"
            default_checkpoint_path = checkpoint_path.replace(".pt", f"_{tag}_model.pt")
            if not os.path.exists(default_checkpoint_path):
                default_checkpoint_path = checkpoint_path  # starting from the release checkpoint
                log.warning(f"{is_ema} model is not found. Loading from {default_checkpoint_path}")
            if tag == "ema" and ema_id > 0:
                _checkpoint_path = checkpoint_path.replace(".pt", f"_RANK{ema_id}.pt")
                _checkpoint_path = _checkpoint_path.replace(".pt", f"_{tag}_model.pt")
                if self._check_checkpoint_exists(_checkpoint_path, is_raise=False):
                    default_checkpoint_path = _checkpoint_path
                else:
                    print(
                        f"{distributed.get_rank()}: Checkpoint not found: {_checkpoint_path} "
                        f"(fallback to {default_checkpoint_path})"
                    )
            checkpoint_path = default_checkpoint_path
            self._check_checkpoint_exists(checkpoint_path)

            log.info(f"Loading checkpoint (local): {checkpoint_path}")
            state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False)
            log.success(f"Complete loading checkpoint (local): {checkpoint_path}")
            log.info("- Loading the model...")
            if self.strict_resume:
                log.info(model.load_state_dict(state_dict, strict=self.strict_resume))
            else:
                log.critical("\t Using non-strict model")
                from cosmos_predict1.diffusion.training.utils.checkpointer import non_strict_load_model

                log.info(non_strict_load_model(model, state_dict))
            log.info("-finish model loading")
        else:
            log.info(f"{is_ema} model is not founded and loaded.")

    @misc.timer("FSDP.load_optim_scheduler_during_init")
    def load_optim_scheduler_during_init(self, fsdp_model, optimizer, scheduler):
        checkpoint_path, resume = self._load_ckpt_file_during_init()
        log.critical(f"Loading optimizer and scheduler: {checkpoint_path} (resume: {resume}")
        if checkpoint_path is not None:
            if resume:
                checkpoint_path = checkpoint_path.replace(".pt", "_optim.pt")
                self._check_checkpoint_exists(checkpoint_path)
                if distributed.get_rank() == 0:
                    log.info(f"Loading checkpoint (local): {checkpoint_path}")
                    state_dict = torch.load(
                        checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False
                    )
                    log.success(f"Complete loading checkpoint (local): {checkpoint_path}")
                    log.info("- Loading the optimizer (FSDP scatter)...")
                else:
                    state_dict = {
                        "optimizer": None,
                        "scheduler": None,
                    }
                distributed.barrier()
                sharded_optimizer_state_dict = scatter_full_optim_state_dict(  # <---- FSDP
                    state_dict["optimizer"],
                    fsdp_model,
                )
                log.info("- Loading the optimizer (FSDP load_state_dict)...")
                log.info(optimizer.load_state_dict(sharded_optimizer_state_dict))
                log.critical("Skip loading the scheduler...")
                return
                log.info("- Loading the scheduler...")
                scheduler.load_state_dict(state_dict["scheduler"])

    @misc.timer("FSDP get_optim_scheduler_state")
    def get_optim_scheduler_state(self, optim, fsdp_model, scheduler):
        with FSDP.state_dict_type(
            fsdp_model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim)
        scheduler_statedict = scheduler.state_dict()
        return {
            "optimizer": optim_statedict,
            "scheduler": scheduler_statedict,
        }

    def save(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int,
    ) -> None:
        """Save network weights, optimizer parameters, scheduler parameters to a checkpoint.

        Args:
            model (Model): The PyTorch model.
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
            iteration (int): Current iteration number.
        """
        self.callbacks.on_save_checkpoint_start(model, iteration)

        model_state_dict = model.state_dict_model()
        optim_scheduler_state_dict = self.get_optim_scheduler_state(optimizer, model.model, scheduler)
        torch.cuda.empty_cache()
        state_dict = dict(
            iteration=iteration,
        )
        self.callbacks.on_save_checkpoint(model, state_dict=state_dict)

        postfix, replicate_idx, shard_idx, total_ema_num = model.get_ckpt_postfix()
        if replicate_idx == 0 and shard_idx == 0:
            pass  # save whole; it is rank0
        elif replicate_idx < total_ema_num and shard_idx == 0:
            model_state_dict["model"] = None  # only save ema
            optim_scheduler_state_dict = None
            state_dict = None
        else:
            return

        checkpoint_file = f"iter_{iteration:09}{postfix}.pt"
        # Wait for previous saver thread to end.
        if self.save_thread:
            self.save_thread.join()
        # Run the checkpoint saver in a separate thread.
        self.save_thread = threading.Thread(
            target=self._save_worker_local,
            daemon=False,
            args=(model_state_dict, optim_scheduler_state_dict, state_dict, checkpoint_file, distributed.get_rank()),
        )
        self.save_thread.start()

        # Note: Checkpoints are saved on a separate thread and this callback is not accurate.
        # Please check logs from on_save_checkpoint_success() for better accuracy
        self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration)

    @misc.timer("checkpoint saving (local)")
    def _save_worker_local(
        self,
        model_state_dict: dict[str, torch.Tensor],
        optim_scheduler_state_dict: dict[str, torch.Tensor],
        state_dict: dict[str, torch.Tensor],
        checkpoint_file: str,
        rank: int = 0,
    ) -> None:
        """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training).

        Args:
            state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler.
            checkpoint_file (str): The file name of the model checkpoint.
            rank (int): GPU device (default: 0).
        """
        checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file)
        os.makedirs(self.checkpoint_dir_local, exist_ok=True)
        try:
            model_state_dict, ema_model_state_dict = model_state_dict["model"], model_state_dict["ema"]
            if model_state_dict is not None:
                torch.save(model_state_dict, checkpoint_path.replace(".pt", "_reg_model.pt"))
            if ema_model_state_dict is not None:
                torch.save(ema_model_state_dict, checkpoint_path.replace(".pt", "_ema_model.pt"))
            if optim_scheduler_state_dict is not None:
                torch.save(optim_scheduler_state_dict, checkpoint_path.replace(".pt", "_optim.pt"))
            if state_dict is not None:
                torch.save(state_dict, checkpoint_path)
            if rank == 0:
                self._write_latest_checkpoint_file(checkpoint_file)
            log.success(f"Saved checkpoint (local): {checkpoint_path}")
            iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", ""))
            self.callbacks.on_save_checkpoint_success(iteration=iteration)
        except Exception as e:  # noqa: BLE001
            log.exception(f"Checkpoint failed to save (local): {e}")

    @misc.timer("checkpoint loading")
    def load(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer | None = None,
        scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
        grad_scaler: torch.amp.GradScaler | None = None,
    ) -> int:
        """Load network weights and optimizer states from a checkpoint in a single process.

        The priority of the checkpoint loading logic is:
        1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name.
        2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path.
           - This is typically used for inference mode.
           - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states.
        3. If none of the above, randomly initialize the model parameters and train from scratch.

        Args:
            model (FSDPDiffModle): The PyTorch model.
            optimizer (torch.optim.Optimizer | None): The model optimizer (default: None).
            scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None).
            grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training).

        Returns:
            iteration (int): the iteration number to start/resume from.
        """
        self.callbacks.on_load_checkpoint_start(model)

        del optimizer, grad_scaler
        checkpoint_path, resume = self._load_ckpt_file_during_init()
        iteration = 0
        if checkpoint_path is not None:
            self._check_checkpoint_exists(checkpoint_path)
            log.info(f"Loading checkpoint (local): {checkpoint_path}")
            state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False)
            log.success(f"Complete loading checkpoint (local): {checkpoint_path}")
            self.callbacks.on_load_checkpoint(model, state_dict=state_dict)
            if resume:
                iteration = state_dict["iteration"]
            log.success("Done with loading the checkpoint.")
        else:
            log.info("Training from scratch.")
        torch.cuda.empty_cache()

        self.callbacks.on_load_checkpoint_end(model)

        if scheduler is not None:
            scheduler.last_epoch = iteration
            log.critical(f"resume scheduler from {iteration}", rank0_only=False)

        return iteration

    def _read_latest_checkpoint_file(self) -> str | None:
        """Get the file name of the latest saved checkpoint. If it doesn't exist, return None.

        Returns:
            checkpoint_file (str | None): file name of the latest saved checkpoint.
        """
        checkpoint_file = None
        latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
        if os.path.isfile(latest_path):
            checkpoint_file = open(latest_path).read().strip()
        if checkpoint_file is None:
            log.warning(f"Latest ckpt file not found: {latest_path}")
        else:
            log.info(f"Found latest checkpoint: {checkpoint_file}")
        return checkpoint_file

    def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None:
        """Track the file name of the latest saved checkpoint.

        Args:
            checkpoint_file (str): file name of the latest saved checkpoint.
        """
        content = f"{checkpoint_file}\n"
        latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
        with open(latest_path, "w") as file:
            file.write(content)

    def _check_checkpoint_exists(self, checkpoint_path: str, is_raise: bool = True) -> None:
        """If the file checkpoint_path does not exist, raise an error.

        Args:
            checkpoint_path (str): full path to the checkpoint.
        """
        if not os.path.exists(checkpoint_path):
            if is_raise:
                raise FileNotFoundError(f"File not found (local): {checkpoint_path}")
        return True

    def finalize(self) -> None:
        """Finalize the checkpointer."""
        if self.save_thread:
            self.save_thread.join()


class FSDPInferenceCheckpointer:
    def __init__(
        self,
        ckpt_path: str,
        strict_resume: bool = True,
    ):
        self.ckpt_path = ckpt_path
        self.strict_resume = strict_resume

    @misc.timer("FSDPInferenceCheckpointer.load_model_during_init")
    def load_model_during_init(self, model, is_ema=False, ema_id: int = 0):
        del ema_id
        if is_ema:
            log.warning("EMA model is not supported in inference mode.")
            return
        assert easy_io.exists(self.ckpt_path)
        log.info(f"Loading from {self.ckpt_path}")
        state_dict = torch.load(self.ckpt_path, map_location=lambda storage, loc: storage, weights_only=False)
        if self.strict_resume:
            log.info(model.load_state_dict(state_dict, strict=self.strict_resume))
        else:
            log.critical("\t Using non-strict model")
            from cosmos_predict1.diffusion.training.utils.checkpointer import non_strict_load_model

            log.info(non_strict_load_model(model, state_dict))
        log.info("-finish model loading")

    def load_optim_scheduler_during_init(self, *args, **kwargs):
        """
        We do not do load in inference mode. The function is here to maintain the same interface to avoid errors.
        """
        pass

    def save(self, *args, **kwargs):
        """
        We do not save anything in inference mode. The function is here to maintain the same interface to avoid errors.
        """
        pass

    def load(self, *args, **kwargs):
        """
        We do not do load in inference mode. The function is here to maintain the same interface to avoid errors.
        """
        return 0
