import os
import torch
import pickle

import numpy as np

from filelock import FileLock

class DiskPythonObjectDB:
    """Append-only store of torch-serialisable objects kept in a directory.

    The ``idx``-th object pushed under ``name`` is saved as
    ``<directory>/<name>___<idx>.pt`` with ``idx`` counting from 0. Every
    public method takes a ``FileLock`` on ``<directory>/lock.lock`` while it
    touches the directory.
    """

    def __init__(self, directory):
        self._directory = directory
        # exist_ok avoids a race between an existence check and creation when
        # several processes open the same directory at the same time.
        os.makedirs(self._directory, exist_ok=True)

    def _lock(self):
        """Return the file lock guarding this directory."""
        return FileLock(os.path.join(self._directory, "lock.lock"))

    def _filename(self, name, idx):
        """Return the path of the ``idx``-th object stored under ``name``."""
        return os.path.join(self._directory, str(name) + "___" + str(idx) + ".pt")

    def _size(self, name):
        """Count the objects stored under ``name``; the caller must hold the lock.

        Only files of the exact form ``<name>___<digits>.pt`` are counted, so a
        different name that merely starts with ``<name>___`` (e.g. ``a___b``
        versus ``a``) does not inflate the count.
        """
        prefix = str(name) + "___"
        suffix = ".pt"
        count = 0
        for f in os.listdir(self._directory):
            if f.startswith(prefix) and f.endswith(suffix):
                middle = f[len(prefix):-len(suffix)]
                if middle.isdecimal():
                    count += 1
        return count

    def size(self, name):
        """Return the number of objects stored under ``name``."""
        with self._lock():
            return self._size(name)

    def get_last(self, name):
        """Return the most recently pushed object for ``name``, or ``None`` if empty.

        The object is loaded with its tensors mapped to CPU.
        """
        with self._lock():
            idx = self._size(name) - 1
            if idx < 0:
                return None
            return torch.load(
                self._filename(name, idx), map_location=torch.device("cpu")
            )

    def get(self, name, idx):
        """Return the ``idx``-th object stored under ``name`` (loaded on CPU).

        A negative ``idx`` counts from the end, as in list indexing.

        Raises:
            IndexError: if ``idx`` is outside ``[-size, size)``.
        """
        with self._lock():
            # Size and load happen under one lock so a concurrent push cannot
            # change the count between the bounds check and the load.
            s = self._size(name)
            real_idx = idx + s if idx < 0 else idx
            if not 0 <= real_idx < s:
                raise IndexError(
                    f"index {idx} out of range for {name!r} with {s} entries"
                )
            return torch.load(self._filename(name, real_idx), map_location="cpu")

    def push(self, name, fn):
        """Append ``fn`` (any object ``torch.save`` accepts) under ``name``."""
        with self._lock():
            torch.save(fn, self._filename(name, self._size(name)))

# Interface for reading (and consuming) tracked episodes.
class EpisodesReader:
    """Abstract interface for reading stored episodes.

    Episodes are addressed by an id tuple as returned by :meth:`get_ids`,
    typically ``(application_id, session_id, episode_id)``. Subclasses may use
    longer tuples. ``__getitem__`` and ``pop`` take that single tuple, so that
    ``reader[idx]`` works with Python's subscript protocol.
    """

    def get_ids(self):
        """Return the list of ids of the stored episodes."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """Return the episode whose id tuple is ``idx``."""
        raise NotImplementedError

    def pop(self, idx):
        """Return the episode whose id tuple is ``idx`` and remove it from the store."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of stored episodes, i.e. ``len(self.get_ids())``."""
        return len(self.get_ids())


