import torch
from torch.utils.data import Dataset as TorchDataset
import scipy.io as sio

class Dataset(TorchDataset):
    """Base dataset providing min-max normalization of input/output tensors.

    Subclasses are expected to implement ``name``, ``num_channels``,
    ``resolution`` and ``__getitem__``, and to set ``y_dim`` if
    ``label_dim`` is used.

    ``normalize`` records per-position minima/maxima over dimension 0 and maps
    each tensor approximately onto ``[-gain, gain]``. ``denorm_input`` and
    ``denorm_output`` invert that mapping and return CPU tensors.
    """

    # Added to every (max - min) range so constant positions don't divide by zero.
    _NORM_EPS = 1e-6
    # Absolute tolerance used by the optional round-trip check in ``normalize``.
    _ROUND_TRIP_ATOL = 1e-4

    def __init__(self, gain=1.0, use_labels=True, pde_direction='forward'):
        # NOTE(review): ``use_labels`` and ``pde_direction`` are accepted but not
        # stored by this base class; subclasses must keep them if they need them.
        self.gain = gain

        # Normalization statistics, populated by ``normalize`` via
        # ``_set_normalization_parameters``.
        self.min_input = None
        self.max_input = None
        self.min_output = None
        self.max_output = None

    @property
    def name(self):
        raise NotImplementedError("Subclasses should implement this property.")

    @property
    def num_channels(self):
        raise NotImplementedError("Subclasses should implement this property.")

    @property
    def resolution(self):
        raise NotImplementedError("Subclasses should implement this property.")

    @property
    def label_dim(self):
        # NOTE(review): ``y_dim`` is never assigned in this base class; presumably
        # each subclass sets it. Until then this raises AttributeError.
        return self.y_dim

    # Following function is from https://github.com/neuraloperator/cond-diffusion-operators-edm/tree/master
    def normalize(self, input_data: torch.Tensor, output_data: torch.Tensor, gain: float, check_valid: bool = False):
        """Fit min-max statistics on the given data and return normalized copies.

        Args:
            input_data: Input tensor; statistics are taken over dim 0.
            output_data: Output tensor; statistics are taken over dim 0.
            gain: Scale applied after mapping to [-1, 1]; stored as ``self.gain``.
            check_valid: If True, verify that denormalizing reproduces the
                originals within an absolute tolerance of 1e-4.

        Returns:
            Tuple ``(input_normed, output_normed)``.

        Raises:
            AssertionError: If ``check_valid`` is set and the round trip fails.
        """
        self._set_normalization_parameters(input_data, output_data)
        self.gain = gain
        input_normed = self.norm_input(input_data)
        output_normed = self.norm_output(output_data)

        if check_valid:
            # Explicit raises (rather than ``assert``) so the check survives ``python -O``.
            self._check_round_trip(self.denorm_input(input_normed), input_data, "input")
            self._check_round_trip(self.denorm_output(output_normed), output_data, "output")

        return input_normed, output_normed

    def _check_round_trip(self, recovered, original, label):
        """Raise AssertionError if ``recovered`` is not close to ``original``."""
        # ``denorm_*`` returns CPU tensors; compare on the same device as the result.
        close = torch.isclose(recovered, original.to(recovered.device), atol=self._ROUND_TRIP_ATOL)
        if not close.all():
            raise AssertionError(f"{label} normalization round trip failed")

    # Following function is from https://github.com/neuraloperator/cond-diffusion-operators-edm/tree/master
    def _set_normalization_parameters(self, a, u):
        """Store per-position min/max over dim 0 of ``a`` (inputs) and ``u`` (outputs)."""
        self.min_input = a.min(dim=0, keepdim=True).values
        self.max_input = a.max(dim=0, keepdim=True).values

        self.min_output = u.min(dim=0, keepdim=True).values
        self.max_output = u.max(dim=0, keepdim=True).values

    @staticmethod
    def _require_stats(lo, hi):
        """Raise a clear error if normalization statistics have not been set."""
        if lo is None or hi is None:
            raise RuntimeError("Normalization statistics are not set; call normalize() first.")

    @staticmethod
    def _cpu(value):
        """Move ``value`` to the CPU if it is a tensor; return other values unchanged."""
        return value.to("cpu") if isinstance(value, torch.Tensor) else value

    def _scale(self, data, lo, hi):
        """Map ``data`` from [lo, hi] approximately onto [-gain, gain]."""
        self._require_stats(lo, hi)
        data = (data - lo) / (hi - lo + self._NORM_EPS)
        return ((data - 0.5) / 0.5) * self.gain

    def _unscale(self, normed, lo, hi):
        """Invert ``_scale``; the result is returned on the CPU."""
        self._require_stats(lo, hi)
        normed = normed.to("cpu")
        # Move the statistics too, so stats computed on an accelerator don't
        # cause a device mismatch with the CPU tensor.
        lo, hi = self._cpu(lo), self._cpu(hi)
        normed = (normed * 0.5 + (0.5 * self.gain)) / self.gain
        return normed * (hi - lo + self._NORM_EPS) + lo

    # Following function is from https://github.com/neuraloperator/cond-diffusion-operators-edm/tree/master
    def norm_input(self, input_data):
        """Normalize input data for training using the stored input statistics."""
        return self._scale(input_data, self.min_input, self.max_input)

    # Following function is from https://github.com/neuraloperator/cond-diffusion-operators-edm/tree/master
    def denorm_input(self, input_normed):
        """Denormalize input data; the result is a CPU tensor."""
        return self._unscale(input_normed, self.min_input, self.max_input)

    # Following function is from https://github.com/neuraloperator/cond-diffusion-operators-edm/tree/master
    def norm_output(self, output_data):
        """Normalize output data for training using the stored output statistics."""
        return self._scale(output_data, self.min_output, self.max_output)

    # Following function is from https://github.com/neuraloperator/cond-diffusion-operators-edm/tree/master
    def denorm_output(self, output_normed):
        """Denormalize output data; the result is a CPU tensor."""
        return self._unscale(output_normed, self.min_output, self.max_output)

    def __getitem__(self, idx):
        raise NotImplementedError("Subclasses should implement this method.")
