import random

from sklearn.preprocessing import StandardScaler
import torch
from copy import deepcopy
from torch.utils.data import Dataset, Subset
from torch_geometric.data import Batch
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.nn.functional import interpolate
from torch_geometric.data import Data


# Integer index assigned to each node-type code, in declaration order.
type2index = {code: position for position, code in enumerate(['zb', 'gb', 'hy', 'qy', 'cs', 'zj'])}
# Per-node-type ratio; what the ratio controls is not established in this module.
type2ratio = dict(zb=2, gb=2, hy=0.1, qy=0.1, cs=0.1, zj=0.1)

def upsample_sequence(sequence, target_size=10000):
    """Linearly resample every row of a 2-D tensor to ``target_size`` points.

    Args:
        sequence: Tensor of shape ``(rows, length)``; each row is interpolated
            independently. It must have a floating-point dtype, as required by
            ``torch.nn.functional.interpolate``.
        target_size: Number of samples per row in the result.

    Returns:
        Tensor of shape ``(rows, target_size)``.
    """
    rows = sequence.size(0)
    # interpolate expects (batch, channels, length): treat each row as its own
    # single-channel batch item so rows never mix.
    resampled = interpolate(
        sequence.reshape(rows, 1, -1),
        # Request the output length directly. Deriving it from a float
        # scale factor (target_size / length) can truncate to target_size - 1
        # (e.g. 49 -> 64 gives 63), after which the reshape below either
        # fails or silently scrambles rows.
        size=target_size,
        mode='linear',
        align_corners=False,  # the default for 'linear', made explicit
    )
    return resampled.reshape(rows, target_size)

