import torch
import os


class Checkpoint:
    """
    Checkpoint class for saving and loading experiments.

    Holds the training epoch, the model, the experiment ``params`` and the
    ``state_dict()`` of the optimizer and scheduler (not the objects
    themselves). Pickling goes through ``__getstate__``/``__setstate__``:
    the model is stored via its ``serialize()`` method, and ``params`` is
    NOT stored (it is ``None`` on a loaded checkpoint).
    """

    def __init__(self, epoch=-1, model=None, optimizer=None, scheduler=None, params=None):
        """
        :param epoch: epoch this checkpoint belongs to (-1 by default).
        :param model: model object; it must provide ``serialize()`` for the
            checkpoint to be pickled (see ``__getstate__``).
        :param optimizer: optional object exposing ``state_dict()``; only the
            returned state dict is kept.
        :param scheduler: optional object exposing ``state_dict()``; only the
            returned state dict is kept.
        :param params: experiment parameters; ``save(run=...)`` reads
            ``params.log_params.name`` as the artifact name.
        """
        self.epoch: int = epoch
        self.model = model
        self.params = params

        # Each state dict is collected independently and defaults to None, so
        # the instance is always fully initialised even if one call fails
        # (``__getstate__`` reads both attributes).
        self.scheduler_state_dict = self._state_dict_or_none(scheduler)
        self.optimizer_state_dict = self._state_dict_or_none(optimizer)

    @staticmethod
    def _state_dict_or_none(obj):
        """Return ``obj.state_dict()``, or None if obj is None or the call raises TypeError."""
        if obj is None:
            return None
        try:
            return obj.state_dict()
        except TypeError:
            # Best-effort: report and continue without this state dict.
            print("Error loading experiment")
            return None

    def save(self, path, run=None, save_locally=False):
        """
        Pickle this checkpoint under ``<path>/checkpoints`` and optionally log it to W&B.

        :param path: experiment directory; ``checkpoints/`` is created inside it.
        :param run: optional wandb run. When given, a ``model`` artifact named
            ``self.params.log_params.name`` is logged containing the checkpoint
            file and ``architecture.txt``.
        :param save_locally: if True, write ``checkpoint_<epoch>.pth`` and a
            small ``ref.pth`` that stores that file's path string; otherwise
            (over)write ``model.pth`` with the full checkpoint.
        :return: path of ``ref.pth`` (save_locally) or ``model.pth``.
        """
        checkpoint_dir = os.path.join(path, "checkpoints")
        os.makedirs(checkpoint_dir, exist_ok=True)

        if save_locally:
            local_checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{self.epoch}.pth")
            checkpoint_path = os.path.join(checkpoint_dir, "ref.pth")
            # ref.pth holds only the path string of the per-epoch file; ref.pth
            # (not the full checkpoint) is what gets attached to the artifact below.
            torch.save(local_checkpoint_path, checkpoint_path)
            torch.save(self, local_checkpoint_path)
            description = f"epoch: {self.epoch}  \n"\
                          f"path: {local_checkpoint_path}"
        else:
            checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
            torch.save(self, checkpoint_path)
            description = f"epoch: {self.epoch}"

        # The architecture dump is written once, on the first save into this directory.
        architecture_path = os.path.join(checkpoint_dir, "architecture.txt")
        if not os.path.exists(architecture_path):
            with open(architecture_path, "w", encoding="utf-8") as architecture:
                architecture.write(str(self.model))

        if run is not None:
            import wandb  # imported lazily so wandb is only needed when logging
            artifact = wandb.Artifact(self.params.log_params.name, type='model', description=description)
            artifact.add_file(checkpoint_path)
            artifact.add_file(architecture_path)
            run.log_artifact(artifact)
        return checkpoint_path

    def save_migration(self, path):
        """Pickle this checkpoint to ``<path>/migrated_checkpoint.pth`` and return that path."""
        os.makedirs(path, exist_ok=True)
        checkpoint_path = os.path.join(path, "migrated_checkpoint.pth")
        torch.save(self, checkpoint_path)
        return checkpoint_path

    @staticmethod
    def load(path):
        """
        Load a checkpoint file written by ``save``/``save_migration`` onto the CPU.

        Returns whatever object the file contains. For ``ref.pth`` written with
        ``save_locally=True``, that object is the path string, not a Checkpoint.

        SECURITY: ``weights_only=False`` performs full unpickling, which can
        execute arbitrary code. Only load files from trusted sources.
        """
        # Add safe globals for PyTorch 2.6+ compatibility
        try:
            from torch.serialization import add_safe_globals
            add_safe_globals([Checkpoint])
        except ImportError:
            # Fallback for older PyTorch versions that don't have add_safe_globals
            pass

        # Load with weights_only=False for backward compatibility
        experiment: Checkpoint = torch.load(path, map_location='cpu', weights_only=False)
        return experiment

    def get_model(self):
        """Return the stored model."""
        return self.model

    def __getstate__(self):
        """
        Build the pickled state.

        The model is stored via ``self.model.serialize()``. A ``sequential``
        flag records whether it is an ``hSequenceVAE``, so ``__setstate__`` can
        pick the matching ``deserialize``. ``params`` is intentionally not
        stored. NOTE(review): this fails if ``model`` is None.
        """
        from .sequence import hSequenceVAE
        return {
                "epoch": self.epoch,
                "model":       self.model.serialize(),
                "scheduler_state_dict": self.scheduler_state_dict,
                "optimizer_state_dict": self.optimizer_state_dict,
                "sequential": isinstance(self.model, hSequenceVAE)
                }

    def __setstate__(self, state):
        """Restore from ``__getstate__`` output, rebuilding the model via hVAE/hSequenceVAE.deserialize."""
        from .hvae import hVAE
        from .sequence import hSequenceVAE

        self.epoch = state["epoch"]
        self.model = hVAE.deserialize(state["model"]) if not state["sequential"] \
            else hSequenceVAE.deserialize(state["model"])
        self.scheduler_state_dict = state["scheduler_state_dict"]
        self.optimizer_state_dict = state["optimizer_state_dict"]
        # params is not part of the pickled state; define it so loaded
        # checkpoints expose the same attributes as freshly built ones.
        self.params = None


