from sklearn.preprocessing import StandardScaler
import torch
from copy import deepcopy
from torch.utils.data import Dataset, Subset
from torch_geometric.data import Batch
import numpy as np
from torch.utils.data import DataLoader
from torch_geometric.utils import add_self_loops



# Lookup from node-type code string to integer index; collate_fn uses it to
# encode each node's ``node_attr`` entry as a LongTensor value.
type2index = dict(zb=0, gb=1, hy=2, qy=3, hz=4, zj=5)




class SubDataset(Dataset):
    """Sliding-window dataset over the last axis of a graph's ``total_x`` array.

    Item ``i`` is a deep copy of ``data`` (without ``total_x``) carrying
    ``x = total_x[:, i:i + SEQ_LEN]`` and
    ``y = total_x[:, i + SEQ_LEN:i + SEQ_LEN + PRED_LEN]`` as ``torch.Tensor``s.

    Args:
        data: Object with a 2-D ``total_x`` attribute (windows slide along its
            last axis). It is deep-copied; the caller's object is not modified.
        SEQ_LEN: Number of input columns per window.
        LABEL_LEN: Stored on the instance; not used by this class.
        PRED_LEN: Number of target columns that follow the input window.
        normalize: If True (the default, matching the former hard-coded
            behaviour), standardize every row of ``total_x`` to zero mean and
            unit variance along the last axis before windowing.
    """

    def __init__(self, data, SEQ_LEN, LABEL_LEN, PRED_LEN, normalize=True):
        self.data = deepcopy(data)
        # Private copy of the series; np.asarray keeps its dtype and makes the
        # later 2-D slicing work even when normalization is skipped.
        # NOTE(review): an earlier note recorded a shape of (10697, 396) here,
        # presumably (nodes, time steps) -- confirm.
        self.total_x = np.asarray(deepcopy(self.data.total_x))
        if normalize:
            # StandardScaler scales columns, so transpose to turn each row of
            # total_x into a column, scale, then transpose back.
            # NOTE(review): the scaler is fit on the whole series, including the
            # windows PowerDataset.get_dataset later assigns to valid/test, so
            # their statistics leak into the training inputs.
            self.total_x = StandardScaler().fit_transform(self.total_x.T).T
        self.normalize = normalize

        # The series now lives on self.total_x; removing it from the template
        # keeps the per-item deepcopy in __getitem__ from copying it each time.
        del self.data.total_x

        self.SEQ_LEN = SEQ_LEN
        self.LABEL_LEN = LABEL_LEN
        self.PRED_LEN = PRED_LEN
        self.WINDOW_LENGTH = SEQ_LEN + PRED_LEN

        # Number of complete windows. Clamped at 0 so a series shorter than one
        # window yields an empty dataset instead of a negative __len__, which
        # Python rejects with ValueError.
        self.LEN = max(0, self.total_x.shape[-1] - self.WINDOW_LENGTH + 1)

    def __getitem__(self, index):
        """Return window ``index`` as a fresh copy of the template with ``x``/``y`` set."""
        s_begin = index
        s_end = s_begin + self.SEQ_LEN
        r_begin = s_end
        r_end = r_begin + self.PRED_LEN

        x = torch.Tensor(self.total_x[:, s_begin:s_end])
        y = torch.Tensor(self.total_x[:, r_begin:r_end])

        # Deep copy so samples never share mutable attributes with the template.
        data = deepcopy(self.data)
        data.x = x
        data.y = y
        return data

    def __len__(self):
        """Return the number of complete windows (never negative)."""
        return self.LEN


class PowerDataset(object):
    """Chronological train/validation/test splitter for sliding-window samples.

    Windows are produced by ``SubDataset`` over the last axis of
    ``data.total_x``. ``get_dataset`` assigns them in order, without shuffling,
    to three splits and collates each split into a single batch.

    Args:
        data: Object with a 2-D ``total_x`` attribute, plus whatever
            ``SubDataset`` needs. It is deep-copied on construction.
        split_ratio: ``[train, valid, test]`` fractions. Only the first two are
            read; the test split receives every remaining window.
            Defaults to ``[0.6, 0.2, 0.2]``.
        SEQ_LEN: Input-window length forwarded to ``SubDataset``.
        LABEL_LEN: Forwarded to ``SubDataset`` unchanged.
        PRED_LEN: Target-window length forwarded to ``SubDataset``.
    """

    def __init__(self, data, split_ratio=None, SEQ_LEN=96, LABEL_LEN=48, PRED_LEN=7):
        if split_ratio is None:
            split_ratio = [0.6, 0.2, 0.2]
        self.split_ratio = split_ratio
        self.data = deepcopy(data)

        self.SEQ_LEN = SEQ_LEN
        self.LABEL_LEN = LABEL_LEN
        self.PRED_LEN = PRED_LEN
        self.WINDOW_LENGTH = SEQ_LEN + PRED_LEN

        # Number of complete windows; clamped so a too-short series gives 0.
        self.LEN = max(0, self.data.total_x.shape[-1] - self.WINDOW_LENGTH + 1)

    def get_dataset(self):
        """Build the three splits and collate each into one batch.

        The first ``int(split_ratio[0] * LEN)`` windows go to train. The next
        ``int(split_ratio[1] * LEN)`` go to validation, and the rest go to test.
        Each collated split is printed before returning.

        Returns:
            ``(train, valid, test)``: the ``collate_fn`` output for each split.

        Raises:
            ValueError: if any split contains no windows, since an empty list
                cannot be collated.
        """
        n_windows = self.LEN
        indexes = list(range(n_windows))

        train_length = int(self.split_ratio[0] * n_windows)
        valid_length = int(self.split_ratio[1] * n_windows)
        valid_end = train_length + valid_length

        # List slicing clips gracefully if the ratios sum past 1.
        splits = (
            ("train", indexes[:train_length]),
            ("valid", indexes[train_length:valid_end]),
            ("test", indexes[valid_end:]),
        )

        dataset = SubDataset(data=self.data, SEQ_LEN=self.SEQ_LEN, LABEL_LEN=self.LABEL_LEN, PRED_LEN=self.PRED_LEN)

        collated = []
        for name, split in splits:
            if not split:
                raise ValueError(
                    f"{name} split is empty ({n_windows} windows in total, "
                    f"split_ratio={self.split_ratio})"
                )
            # Materialize every window of the split in order, then collate once.
            collated.append(collate_fn([dataset[i] for i in split]))

        for batch in collated:
            print(batch)

        return tuple(collated)


def _flatten_node_attr(values):
    """Yield the leaf entries of a nested list/tuple/ndarray in row-major order."""
    if isinstance(values, np.ndarray):
        values = values.tolist()
    if isinstance(values, (list, tuple)):
        for item in values:
            yield from _flatten_node_attr(item)
    else:
        yield values


def collate_fn(batch):
    """Collate a list of graph samples into one ``Batch`` and encode node types.

    Each leaf of the collated ``node_attr`` is mapped through ``type2index`` to
    an integer, giving a 1-D ``torch.long`` tensor. Flattening is done in Python
    rather than via ``np.array``, so per-graph lists of different lengths
    (graphs with different node counts) are supported. A trailing singleton
    dimension is appended to ``x`` and ``y``.

    Raises:
        KeyError: if a node-type code is not a key of ``type2index``.
    """
    batch = Batch.from_data_list(batch)
    batch.node_attr = torch.tensor(
        [type2index[code] for code in _flatten_node_attr(batch.node_attr)],
        dtype=torch.long,
    )
    batch.x = batch.x.unsqueeze(-1)
    batch.y = batch.y.unsqueeze(-1)
    return batch