import os
import bisect
import pickle
from pathlib import Path

import lmdb
import torch
import torch_geometric
import numpy as np
from torch.utils.data import Dataset
from torch_geometric.data import Data
from typing import Type, TypeVar

# Type variable that lets assert_is_instance return its argument typed as ``cls``.
_T = TypeVar("_T")

# This is a copy of OC22LmdbDataset with minor changes to enable label rescaling:
# energies and forces are divided by ``C`` in ``collate``.
class SOC22LmdbDataset(Dataset):
    r"""Dataset class to load from LMDB files containing relaxation
    trajectories or single point computations.

    Useful for Structure to Energy & Force (S2EF), Initial State to
    Relaxed State (IS2RS), and Initial State to Relaxed Energy (IS2RE) tasks.

    The keys in the LMDB must be integers (stored as ascii objects) starting
    from 0 through the length of the LMDB. For historical reasons any key named
    "length" is ignored since that was used to infer length of many lmdbs in the same
    folder, but lmdb lengths are now calculated directly from the number of keys.

    Args:
            task (str): Name of the sub-directory (of ``*.lmdb`` shards) or LMDB
                    file under ``root``. ``"test"`` is redirected to ``"test_ood"``.
            root (str): Directory containing the task directories/files.
            split (str): Name of the index subset stored as
                    ``random/<split>/<task>.pkl`` next to this file, or ``"all"``
                    to use every entry. (default: ``"200000"``)
            transform (callable, optional): Data transform function.
                    (default: :obj:`None`)
    """

    def __init__(self, task, root, split="200000", transform=None) -> None:
        super().__init__()
        # we only use out of domain val set to test our work
        if task == "test":
            task = "test_ood"

        # Label scale: collate() divides energies and forces by this value.
        self.C = 10000.0
        self.identifier = f"SOC22LmdbDataset_{task}"
        self.path = Path(root) / task
        self.data2train = split

        if not self.path.is_file():
            # Directory layout: every *.lmdb file in the folder is one shard.
            db_paths = sorted(self.path.glob("*.lmdb"))
            if not db_paths:
                # An explicit raise (not assert) so the check survives `python -O`.
                raise FileNotFoundError(f"No LMDBs found in '{self.path}'")

            self.metadata_path = self.path / "metadata.npz"

            self._keys, self.envs = [], []
            for db_path in db_paths:
                cur_env = self.connect_db(db_path)
                self.envs.append(cur_env)
                # Keys inside each shard are 0..num_entries-1.
                self._keys.append(list(range(self._count_entries(cur_env))))

            keylens = [len(k) for k in self._keys]
            # Running totals used by _locate() to map a global index to a shard.
            self._keylen_cumulative = np.cumsum(keylens).tolist()
            self.num_samples = sum(keylens)
        else:
            # Single-file layout.
            self.metadata_path = self.path.parent / "metadata.npz"
            self.env = self.connect_db(self.path)
            self._keys = list(range(self._count_entries(self.env)))
            self.num_samples = len(self._keys)

        # Optional subset of global indices. Loaded for BOTH layouts so that
        # __getitem__ never looks up indices that were not set.
        self.indices = None
        if self.data2train != "all":
            self.indices = self._load_split_indices(task, split)
            self.num_samples = len(self.indices)

        self.transform = transform
        # Linear / OC20 energy references are disabled in this copy; the
        # corresponding branches in __getitem__ stay inactive while these are False.
        self.lin_ref = self.oc20_ref = False
        # only needed for oc20 datasets, oc22 is total by default
        self.train_on_oc20_total_energies = False
        # No subsampling: __len__ only truncates when this is truthy.
        self.subsample = False

    @staticmethod
    def _count_entries(env):
        """Return the number of data entries in ``env``.

        A legacy "length" key, if present, is not counted as a datapoint.
        """
        num_entries = assert_is_instance(env.stat()["entries"], int)
        with env.begin() as txn:
            if txn.get("length".encode("ascii")) is not None:
                num_entries -= 1
        return num_entries

    @staticmethod
    def _load_split_indices(task, split):
        """Load the index list generated by split_data.py for ``task``/``split``."""
        idx_dir = os.path.dirname(os.path.realpath(__file__))
        with open(os.path.join(idx_dir, "random", split, f"{task}.pkl"), "rb") as f:
            return pickle.load(f)

    def __len__(self) -> int:
        if self.subsample:
            return min(self.subsample, self.num_samples)
        return self.num_samples

    def _locate(self, idx):
        """Map a global index to ``(db_idx, el_idx)`` across the LMDB shards."""
        # bisect_right: an index equal to a running total starts the next shard.
        db_idx = bisect.bisect(self._keylen_cumulative, idx)
        el_idx = idx - self._keylen_cumulative[db_idx - 1] if db_idx else idx
        assert el_idx >= 0
        return db_idx, el_idx

    def __getitem__(self, idx):
        """Load, unpickle and post-process datapoint ``idx``.

        Raises:
            ValueError: if a labelled datapoint has no ``oc22`` attribute
                (OC20 data), since OC20 references are not loaded here.
        """
        if self.indices is not None:
            idx = self.indices[idx]

        # NOTE: pickle.loads executes arbitrary code; only load trusted LMDBs.
        if not self.path.is_file():
            db_idx, el_idx = self._locate(idx)
            key = f"{self._keys[db_idx][el_idx]}".encode("ascii")
            with self.envs[db_idx].begin() as txn:
                datapoint_pickled = txn.get(key)
            data_object = pyg2_data_transform(pickle.loads(datapoint_pickled))
            data_object.id = f"{db_idx}_{el_idx}"
        else:
            key = f"{self._keys[idx]}".encode("ascii")
            with self.env.begin() as txn:
                datapoint_pickled = txn.get(key)
            data_object = pyg2_data_transform(pickle.loads(datapoint_pickled))

        if self.transform is not None:
            data_object = self.transform(data_object)

        # make types consistent: tensor ids become Python scalars
        if isinstance(data_object.sid, torch.Tensor):
            data_object.sid = data_object.sid.item()
        if "fid" in data_object and isinstance(data_object.fid, torch.Tensor):
            data_object.fid = data_object.fid.item()

        if hasattr(data_object, "y_relaxed"):
            attr = "y_relaxed"
        elif hasattr(data_object, "y"):
            attr = "y"
        else:
            # if targets are not available, test data is being used
            return data_object

        # OC20 entries (no "oc22" attribute) would need the OC20 reference table
        # to become total energies; it is never loaded by this class.
        if "oc22" not in data_object:
            raise ValueError(
                "To train OC20 or OC22+OC20 on total energies set train_on_oc20_total_energies=True"
            )

        if self.lin_ref is not False:
            lin_energy = sum(self.lin_ref[data_object.atomic_numbers.long()])
            data_object[attr] -= lin_energy

        # to jointly train on oc22+oc20, need to delete these oc20-only attributes
        # ensure otf_graph=1 in your model configuration
        for key in ("edge_index", "cell_offsets", "distances"):
            if key in data_object:
                delattr(data_object, key)

        return data_object

    def connect_db(self, lmdb_path=None):
        """Open ``lmdb_path`` as a read-only, lock-free single-file LMDB environment."""
        env = lmdb.open(
            str(lmdb_path),
            subdir=False,
            readonly=True,
            lock=False,
            readahead=True,
            meminit=False,
            max_readers=1,
        )
        return env

    def close_db(self) -> None:
        """Close every LMDB environment opened by __init__."""
        if not self.path.is_file():
            for env in self.envs:
                env.close()
        else:
            self.env.close()

    def collate(self, batch):
        """Collate samples into flat ``(data, label)`` dicts of tensors.

        ``data``: ``R`` concatenated positions, ``z`` atomic numbers (int64),
        ``n`` per-sample ``natoms`` as a float tensor, ``batch`` the sample
        index of every atom. ``label``: ``E`` stacked energies (shape ``(B, 1)``
        assuming each ``y`` is a scalar) and ``F`` concatenated forces, both
        divided by ``self.C``.
        """
        data = {"R": [], "z": [], "batch": [], "n": []}
        label = {"E": [], "F": []}

        for i, b in enumerate(batch):
            data["R"].append(b["pos"])
            data["z"].append(b["atomic_numbers"].long())
            data["n"].append(b["natoms"])
            data["batch"].append(
                torch.full(b["atomic_numbers"].size(), i, dtype=torch.int64)
            )
            label["E"].append(torch.Tensor([b["y"]]))
            label["F"].append(b["force"])

        data["R"] = torch.cat(data["R"])
        data["z"] = torch.cat(data["z"])
        data["n"] = torch.Tensor(data["n"])
        data["batch"] = torch.cat(data["batch"])
        label["E"] = torch.stack(label["E"]) / self.C
        label["F"] = torch.cat(label["F"]) / self.C

        return data, label

