import os
import random

import h5py
import torch
from torch.utils.data import Dataset

from ._base import register_dataset


@register_dataset('darcy')
class DarcyDataset(Dataset):
    """Dataset served from a single ``.pt`` file under ``root``.

    The file is read with ``torch.load`` and must provide the entries
    ``'x'`` and ``'y'``. They are stacked along a new trailing axis in the
    order ``(y, x)``, so index 0 of the last axis holds ``'y'`` and index 1
    holds ``'x'``.

    NOTE(review): presumably ``'x'`` is the coefficient field and ``'y'`` the
    solution, each 128x128 per sample (see ``resolution``) -- confirm against
    the script that produced the data files.
    """

    def __init__(self, root='./datasets/data', split='train'):
        """Load the data file for ``split`` and build the stacked tensor.

        Args:
            root: Directory containing ``darcy_train_128.pt`` and
                ``darcy_test_128.pt``.
            split: Either ``'train'`` or ``'test'``.

        Raises:
            NotImplementedError: If ``split`` is any other value.
        """
        self.root = root
        self.split = split

        # Reject unknown splits before touching the filesystem.
        if split not in ('train', 'test'):
            raise NotImplementedError
        if split == 'train':
            self.data_file = f'{self.root}/darcy_train_128.pt'
        else:
            self.data_file = f'{self.root}/darcy_test_128.pt'

        payload = torch.load(self.data_file)
        self.inputs = payload['x']
        self.outputs = payload['y']
        # Last axis: index 0 -> outputs ('y'), index 1 -> inputs ('x').
        self.data = torch.stack((self.outputs, self.inputs), dim=-1)
        # Fixed per-sample shape; not derived from the loaded tensors.
        self.resolution = (128, 128, 2)

    def __len__(self):
        """Return the size of the stacked tensor along its first axis."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(self.data[index], {})``; the dict is always empty."""
        sample = self.data[index]
        return sample, {}

    def get_condition(self, x, type):
        """Return a copy of ``x`` with a mode-dependent region set to zero.

        ``x`` itself is left untouched. The zeroed region per mode is:

        * ``'ic'``    -- ``[..., 1:, 1]``: index 1 of the last axis, at
          positions ``1:`` of the second-to-last axis.
        * ``'bc'``    -- ``[:, 1:, 1]``: index 1 of axis 2, at positions
          ``1:`` of axis 1, across any remaining trailing axes.
        * ``'round'`` -- ``[:, 1:-1, 1:-1, 1]``: index 1 of axis 3, excluding
          the first and last positions of axes 1 and 2.
        * ``'coeff'`` -- ``[:, :, :, 0]``: index 0 of axis 3 everywhere.

        NOTE(review): for 4-D input ``'bc'`` and ``'ic'`` address different
        axes; whether that is intended cannot be told from this file.

        Args:
            x: Tensor to derive the condition from.
            type: One of ``'ic'``, ``'bc'``, ``'round'``, ``'coeff'``. The
                name shadows the builtin but is part of the call interface
                (``refine_condition`` passes it by keyword).

        Raises:
            ValueError: If ``type`` is not one of the modes above.
        """
        # Validate first so an unknown mode never triggers a clone.
        if type not in ('ic', 'bc', 'round', 'coeff'):
            raise ValueError(f'Invalid condition type: {type}')

        cond = x.clone()
        if type == 'ic':
            cond[..., 1:, 1] = 0
        elif type == 'bc':
            cond[:, 1:, 1] = 0
        elif type == 'round':
            cond[:, 1:-1, 1:-1, 1] = 0
        else:  # 'coeff'
            cond[:, :, :, 0] = 0
        return cond

    def refine_condition(self, samples, condition_type):
        """Select the entries of ``samples`` that ``get_condition`` zeroes.

        An all-ones tensor shaped like ``samples`` is passed through
        ``get_condition``; positions that come back below 0.5 (i.e. were
        zeroed) form a boolean mask used to index ``samples``.

        Args:
            samples: Tensor to select from.
            condition_type: Mode forwarded to ``get_condition`` as ``type``.

        Returns:
            The masked entries of ``samples`` (boolean-mask indexing, so the
            result is one-dimensional).
        """
        probe = self.get_condition(torch.ones_like(samples), type=condition_type)
        return samples[probe < 0.5]
