# import dill as pickle
import dill as pickle
import jax
from .locking import NFSLock
import shutil
import atexit
from pathlib import Path
from functools import partial
from copy import deepcopy

def locked(f, mutable):
    """Wrap method ``f`` so it runs under the instance's lock with a fresh registry.

    The returned wrapper expects ``self`` to provide ``lock_path``,
    ``registry``, ``_load()`` and ``_save()``. On every call it:

    1. acquires ``NFSLock(self.lock_path)``;
    2. calls ``self._load()`` to refresh the in-memory state from disk;
    3. runs ``f(self, *args, **kwargs)``;
    4. if ``mutable`` is true, calls ``self._save()`` to persist the result;
       otherwise asserts that ``f`` left ``self.registry`` unchanged.

    Args:
        f: The unbound method to wrap.
        mutable: Whether ``f`` is allowed to modify ``self.registry``.

    Returns:
        The wrapping function, carrying ``f``'s name, docstring and other
        metadata (via ``functools.wraps``).

    Raises:
        AssertionError: From the wrapper, when ``mutable`` is false and ``f``
            changed ``self.registry``.
    """
    # Imported locally because the module only pulls ``partial`` from functools.
    from functools import wraps

    @wraps(f)  # keep __name__/__doc__ so decorated methods stay introspectable
    def g(self, *args, **kwargs):
        with NFSLock(self.lock_path):
            self._load()
            if not mutable:
                # Snapshot so we can verify the read-only contract afterwards.
                old_registry = deepcopy(self.registry)

            ret = f(self, *args, **kwargs)

            if mutable:
                self._save()
            else:
                assert old_registry == self.registry, (
                    f"{getattr(f, '__name__', f)!s} modified the registry "
                    "but is declared immutable"
                )

            return ret

    return g

# Decorator for methods that modify the registry: the wrapper built by
# ``locked`` reloads state under the lock, runs the method, then calls
# ``_save`` to persist the instance.
mutable_lock = partial(locked, mutable=True)
# Decorator for read-only methods: the wrapper reloads state under the lock,
# runs the method, and asserts that the registry was left unchanged.
immutable_lock = partial(locked, mutable=False)