class SubDataset(Dataset):
    """Sliding-window dataset over the series stored in ``data.total_x``.

    Windows are cut along the last axis of ``total_x`` (shape ``(rows, T)``).
    Each item is a deep copy of the graph object (without ``total_x``) carrying
    one input window in ``x``, the matching covariate slice in ``x_cov`` and a
    task-dependent target in ``y``.

    Args:
        data: Graph-like object with at least ``total_x`` and ``edge_index``;
            ``y`` is also read for ``task='classification'``. It is deep-copied,
            so the caller's object is not modified.
        step: Stride between consecutive window starts.
        SEQ_LEN: Input window length.
        LABEL_LEN: Stored on the instance; not used by this class.
        PRED_LEN: Target length for ``task='prediction'``.
        task: ``'prediction'`` or ``'classification'``; any other value
            (e.g. ``'imputation'``, ``'pretrain'``) builds masked-reconstruction
            items.
        mask_ratio: Fraction of time steps zeroed per window for masked items.
    """

    # Covariate file, read only when task == 'pretrain'.
    COV_PATH = '/data/tsh/PowerGPT/preprocess/cov/cov_finetune.pt'

    def __init__(self, data, step, SEQ_LEN, LABEL_LEN, PRED_LEN, task='prediction', mask_ratio=0.125):
        self.step = step
        self.task = task
        self.mask_ratio = mask_ratio

        self.data = deepcopy(data)
        self.total_x = deepcopy(self.data.total_x)
        if task == 'imputation':
            # NOTE(review): each series is upsampled 50x by linear interpolation
            # before windowing; the rationale is not documented here — confirm.
            self.total_x = upsample_sequence(self.total_x, self.total_x.shape[-1] * 50)

        if task == 'pretrain':
            # Only the pretrain task uses the covariate file; loading it for other
            # tasks was wasted I/O and failed whenever the file was absent.
            cov = torch.load(self.COV_PATH)
        else:
            # Placeholder random covariates, sized *after* any upsampling so the
            # time slices taken in __getitem__ line up with total_x.
            cov = torch.rand(4, self.total_x.shape[-1])
        # One copy of the covariates per row of total_x.
        self.cov = cov.unsqueeze(0).repeat(self.total_x.shape[0], 1, 1)

        if task != 'classification':
            # Standardise each row over time (StandardScaler works per column,
            # hence the transposes); the result is a NumPy array.
            self.total_x = StandardScaler().fit_transform(self.total_x.transpose(1, 0)).transpose(1, 0)

        # The series now lives on the dataset; keep it out of per-item graph copies.
        del self.data.total_x

        self.SEQ_LEN = SEQ_LEN
        self.LABEL_LEN = LABEL_LEN
        self.PRED_LEN = PRED_LEN
        self.WINDOW_LENGTH = SEQ_LEN + PRED_LEN

        # Number of complete windows; clamp at 0 so a series shorter than one
        # window gives an empty dataset instead of a negative __len__.
        self.LEN = max(0, (self.total_x.shape[-1] - self.WINDOW_LENGTH) // self.step + 1)

    def __getitem__(self, index):
        """Return a deep copy of the graph with window ``index`` attached.

        Sets ``x`` (float input window), ``x_cov`` (covariate slice) and ``y``:
        ``data.y[index]`` for classification, the next ``PRED_LEN`` steps for
        prediction, otherwise the unmasked window, with ``val_mask`` marking the
        positions zeroed in ``x``.
        """
        s_begin = index * self.step
        s_end = s_begin + self.SEQ_LEN
        r_begin = s_end
        r_end = r_begin + self.PRED_LEN

        data = deepcopy(self.data)

        x = torch.Tensor(self.total_x[:, s_begin:s_end])
        x_cov = torch.Tensor(self.cov[:, :, s_begin:s_end])
        if self.task == 'classification':
            y = self.data.y[index]
        elif self.task == 'prediction':
            y = torch.Tensor(self.total_x[:, r_begin:r_end])
        else:
            # Masked reconstruction: zero the same random time steps in every row.
            rows, steps = x.shape
            picked = torch.argsort(torch.rand(steps))[:int(steps * self.mask_ratio)]
            mask = torch.zeros(steps, dtype=torch.bool)
            mask[picked] = True
            mask = mask.unsqueeze(0).repeat(rows, 1)
            data.val_mask = mask
            y = x
            x = x.masked_fill(mask, 0)

        data.x = x.float()
        data.y = y
        data.x_cov = x_cov
        return data

    def __len__(self):
        return self.LEN



class FinetunePowerDataset(object):
    """Builds shuffled train/valid/test window subsets and collates them into batches.

    Call ``add_dataset`` one or more times to register subsets, then
    ``get_dataset`` to obtain the ``(train, valid, test)`` batches.

    Args:
        name: Dataset name stored on the instance. NOTE(review): it does not
            select the file that ``add_dataset`` loads (a fixed path) — confirm.
        split_ratio: ``[train, valid, test]`` fractions of the shuffled windows;
            defaults to ``[0, 0, 1]``.
        SEQ_LEN, LABEL_LEN, PRED_LEN, step, task, mask_ratio: forwarded to
            ``SubDataset``. ``step`` is overridden to 5000 for classification.
    """

    def __init__(self, name, split_ratio=None, SEQ_LEN=256, LABEL_LEN=48, PRED_LEN=7, step=50, task='prediction', mask_ratio=0.125):
        self.name = name
        if split_ratio is None:
            split_ratio = [0, 0, 1]
        self.split_ratio = split_ratio

        self.task = task
        self.mask_ratio = mask_ratio

        self.SEQ_LEN = SEQ_LEN
        self.LABEL_LEN = LABEL_LEN
        self.PRED_LEN = PRED_LEN
        self.WINDOW_LENGTH = SEQ_LEN + PRED_LEN

        self.step = step

        # One Subset per add_dataset() call, per split.
        self.train_datasets = []
        self.valid_datasets = []
        self.test_datasets = []

    def add_dataset(self):
        """Load the data file, wrap it in a ``SubDataset`` and register its split subsets."""
        # NOTE(review): fixed path, independent of self.name — confirm intended.
        data = torch.load('/data/tsh/PowerGPT/preprocess/pretrain/pretrain_load_yp.pt')
        if not hasattr(data, 'total_x'):
            # SubDataset reads the series from ``total_x``.
            data.total_x = data.x
            del data.x
        if self.task == 'classification':
            data = self._as_single_node_graph(data)
            self.step = 5000

        dataset = SubDataset(data=data, step=self.step, SEQ_LEN=self.SEQ_LEN, LABEL_LEN=self.LABEL_LEN,
                             PRED_LEN=self.PRED_LEN, task=self.task, mask_ratio=self.mask_ratio)

        # Clamp at 0: a series shorter than one window yields no windows.
        indexes = list(range(max(dataset.LEN, 0)))
        random.shuffle(indexes)
        train_indexes, valid_indexes, test_indexes = self._split_indexes(indexes)

        self.train_datasets.append(Subset(dataset, train_indexes))
        self.valid_datasets.append(Subset(dataset, valid_indexes))
        self.test_datasets.append(Subset(dataset, test_indexes))

    @staticmethod
    def _as_single_node_graph(data):
        """Concatenate all samples, in shuffled order, into one series on a single self-looped node."""
        g = Data(edge_index=torch.LongTensor([0, 0]).reshape(2, 1),
                 edge_type=torch.zeros(1).long(), node_attr=['gb'])
        order = list(range(data.total_x.shape[0]))
        random.shuffle(order)
        order = torch.LongTensor(order)
        g.total_x = data.total_x[order].reshape(1, -1)
        g.y = data.y[order]
        g.num_nodes = 1
        return g

    def _split_indexes(self, indexes):
        """Split a shuffled index list into ``(train, valid, test)`` by ``self.split_ratio``.

        Train takes the head of the list, test the tail, and valid the slice just
        before the test tail. For tasks other than imputation and classification,
        the test tail is additionally thinned by ``self.step``. An empty train or
        valid split falls back to the test indexes.
        """
        n = len(indexes)
        train_length = int(self.split_ratio[0] * n)
        valid_length = int(self.split_ratio[1] * n)
        test_length = int(self.split_ratio[2] * n)

        train_indexes = indexes[:train_length]
        valid_indexes = indexes[n - valid_length - test_length:n - test_length]
        test_indexes = indexes[n - test_length:]
        if self.task != 'imputation' and self.task != 'classification':
            # NOTE(review): windows are already strided by step; this thins the
            # test set again by the same factor — confirm intended.
            test_indexes = test_indexes[::self.step]

        # Every split used to receive the test indexes, so training saw test
        # windows. Use the real splits. Fall back to the test indexes only for an
        # empty split (e.g. the default [0, 0, 1]), which matches the previous
        # behaviour there and keeps get_dataset from collating an empty list.
        return train_indexes or test_indexes, valid_indexes or test_indexes, test_indexes

    @staticmethod
    def _gather(datasets):
        """Return every item of ``datasets`` as one flat list, in order."""
        items = []
        for dataset in tqdm(datasets):
            loader = DataLoader(dataset, batch_size=10, shuffle=False, collate_fn=lambda x: x)
            for chunk in loader:
                items += chunk
        return items

    def get_dataset(self):
        """Materialise the registered subsets and collate each split into one batch.

        Returns:
            ``(train_batch, valid_batch, test_batch)`` as produced by the
            module-level ``collate_fn``.
        """
        train_batch = self._gather(self.train_datasets)
        valid_batch = self._gather(self.valid_datasets)
        test_batch = self._gather(self.test_datasets)
        return collate_fn(train_batch), collate_fn(valid_batch), collate_fn(test_batch)


def collate_fn(batch):
    """Merge a list of per-sample graphs into a single ``Batch``.

    The merged ``x`` and ``y`` each gain a trailing axis of size 1.
    """
    merged = Batch.from_data_list(batch)
    merged.x, merged.y = merged.x.unsqueeze(-1), merged.y.unsqueeze(-1)
    return merged