import os
import random

import h5py
import torch
from torch.utils.data import Dataset
import scipy.io as sio


from ._base import register_dataset


@register_dataset('poisson')
class PoissonDataset(Dataset):
    """Paired fields for a 2D Poisson problem, loaded from ``poisson_1.mat``.

    Each item is a tensor whose last axis holds two channels:
    channel 0 is the ``phi_data`` array (the outputs) and channel 1 is the
    ``f_data`` array (the inputs). Each field is divided by its own global
    standard deviation; ``output_std`` / ``input_std`` keep the scales so the
    normalisation can be undone (see ``compute_pde_error``).
    """

    def __init__(self, root='./datasets/data', split='train'):
        """Load the ``.mat`` file and build the normalised sample tensor.

        Args:
            root: Directory containing ``poisson_1.mat``.
            split: Accepted for interface compatibility with other registered
                datasets; it is not used, so every split sees the same data.
        """
        self.root = root
        self.data_file = os.path.join(self.root, 'poisson_1.mat')
        mat = sio.loadmat(self.data_file)

        self.inputs = torch.tensor(mat['f_data'], dtype=torch.float)
        self.outputs = torch.tensor(mat['phi_data'], dtype=torch.float)

        # Scale by a single global std (no mean shift) so compute_pde_error
        # can recover physical units by multiplying back.
        self.input_std = torch.std(self.inputs)
        self.inputs = self.inputs / self.input_std

        self.output_std = torch.std(self.outputs)
        self.outputs = self.outputs / self.output_std

        # Channel 0 = outputs (phi), channel 1 = inputs (f). Assuming the .mat
        # arrays are laid out (N, H, W) -- TODO confirm -- this is (N, H, W, 2).
        self.data = torch.stack((self.outputs, self.inputs), dim=-1)
        # Derived from the loaded grid rather than hard-coded to (128, 128, 2),
        # so a different grid size cannot silently disagree with the data.
        self.resolution = tuple(self.data.shape[1:])

    def __len__(self):
        """Return the number of samples (size of the first data axis)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(sample, {})``; this dataset has no per-sample parameters."""
        return self.data[index], {}

    def visualize_batch(self, num_samples=20, rows=4, cols=5, samples=None):
        """Show channel 0 of several samples as images in a ``rows x cols`` grid.

        Args:
            num_samples: Number of panels to draw. Clipped to ``rows * cols``
                and to the number of available samples.
            rows: Number of grid rows.
            cols: Number of grid columns.
            samples: Optional indexable collection of tensors to plot instead
                of the dataset's own items (e.g. model outputs). Each entry is
                expected to squeeze to (H, W, C) -- assumption, confirm with
                callers.
        """
        from matplotlib import pyplot as plt

        # squeeze=False always yields a 2D array of Axes, so flatten() also
        # works for a 1x1 grid (where subplots would return a bare Axes).
        _, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3), squeeze=False)
        axes = axes.flatten()

        available = len(self) if samples is None else len(samples)
        count = min(num_samples, len(axes), available)

        for i in range(count):
            y = self[i][0] if samples is None else samples[i]
            # detach/cpu let model outputs (requiring grad or on GPU) be
            # converted to NumPy; both are no-ops for CPU dataset tensors.
            y = y.detach().cpu().squeeze().numpy()[:, :, 0]

            ax = axes[i]
            ax.imshow(y)
            # NOTE(review): these labels look carried over from a 1D (y, t)
            # dataset; for a 2D field image they may be misleading.
            ax.set_xlabel("Position (y)")
            ax.set_ylabel("u(y, t)")

        # Hide grid cells that did not receive a sample.
        for ax in axes[count:]:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

    def compute_pde_error(self, u):
        """Return the per-pixel residual of the discrete Poisson equation.

        Args:
            u: Normalised 4D batch indexed as (B, H, W, C), with channel 0 the
                output field and channel 1 the input field (same channel order
                as the dataset items).

        Returns:
            Tensor of shape (B, H, W, 1) equal to ``h**2 * (lap(u) - a)`` in
            de-normalised units. Here ``lap`` is the 5-point Laplacian with
            zero padding outside the grid and ``h = 1 / (H - 1)``. The outermost
            ring of pixels is set to zero.
        """
        u, a = u[:, :, :, 0], u[:, :, :, 1]

        # Undo the per-field normalisation applied in __init__.
        u = u * self.output_std
        a = a * self.input_std

        a = a.unsqueeze(1)  # (B, 1, H, W)
        u = u.unsqueeze(1)  # (B, 1, H, W)
        S = u.size(2)
        h = 1 / (S - 1)
        u_padded = torch.nn.functional.pad(u, (1, 1, 1, 1), 'constant', 0)
        # Five-point stencil: up + down + left + right - 4 * centre.
        d2u = (u_padded[:, :, :-2, 1:-1] + u_padded[:, :, 2:, 1:-1] +
               u_padded[:, :, 1:-1, :-2] + u_padded[:, :, 1:-1, 2:] - 4 * u) / h ** 2
        pde_loss = d2u - a
        pde_loss = pde_loss * (h ** 2)
        # The stencil there reads zero padding, so the residual on the outer
        # ring is excluded.
        pde_loss[:, :, 0, :] = 0
        pde_loss[:, :, -1, :] = 0
        pde_loss[:, :, :, 0] = 0
        pde_loss[:, :, :, -1] = 0
        return torch.permute(pde_loss, (0, 2, 3, 1))  # (B, H, W, 1)

    def get_condition(self, x, type):
        """Return a copy of *x* keeping only the conditioning channel(s).

        Args:
            x: 4D batch whose last axis is the channel axis (same layout as
                the dataset items).
            type: Condition kind. Only ``'a'`` is supported: channel 0 (the
                output field) is zeroed, leaving channel 1 (the input field).
                The name shadows the builtin but is kept because callers may
                pass it by keyword.

        Raises:
            ValueError: If *type* is not ``'a'``.
        """
        if type != 'a':
            raise ValueError(f'Invalid condition type: {type}')
        cond = x.clone()
        cond[:, :, :, 0] = 0
        return cond
