import os
from pathlib import Path

import h5py
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

def to_tensor(array):
    """
    Convert array-like data into a float32 torch tensor.

    Parameters:
    - array: array-like data accepted by torch.tensor (e.g. numpy array, list)

    Returns:
    - tensor: torch tensor of dtype torch.float32 holding a copy of the data
    """
    # torch.float is an alias of torch.float32; torch.tensor always copies
    tensor = torch.tensor(array, dtype=torch.float32)
    return tensor

def to_causal_segments(spikes, ic_window_length):
    """
    Split spikes into encoding and decoding segments for causal training.

    The first `ic_window_length` timepoints of every trial form the encoding
    (initial condition) segment; the remaining timepoints form the decoding
    segment.

    Parameters:
    - spikes: Tensor of shape (n_trials, seq_length, n_neurons)
    - ic_window_length: int, length of the initial condition window. Must be
                        smaller than seq_length, otherwise nothing is left
                        for the decoding segment.

    Returns:
    - encoding_spikes: Tensor of shape (n_trials, ic_window_length, n_neurons)
    - decoding_spikes: Tensor of shape (n_trials, seq_length - ic_window_length, n_neurons)

    Raises:
    - ValueError: if ic_window_length is greater than or equal to seq_length
    """
    seq_length = spikes.shape[1]

    # A window as long as (or longer than) the sequence leaves an empty
    # decoding segment, so reject it up front instead of returning it silently
    if ic_window_length >= seq_length:
        raise ValueError(
            f'I.C. window length ({ic_window_length}) must be smaller than '
            f'seq length of spikes ({seq_length}), cannot split causally. '
            'Change \'ic_window_size\' field in config.yaml to a value '
            'smaller than seq length.'
        )

    # Split spikes into past (encoding) and future (decoding) segments
    encoding_spikes = spikes[:, :ic_window_length, :]
    decoding_spikes = spikes[:, ic_window_length:, :]

    return (encoding_spikes, decoding_spikes)

def create_windows(data, window_size, step_size):
    """
    Cut a single time series into (possibly overlapping) windows.

    Parameters:
    - data: numpy array of shape (num_timepoints, num_neurons)
    - window_size: int, number of timepoints in each window
    - step_size: int, number of timepoints between the starts of
                 consecutive windows

    Returns:
    - windows: numpy array of shape (num_windows, window_size, num_neurons)
    """
    # Every possible window along the time axis. The window axis is appended
    # last: (num_timepoints - window_size + 1, num_neurons, window_size)
    every_window = sliding_window_view(data, window_shape=window_size, axis=0)

    # Keep only every `step_size`-th window start
    kept_windows = every_window[::step_size]

    # Swap the last two axes so time within the window precedes neurons
    return kept_windows.transpose((0, 2, 1))
    

def window_data(data, window_size, step_size, concat=True):
    """
    Create sliding windows from a list of time series data. Returned window data
    will be concatenated into a single tensor by default.

    Bouts with fewer than `window_size` timepoints cannot produce a window and
    are skipped.

    Parameters:
    - data: list of numpy arrays, each of shape (num_timepoints, num_neurons)
    - window_size: int, size of each window
    - step_size: int, step size between windows (1 <= step_size <= window_size)
    - concat: bool, if True (default) concatenate the windows of all bouts
              along the first dimension; if False return one tensor per bout

    Returns:
    - windows: torch tensor of shape (total_windows, window_size, num_neurons)
               OR
               list of torch tensors, each shape (n_windows, win_size, n_neurs)
               (the list may be empty when no bout is long enough)

    Raises:
    - ValueError: if step_size is larger than window_size, if step_size is
                  not positive, or if concat is True and no bout is long
                  enough to produce a window
    """
    # Raise error if step size > window size
    if step_size > window_size:
        raise ValueError("Step size must be less than or equal to window size.")
    # A zero step is invalid and a negative step would silently reverse the
    # window order, so reject both explicitly
    if step_size < 1:
        raise ValueError("Step size must be a positive integer.")

    windows = [
        to_tensor(
            create_windows(
                bout,
                window_size=window_size,
                step_size=step_size
            )
        ) for bout in data if bout.shape[0] >= window_size
    ]

    if not concat:
        return windows

    # torch.cat rejects an empty list with an opaque error; explain the cause
    if not windows:
        raise ValueError(
            "No bout has at least 'window_size' timepoints, "
            "cannot create any windows."
        )

    return torch.cat(windows, dim=0)


