import copy
import json
import numpy as np
import logging
from torch.utils.data import Dataset

# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)

# Fixed seed for NumPy's global (legacy) RNG, applied once at import time.
# NOTE: np.random.seed is process-wide, so importing this module reseeds
# np.random for every other user in the process as well.
SEED = 1
np.random.seed(SEED)


class CosseratRodCorrectionDataSet(Dataset):
    """Map-style dataset of graph samples loaded from a single JSON file.

    The JSON file must contain:

    * ``"graph"``: a template graph whose ``globals``, ``edges``, ``senders``,
      ``receivers``, ``nodes``, ``nodes_index`` and ``edges_index`` entries
      are converted to NumPy arrays (other keys are kept as loaded).
    * ``"data"``: per-sample ``input_state``, ``corr_x`` and ``lambda_v``
      lists, each converted to a NumPy array (other keys are dropped).

    Each item reuses a copy of the template graph with its ``"nodes"`` entry
    replaced by that sample's ``input_state``.

    Attributes set while loading:
        NODES_SHAPE: shape of one sample's ``input_state``.
        EDGES_SHAPE: shape of the template ``edges`` array.
        GLOBAL_SHAPE: shape of the template ``globals`` array.
    """

    def __init__(self, filename):
        """Load the dataset from the JSON file at *filename* and log array shapes."""
        self._db = self._get_db(filename)
        graph = self._db["graph"]
        logger.info("Global size: {},\nedges size: {},\nnodes size: {},\nsender size: {},\n"
                    "receiver size: {}".format(graph["globals"].shape,
                                               graph["edges"].shape,
                                               graph["nodes"].shape,
                                               graph["senders"].shape,
                                               graph["receivers"].shape))

    def __len__(self):
        """Return the number of samples (the length of ``input_state``)."""
        return len(self._db["data"]["input_state"])

    def __getitem__(self, item):
        """Return ``(graph_input, corr_x, lambda_v)`` for sample *item*.

        ``graph_input`` is an independent copy of the template graph whose
        ``"nodes"`` entry is this sample's ``input_state``. All returned
        values are copies, so callers may mutate them without affecting the
        dataset.
        """
        # Shuffling is delegated to the DataLoader (shuffle=True, seeded e.g.
        # with torch.manual_seed(2)), so the requested index is used directly.
        it = item

        # NOTE(review): a previous comment described the node features as
        # "3+3+4+3+4 = 17, [x, v, q, omega]", but those four named fields sum
        # to 13 -- confirm the actual per-node layout.
        input_state = copy.deepcopy(self._db["data"]["input_state"][it])

        # Deep-copy every template entry except "nodes", which is replaced by
        # input_state anyway (copying the template nodes per item was wasted
        # work). A shared memo keeps deepcopy's aliasing semantics across
        # values, and iterating the template preserves its key order.
        memo = {}
        graph_input = {
            key: input_state if key == "nodes" else copy.deepcopy(value, memo)
            for key, value in self._db["graph"].items()
        }

        corr_x = copy.deepcopy(self._db["data"]["corr_x"][it])
        lambda_v = copy.deepcopy(self._db["data"]["lambda_v"][it])

        return graph_input, corr_x, lambda_v

    def _get_db(self, filename):
        """Load *filename* as JSON and convert its graph and sample data to arrays.

        Also records ``NODES_SHAPE``, ``EDGES_SHAPE`` and ``GLOBAL_SHAPE`` on
        ``self``.

        Raises:
            KeyError: if a required ``graph`` or ``data`` entry is missing.
        """
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(filename, 'r', encoding='utf-8') as inputfile:
            data = json.load(inputfile)

        graph = data["graph"]
        for key in ("globals", "edges", "senders", "receivers",
                    "nodes", "nodes_index", "edges_index"):
            graph[key] = np.array(graph[key])

        # Keep only the three per-sample arrays; any other "data" keys are dropped.
        data["data"] = {key: np.array(data["data"][key])
                        for key in ("input_state", "corr_x", "lambda_v")}

        # Per-sample node shape (the template graph["nodes"] is replaced per
        # item). shape[1:] equals input_state[0].shape for non-empty data but
        # does not raise IndexError when the file contains zero samples.
        self.NODES_SHAPE = data["data"]["input_state"].shape[1:]
        self.EDGES_SHAPE = graph["edges"].shape
        self.GLOBAL_SHAPE = graph["globals"].shape

        return data