import glob
import os
import pickle
import re
import subprocess
import tempfile

import torch
import gym
import d4rl
import numpy as np
from filelock import FileLock
from tqdm import tqdm

from sgcrl.data.dbs.d4rl_utils import build_d4rl_dataset

class DiskPythonObjectDB:
    """Append-only on-disk store of Python objects serialized with ``torch.save``.

    The i-th object pushed under ``name`` is stored as
    ``<directory>/<name>___<i>.pt`` (i = 0, 1, 2, ...). Every public method
    holds a ``FileLock`` on ``<directory>/lock.lock`` while it touches the
    directory.

    NOTE(review): torch>=2.6 defaults ``torch.load(weights_only=True)``, which
    refuses arbitrary pickled Python objects; if this store holds more than
    tensors/plain containers the loads below may need ``weights_only=False``
    -- confirm against the torch version in use.
    """

    _EXTENSION = ".pt"

    def __init__(self, directory):
        self._directory = directory
        # exist_ok avoids failing when another process creates it concurrently.
        os.makedirs(self._directory, exist_ok=True)

    def _lock(self):
        """Return the file lock guarding this directory."""
        return FileLock(os.path.join(self._directory, "lock.lock"))

    def _filename(self, name, idx):
        """Return the path of entry ``idx`` of ``name``."""
        return os.path.join(self._directory, f"{name}___{idx}{self._EXTENSION}")

    def _size(self, name):
        """Return how many entries ``name`` has; the caller must hold the lock.

        Only files named exactly ``<name>___<digits>.pt`` are counted, so the
        entries of a different name that merely starts with ``<name>___``
        (e.g. ``a___b`` vs ``a``) do not shift this name's indices.
        """
        prefix = f"{name}___"
        count = 0
        for entry in os.listdir(self._directory):
            if not (entry.startswith(prefix) and entry.endswith(self._EXTENSION)):
                continue
            index_text = entry[len(prefix):-len(self._EXTENSION)]
            if index_text.isascii() and index_text.isdigit():
                count += 1
        return count

    def get_last(self, name):
        """Load the most recently pushed entry of ``name``, or None if it has none."""
        with self._lock():
            size = self._size(name)
            if size == 0:
                return None
            return torch.load(self._filename(name, size - 1))

    def size(self, name):
        """Return the number of entries stored under ``name``."""
        with self._lock():
            return self._size(name)

    def get(self, name, idx):
        """Load entry ``idx`` of ``name``; a negative ``idx`` counts from the end.

        Raises:
            IndexError: if ``idx`` is out of range for ``name``.
        """
        # Resolve and bounds-check under the same lock as the load so a
        # concurrent push cannot change the size in between.
        with self._lock():
            size = self._size(name)
            position = idx + size if idx < 0 else idx
            if not 0 <= position < size:
                raise IndexError(
                    f"index {idx} out of range for '{name}' with {size} entries"
                )
            return torch.load(self._filename(name, position))

    def push(self, name, fn):
        """Append object ``fn`` as the next entry of ``name`` (prints the target path)."""
        with self._lock():
            filename = self._filename(name, self._size(name))
            print(f"Saving model in {filename}")
            torch.save(fn, filename)

