import pickle
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler

class Forecasting_Dataset(Dataset):
    """Sliding-window forecasting dataset backed by pre-pickled arrays.

    Each sample is a window of ``history_length + pred_length`` consecutive
    time steps. In ``gt_mask`` the last ``pred_length`` steps are zeroed, so
    those are the positions a model has to forecast.

    Args:
        datatype: Dataset name. Only ``'electricity'`` is supported.
        mode: One of ``'train'``, ``'valid'`` or ``'test'``.

    Raises:
        ValueError: If ``datatype`` or ``mode`` is not supported.
    """

    # datatype -> (data folder, test_length, valid_length); lengths in time steps.
    _CONFIGS = {
        'electricity': ('./data/electricity_nips', 24 * 7, 24 * 5),
    }
    _MODES = ('train', 'valid', 'test')

    def __init__(self, datatype, mode="train"):
        # Validate before any file I/O so bad arguments fail fast and clearly.
        if datatype not in self._CONFIGS:
            raise ValueError(f"unsupported datatype: {datatype!r}")
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")

        self.history_length = 168
        self.pred_length = 24
        datafolder, self.test_length, self.valid_length = self._CONFIGS[datatype]
        self.seq_length = self.history_length + self.pred_length

        # NOTE(review): pickle.load can execute arbitrary code; only load
        # trusted files. data.pkl holds (main_data, mask_data), documented by
        # the original author as shape (T x N), with mask_data usually all ones.
        with open(datafolder + '/data.pkl', 'rb') as f:
            self.main_data, self.mask_data = pickle.load(f)
        with open(datafolder + '/meanstd.pkl', 'rb') as f:
            self.mean_data, self.std_data = pickle.load(f)

        # Standardize with the stored mean/std statistics.
        self.main_data = (self.main_data - self.mean_data) / self.std_data

        self.use_index = self._window_starts(len(self.main_data), mode)

    def _window_starts(self, total_length, mode):
        """Return the window start indices that belong to ``mode``.

        Training windows slide by one step over the leading part of the
        series. Validation and test windows are spaced ``pred_length`` apart,
        so their forecast horizons cover the last ``valid_length`` /
        ``test_length`` steps back to back.
        """
        seq, pred = self.seq_length, self.pred_length
        if mode == 'train':
            end = total_length - seq - self.valid_length - self.test_length + 1
            return np.arange(0, end, 1)
        if mode == 'valid':
            start = total_length - seq - self.valid_length - self.test_length + pred
            end = total_length - seq - self.test_length + pred
            return np.arange(start, end, pred)
        # mode == 'test' (validated in __init__)
        start = total_length - seq - self.test_length + pred
        end = total_length - seq + pred
        return np.arange(start, end, pred)

    def __getitem__(self, orgindex):
        """Return the sample dict for the ``orgindex``-th window of this split."""
        index = self.use_index[orgindex]
        window = slice(index, index + self.seq_length)
        # Copy so hiding the horizon does not modify the shared mask array.
        target_mask = self.mask_data[window].copy()
        target_mask[-self.pred_length:] = 0.  # pred mask for test pattern strategy
        return {
            'observed_data': self.main_data[window],  # [seq_length, feature_dim]
            'observed_mask': self.mask_data[window],
            'gt_mask': target_mask,
            'timepoints': np.arange(self.seq_length) * 1.0,
            'feature_id': np.arange(self.main_data.shape[1]) * 1.0,
        }

    def __len__(self):
        """Number of windows in this split."""
        return len(self.use_index)

def get_dataloader(datatype,device,batch_size=8,is_multi_res=False):

    if datatype == 'electricity':
        dataset = Forecasting_Dataset(datatype,mode='train')
        train_loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=1)
        valid_dataset = Forecasting_Dataset(datatype,mode='valid')
        valid_loader = DataLoader(
            valid_dataset, batch_size=batch_size, shuffle=0)
        test_dataset = Forecasting_Dataset(datatype,mode='test')
        test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=0)

        scaler = torch.from_numpy(dataset.std_data).to(device).float()
        mean_scaler = torch.from_numpy(dataset.mean_data).to(device).float()

        return train_loader, valid_loader, test_loader, scaler, mean_scaler
    elif datatype in ['gas', 'gas2', 'bs', 'BE', "DE", "FR", "NP", "PJM"]:
        if datatype in ['gas', 'gas2']:
            csv_path = f'./data/{datatype}/gas.csv'  # Update with actual path
            stats_path = f'./data/{datatype}/data_stats.pkl'  # Update with actual path
            history_len = 30
            pred_len = 7
        elif datatype in ['BE', 'DE', 'FR', 'NP', 'PJM']:
            csv_path = f'./data/EPF/{datatype}.csv'
            stats_path = f'./data/EPF/{datatype}_data_stats.pkl'
            history_len = 168
            pred_len = 24
        else:
            csv_path = f'./data/bike_sharing/hour.csv'
            stats_path = f'./data/bike_sharing/data_stats.pkl'
            history_len = 168
            pred_len = 24
        
        dataset = dataset_func(
            csv_path=csv_path,
            stats_path=stats_path,
            subset='train',
            history_len=history_len,
            pred_len=pred_len,
        )
        train_loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True
        )
        valid_dataset = dataset_func(
            csv_path=csv_path,
            stats_path=stats_path,
            subset='val',
            history_len=history_len,
            pred_len=pred_len,
        )
        valid_loader = DataLoader(
            valid_dataset, batch_size=batch_size, shuffle=False
        )
        test_dataset = dataset_func(
            csv_path=csv_path,
            stats_path=stats_path,
            subset='test',
            history_len=history_len,
            pred_len=pred_len,
        )
        test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False
        )

        scaler = torch.from_numpy(dataset.stds).to(device).float()
        mean_scaler = torch.from_numpy(dataset.means).to(device).float()

        return train_loader, valid_loader, test_loader, scaler, mean_scaler


