import pickle

import os
import re
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from scipy.ndimage import binary_dilation


def get_nearmask(mask: np.ndarray, prob: float, dilation_structure=None) -> np.ndarray:
    """Sample artificial missing entries among observed steps that border a gap.

    For every sample ``b`` and feature ``j`` the missing pattern along the time
    axis is dilated with ``scipy.ndimage.binary_dilation``; observed steps that
    fall inside the dilated footprint are candidates, and each candidate is kept
    independently with probability ``prob`` (drawn from the global NumPy RNG).

    Args:
        mask: Array of shape (B, L, K); truthy entries are observed.
        prob: Probability of keeping each candidate entry.
        dilation_structure: Optional structuring element forwarded to
            ``binary_dilation``; ``None`` uses SciPy's default structure.

    Returns:
        float32 array of shape (B, L, K) with 1.0 at the sampled entries.
    """
    observed = mask.astype(bool)
    n_batch, seq_len, n_feat = observed.shape

    def sample_column(col: np.ndarray) -> np.ndarray:
        # Observed steps inside the dilated footprint of the missing steps.
        near_gap = binary_dilation(~col, structure=dilation_structure) & col
        # Keep each candidate independently with probability `prob`.
        return near_gap & (np.random.rand(seq_len) < prob)

    result = np.zeros_like(observed, dtype=bool)
    for b in range(n_batch):
        # Features are processed in index order, so global-RNG draws happen in
        # the same sequence as a plain per-feature loop.
        result[b] = np.stack([sample_column(observed[b, :, j]) for j in range(n_feat)], axis=1)
    return result.astype('float32')

def get_random_mask(mask: np.ndarray, prob: float = 0.1) -> np.ndarray:
    """Randomly select a fraction of the observed entries of each series.

    For every sample ``b`` and feature ``k``, ``floor(n_obs * prob)`` of the
    observed time steps are chosen without replacement using the global NumPy
    RNG (``np.random``) and flagged in the returned mask.

    Args:
        mask: Array of shape (B, L, K); truthy entries are observed.
        prob: Fraction of each series' observed entries to select; must lie
            in [0, 1].

    Returns:
        float32 array of shape (B, L, K) with 1.0 at the selected entries.
        Selected entries are always a subset of the observed ones.

    Raises:
        ValueError: If ``prob`` is outside [0, 1] (or is NaN).
    """
    if not 0.0 <= prob <= 1.0:
        raise ValueError(f"prob must be in [0, 1], got {prob}")

    mask = mask.astype(bool)
    B, L, K = mask.shape

    indicating_mask = np.zeros_like(mask, dtype=bool)
    for b in range(B):
        for k in range(K):
            observed_idx = np.where(mask[b, :, k])[0]
            n_obs = len(observed_idx)
            if n_obs == 0:
                continue

            # Strip float noise before flooring: e.g. 100 * 0.29 evaluates to
            # 28.999999999999996, which a bare int() truncates to 28, not 29.
            num_select = min(int(round(n_obs * prob, 9)), n_obs)
            # The call stays unconditional (even when num_select == 0) because
            # it may advance the global RNG; keeping it preserves the random
            # stream, and thus seeded results, of the original implementation.
            selected_idx = np.random.choice(observed_idx, num_select, replace=False)

            indicating_mask[b, selected_idx, k] = True

    return indicating_mask.astype(np.float32)

