import os
from PIL import Image
from omegaconf import OmegaConf
import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.plugins.io import CheckpointIO
from typing import (Any, Union, Dict, List)
from pathlib import Path
import io
import torch
import fsspec

from lightning_fabric.utilities.cloud_io import _atomic_save, get_filesystem
from lightning_fabric.utilities.cloud_io import _load as pl_load


class SetupCallback(Callback):
    """Create run directories and persist configs when fitting starts.

    On global rank 0 this creates ``logdir``, ``ckptdir`` and ``cfgdir``,
    prints both configs as YAML, and writes them into ``cfgdir`` as
    ``<now>-project.yaml`` and ``<now>-lightning.yaml``. The Lightning config
    is wrapped under a top-level ``lightning`` key. On every other rank, when
    not resuming, an already-existing ``logdir`` is moved into a sibling
    ``child_runs`` directory.

    Args:
        resume: Truthy when resuming a run; disables the relocation on
            non-zero ranks.
        now: String prefix for the saved config filenames (presumably a run
            timestamp; confirm against caller).
        logdir: Root log directory of the run.
        ckptdir: Checkpoint directory to create.
        cfgdir: Directory the config YAML files are written to.
        config: Project config handed to ``OmegaConf``.
        lightning_config: Lightning config handed to ``OmegaConf``.
    """

    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        # Plain attribute storage; everything is consumed in on_fit_start.
        self.resume, self.now = resume, now
        self.logdir, self.ckptdir, self.cfgdir = logdir, ckptdir, cfgdir
        self.config, self.lightning_config = config, lightning_config

    def on_fit_start(self, trainer, pl_module):
        """Prepare directories/configs on rank 0; relocate stray logdirs elsewhere."""
        if trainer.global_rank != 0:
            self._move_logdir_to_child_runs()
            return

        print(f"BASE LOG DIR: {self.logdir}")
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True)

        print("Project config")
        print(OmegaConf.to_yaml(self.config))
        project_path = os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))
        OmegaConf.save(self.config, project_path)

        print("Lightning config")
        print(OmegaConf.to_yaml(self.lightning_config))
        wrapped = OmegaConf.create({"lightning": self.lightning_config})
        lightning_path = os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))
        OmegaConf.save(wrapped, lightning_path)

    def _move_logdir_to_child_runs(self):
        """Move an existing ``logdir`` to ``<parent>/child_runs/<name>`` on fresh runs."""
        # Per the original note: the ModelCheckpoint callback may have created
        # the log directory on this non-zero rank, so it is moved aside.
        if self.resume or not os.path.exists(self.logdir):
            return
        parent, run_name = os.path.split(self.logdir)
        target = os.path.join(parent, "child_runs", run_name)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        try:
            os.rename(self.logdir, target)
        except FileNotFoundError:
            # Best effort: the directory may already be gone by the time we
            # rename it (presumably moved by another process).
            pass


def get_step_number(path):
    """Extract the training step encoded in a checkpoint path.

    The first ``step=<digits>`` occurrence is used, e.g.
    ``"ckpts/epoch=0-step=100.ckpt"`` -> ``"100"``.

    Args:
        path: Checkpoint path as ``str`` or ``os.PathLike`` (e.g. ``pathlib.Path``).

    Returns:
        The step digits as a string (not converted to ``int``).

    Raises:
        ValueError: If ``path`` contains no ``step=<digits>`` component.
    """
    import re

    # os.fspath lets pathlib.Path inputs through; re only accepts str/bytes.
    text = os.fspath(path)
    match = re.search(r"step=(\d+)", text)
    if match is None:
        raise ValueError(f"Could not find 'step=<digits>' in checkpoint path: {text}")
    return match.group(1)

def huggingface_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None:
    """Write a Lightning checkpoint and a ``save_pretrained`` export of its model.

    ``checkpoint`` must contain a ``"model"`` entry, which the LightningModule's
    ``on_save_checkpoint`` is expected to add. The entry must expose
    ``save_pretrained``. It is **popped** from ``checkpoint``, which mutates the
    caller's dict, so the model object is not serialized into the ``.ckpt``
    file. The remaining dict is written to ``filepath`` with ``torch.save``.
    The model is saved to ``checkpoint-<step>`` in the same directory, where
    ``<step>`` is parsed from ``filepath`` by ``get_step_number``.

    Args:
        checkpoint: Lightning checkpoint dict; mutated (``"model"`` removed).
        filepath: Destination of the ``.ckpt`` file; must encode ``step=<n>``.

    Raises:
        ValueError: If ``checkpoint`` has no ``"model"`` entry.
        Whatever ``get_step_number`` raises if ``filepath`` has no step number.
    """
    filepath = os.fspath(filepath)
    # Parse the step first so a malformed path fails before anything is
    # written and before the caller's dict is mutated.
    stepnumber = get_step_number(filepath)
    try:
        model = checkpoint.pop("model")
    except KeyError:
        raise ValueError(
            "Model not found in checkpoint; You must override on_save_checkpoint method in LightningModule to save the model"
        ) from None

    # Serialize into memory first, then hand the bytes to fsspec in one write.
    bytesbuffer = io.BytesIO()
    torch.save(checkpoint, bytesbuffer)
    with fsspec.open(filepath, "wb") as f:
        f.write(bytesbuffer.getvalue())
    model.save_pretrained(os.path.join(os.path.dirname(filepath), f"checkpoint-{stepnumber}"))

class TransformerCheckpointIO(CheckpointIO):
    """``CheckpointIO`` plugin that also exports a Hugging Face model on save.

    Saving delegates to ``huggingface_save``, which pops ``checkpoint["model"]``
    and writes it with ``save_pretrained`` to ``checkpoint-<step>`` next to the
    ``.ckpt`` file. The remaining dict is then written atomically. Loading and
    removal act on the ``.ckpt`` file only, the same way Lightning's
    ``TorchCheckpointIO`` does.
    """

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options=None) -> None:
        """Save ``checkpoint`` to ``path`` and export its model alongside.

        Raises:
            TypeError: If ``storage_options`` is given (unsupported).
        """
        if storage_options is not None:
            raise TypeError(
                "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
                f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
                " to define how you'd like to use `storage_options`."
            )
        fs = get_filesystem(path)
        fs.makedirs(os.path.dirname(path), exist_ok=True)
        # huggingface_save mutates `checkpoint` (pops "model") and already writes
        # the remaining dict to `path`; _atomic_save then rewrites the same
        # content atomically.
        # NOTE(review): the .ckpt file is therefore written twice per save.
        huggingface_save(checkpoint, path)
        _atomic_save(checkpoint, path)
        # The abstract CheckpointIO.save_checkpoint has no implementation, so
        # there is nothing to delegate to here.

    def load_checkpoint(self, path: Union[str, Path], map_location=None) -> Dict:
        """Load the checkpoint dict stored at ``path``.

        The returned dict has no ``"model"`` entry, because it was popped before
        saving. The model weights live in the ``checkpoint-<step>`` directory.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        # The base-class method is abstract and returns None, so the file must
        # actually be read here.
        fs = get_filesystem(path)
        if not fs.exists(path):
            raise FileNotFoundError(f"Checkpoint file not found: {path}")
        return pl_load(path, map_location=map_location)

    def remove_checkpoint(self, path: Union[str, Path]) -> None:
        """Delete the checkpoint at ``path`` if it exists.

        NOTE(review): the sibling ``checkpoint-<step>`` Hugging Face directory is
        not removed, because another ``.ckpt`` may share the same step.
        """
        fs = get_filesystem(path)
        if fs.exists(path):
            fs.rm(path, recursive=True)