class DiskPickleObjectDB:
    """Append-only on-disk store of Python objects serialized with ``pickle``.

    The i-th object pushed under ``name`` is stored as
    ``<directory>/<name>___<i>.pickle`` (i = 0, 1, 2, ...). Every public
    method holds a ``FileLock`` on ``<directory>/lock.lock`` while it touches
    the directory.

    Security: ``pickle.load`` executes arbitrary code from the file; only use
    this store on directories you trust.
    """

    _EXTENSION = ".pickle"

    def __init__(self, directory):
        self._directory = directory
        # exist_ok avoids failing when another process creates it concurrently.
        os.makedirs(self._directory, exist_ok=True)

    def _lock(self):
        """Return the file lock guarding this directory."""
        return FileLock(os.path.join(self._directory, "lock.lock"))

    def _filename(self, name, idx):
        """Return the path of entry ``idx`` of ``name``."""
        return os.path.join(self._directory, f"{name}___{idx}{self._EXTENSION}")

    def _load(self, filename):
        """Unpickle and return the object stored in ``filename``."""
        with open(filename, "rb") as file:
            return pickle.load(file)

    def _size(self, name):
        """Return how many entries ``name`` has; the caller must hold the lock.

        Only files named exactly ``<name>___<digits>.pickle`` are counted, so
        the entries of a different name that merely starts with ``<name>___``
        (e.g. ``a___b`` vs ``a``) do not shift this name's indices.
        """
        prefix = f"{name}___"
        count = 0
        for entry in os.listdir(self._directory):
            if not (entry.startswith(prefix) and entry.endswith(self._EXTENSION)):
                continue
            index_text = entry[len(prefix):-len(self._EXTENSION)]
            if index_text.isascii() and index_text.isdigit():
                count += 1
        return count

    def get_last(self, name):
        """Load the most recently pushed entry of ``name``, or None if it has none."""
        with self._lock():
            size = self._size(name)
            if size == 0:
                return None
            return self._load(self._filename(name, size - 1))

    def size(self, name):
        """Return the number of entries stored under ``name``."""
        with self._lock():
            return self._size(name)

    def get(self, name, idx):
        """Load entry ``idx`` of ``name``; a negative ``idx`` counts from the end.

        Raises:
            IndexError: if ``idx`` is out of range for ``name``.
        """
        # Resolve and bounds-check under the same lock as the load so a
        # concurrent push cannot change the size in between.
        with self._lock():
            size = self._size(name)
            position = idx + size if idx < 0 else idx
            if not 0 <= position < size:
                raise IndexError(
                    f"index {idx} out of range for '{name}' with {size} entries"
                )
            return self._load(self._filename(name, position))

    def push(self, name, fn):
        """Append object ``fn`` as the next entry of ``name`` (prints the target path)."""
        with self._lock():
            filename = self._filename(name, self._size(name))
            print(f"Saving model in {filename}")
            with open(filename, "wb") as file:
                pickle.dump(fn, file)

class EpisodesReader:
    """Abstract read interface over a collection of stored episodes.

    Episodes are addressed by an id tuple as returned by ``get_ids``
    (typically ``(application_id, session_id, episode_id, ...)``); subclasses
    implement the storage-specific methods.
    """

    def get_ids(self):
        """Return a list of id tuples, one per available episode."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """Return the episode whose id tuple is ``idx`` (used as ``reader[idx]``)."""
        raise NotImplementedError

    def pop(self, idx):
        """Remove and return the episode whose id tuple is ``idx``."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of available episodes."""
        return len(self.get_ids())

class EpisodesWriter:
    """Abstract write interface: persists one episode under its identifiers."""

    def write(self, application_id, session_id, episode_id, episode):
        """Store ``episode`` keyed by the three ids; concrete subclasses must override."""
        raise NotImplementedError()

class PytorchEpisodesReader(EpisodesReader):
    """Marker base for ``EpisodesReader`` implementations whose episodes hold torch tensors."""

