import torch
import h5py
import random
import numpy as np
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

class HDF5DatasetGraph(Dataset):
    """Map-style dataset over one split of an HDF5 file of PDE trajectories.

    Every item is a dict with the solution ``u``, the spatial coordinates
    ``x`` and the time stamps ``t`` of one trajectory, all as torch tensors.
    """

    def __init__(self,
                 path,
                 nt,
                 nx,
                 mode='train'):
        """Open ``path`` read-only and select the group named after ``mode``.

        Args:
            path: location of the HDF5 file.
            nt: temporal resolution; only used to build the dataset name.
            nx: spatial resolution; only used to build the dataset name.
            mode: one of ``'train'``, ``'valid'`` or ``'test'``.

        Raises:
            AssertionError: if ``mode`` is not one of the accepted values.
        """
        assert mode in ['train', 'valid', 'test'], "mode must belong to one of these ['train', 'valid', 'test']"

        f = h5py.File(path, 'r')
        self.mode = mode
        # The group keeps a reference to the open file, so no separate
        # handle is stored on the instance.
        self.data = f[self.mode]
        self.dataset = f'pde_{nt}-{nx}'

    def __len__(self):
        """Return the size of the first axis of the selected dataset."""
        return self.data[self.dataset].shape[0]

    def __getitem__(self, idx):
        """Return the tensors of trajectory ``idx``.

        Returns:
            dict with keys ``'u'`` (stored array with its two axes swapped),
            ``'x'`` (coordinates with a trailing singleton axis) and ``'t'``.
        """
        x = torch.from_numpy(self.data['x'][idx]).unsqueeze(-1)  # N, 1
        t = torch.from_numpy(self.data['t'][idx])  # T
        # presumably stored as (T, N); the permute yields (N, T) -- TODO confirm
        u = torch.from_numpy(self.data[self.dataset][idx]).permute(1, 0)

        return {
            'u': u,
            'x': x,
            't': t
        }
    
    
class HDF5DatasetImplicitGNN(Dataset):
    """Dataset pairing a 2x-subsampled trajectory with held-out target points.

    Even positions along the last axis form the low-resolution input; the
    odd positions are the candidates for the high-resolution targets.
    """

    def __init__(self,
                 path,
                 nt,
                 nx,
                 mode='train',
                 samples = 256):
        """Open ``path`` read-only and select the group named after ``mode``.

        Args:
            path: location of the HDF5 file.
            nt: temporal resolution; only used to build the dataset name.
            nx: spatial resolution; only used to build the dataset name.
            mode: one of ``'train'``, ``'valid'`` or ``'test'``.
            samples: number of odd positions drawn per item in train mode.

        Raises:
            AssertionError: if ``mode`` is not one of the accepted values.
        """
        assert mode in ['train', 'valid', 'test'], "mode must belong to one of these ['train', 'valid', 'test']"

        f = h5py.File(path, 'r')
        self.mode = mode
        self.data = f[self.mode]
        self.dataset = f'pde_{nt}-{nx}'
        self.samples = samples

    def __len__(self):
        """Return the size of the first axis of the selected dataset."""
        return self.data[self.dataset].shape[0]

    def __getitem__(self, idx):
        """Build the low-resolution input and the target points of item ``idx``.

        Returns:
            dict with ``'t'``, ``'lr_frames'``, ``'hr_frames'``,
            ``'hr_points'``, ``'coords_hr'`` and ``'coords_lr'``; in train
            mode also ``'sample_idx'`` with the sorted sampled positions.

        Raises:
            ValueError: in train mode, when ``samples`` exceeds the number of
                odd positions (raised by ``np.random.choice``).
        """
        x = self.data['x'][idx]
        # Rescale the spatial coordinates to [-1, 1]
        x = 2*(x-x.min())/(x.max()-x.min())-1

        t = self.data['t'][idx]
        u_hr = torch.from_numpy(self.data[self.dataset][idx]).unsqueeze(1)  # T, 1, L
        _, _, L = u_hr.shape
        u_lr = u_hr[:, :, ::2]  # T, 1, ceil(L/2)
        lr_coord = x[::2]

        # NOTE: an earlier extrapolation variant read x, t and the trajectory
        # with an extra `:125` slice (`[idx, :125]` / `[idx, :125, :]`) before
        # performing the same steps as above.

        # Odd positions, i.e. everything the low-resolution input does not see
        indices_left = np.arange(1, L, 2)

        if self.mode == 'train':
            # Random subset of the unseen positions, kept in ascending order
            sample_lst = torch.tensor(sorted(np.random.choice(indices_left, self.samples, replace=False)))
            return {
                't': t,
                'sample_idx': sample_lst,
                'lr_frames': u_lr,
                'hr_frames': u_hr,
                'hr_points': u_hr[:, :, sample_lst].permute(0, 2, 1),
                'coords_hr': x[sample_lst],
                'coords_lr': lr_coord
            }

        # valid / test: evaluate on every unseen position
        return {
            't': t,
            'lr_frames': u_lr,
            'hr_frames': u_hr,
            'hr_points': u_hr[:, :, indices_left].permute(0, 2, 1),
            'coords_hr': x[indices_left],
            'coords_lr': lr_coord
        }
    