class ETT_Dataset(Dataset):
    """Windowed time-series dataset with a randomly drawn artificial-missing mask.

    Loads ``../data/{data_name}/{mode}_data_info.pk``, a pickled dict read via
    the keys ``'X_ori'``, ``'X'`` and ``'observed_mask'``. It then marks a random
    ``artificial_missing_ratio`` fraction of the observed entries using
    ``get_random_mask``. The global NumPy RNG is re-seeded with ``seed`` first.
    """

    def __init__(self, data_name='ett', eval_length=24, mode='train', artificial_missing_ratio=0.0, seed=1):
        self.eval_length = eval_length
        # Seeds the *global* NumPy RNG so the random mask below is reproducible.
        np.random.seed(seed)
        self.mode = mode
        self.missing_ratio = artificial_missing_ratio

        source_path = f"../data/{data_name}/{self.mode}_data_info.pk"
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(source_path, "rb") as handle:
            payload = pickle.load(handle)

        self.original_values = payload['X_ori']
        self.observed_values = payload['X']
        self.observed_masks = payload['observed_mask']
        # Random subset of the observed entries (1.0 where selected).
        self.indicating_masks = get_random_mask(self.observed_masks, prob=self.missing_ratio)
        # Observed entries that were not selected above.
        self.gt_masks = self.observed_masks - self.indicating_masks

        self.use_index_list = np.arange(len(self.observed_values))

    def __getitem__(self, org_index):
        """Return window ``org_index`` as a dict of per-window arrays plus time indices."""
        i = self.use_index_list[org_index]
        return dict(
            original_data=self.original_values[i],
            observed_data=self.observed_values[i],
            observed_mask=self.observed_masks[i],
            indicating_mask=self.indicating_masks[i],
            gt_mask=self.gt_masks[i],
            timepoints=np.arange(self.eval_length),
        )

    def __len__(self):
        """Number of windows in the dataset."""
        return len(self.use_index_list)

    def get_dataset_info(self):
        """Return (observed missing ratio, artificial missing ratio, number of selected entries)."""
        n_target = self.indicating_masks.sum()
        observed_missing_ratio = 1 - self.observed_masks.sum() / self.observed_masks.size
        artificial_missing_ratio = n_target / self.indicating_masks.size
        return observed_missing_ratio, artificial_missing_ratio, n_target

    def save_eval_info(self, foldername):
        """Pickle the values and masks of this split into ``foldername``."""
        target = os.path.join(foldername, f"{self.mode}_eval_info_{self.missing_ratio}_random.pk")
        snapshot = dict(
            X_ori=self.original_values,
            X=self.observed_values,
            observed_mask=self.observed_masks,
            indicating_mask=self.indicating_masks,
            gt_mask=self.gt_masks,
        )
        with open(target, 'wb') as handle:
            pickle.dump(snapshot, handle)



def get_all_dataloader(data_name='ett', eval_length=24, seed=1, batch_size=16, missing_ratio=0.1):
    """Build train/val/test DataLoaders over ``ETT_Dataset``.

    Datasets are constructed in the order train, test, val, all with the same
    ``seed`` and ``missing_ratio``. Only the train loader shuffles.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    loaders = {}
    for split, shuffle in (('train', True), ('test', False), ('val', False)):
        dataset = ETT_Dataset(
            data_name=data_name, eval_length=eval_length, mode=split, artificial_missing_ratio=missing_ratio, seed=seed
        )
        loaders[split] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    return loaders['train'], loaders['val'], loaders['test']

def get_train_dataloader(data_name='ett', eval_length=24, seed=1, batch_size=16, missing_ratio=0.1):
    """Build the shuffled training split.

    Returns:
        Dict with keys ``'dataset'`` (the ``ETT_Dataset``) and ``'dataloader'``.
    """
    dataset = ETT_Dataset(
        data_name=data_name, eval_length=eval_length, mode='train', artificial_missing_ratio=missing_ratio, seed=seed
    )
    return {
        'dataset': dataset,
        'dataloader': DataLoader(dataset, batch_size=batch_size, shuffle=True),
    }

def get_test_dataloader(data_name='ett', eval_length=24, seed=1, batch_size=16, missing_ratio=0.1):
    """Build the (unshuffled) test split.

    Returns:
        Dict with keys ``'dataset'`` (the ``ETT_Dataset``) and ``'dataloader'``.
    """
    dataset = ETT_Dataset(
        data_name=data_name, eval_length=eval_length, mode='test', artificial_missing_ratio=missing_ratio, seed=seed
    )
    return {
        'dataset': dataset,
        'dataloader': DataLoader(dataset, batch_size=batch_size, shuffle=False),
    }

def get_val_dataloader(data_name='ett', eval_length=24, seed=1, batch_size=16, missing_ratio=0.1):
    """Build the (unshuffled) validation split.

    Returns:
        Dict with keys ``'dataset'`` (the ``ETT_Dataset``) and ``'dataloader'``.
    """
    dataset = ETT_Dataset(
        data_name=data_name, eval_length=eval_length, mode='val', artificial_missing_ratio=missing_ratio, seed=seed
    )
    return {
        'dataset': dataset,
        'dataloader': DataLoader(dataset, batch_size=batch_size, shuffle=False),
    }