class PytorchOnDiskEpisodesDB(PytorchEpisodesReader, EpisodesWriter):
    """Episode store keeping one ``.pt`` file per episode in a directory.

    ``write`` saves an episode (a dict of tensors) to
    ``<directory>/<application_id>___<session_id>___<episode_id>.pt``; if that
    file already exists, a numeric suffix ``___1``, ``___2``, ... is appended
    before the extension. An episode's id is its file stem split on ``"___"``,
    i.e. a 3-tuple, or a 4-tuple for a suffixed duplicate.

    Args:
        env_name: forwarded to ``build_d4rl_dataset`` to populate the store.
        directory: storage directory; a new temporary directory is created
            (and not automatically removed) when None.
        episode_id: if not None, only episodes with this episode id are
            written and listed.
        use_pickle: when True, files are read with ``pickle`` and each value
            converted with ``torch.tensor``; ``build_d4rl_dataset`` is then
            NOT called.
        split_obs: forwarded to ``build_d4rl_dataset``.
    """

    def __init__(self, env_name, directory=None, episode_id=None, use_pickle=False, split_obs=True):
        # mkdtemp (rather than TemporaryDirectory().name) because the directory
        # must outlive this expression; a discarded TemporaryDirectory object
        # deletes its directory when it is finalized.
        self._directory = tempfile.mkdtemp() if directory is None else directory
        self._episode_id = episode_id
        os.makedirs(self._directory, exist_ok=True)

        self.use_pickle = use_pickle

        if not use_pickle:
            build_d4rl_dataset(env_name, self, split_obs=split_obs)

    def _lock(self):
        """Return the file lock guarding this directory."""
        return FileLock(os.path.join(self._directory, "lock.lock"))

    def _path(self, idx):
        """Return the file path for the episode id tuple ``idx``."""
        return os.path.join(self._directory, "___".join(idx) + ".pt")

    def get_ids(self):
        """Return the id tuples of all stored episodes (filtered by ``episode_id`` if set)."""
        ids = [
            tuple(entry[: -len(".pt")].split("___"))
            for entry in os.listdir(self._directory)
            if entry.endswith(".pt")
        ]
        if self._episode_id is not None:
            # Keep the full id, including any duplicate suffix, so each id
            # still maps back to its own file.
            ids = [i for i in ids if len(i) > 2 and i[2] == self._episode_id]
        return ids

    def __getitem__(self, idx):
        """Load and return the episode with id tuple ``idx``."""
        filename = self._path(idx)
        if self.use_pickle:
            # Security: pickle.load runs arbitrary code; only read trusted directories.
            with open(filename, "rb") as file:
                values = pickle.load(file)
            return {k: torch.tensor(v) for k, v in values.items()}
        return torch.load(filename)

    def pop(self, idx):
        """Load the episode with id tuple ``idx``, delete its file, and return it."""
        # Reuse __getitem__ so pop honours use_pickle exactly like reads do.
        result = self[idx]
        os.remove(self._path(idx))
        return result

    def random_pop(self):
        """Pop one episode under the directory lock.

        Takes the first id in ``os.listdir`` order (arbitrary, not randomized).

        Returns:
            ``(episode, id)``, or ``(None, None)`` when the store is empty.
        """
        with self._lock():
            ids = self.get_ids()
            if not ids:
                return None, None
            first_id = ids[0]
            return self.pop(first_id), first_id

    def __len__(self):
        """Return the number of (filtered) stored episodes."""
        return len(self.get_ids())

    def write(self, application_id, session_id, episode_id, episode):
        """Save ``episode`` (a dict of tensors / numpy arrays) under the given ids.

        numpy arrays are converted with ``torch.from_numpy`` in a shallow copy,
        so the caller's dict is left untouched. Episodes whose ``episode_id``
        does not match a configured filter are validated but not saved.
        """
        with self._lock():
            assert isinstance(episode, dict)
            episode = {
                k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v
                for k, v in episode.items()
            }
            for k, v in episode.items():
                assert isinstance(v, torch.Tensor), k + " = " + str(type(v))
            if self._episode_id is None or self._episode_id == episode_id:
                stem = os.path.join(
                    self._directory,
                    application_id + "___" + session_id + "___" + episode_id,
                )
                # Find the first free name: stem.pt, stem___1.pt, stem___2.pt, ...
                candidate = stem
                suffix = 0
                while os.path.exists(candidate + ".pt"):
                    suffix += 1
                    candidate = stem + "___" + str(suffix)
                torch.save(episode, candidate + ".pt")

def _version_sort_key(path):
    """Natural-order key approximating ``sort -V``: digit runs compare numerically."""
    parts = re.split(r"(\d+)", path)
    # re.split with a capture group puts text at even and digits at odd positions,
    # so element types line up position-by-position between keys.
    return [int(p) if i % 2 else p for i, p in enumerate(parts)], path


def load_model(env_name, algorithm_name, seed, device, model_name):
    """Load the latest saved model matching ``model_name`` for one run.

    Searches ``runs/<env_name>/<algorithm_name>/<seed>/*/*/*/models/``
    relative to the current working directory. If ``model_name`` contains
    ``"pickle"``, it is matched exactly and unpickled. Otherwise
    ``<model_name>*.pt`` is matched and loaded with ``torch.load`` onto
    ``device``. Among several matches, the resolved path that sorts last in
    version order (digit runs compared numerically) is used.

    Returns:
        The loaded object, or None (after printing a message) if nothing matches.
    """
    models_dir = f"runs/{env_name}/{algorithm_name}/{seed}/*/*/*/models/"
    is_pickle = "pickle" in model_name
    pattern = models_dir + (model_name if is_pickle else f"{model_name}*.pt")
    # realpath + set mirror `readlink -f | sort -u`: resolve symlinks, drop duplicates.
    candidates = {os.path.realpath(p) for p in glob.glob(pattern)}
    if not candidates:
        print(f'Cannot load a model using the name \'{model_name}\'.\nFull error: no file matches {pattern}')
        return None
    model_file = max(candidates, key=_version_sort_key)
    print(f'loaded the file: {model_file}')
    if is_pickle:
        # Security: pickle.load runs arbitrary code; only load trusted run directories.
        with open(model_file, 'rb') as f:
            return pickle.load(f)
    return torch.load(model_file, map_location=torch.device(device))