class HDF5Dataset2d(Dataset):
    """
    Serve 2D PDE trajectories stored one HDF5 group per sample.
    """
    def __init__(self, 
                 path: str,
                 mode: str,
                 nt: int,
                 res: int,
                 dtype=torch.float32):
        """Open the file and pick the sample keys belonging to ``mode``.
        Args:
            path: path to dataset
            mode: 'train' or 'test'; any other value selects the validation keys
            nt: temporal resolution (accepted but not used by this class)
            res: spatial resolution (accepted but not used by this class)
            dtype: floating precision; stored on the instance, not applied here
        """
        super().__init__()
        self.f = h5py.File(path, 'r')
        self.mode = mode
        self.dtype = dtype

        # Sample groups are named '0000' ... '0999'
        all_keys = [f'{i:04d}' for i in range(1000)]

        # First hold out 20% for testing, then 25% of the remainder (20% of
        # the total) for validation; fixed seeds keep the split reproducible.
        train_valid_keys, self.test_keys = train_test_split(all_keys, test_size=0.2, random_state=42)
        self.train_keys, self.valid_keys = train_test_split(train_valid_keys, test_size=0.25, random_state=42)

        keys_by_mode = {'train': self.train_keys, 'test': self.test_keys}
        self.keys = keys_by_mode.get(self.mode, self.valid_keys)

    def __len__(self):
        """Number of samples in the active split."""
        return len(self.keys)

    def __getitem__(self, idx: int):
        """
        Fetch one trajectory together with its grid spacings.
        Args:
            idx: position within the active split
        Returns:
            u: stored trajectory without its last time step, every 4th grid
               point from index 2 on both spatial axes, trailing axis squeezed
               (left as returned by slicing the stored dataset)
            torch.Tensor: dx, difference between x[4] and x[0]
            torch.Tensor: dy, difference between y[4] and y[0]
            torch.Tensor: dt, difference between t[1] and t[0]
        """
        sample = self.f[self.keys[idx]]
        grid = sample['grid']

        # Spacing is measured over 4 grid cells to match the 4x subsampling
        spacing_x = torch.from_numpy(grid['x'][4:5] - grid['x'][0:1])
        spacing_y = torch.from_numpy(grid['y'][4:5] - grid['y'][0:1])
        spacing_t = torch.from_numpy(grid['t'][1:2] - grid['t'][0:1])

        trajectory = sample['data'][:-1, 2::4, 2::4, :]
        trajectory = trajectory.squeeze(-1)

        return trajectory, spacing_x[0], spacing_y[0], spacing_t[0]
    
    
