import random

from sklearn.preprocessing import StandardScaler
import torch
from copy import deepcopy
from torch.utils.data import Dataset, Subset
from torch_geometric.data import Batch
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm


# Integer id assigned to each node-type code; used to encode ``node_attr``
# strings as a LongTensor when graphs are collated.
# NOTE(review): the meaning of the two-letter codes is not documented in this
# file — confirm with the data owner.
type2index = {
    'zb': 0,
    'gb': 1,
    'hy': 2,
    'qy': 3,
    'cs': 4,
    'zj': 5,
}

# Per-type ratio table, keyed by the same node-type codes.
# NOTE(review): not referenced by the code visible in this file; presumably a
# sampling/weighting ratio consumed elsewhere — confirm.
type2ratio = {
    'zb': 2,
    'gb': 2,
    'hy': 0.1,
    'qy': 0.1,
    'cs': 0.1,
    'zj': 0.1,
}


class SubDataset(Dataset):
    """Map-style dataset that cuts sliding windows out of ``data.total_x``.

    Every item is a deep copy of the stored graph object whose ``x`` is a
    ``SEQ_LEN``-wide slice along the last axis of ``total_x`` and whose ``y``
    depends on ``task``:

    * ``'anomaly'``    -- ``y`` is the stored object's ``y`` cast to long.
    * ``'prediction'`` -- ``y`` is the ``PRED_LEN`` steps that follow the
      input window.
    * anything else    -- masked reconstruction: a random subset of time
      steps is zeroed in ``x``, ``y`` is the unmasked window and the masked
      positions are stored in ``val_mask``.

    Args:
        data: Graph object exposing a 2-D ``total_x`` attribute (and ``y``
            when ``task == 'anomaly'``). It is deep-copied; the caller's
            object is left untouched.
            NOTE(review): presumably a ``torch_geometric`` ``Data`` with
            ``total_x`` shaped (nodes, time) -- confirm against the producer
            of the ``.pt`` files.
        step: Stride, in time steps, between consecutive windows.
        SEQ_LEN: Length of the input window.
        LABEL_LEN: Stored for callers; not used by this class.
        PRED_LEN: Length of the prediction target window.
        task: Selects how ``y`` is built (see above).
        mask_ratio: Fraction of the ``SEQ_LEN`` time steps zeroed out in the
            masked-reconstruction task.
    """

    def __init__(self, data, step, SEQ_LEN, LABEL_LEN, PRED_LEN, task='prediction', mask_ratio=0.125):
        self.step = step
        self.task = task
        self.mask_ratio = mask_ratio

        # Private copy: ``total_x`` is detached from it below and every
        # ``__getitem__`` call clones it again as the per-item template.
        self.data = deepcopy(data)

        # Standardize each row of ``total_x`` along its last axis.
        # StandardScaler works column-wise, hence the transpose round trip.
        # The scaler returns a fresh array, so no extra copy of the raw
        # series is needed first.
        self.total_x = StandardScaler().fit_transform(self.data.total_x.transpose(1, 0)).transpose(1, 0)

        # Keep the (large) raw series out of the per-item template.
        del self.data.total_x

        self.SEQ_LEN = SEQ_LEN
        self.LABEL_LEN = LABEL_LEN
        self.PRED_LEN = PRED_LEN
        self.WINDOW_LENGTH = SEQ_LEN + PRED_LEN

        # Number of complete windows; clamped so that a series shorter than
        # one window yields an empty dataset instead of a negative length
        # (which would make ``len()`` raise ValueError).
        self.LEN = max(0, (self.total_x.shape[-1] - self.WINDOW_LENGTH) // self.step + 1)

    def __getitem__(self, index):
        """Return the ``index``-th window as a copy of the graph object.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
                Without this, out-of-range indices would silently produce
                truncated windows and plain ``for item in dataset`` iteration
                would never stop.
        """
        position = index + self.LEN if index < 0 else index
        if not 0 <= position < self.LEN:
            raise IndexError(f'index {index} is out of range for SubDataset of length {self.LEN}')

        s_begin = position * self.step
        s_end = s_begin + self.SEQ_LEN
        r_begin = s_end
        r_end = r_begin + self.PRED_LEN

        # Fresh copy per item so assigning x / y never mutates the template.
        data = deepcopy(self.data)

        x = torch.Tensor(self.total_x[:, s_begin:s_end])

        if self.task == 'anomaly':
            y = self.data.y.long()
        elif self.task == 'prediction':
            y = torch.Tensor(self.total_x[:, r_begin:r_end])
        else:
            rows, T = x.shape
            # Pick int(T * mask_ratio) distinct time steps uniformly at
            # random; the same time steps are masked in every row.
            masked_steps = torch.argsort(torch.rand(T))[:int(T * self.mask_ratio)]
            step_mask = torch.zeros(T, dtype=torch.bool)
            step_mask[masked_steps] = True
            mask = step_mask.unsqueeze(0).repeat(rows, 1)

            x_masked = torch.clone(x)
            x_masked[mask] = 0
            data.val_mask = mask
            y = x
            x = x_masked

        data.x = x
        data.y = y
        return data

    def __len__(self):
        """Return the number of complete windows available."""
        return self.LEN
class FinetunePowerDataset(object):
    """Collects several windowed datasets and splits each into train/valid/test.

    Call :meth:`add_dataset` once per dataset file, then :meth:`get_dataset`
    to obtain one collated batch per split.

    Args:
        split_ratio: ``[train, valid, test]`` fractions applied to the window
            count of every added dataset. Defaults to ``[0.6, 0.2, 0.2]``.
        SEQ_LEN: Input window length forwarded to ``SubDataset``.
        LABEL_LEN: Forwarded to ``SubDataset``.
        PRED_LEN: Prediction window length forwarded to ``SubDataset``.
        step: Stride between consecutive windows.
        task: Task name forwarded to ``SubDataset``.
        mask_ratio: Mask fraction forwarded to ``SubDataset``.
        data_root: Directory containing the ``<dataset_name>.pt`` files.
            Defaults to the path that used to be hard-coded.
    """

    def __init__(self, split_ratio=None, SEQ_LEN=256, LABEL_LEN=48, PRED_LEN=7, step=500, task='prediction',
                 mask_ratio=0.125, data_root='/data/tsh/PowerGPT/finetune/'):

        if split_ratio is None:
            split_ratio = [0.6, 0.2, 0.2]
        self.split_ratio = split_ratio

        self.task = task
        self.mask_ratio = mask_ratio
        self.data_root = data_root

        self.SEQ_LEN = SEQ_LEN
        self.LABEL_LEN = LABEL_LEN
        self.PRED_LEN = PRED_LEN
        self.WINDOW_LENGTH = SEQ_LEN + PRED_LEN

        self.step = step

        self.train_datasets = []
        self.valid_datasets = []
        self.test_datasets = []

    def _split_indexes(self, n):
        """Shuffle ``range(n)`` and cut it into train / valid / test index lists."""
        indexes = list(range(n))
        random.shuffle(indexes)

        train_length = int(self.split_ratio[0] * n)
        valid_length = int(self.split_ratio[1] * n)
        test_length = int(self.split_ratio[2] * n)

        # Train is taken from the front, valid and test from the back. Because
        # each length is floored independently, a few indices in the middle
        # may end up in no split.
        train_indexes = indexes[:train_length]
        valid_indexes = indexes[n - valid_length - test_length:n - test_length]
        test_indexes = indexes[n - test_length:]
        return train_indexes, valid_indexes, test_indexes

    def _materialize(self, datasets):
        """Load every item of every dataset in ``datasets`` into one flat list."""
        items = []
        for dataset in tqdm(datasets):
            # The identity collate_fn makes each batch a plain list of items.
            dl = DataLoader(dataset, batch_size=10, shuffle=False, collate_fn=lambda x: x)
            for chunk in dl:
                items.extend(chunk)
        return items

    def add_dataset(self, dataset_name):
        """Load ``<data_root>/<dataset_name>.pt`` and register its three splits.

        The file is wrapped in a ``SubDataset``; its window indices are
        shuffled and divided according to ``split_ratio``, and the resulting
        ``Subset`` objects are appended to the train/valid/test lists.
        """
        import os

        # NOTE: torch.load unpickles the file, which can execute arbitrary
        # code -- only point this at trusted dataset files.
        data = torch.load(os.path.join(self.data_root, f'{dataset_name}.pt'))

        dataset = SubDataset(data=data, step=self.step, SEQ_LEN=self.SEQ_LEN, LABEL_LEN=self.LABEL_LEN,
                             PRED_LEN=self.PRED_LEN, task=self.task, mask_ratio=self.mask_ratio)
        # Number of complete windows, computed with the same formula that
        # SubDataset uses for its own length.
        N = (data.total_x.shape[-1] - self.WINDOW_LENGTH) // self.step + 1

        train_indexes, valid_indexes, test_indexes = self._split_indexes(N)

        self.train_datasets.append(Subset(dataset, train_indexes))
        self.valid_datasets.append(Subset(dataset, valid_indexes))
        self.test_datasets.append(Subset(dataset, test_indexes))

    def get_dataset(self):
        """Return ``(train, valid, test)``, each split collated into one batch.

        Every registered subset is fully materialised in memory (train first,
        then valid, then test) before being passed to ``collate_fn``.
        """
        train_batch = self._materialize(self.train_datasets)
        valid_batch = self._materialize(self.valid_datasets)
        test_batch = self._materialize(self.test_datasets)

        return collate_fn(train_batch), collate_fn(valid_batch), collate_fn(test_batch)


def collate_fn(batch):
    """Merge a list of graph objects into one ``Batch`` ready for the model.

    After batching, the per-graph ``node_attr`` entries are flattened and
    each type code is replaced by its integer id from ``type2index``; ``x``
    and ``y`` each gain a trailing singleton dimension.

    NOTE(review): assumes every graph's ``node_attr`` is an iterable of
    type-code strings that are keys of ``type2index`` -- confirm against the
    dataset files.
    """
    merged = Batch.from_data_list(batch)

    # Concatenate the per-graph attribute sequences into one flat list.
    flat_types = [code for group in merged.node_attr for code in group]
    type_ids = [int(type2index[code]) for code in np.array(flat_types).reshape(-1)]
    merged.node_attr = torch.LongTensor(type_ids)

    merged.x = merged.x.unsqueeze(-1)
    merged.y = merged.y.unsqueeze(-1)
    return merged