class DDict:
    """Dictionary persisted on disk and guarded by an ``NFSLock``.

    Files under ``dir``:
        manager.pkl       the pickled ``DDict`` instance (carries ``registry``)
        values/<key>.pkl  one pickle per stored value
        lock              path handed to ``NFSLock``

    ``registry`` maps each key to the path of its value file. The public
    accessors are wrapped by ``mutable_lock`` / ``immutable_lock``, which
    reload ``registry`` from ``manager.pkl`` under the lock before running
    (and, for mutating methods, save the instance afterwards). Because of
    that reload, ``manager.pkl`` must exist before any accessor is used;
    ``load_or_create`` guarantees this.
    """

    # Legacy layout notes kept from the original source (NOTE(review): these
    # do not match the paths used below -- confirm whether they are stale):
    # cache_dir format: states/state_{it}.pkl, dthetas/dtheta_{stage}.pkl
    # save_dir format: deps/dep_{stage}.pkl

    @classmethod
    def from_disk(cls, dired):
        """Unpickle and return the manager stored in ``dired / 'manager.pkl'``.

        Raises:
            FileNotFoundError: If ``manager.pkl`` does not exist in ``dired``.
        """
        manager_path = Path(dired) / 'manager.pkl'
        if not manager_path.exists():
            # Include the path so the failure is diagnosable.
            raise FileNotFoundError(manager_path)

        with open(manager_path, 'rb') as f:
            return pickle.load(f)

    @classmethod
    def load_or_create(cls, dired, *args, **kwargs):
        """Return the manager stored in ``dired``, creating one if absent.

        When no ``manager.pkl`` exists, a new instance is built with
        ``cls(dired, *args, **kwargs)`` and immediately saved.

        NOTE(review): neither the load nor the initial save happens under the
        lock, so two processes racing to create the same directory are not
        serialized here.
        """
        print('** Loading from dired', dired)
        try:
            return cls.from_disk(dired)
        except FileNotFoundError:
            dm = cls(dired, *args, **kwargs)
            dm._save()
            return dm

    def _save(self):
        """Pickle ``self`` to ``manager.pkl`` by writing a swap file first."""
        swap_file = self.dir / 'manager_new.pkl'
        actual_file = self.dir / 'manager.pkl'

        with open(swap_file, 'wb') as f:
            pickle.dump(self, f)

        # ``Path.replace`` overwrites the destination in a single rename, so
        # there is never a moment where ``manager.pkl`` is missing (the
        # previous unlink-then-rename sequence left such a window, during
        # which the unlocked ``from_disk`` would report a missing manager).
        swap_file.replace(actual_file)

    def _load(self):
        """Refresh ``self.registry`` from the manager pickled on disk."""
        with open(self.dir / 'manager.pkl', 'rb') as f:
            self.registry = pickle.load(f).registry

    def __init__(self, dir, clear_on_exit=False):
        """Set up paths under ``dir`` and create the ``values`` directory.

        Args:
            dir: Root directory of the store (created if missing, together
                with its ``values`` subdirectory).
            clear_on_exit: When true, ``self.clear`` is registered with
                ``atexit``.
        """
        self.dir = Path(dir)
        self.values_dir = self.dir / 'values'
        self.values_dir.mkdir(exist_ok=True, parents=True)
        self.registry = {}
        self.lock_path = self.dir / 'lock'
        if clear_on_exit:
            atexit.register(self.clear)

    @immutable_lock
    def __len__(self):
        """Return the number of registered keys."""
        return len(self.registry)

    def _path_for(self, key):
        """Return the value-file path for ``key`` (``values/<key>.pkl``)."""
        return self.values_dir / f'{key}.pkl'

    @mutable_lock
    def __setitem__(self, key, value):
        """Store ``value`` under a new ``key``; see ``_save_item`` for errors."""
        self._save_item(key, value)

    @immutable_lock
    def __getitem__(self, key):
        """Load and return the value for ``key``.

        Raises:
            FileNotFoundError: If the value file is missing (note: not
                ``KeyError``).
        """
        return self._load_item(key)

    @mutable_lock
    def __delitem__(self, key):
        """Remove ``key`` from the registry and delete its value file."""
        self._delete(key)

    @immutable_lock
    def __contains__(self, key):
        """Return whether ``key`` is registered."""
        return key in self.registry

    @immutable_lock
    def keys(self):
        """Return the registered keys as a list."""
        return list(self.registry.keys())

    def _delete(self, key):
        """Unregister ``key``, persist the registry, then delete the file.

        Raises:
            KeyError: If ``key`` is not registered.
        """
        p = self.registry.pop(key)
        # Save the registry without the key before touching the file, in case
        # something goes wrong during the unlink.
        self._save()
        p.unlink()

    @mutable_lock
    def force_set(self, key, value):
        """Store ``value`` under ``key``, overwriting a leftover value file.

        The key must still be absent from the registry; only an existing
        file on disk is tolerated (and replaced).
        """
        self._save_item(key, value, force=True)

    def _save_item(self, key, value, force=False):
        """Pickle ``value`` to the file for ``key`` and register it.

        Args:
            key: Key to register; must not already be in the registry.
            value: Object to pickle.
            force: When true, an existing value file is deleted and rewritten
                instead of raising.

        Raises:
            AssertionError: If ``key`` is already registered.
            FileExistsError: If the value file exists and ``force`` is false.
        """
        assert key not in self.registry, key
        p = self._path_for(key)
        if p.exists():
            if not force:
                raise FileExistsError

            p.unlink()

        with open(p, 'wb') as f:
            pickle.dump(value, f)

        self.registry[key] = p

    def _load_item(self, key):
        """Unpickle and return the value stored in the file for ``key``.

        Existence is decided by the value file, not by the registry.

        Raises:
            FileNotFoundError: If the value file does not exist; the
                exception argument is ``(key, path)``.
        """
        p = self._path_for(key)
        if not p.exists():
            raise FileNotFoundError((key, p))

        with open(p, 'rb') as f:
            return pickle.load(f)

    def clear(self):
        """Delete the whole store: values, manager, then the root directory."""
        with NFSLock(self.lock_path):
            shutil.rmtree(self.values_dir)
            (self.dir / 'manager.pkl').unlink()

        # Removed outside the lock; the lock path itself lives in this directory.
        shutil.rmtree(self.dir)

from .dlpack import dlpack_cpu2gpu

class CastedDDict(DDict):
    """``DDict`` whose loaded values are rebuilt onto a template object.

    ``cast`` must offer ``replace(params=..., opt_state=..., batch_stats=...)``
    and stored values must expose those three attributes
    (NOTE(review): looks like a train-state style object -- confirm).
    """

    def __init__(self, dir, cast, device, clear_on_exit):
        """Initialise the store and keep a payload-free copy of ``cast``.

        Args:
            dir: Root directory, forwarded to ``DDict``.
            cast: Template object; its ``params``, ``opt_state`` and
                ``batch_stats`` are blanked to ``None`` before it is kept.
            device: Only checked against ``None`` in ``_cast``; a non-``None``
                value enables the ``dlpack_cpu2gpu`` step.
            clear_on_exit: Forwarded to ``DDict``.
        """
        super().__init__(dir, clear_on_exit=clear_on_exit)
        template = cast.replace(params=None, opt_state=None, batch_stats=None)
        self.cast = template
        self.device = device

    def _cast(self, v):
        """Copy ``v``'s three payload fields onto the template.

        When ``self.device`` is not ``None`` the result is additionally passed
        through ``dlpack_cpu2gpu(..., blocking=True)``.
        """
        rebuilt = self.cast.replace(
            params=v.params,
            opt_state=v.opt_state,
            batch_stats=v.batch_stats,
        )
        if self.device is None:
            return rebuilt

        return dlpack_cpu2gpu(rebuilt, blocking=True)

    def _load_item(self, key):
        """Load the raw value via ``DDict._load_item`` and cast it."""
        return self._cast(super()._load_item(key))