class GasLoadForecastDataset(Dataset):
    """Sliding-window forecasting dataset read from a CSV file.

    The CSV's first column is kept as ``self.dates``; every remaining column
    is a numeric feature. Feature column 0 is the one hidden in ``gt_mask``
    over the forecast horizon (the target). Values are standardized with a
    ``StandardScaler`` fitted on training rows.

    Args:
        csv_path: Path to the CSV file.
        stats_path: Accepted for interface compatibility; not read by this class.
        subset: ``'train'``, ``'val'`` or ``'test'``.
        history_len: Number of observed steps per window.
        pred_len: Number of forecast steps per window.
        train_ratio: Fraction of window start positions assigned to training.
        val_ratio: Fraction of window start positions assigned to validation;
            the remaining positions form the test split.
        mask_mode: ``'history'`` hides every feature over the forecast horizon;
            ``'history+future_cov'`` hides only feature 0 and keeps future
            covariates visible.

    Raises:
        ValueError: If ``subset`` or ``mask_mode`` is invalid.
    """

    _SUBSETS = ('train', 'val', 'test')
    _MASK_MODES = ("history", "history+future_cov")

    def __init__(self, csv_path, stats_path, subset='train',
                 history_len=30, pred_len=7,
                 train_ratio=0.7, val_ratio=0.1,
                 mask_mode="history+future_cov"):
        # Validate arguments before reading the CSV so errors surface immediately.
        if subset not in self._SUBSETS:
            raise ValueError("subset must be 'train', 'val', or 'test'")
        if mask_mode not in self._MASK_MODES:
            raise ValueError("mask_mode must be 'history' or 'history+future_cov'")

        self.history_len = history_len
        self.pred_len = pred_len
        self.window_size = history_len + pred_len
        self.subset = subset
        self.mask_mode = mask_mode

        df = pd.read_csv(csv_path)
        self.dates = df.iloc[:, 0].values
        self.data = df.iloc[:, 1:].values.astype(np.float32)

        train_idx, val_idx, test_idx = self.split_indices(
            len(self.data), self.window_size, train_ratio, val_ratio)
        self.indices = {'train': train_idx, 'val': val_idx, 'test': test_idx}[subset]

        # Fit normalization on training rows only, whatever the subset, so
        # every split shares the same statistics.
        # NOTE(review): this selects the rows at the training *start* indices,
        # which omits the last window_size - 1 rows touched by training windows.
        self.scaler = StandardScaler()
        self.train_data = self.data[train_idx]
        self.scaler.fit(self.train_data)
        self.data = self.scaler.transform(self.data)
        self.means = self.scaler.mean_.astype(np.float32)
        self.stds = self.scaler.scale_.astype(np.float32)

        self.feature_dim = self.data.shape[1]

    def __len__(self):
        """Number of windows in this subset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the sample dict for the ``idx``-th window of this subset."""
        start = self.indices[idx]
        window = self.data[start:start + self.window_size]

        # Zeros in gt_mask mark the positions to be forecast.
        gt_mask = np.ones_like(window, dtype=np.float32)
        gt_mask[self.history_len:, 0] = 0
        if self.mask_mode == "history":
            # Also hide the covariates over the forecast horizon.
            gt_mask[self.history_len:, 1:] = 0

        return {
            "observed_data": window,
            "observed_mask": np.ones_like(window, dtype=np.float32),
            "gt_mask": gt_mask,
            "timepoints": np.arange(self.window_size, dtype=np.float32) * 1.0,
            "feature_id": np.arange(self.feature_dim, dtype=np.float32) * 1.0,
        }

    def split_indices(self, total_length, window_size, train_ratio, val_ratio):
        """Split window start positions chronologically into train/val/test.

        With ``max_start = total_length - window_size + 1`` full windows,
        returns the start-index lists ``[0, train_end)``,
        ``[train_end, val_end)`` and ``[val_end, max_start)``. All three are
        empty when the series is shorter than one window.
        """
        max_start = max(total_length - window_size + 1, 0)
        # Clamp so ratios summing past 1 cannot produce out-of-range windows.
        train_end = min(int(train_ratio * max_start), max_start)
        val_end = min(int((train_ratio + val_ratio) * max_start), max_start)

        train_indices = list(range(0, train_end))
        val_indices = list(range(train_end, val_end))
        test_indices = list(range(val_end, max_start))
        return train_indices, val_indices, test_indices