import os
import random

import h5py
import torch
from torch.utils.data import Dataset
import scipy.io as sio


from ._base import register_dataset


@register_dataset('helmholtz')
class HelmholtzDataset(Dataset):
    """Paired fields loaded from ``helmholtz_1.mat`` for the Helmholtz task.

    Every sample stacks two fields along a trailing axis:
    channel 0 is ``psi_data`` and channel 1 is ``f_data``.
    Each field is divided by its own global standard deviation.
    Those standard deviations are kept on the instance
    (``output_std`` / ``input_std``) so that ``compute_pde_error`` can undo
    the normalization.
    """

    def __init__(self, root='./datasets/data', split='train'):
        """Load ``helmholtz_1.mat`` from ``root`` and normalize both fields.

        Args:
            root: Directory containing ``helmholtz_1.mat``.
            split: Accepted for signature compatibility with other registered
                datasets. NOTE(review): it is currently ignored, so every
                split yields the same samples.
        """
        self.root = root
        self.data_file = f'{self.root}/helmholtz_1.mat'
        # Keep the raw .mat dict in a local so it is not stored on the
        # instance; ``self.data`` only ever refers to the stacked tensor.
        mat = sio.loadmat(self.data_file)

        self.inputs = torch.tensor(mat['f_data'], dtype=torch.float)
        self.input_std = torch.std(self.inputs)
        self.inputs = self.inputs / self.input_std
        self.outputs = torch.tensor(mat['psi_data'], dtype=torch.float)
        self.output_std = torch.std(self.outputs)
        self.outputs = self.outputs / self.output_std
        # Stack on a new trailing axis: channel 0 = psi, channel 1 = f.
        # Presumably f_data / psi_data are (N, H, W), giving (N, H, W, 2),
        # which is the layout compute_pde_error indexes — TODO confirm.
        self.data = torch.stack((self.outputs, self.inputs), dim=-1)
        # Per-sample shape, taken from the loaded tensor rather than
        # hard-coded so it always matches the data actually on disk.
        self.resolution = tuple(self.data.shape[1:])

    def __len__(self):
        """Return the number of samples (length of the leading axis)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(sample, {})``; no extra per-sample metadata is attached."""
        return self.data[index], {}

    def compute_pde_error(self, u):
        """Compute the finite-difference residual ``lap(psi) + psi - f``.

        Args:
            u: Tensor of shape ``(B, H, W, C)`` in this dataset's normalized
                space, with channel 0 = psi and channel 1 = f (the stacking
                order used in ``__init__``). The grid is assumed square:
                ``h = 1 / (H - 1)`` is used as the spacing for both axes.

        Returns:
            Tensor of shape ``(B, H, W, 1)`` holding the residual
            multiplied by ``h**2``. The outermost rows and columns are set to zero.
        """
        # Undo the per-field normalization applied in ``__init__``.
        psi = u[:, :, :, 0] * self.output_std
        f = u[:, :, :, 1] * self.input_std

        # Add a channel axis: (B, H, W) -> (B, 1, H, W).
        f = f.unsqueeze(1)
        psi = psi.unsqueeze(1)
        n = psi.size(2)
        h = 1 / (n - 1)
        # Padding only keeps the stencil output the same size as psi; the
        # border entries it influences are zeroed out below.
        psi_padded = torch.nn.functional.pad(psi, (1, 1, 1, 1), 'constant', 0)
        # Five-point Laplacian: up + down + left + right - 4 * centre.
        lap = (psi_padded[:, :, :-2, 1:-1] + psi_padded[:, :, 2:, 1:-1] +
               psi_padded[:, :, 1:-1, :-2] + psi_padded[:, :, 1:-1, 2:] -
               4 * psi) / h ** 2
        pde_loss = lap + psi - f
        # Rescale by h**2 (undoes the 1/h**2 in the stencil).
        pde_loss = pde_loss * (h ** 2)
        # The residual is only evaluated on interior points.
        pde_loss[:, :, 0, :] = 0
        pde_loss[:, :, -1, :] = 0
        pde_loss[:, :, :, 0] = 0
        pde_loss[:, :, :, -1] = 0
        # Back to channel-last: (B, 1, H, W) -> (B, H, W, 1).
        return torch.permute(pde_loss, (0, 2, 3, 1))

    def get_condition(self, x, type):
        """Return a copy of ``x`` with the solution channel zeroed out.

        Args:
            x: Batch tensor, presumably ``(B, H, W, C)`` laid out like
                stacked ``__getitem__`` samples (channel 0 = psi).
            type: Condition kind; only ``'a'`` is supported. The name shadows
                the builtin but is kept for keyword-argument compatibility.

        Returns:
            A clone of ``x`` whose channel 0 is zero; all other channels
            are unchanged.

        Raises:
            ValueError: If ``type`` is not ``'a'``.
        """
        if type != 'a':
            raise ValueError(f'Invalid condition type: {type}')
        cond = x.clone()
        cond[:, :, :, :1] = 0
        return cond
