import io
import os

import numpy as np
import pyarrow.parquet as pq
import torch

def get_data(path):
    """
    Read the 'u' column of a parquet file into a single float32 PyTorch tensor.

    Every entry of the 'u' column is treated as the raw bytes of a serialized
    NumPy array (it is decoded with ``np.load``). The decoded arrays are
    stacked along a new leading axis, cast to float32 and wrapped in a tensor.

    Args:
        path (str): Path to the parquet file containing the data.

    Returns:
        torch.Tensor: The stacked data. The expected layout is
                     (batch_size, seq_len, nx), where nx is the spatial
                     dimension; the trailing axes are whatever shape the
                     stored arrays have.

    Raises:
        FileNotFoundError: If nothing exists at ``path``.

    ##TODO:
    - Data generation should be modified to also include the following:
        - Dimension of the problem: dim
        - Size of step in time: dt
    - Using dim, we can reuse get_data and process_data for different dimensions.
    - Using dt, which we can pass to process_data, we can generate data at different
      temporal resolutions.
    """
    # Fail early with a readable message instead of a parquet-reader error.
    if not os.path.exists(path):
        raise FileNotFoundError(f"Data file not found: {path}.")

    def _decode(cell):
        # Each cell holds a serialized array; expose its bytes as a file-like
        # object so np.load can parse it.
        return np.load(io.BytesIO(cell.as_buffer()))

    u_column = pq.read_table(path)["u"]
    decoded = [_decode(cell) for cell in u_column]

    # One row of the file per entry along the leading axis; float32 halves the
    # memory footprint relative to float64 arrays.
    stacked = np.stack(decoded).astype(np.float32)

    return torch.tensor(stacked)

def process_data(data, data_dt, pred_dt, residual=False):
    """
    Build (input, target) training pairs separated by pred_dt in time.

    With ``step_size = round(pred_dt / data_dt)``, every state at time index
    ``t`` of every trajectory is paired with the state at ``t + step_size``.
    The pairs of all trajectories are flattened into one sample axis
    (trajectory-major, then time).

    Args:
        data (torch.Tensor): Input data tensor with shape
            (batch_size, seq_len, nx). Any number of trailing spatial axes is
            accepted, i.e. (batch_size, seq_len, *spatial).
        data_dt (float): Time step size in the original data.
        pred_dt (float): Desired prediction time step size. Must be a positive
            integer multiple of data_dt.
        residual (bool, optional): If True, targets are returned as the
            difference ``target - input`` instead of the raw future state.
            Default is False.

    Returns:
        tuple: (inputs, targets), each with shape
            (batch_size * (seq_len - step_size), 1, *spatial).

    Raises:
        AssertionError: If pred_dt is not an integer multiple of data_dt.
        ValueError: If pred_dt corresponds to fewer than one data step
            (e.g. pred_dt == 0 or negative).

    Note:
        Sequences shorter than ``step_size + 1`` cannot form a single pair and
        are rejected by ``Tensor.unfold``.
    """
    ratio = pred_dt / data_dt
    step_size = round(ratio)

    # Raised explicitly (rather than via `assert`) so the check still runs
    # under `python -O`; the exception type is kept for existing callers.
    if not abs(ratio - step_size) < 1e-10:
        raise AssertionError("pred_dt must be a multiple of data_dt!")

    # A zero or negative step would otherwise surface as an opaque slicing
    # error further down.
    if step_size < 1:
        raise ValueError(
            "pred_dt must span at least one data_dt step "
            f"(got pred_dt={pred_dt}, data_dt={data_dt})."
        )

    # Each window holds the input state, the target state and everything
    # in between (hence +1).
    window_size = step_size + 1

    # unfold appends the window axis last: (batch, n_windows, *spatial, window).
    # Move it right after the window index: (batch, n_windows, window, *spatial).
    windows = data.unfold(1, window_size, 1).movedim(-1, 2)

    # Keep only the first and last entry of every window -> (input, target).
    pairs = windows[:, :, ::step_size]

    # Merge batch and window-index axes into independent samples. The sample
    # count is spelled out so the reshape stays well-defined for empty tensors.
    n_samples = pairs.shape[0] * pairs.shape[1]
    pairs = pairs.reshape(n_samples, 2, *pairs.shape[3:])

    inputs = pairs[:, :1]
    targets = pairs[:, 1:]

    if residual:
        # Subtraction already allocates a new tensor; no clone required.
        targets = targets - inputs

    return inputs, targets

def process_traj(traj, data_dt, pred_dt):
    """
    Subsample a trajectory in time so consecutive states are pred_dt apart.

    Keeps every ``step_size``-th state along axis 1, where
    ``step_size = round(pred_dt / data_dt)``, starting from the first state.

    Args:
        traj (torch.Tensor): Trajectory tensor with time along axis 1,
            e.g. (batch_size, seq_len, nx).
        data_dt (float): Time step size of the stored trajectory.
        pred_dt (float): Desired time step size of the returned trajectory.
            Must be a positive integer multiple of data_dt.

    Returns:
        torch.Tensor: The subsampled trajectory. This is a strided view of
            ``traj``, not a copy; clone it before modifying it in place if
            the original must stay untouched.

    Raises:
        AssertionError: If pred_dt is not an integer multiple of data_dt.
        ValueError: If pred_dt corresponds to fewer than one data step
            (e.g. pred_dt == 0 or negative).
    """
    ratio = pred_dt / data_dt
    step_size = round(ratio)

    # Raised explicitly (rather than via `assert`) so the check still runs
    # under `python -O`; the exception type is kept for existing callers.
    if not abs(ratio - step_size) < 1e-10:
        raise AssertionError("pred_dt must be a multiple of data_dt!")

    # Give a clear message instead of the generic slicing error for a
    # zero or negative step.
    if step_size < 1:
        raise ValueError(
            "pred_dt must span at least one data_dt step "
            f"(got pred_dt={pred_dt}, data_dt={data_dt})."
        )

    return traj[:, ::step_size]