# Interface for writing tracked episodes.
class EpisodesWriter:
    """Abstract interface for storing episodes."""

    def write(self, application_id, session_id, episode_id, episode):
        """Store ``episode`` under the given application/session/episode ids.

        Subclasses must override this; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

class PytorchEpisodesReader(EpisodesReader):
    """Marker base class for episode readers that return torch tensors.

    Adds nothing to :class:`EpisodesReader`; it exists so concrete
    torch-backed readers share a common base.
    """

class PytorchOnDiskEpisodesDB(PytorchEpisodesReader, EpisodesWriter):
    """Directory-backed episode store holding one ``.pt`` file per episode.

    Naming:
        Files are named ``<application_id>___<session_id>___<episode_id>.pt``.
        When that name is already taken, a numeric suffix is appended
        (``...___<episode_id>___1.pt``, ``___2.pt``, ...). The ids returned by
        :meth:`get_ids` are the filename stems split on ``"___"``. They are
        3-tuples, or 4-tuples for suffixed duplicates.

    Args:
        directory: directory holding the episode files; created if missing.
        episode_id: if not None, :meth:`get_ids` only lists episodes with
            this episode id and :meth:`write` silently ignores other ones.
        use_pickle: if True, :meth:`__getitem__` and :meth:`pop` read files
            with :mod:`pickle` and convert every value with ``torch.tensor``.
            Otherwise they use ``torch.load``. :meth:`write` always uses
            ``torch.save``.
    """

    def __init__(self, directory, episode_id=None, use_pickle=False):
        self._directory = directory
        self._episode_id = episode_id
        # exist_ok avoids a check-then-create race between processes.
        os.makedirs(self._directory, exist_ok=True)

        self.use_pickle = use_pickle

    def _lock(self):
        """Return the file lock guarding this directory."""
        return FileLock(os.path.join(self._directory, "lock.lock"))

    def _path(self, idx):
        """Return the file path for the id tuple ``idx``."""
        return os.path.join(self._directory, "___".join(idx) + ".pt")

    def get_ids(self):
        """Return the id tuples of all ``.pt`` files in the directory.

        When an ``episode_id`` filter is set, only ids whose third component
        equals it are kept. The full tuple is kept, including any
        duplicate suffix, so it still addresses the right file.
        """
        ids = [
            tuple(f[: -len(".pt")].split("___"))
            for f in os.listdir(self._directory)
            if f.endswith(".pt")
        ]
        if self._episode_id is not None:
            ids = [i for i in ids if len(i) >= 3 and i[2] == self._episode_id]
        return ids

    def __getitem__(self, idx):
        """Load and return the episode stored under the id tuple ``idx``."""
        filename = self._path(idx)
        if self.use_pickle:
            # NOTE(review): pickle can execute arbitrary code while loading;
            # only point this at directories you trust.
            with open(filename, "rb") as f:
                values = pickle.load(f)
            return {k: torch.tensor(v) for k, v in values.items()}
        return torch.load(filename)

    def pop(self, idx):
        """Load the episode stored under ``idx``, delete its file and return it.

        Uses the same loader as :meth:`__getitem__` so that ``use_pickle``
        applies here too.
        """
        result = self[idx]
        os.remove(self._path(idx))
        return result

    def random_pop(self):
        """Pop one stored episode under the directory lock.

        The episode taken is the first id reported by :meth:`get_ids`. Its
        order follows ``os.listdir``, which is arbitrary rather than
        randomly sampled.

        Returns:
            ``(episode, id)``, or ``(None, None)`` when the store is empty.
        """
        with self._lock():
            ids = self.get_ids()
            if not ids:
                return None, None
            _id = ids[0]
            return self.pop(_id), _id

    def __len__(self):
        """Return the number of (filtered) stored episodes."""
        return len(self.get_ids())

    def write(self, application_id, session_id, episode_id, episode):
        """Save ``episode`` (a dict of tensors / numpy arrays) to a new file.

        numpy arrays are converted with ``torch.from_numpy``. The caller's dict
        is left unchanged. If an ``episode_id`` filter is set and does not
        match, nothing is written. An existing file is never overwritten;
        a ``___<n>`` suffix is added instead.
        """
        assert isinstance(episode, dict)
        # Build a converted copy instead of mutating the caller's dict.
        converted = {}
        for k, v in episode.items():
            if isinstance(v, np.ndarray):
                v = torch.from_numpy(v)
            assert isinstance(v, torch.Tensor), f"{k} = {type(v)}"
            converted[k] = v
        if self._episode_id is not None and self._episode_id != episode_id:
            return
        with self._lock():
            stem = os.path.join(
                self._directory,
                application_id + "___" + session_id + "___" + episode_id,
            )
            candidate = stem
            n = 0
            # Find the first free name: <stem>.pt, <stem>___1.pt, <stem>___2.pt, ...
            while os.path.exists(candidate + ".pt"):
                n += 1
                candidate = stem + "___" + str(n)
            torch.save(converted, candidate + ".pt")