def train_val_split(data, config, val_size=0.2):
    """
    Randomly split a list of bouts into training and validation bouts.

    Whole bouts are assigned to one set or the other. The number of
    validation bouts is `int(len(data) * val_size)`; a warning is printed
    when the chosen bouts hold more than `val_size` of all timepoints.

    Parameters:
    - data: list of numpy arrays, each with timepoints along the first axis
    - config: dict, must contain 'seed', used to seed the random generator
    - val_size: float, fraction of bouts to place in the validation set

    Returns:
    - train_data: list of bouts not chosen for validation, in original order
    - val_data: list of validation bouts, in the order they were drawn
    """
    # Create numpy generator with specified seed
    rng = np.random.default_rng(config['seed'])

    # Timepoints per bout and in total
    bout_lengths = [bout.shape[0] for bout in data]
    total_timepoints = sum(bout_lengths)

    # Target validation size, in timepoints and in bouts
    val_timepoints = int(total_timepoints * val_size)
    n_val_bouts = int(len(data) * val_size)

    # Randomly choose bouts for validation set
    val_indices = rng.choice(
        len(data), size=n_val_bouts, replace=False
    ).tolist()

    # Bouts differ in length, so the chosen bouts may hold more than the
    # requested share of timepoints; warn but keep the split
    val_length = sum(bout_lengths[i] for i in val_indices)
    if val_length > val_timepoints:
        print("Warning: Validation set contains "
              f"{val_length/total_timepoints * 100} percent dataset.")

    # Set lookup keeps the membership test O(1) per bout
    held_out = set(val_indices)
    val_data = [data[i] for i in val_indices]
    train_data = [bout for i, bout in enumerate(data) if i not in held_out]

    return train_data, val_data

# TODO: Need to make changes or a new class for the Wang dataset now
class NeuralData():
    """
    Builds training and validation datasets from a list of bouts.

    Config keys read by this class: 'window_size', 'step_size',
    'causal_model', 'ic_window_size' (only when 'causal_model' is truthy)
    and 'batch_size'. The config is also passed to train_val_split.

    Attributes set by setup():
    - val_windows: validation windows; a single tensor when
                   config['window_size'] == -1, otherwise a list with one
                   tensor of windows per validation bout
    - train_ds / valid_ds: TensorDataset of (encoding, decoding) tensors
    - train_dl: DataLoader over train_ds
    """
    def __init__(
        self,
        config: dict
    ):
        super().__init__()
        self.config = config
        # TODO: Need to think about how to handle when dont want windows

    def _make_dataset(self, windows):
        """Wrap windows into a TensorDataset of (encoding, decoding) pairs."""
        if self.config['causal_model']:
            # Split each window into past (encoding) and future (decoding)
            enc_spikes, dec_spikes = to_causal_segments(
                windows, self.config['ic_window_size']
            )
        else:
            # Use the entire window as encoding and decoding segments
            enc_spikes = windows
            dec_spikes = windows
        return TensorDataset(enc_spikes, dec_spikes)

    def setup(self, data, stage=None):
        """
        Split `data` into training / validation bouts and build the datasets.

        Parameters:
        - data: list of numpy arrays, each of shape (num_timepoints, num_neurons)
        - stage: unused
        """
        # Split data into training and validation sets
        train_bouts, val_bouts = train_val_split(data,
                                                 self.config,
                                                 val_size=0.2)

        if self.config['window_size'] == -1:
            # No windowing: each bout is one sample (np.stack requires all
            # bouts in a set to share the same shape)
            train_windows = to_tensor(np.stack(train_bouts))
            self.val_windows = to_tensor(np.stack(val_bouts))
            val_samples = self.val_windows
        else:
            # Window training and validation data
            train_windows = window_data(train_bouts,
                                        self.config['window_size'],
                                        self.config['step_size'])
            # Save windowed bouts for validation loop
            self.val_windows = window_data(val_bouts,
                                           self.config['window_size'],
                                           self.config['step_size'],
                                           concat=False)
            if self.val_windows:
                val_samples = torch.cat(self.val_windows, dim=0)
            else:
                # No validation bout was long enough: keep an empty set with
                # the same per-sample shape as the training windows
                val_samples = train_windows.new_empty(
                    (0, *train_windows.shape[1:])
                )

        # Store datasets; valid_ds is what val_dataloader() serves
        self.train_ds = self._make_dataset(train_windows)
        self.valid_ds = self._make_dataset(val_samples)

        # Print dataset size for training set
        print(f"Training set size: {len(self.train_ds)}")

        self.train_dl = self.train_dataloader()

    def train_dataloader(self, shuffle=True):
        """Return a DataLoader over the training dataset built by setup()."""
        train_dl = DataLoader(
            self.train_ds,
            batch_size=self.config['batch_size'],
            shuffle=shuffle,
        )
        return train_dl

    def val_dataloader(self):
        """Return an unshuffled DataLoader over the validation dataset."""
        valid_dl = DataLoader(
            self.valid_ds,
            batch_size=self.config['batch_size'],
            shuffle=False
        )
        return valid_dl