def _major_version(version: str) -> int:
    """Return the major component of a version string such as ``"2.5.3+cu121"``.

    Returns 0 when the part before the first "." is not a decimal integer.
    """
    major = version.partition(".")[0]
    return int(major) if major.isdecimal() else 0


def pyg2_data_transform(data: "Data"):
    """
    if we're on the new pyg (2.0 or later) and if the Data stored is in older format
    we need to convert the data to the new format

    Objects whose ``__dict__`` already contains ``_store`` are returned
    unchanged, as is every object when the installed PyG is older than 2.0.
    Otherwise a new ``Data`` is built from the object's attributes, dropping
    those whose value is ``None``.
    """
    if "_store" in data.__dict__:
        return data
    # Compare the major version numerically: as strings, "10.0" >= "2.0" is False.
    if _major_version(torch_geometric.__version__) >= 2:
        return Data(
            **{k: v for k, v in data.__dict__.items() if v is not None}
        )

    return data

def assert_is_instance(obj: object, cls: Type[_T]) -> _T:
    """Hand back ``obj`` unchanged once it is confirmed to be a ``cls``.

    Raises:
        TypeError: when ``isinstance(obj, cls)`` does not hold.
    """
    if isinstance(obj, cls):
        return obj
    raise TypeError(f"obj is not an instance of cls: obj={obj}, cls={cls}")

if __name__ == "__main__":
    # Smoke test: load one sample and one collated batch, printing types/shapes.
    task = "train"
    root = "/usr/data1/OC22/s2ef_total_train_val_test_lmdbs/data/oc22/s2ef-total"
    # The class defined in this module is SOC22LmdbDataset (OC22LmdbDataset is undefined here).
    ds = SOC22LmdbDataset(task, root)
    da = ds[20]
    print(da)
    print(type(da), len(ds))
    for key in ("pos", "atomic_numbers", "natoms", "y", "force"):
        print(f"{key}:", type(da[key]))

    from torch.utils.data import DataLoader
    dataloader = DataLoader(ds, batch_size=20, shuffle=True, collate_fn=ds.collate, num_workers=16)
    for data, label in dataloader:
        print(type(data["R"]), type(data["z"]), type(data["batch"]))
        print(type(label["E"]), type(label["F"]))
        print(data["R"].dtype, data["z"].dtype, data["batch"].dtype)
        print(label["E"].dtype, label["F"].dtype)
        print(data["R"].shape, data["z"].shape, data["batch"].shape)
        print(label["E"].shape, label["F"].shape)
        break