class MemCastedDDict(CastedDDict):
    """``CastedDDict`` with an optional in-process, in-memory overlay.

    While ``use_mem`` is true, writes go to the ``memkv`` dict instead of the
    disk store and reads consult ``memkv`` first. ``memkv`` is never written
    to ``manager.pkl`` (see ``_save``).
    """

    def __init__(self, dir, cast, device, clear_on_exit):
        """Initialise the disk store with an empty overlay, memory mode off."""
        super().__init__(dir, cast, device, clear_on_exit=clear_on_exit)
        self.memkv = {}
        self.use_mem = False

    def _save(self):
        """Persist the manager with ``memkv`` temporarily emptied.

        The overlay is restored in a ``finally`` so a failing save cannot
        drop the in-memory values.
        """
        memkv = self.memkv
        self.memkv = {}
        try:
            super()._save()
        finally:
            self.memkv = memkv

    def set_mode(self, use_mem):
        """Switch the in-memory overlay on (true) or off (false)."""
        self.use_mem = use_mem

    def __getitem__(self, key):
        """Return the cast value for ``key``, preferring the overlay in memory mode."""
        if self.use_mem and key in self.memkv:
            return self._cast(self.memkv[key])

        return super().__getitem__(key)

    def __len__(self):
        """Return the number of keys (union of overlay and disk in memory mode)."""
        if not self.use_mem:
            return super().__len__()

        disk_keys = set(super().keys())
        mem_keys = set(self.memkv.keys())
        return len(mem_keys | disk_keys)

    def __setitem__(self, key, value):
        """Store ``value`` in the overlay (memory mode) or on disk.

        Raises:
            AssertionError: In memory mode, if ``key`` is already in the overlay.
        """
        if self.use_mem:
            assert key not in self.memkv
            self.memkv[key] = value
        else:
            return super().__setitem__(key, value)

    def __delitem__(self, key):
        """Delete ``key`` from the overlay if held there in memory mode, else from disk."""
        if self.use_mem and key in self.memkv:
            self.memkv.pop(key)
        else:
            return super().__delitem__(key)

    def __contains__(self, key):
        """Return whether ``key`` is in the overlay (memory mode) or on disk."""
        if self.use_mem and key in self.memkv:
            return True

        # Go through the locked base implementation so the registry is
        # reloaded first, like ``keys`` and ``__len__`` do; reading
        # ``self.registry`` directly could answer from a stale copy.
        return super().__contains__(key)

    def keys(self):
        """Return disk keys, merged with overlay keys in memory mode."""
        disk_keys = super().keys()
        if not self.use_mem:
            return disk_keys

        mem_keys = set(self.memkv.keys())
        return list(mem_keys | set(disk_keys))

    def force_set(self, key, value):
        """Set ``key`` unconditionally in the overlay, or ``force_set`` on disk."""
        if self.use_mem:
            self.memkv[key] = value
        else:
            return super().force_set(key, value)


# from multiprocessing import shared_memory
# import numpy as np

# # Create a NumPy array
# np_array = np.array([0, 1, 2, 3, 4], dtype=np.int32)

# # Create a shared memory block
# shm = shared_memory.SharedMemory(create=True, size=np_array.nbytes)

# # Create a NumPy array backed by shared memory
# np_array_shared = np.ndarray(np_array.shape, dtype=np_array.dtype, buffer=shm.buf)

# # Copy data to shared memory
# np_array_shared[:] = np_array[:]

# # Print the name of the shared memory block
# print(f"Shared memory name: {shm.name}")

# # Keep the shared memory block alive
# input("Press Enter to continue...")

# # Clean up
# shm.close()
# shm.unlink()

# from multiprocessing import shared_memory
# import numpy as np

# # Get the name of the shared memory block from the writer process
# shm_name = input("Enter the shared memory name: ")

# # Access the existing shared memory block
# existing_shm = shared_memory.SharedMemory(name=shm_name)

# # Create a NumPy array backed by shared memory
# np_array_shared = np.ndarray((5,), dtype=np.int32, buffer=existing_shm.buf)

# # Read data from shared memory
# print(np_array_shared)

# # Clean up
# existing_shm.close()