class HDF5DatasetGraph_2d(Dataset):
    """Graph-style view of 2D PDE trajectories stored one HDF5 group per sample.

    Each item flattens the spatial grid into a list of nodes and returns the
    node values over time, the node coordinates and the time stamps.
    """

    def __init__(self, 
                 path,
                 nt,
                 res,
                 mode='train', 
                 regular=True,
                 seed=0,
                original=False):
        """Open the file, pick the keys of ``mode`` and fix the node sampling.

        Args:
            path: location of the HDF5 file.
            nt: temporal resolution; only used to build ``self.dataset``.
            res: spatial resolution; only used to build ``self.dataset``.
            mode: one of ``'train'``, ``'valid'`` or ``'test'``.
            regular: if True use a regular grid, otherwise a fixed random
                subset of flat grid indices.
            seed: seed for the random subset (only used when not ``regular``).
            original: with ``regular``, keep the full grid instead of every
                4th point.

        Raises:
            AssertionError: if ``mode`` is not one of the accepted values.
        """
        assert mode in ['train', 'valid', 'test'], "mode must belong to one of these ['train', 'valid', 'test']"
        
        self.f = h5py.File(path, 'r')
        self.mode = mode
        self.regular = regular
        self.dataset = f'pde_{nt}-{res}'
        self.seed = seed
        self.original = original
        
        # Generate keys from '0000' to '0999'
        all_keys = [str(i).zfill(4) for i in range(1000)]

        # Hold out 20% for testing, then 25% of the remainder (20% of the
        # total) for validation; fixed seeds keep the split reproducible.
        train_valid_keys, self.test_keys = train_test_split(all_keys, test_size=0.2, random_state=42)
        self.train_keys, self.valid_keys = train_test_split(train_valid_keys, test_size=0.25, random_state=42)
    
        if self.mode == 'train':
            self.keys = self.train_keys
        elif self.mode == 'test':
            self.keys = self.test_keys
        else:  # For 'valid' mode
            self.keys = self.valid_keys

        if not self.regular:
            # Draw the flat node indices once so every item uses the same nodes.
            # NOTE: this reseeds Python's global `random` generator, and the
            # sizes assume a 128 x 128 grid reduced to 32 * 32 nodes.
            random.seed(seed)
            self.sampled_indices = random.sample(range(128 * 128), 32 * 32) 
            
    def __len__(self):
        """Number of samples in the active split."""
        return len(self.keys)
            
    def __getitem__(self, idx):
        """Return node values, node coordinates and time stamps of one sample.

        Returns:
            dict with ``'u'`` (nodes x time steps), ``'x'`` (nodes x 2) and
            ``'t'`` (all stored time stamps).
        """
        sample = self.f[self.keys[idx]]
        grid = sample['grid']

        if self.regular:
            # Full grid, or every 4th point starting at index 2
            step = slice(None) if self.original else slice(2, None, 4)
            data = sample['data'][:, step, step, :]
            u = torch.from_numpy(data.squeeze(-1))
            u = u.reshape(u.shape[0], -1).permute(1, 0)
            x = torch.from_numpy(grid['x'][step])
            y = torch.from_numpy(grid['y'][step])
            # 'ij' is what torch.meshgrid used when no indexing was given;
            # passing it explicitly avoids the deprecation warning.
            coords = torch.stack(torch.meshgrid(x, y, indexing='ij'), dim=-1).reshape(-1, 2)
        else:
            data = sample['data'][:]
            u = torch.from_numpy(data.squeeze(-1))
            u = u.reshape(u.shape[0], -1).permute(1, 0)

            x_full = torch.from_numpy(grid['x'][:])
            y_full = torch.from_numpy(grid['y'][:])
            W = len(x_full)

            # NOTE(review): flat index k maps to (x[k % W], y[k // W]) here,
            # while the regular branch maps it to (x[k // n], y[k % n]) --
            # confirm which one matches the stored axis order.
            x_idx = [k % W for k in self.sampled_indices]
            y_idx = [k // W for k in self.sampled_indices]
            coords = torch.stack((x_full[x_idx], y_full[y_idx]), dim=-1).to(torch.float32)
            u = u[torch.tensor(self.sampled_indices, dtype=torch.long), :]

        t = torch.from_numpy(grid['t'][:])
        
        return_tensors = {
            'u': u,
            'x': coords,
            't': t
        }        
        return return_tensors

class HDF5DatasetImplicitGNN_2d(Dataset):
    """2D counterpart of the implicit-GNN dataset on a regular subsampled grid.

    The flattened grid is split by position parity: even flat positions form
    the low-resolution input, odd ones are the target candidates.
    """

    def __init__(self, path, nt, res, mode='train', samples=256):
        """Open the file and pick the sample keys belonging to ``mode``.

        Args:
            path: location of the HDF5 file.
            nt: temporal resolution (stored, not used by this class).
            res: spatial resolution (stored, not used by this class).
            mode: one of ``'train'``, ``'valid'`` or ``'test'``.
            samples: number of odd positions drawn per item in train mode.

        Raises:
            AssertionError: if ``mode`` is not one of the accepted values.
        """
        assert mode in ['train', 'valid', 'test'], "mode must belong to one of these ['train', 'valid', 'test']"
        
        self.f = h5py.File(path, 'r')
        self.mode = mode
        self.nt = nt
        self.res = res
        self.samples = samples

        # Zero-padded group names '0000' ... '0999', the format that
        # __getitem__ uses to address the file.
        all_keys = [str(i).zfill(4) for i in range(1000)]

        # Hold out 20% for testing, then 25% of the remainder (20% of the
        # total) for validation; fixed seeds keep the split reproducible.
        train_valid_keys, self.test_keys = train_test_split(all_keys, test_size=0.2, random_state=42)
        self.train_keys, self.valid_keys = train_test_split(train_valid_keys, test_size=0.25, random_state=42)
    
        if self.mode == 'train':
            self.keys = self.train_keys
        elif self.mode == 'test':
            self.keys = self.test_keys
        else:  # For 'valid' mode
            self.keys = self.valid_keys
            
    def __len__(self):
        """Number of samples in the active split."""
        return len(self.keys)

    def __getitem__(self, idx):
        """Build the low-resolution input and the target points of one sample.

        Returns:
            dict with ``'t'``, ``'lr_frames'``, ``'hr_frames'``,
            ``'hr_points'``, ``'coords_hr'`` and ``'coords_lr'``; in train
            mode also ``'sample_idx'`` with the sorted sampled positions.
        """
        # Resolve the sample through the split: deriving the key from idx
        # directly would make train, valid and test read the same groups.
        key = self.keys[idx]
    
        # Drop the last time step; keep every 4th grid point from index 2
        data = self.f[key]['data'][:-1, 2::4, 2::4, :]
        grid = self.f[key]['grid']
    
        x = grid['x'][2::4]
        y = grid['y'][2::4]
        t = grid['t'][:-1]
    
        # NOTE(review): np.meshgrid defaults to 'xy' indexing, so flat node k
        # gets (x[k % n], y[k // n]) -- confirm this matches the axis order
        # of the stored data.
        coords = np.stack(np.meshgrid(x, y), axis=-1)
        coords = coords.reshape(-1, coords.shape[-1])  # nodes, 2
    
        u_hr = torch.from_numpy(data).squeeze(-1)
        u_hr = u_hr.reshape(u_hr.shape[0], 1, -1)  # T, 1, nodes
    
        # Rescale each coordinate column to [-1, 1]
        coords = 2*(coords-coords.min(0))/(coords.max(0)-coords.min(0))-1
        
        _, _, N = u_hr.shape
        u_lr = u_hr[:, :, ::2]  # T, 1, ceil(N/2)
        lr_coord = coords[::2]

        # Odd flat positions, i.e. everything the low-resolution input omits
        indices_left = np.arange(1, N, 2)
        
        if self.mode == 'train':
            # Random subset of the omitted positions, in ascending order
            sample_lst = torch.tensor(sorted(np.random.choice(indices_left, self.samples, replace=False)))
            return {
                't': t,
                'sample_idx': sample_lst,
                'lr_frames': u_lr,
                'hr_frames': u_hr,
                'hr_points': u_hr[:, :, sample_lst].permute(0, 2, 1),
                'coords_hr': coords[sample_lst],
                'coords_lr': lr_coord
            }

        # valid / test: evaluate on every omitted position
        return {
            't': t,
            'lr_frames': u_lr,
            'hr_frames': u_hr,
            'hr_points': u_hr[:, :, indices_left].permute(0, 2, 1),
            'coords_hr': coords[indices_left],
            'coords_lr': lr_coord
        }
    
    
    
class HDF5DatasetImplicitGNN_2d_irregular(Dataset): 
    """Implicit-GNN dataset on a fixed random subset of 2D grid points.

    The same randomly chosen (i, j) grid positions are used for every sample.
    Within that list, even positions form the low-resolution input and odd
    positions are the target candidates.
    """

    def __init__(self, path, nt, res, mode='train', samples=256, seed=0):
        """Open the file, pick the keys of ``mode`` and draw the grid points.

        Args:
            path: location of the HDF5 file.
            nt: temporal resolution (stored, not used by this class).
            res: spatial resolution (stored, not used by this class).
            mode: one of ``'train'``, ``'valid'`` or ``'test'``.
            samples: number of odd positions drawn per item in train mode.
            seed: seed for the choice of grid points.

        Raises:
            AssertionError: if ``mode`` is not one of the accepted values.
        """
        assert mode in ['train', 'valid', 'test'], "mode must belong to one of these ['train', 'valid', 'test']"
        
        self.f = h5py.File(path, 'r')
        self.mode = mode
        self.nt = nt
        self.res = res
        self.samples = samples

        # Generate keys from '0000' to '0999'
        all_keys = [str(i).zfill(4) for i in range(1000)]

        # Hold out 20% for testing, then 25% of the remainder (20% of the
        # total) for validation; fixed seeds keep the split reproducible.
        train_valid_keys, self.test_keys = train_test_split(all_keys, test_size=0.2, random_state=42)
        self.train_keys, self.valid_keys = train_test_split(train_valid_keys, test_size=0.25, random_state=42)
    
        if self.mode == 'train':
            self.keys = self.train_keys
        elif self.mode == 'test':
            self.keys = self.test_keys
        else:  # For 'valid' mode
            self.keys = self.valid_keys
            
        # Grid width is read from the first sample of the split; the same
        # width is used for both index ranges below.
        key = self.keys[0]  
        W = len(self.f[key]['grid']['x'][:])

        # Draw 32*32 distinct (i, j) positions.
        # NOTE: this reseeds Python's global `random` generator.
        random.seed(seed)
        self.sampled_coords = random.sample([(i, j) for i in range(W) for j in range(W)], 32*32)

    def __len__(self):
        """Number of samples in the active split."""
        return len(self.keys)

    def __getitem__(self, idx):
        """Build the low-resolution input and the target points of one sample.

        Returns:
            dict with ``'t'``, ``'lr_frames'``, ``'hr_frames'``,
            ``'hr_points'``, ``'coords_hr'`` and ``'coords_lr'``; in train
            mode also ``'sample_idx'`` with the sorted sampled positions.
        """
        # Resolve the sample through the split: deriving the key from idx
        # directly would make train, valid and test read the same groups.
        key = self.keys[idx]
        sample = self.f[key]
        grid = sample['grid']

        first_axis = [i for i, _ in self.sampled_coords]
        second_axis = [j for _, j in self.sampled_coords]

        # Drop the last time step, then gather the sampled grid points
        data = torch.from_numpy(sample['data'][:-1, :, :, :]).squeeze(-1)
        u_hr = data[:, first_axis, second_axis]  # T, nodes
        u_hr = u_hr.reshape(u_hr.shape[0], 1, -1)  # T, 1, nodes

        t = grid['t'][:-1]
        x_full = grid['x'][:]
        y_full = grid['y'][:]

        # Node k sits at (x[i_k], y[j_k]) for its sampled position (i_k, j_k)
        coords = np.stack([x_full[first_axis], y_full[second_axis]], axis=-1)  # nodes, 2

        # Rescale each coordinate column to [-1, 1]
        coords = 2*(coords-coords.min(0))/(coords.max(0)-coords.min(0))-1
        
        _, _, N = u_hr.shape
        u_lr = u_hr[:, :, ::2]  # T, 1, ceil(N/2)
        lr_coord = coords[::2]

        # Odd node positions, i.e. everything the low-resolution input omits
        indices_left = np.arange(1, N, 2)
        
        if self.mode == 'train':
            # Random subset of the omitted positions, in ascending order
            sample_lst = torch.tensor(sorted(np.random.choice(indices_left, self.samples, replace=False)))
            return {
                't': t,
                'sample_idx': sample_lst,
                'lr_frames': u_lr,
                'hr_frames': u_hr,
                'hr_points': u_hr[:, :, sample_lst].permute(0, 2, 1),
                'coords_hr': coords[sample_lst],
                'coords_lr': lr_coord
            }

        # valid / test: evaluate on every omitted position
        return {
            't': t,
            'lr_frames': u_lr,
            'hr_frames': u_hr,
            'hr_points': u_hr[:, :, indices_left].permute(0, 2, 1),
            'coords_hr': coords[indices_left],
            'coords_lr': lr_coord
        }