import os
import random

import h5py
import torch
from torch.utils.data import Dataset
import scipy.io as sio


from ._base import register_dataset


@register_dataset('burger')
class BurgerDataset(Dataset):
    """Samples of a Burgers' equation solution field loaded from ``burger_1.mat``.

    The ``output`` array of the .mat file is loaded in full, cast to float32
    and given a trailing channel axis.  ``__getitem__`` returns
    ``(sample, {})`` -- the second element is always an empty dict.
    """

    def __init__(self, root='./datasets/data', split='train'):
        """Load the ``output`` array from ``<root>/burger_1.mat``.

        Args:
            root: Directory containing ``burger_1.mat``.
            split: Accepted but currently unused -- every split sees the same
                samples.  NOTE(review): presumably kept for interface parity
                with other registered datasets; confirm no train/test split
                is expected here.
        """
        self.root = root
        self.data_file = os.path.join(self.root, 'burger_1.mat')
        mat = sio.loadmat(self.data_file)

        # Append a trailing channel axis: (N, ...) -> (N, ..., 1).
        self.outputs = torch.tensor(mat['output'], dtype=torch.float).unsqueeze(-1)
        # `data` is an alias of `outputs`; both attribute names are kept.
        self.data = self.outputs
        # NOTE(review): hard-coded; not checked against the loaded array's shape.
        self.resolution = (128, 128, 1)

    def __len__(self):
        """Return the number of samples (size of the first axis of ``data``)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(data[index], {})``."""
        return self.data[index], {}

    def compute_pde_error(self, u, nu=0.01, dt=1.0, dx=1.0):
        """Return the pointwise Burgers residual ``u_t + u * u_x - nu * u_xx``.

        Derivatives are central finite differences computed with ``conv2d``
        and zero padding, so the outermost rows/columns (two columns deep for
        ``u_xx``) are contaminated by the padding.

        Args:
            u: Channels-last tensor ``(B, H, W, C)``.  The stencils have a
                single input channel, so ``C`` must be 1.  Axis 1 (H) is
                differenced as time, axis 2 (W) as space.
            nu: Viscosity coefficient (default 0.01, the original constant).
            dt: Grid spacing along the time axis (H).  The default 1.0
                reproduces the original per-index differences.
            dx: Grid spacing along the space axis (W).  Default 1.0, as above.

        Returns:
            Residual tensor with the same shape as ``u``.
        """
        import torch.nn.functional as F

        # (B, H, W, C) -> (B, C, H, W), the layout conv2d expects.
        u = torch.permute(u, (0, 3, 1, 2))
        # Build the stencils in u's dtype: a float32 kernel applied to a
        # float64 input makes conv2d raise a dtype-mismatch error.
        opts = dict(dtype=u.dtype, device=u.device)
        deriv_t = torch.tensor([-1.0, 0.0, 1.0], **opts).view(1, 1, 3, 1) / 2
        deriv_x = torch.tensor([-1.0, 0.0, 1.0], **opts).view(1, 1, 1, 3) / 2

        u_t = F.conv2d(u, deriv_t, padding=(1, 0)) / dt
        u_x = F.conv2d(u, deriv_x, padding=(0, 1)) / dx
        # Second derivative as the central difference of u_x, i.e. the wide
        # stencil (u[j+2] - 2 u[j] + u[j-2]) / (4 dx^2).
        u_xx = F.conv2d(u_x, deriv_x, padding=(0, 1)) / dx

        residual = u_t + u * u_x - nu * u_xx
        # Back to channels-last (B, H, W, C).
        return torch.permute(residual, (0, 2, 3, 1))

    def get_condition(self, x, type):
        """Return a copy of ``x`` with everything except the conditioned entries zeroed.

        ``x`` is presumably channels-last ``(B, H, W, C)``, like the input of
        ``compute_pde_error`` -- TODO confirm.  ``x`` itself is never modified.

        Args:
            x: Field tensor to mask.
            type: ``'ic'`` keeps ``x[:, :, 0]`` (first W column), ``'bc'``
                keeps ``x[:, 0]`` (first H row), ``'round'`` keeps the outer
                ring of the H x W grid.

        Raises:
            ValueError: If ``type`` is not one of the values above.

        NOTE(review): ``compute_pde_error`` differences axis 1 (H) as time, so
        the t=0 slice is ``x[:, 0]`` -- which is what ``'bc'`` keeps, while
        ``'ic'`` keeps a spatial edge.  Either these labels or the axis naming
        in ``compute_pde_error`` look swapped; confirm against the data layout.
        """
        if type not in ('ic', 'bc', 'round'):
            raise ValueError(f'Invalid condition type: {type}')

        cond = x.clone()
        if type == 'ic':
            cond[:, :, 1:] = 0
        elif type == 'bc':
            cond[:, 1:] = 0
        else:  # 'round': zero the interior, keep all four edges.
            cond[:, 1:-1, 1:-1] = 0
        return cond
