import os
import pickle
import random
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from sklearn.preprocessing import StandardScaler
import warnings

# Install a catch-all "ignore" filter at import time: every warning category,
# from every module, is suppressed for the rest of the process.
# NOTE(review): this also hides deprecation/runtime warnings emitted by torch,
# pandas and numpy; consider narrowing it to specific categories/modules.
warnings.simplefilter('ignore')
class Dataset_LPI4AI(Dataset):
    """Map-style dataset over one split of a pickled collection of samples.

    Each sample is read as a dict with the keys ``'intensity'``, ``'output1'``,
    ``'phase_plate'`` and ``'target_size'``. ``'intensity'`` and ``'output1'``
    are converted to tensors of the default dtype with an extra trailing
    dimension of size 1. The other two fields are returned unchanged.

    Args:
        root_path: Directory holding the pickle files. A trailing path
            separator is optional, and ``os.PathLike`` objects are accepted.
        data_path: File-name suffix. The split prefix (``'train'``, ``'test'``
            or ``'valid'``) is prepended to it.
        flag: Split selector, one of ``'train'``, ``'test'`` or ``'val'``. Any
            other value uses ``data_path`` unchanged.
    """

    # Split flag -> file-name prefix. Note that 'val' maps to 'valid'.
    _SPLIT_PREFIXES = {'train': 'train', 'test': 'test', 'val': 'valid'}
    # Splits whose sample order is shuffled once, at load time.
    _SHUFFLED_SPLITS = ('train', 'val')

    def __init__(self, root_path='dataset/', data_path='set_cleaned_sampled.pickle', flag='train'):
        self.root_path = root_path
        self.flag = flag
        prefix = self._SPLIT_PREFIXES.get(flag)
        # Unknown flags fall through to the bare data_path, as before.
        self.data_path = data_path if prefix is None else prefix + data_path

        # Not read anywhere in this class; presumably consumed by model or
        # training code elsewhere - TODO confirm.
        self.label_len = 400
        self.pred_len = 400

        self.__read_data__()

    @staticmethod
    def _to_column_tensor(values):
        """Convert ``values`` to a default-dtype tensor with a trailing size-1 dim."""
        # The legacy ``torch.Tensor(x)`` constructor produced the default dtype,
        # so that dtype is kept explicitly here. Unlike the legacy constructor,
        # ``as_tensor`` never treats an int argument as a *size* and never
        # allocates uninitialised memory.
        return torch.as_tensor(values, dtype=torch.get_default_dtype()).unsqueeze(-1)

    def __read_data__(self):
        """Load the samples, shuffle train/val once, and tensorise the sequences."""
        # os.path.join works whether or not root_path ends with a separator.
        # Plain '+' concatenation silently built a wrong path when it did not.
        file_path = os.path.join(self.root_path, self.data_path)
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(file_path, 'rb') as handle:
            self.dataset = pickle.load(handle)

        if self.flag in self._SHUFFLED_SPLITS:
            # In-place shuffle using the global `random` RNG (seed it upstream
            # for a reproducible order). Assumes the pickle holds a list.
            random.shuffle(self.dataset)

        # The sample dicts are updated in place, so no index bookkeeping is needed.
        for sample in self.dataset:
            sample['intensity'] = self._to_column_tensor(sample['intensity'])
            sample['output1'] = self._to_column_tensor(sample['output1'])

    def __getitem__(self, index):
        """Return ``(intensity, output1, phase_plate, target_size)`` for sample ``index``."""
        sample = self.dataset[index]
        return (
            sample['intensity'],
            sample['output1'],
            sample['phase_plate'],
            sample['target_size'],
        )

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.dataset)