import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import argparse
import numpy as np
import os


def prepare_data_with_stacked_inputs(inputs, outputs, history_length, t_p):
    """
    Build one (features, target) pair per trial for next-step output prediction.

    For every trial the feature vector is the flattened window of the last
    ``history_length`` input steps ending at ``t_p`` (inclusive), followed by
    the output at ``t_p``. If fewer than ``history_length`` steps exist before
    ``t_p``, the window is left-padded with zeros. The target is the output at
    ``t_p + 1``.

    Args:
        inputs (Tensor): Input data of shape (num_trials, seq_length, input_channels).
        outputs (Tensor): Output data of shape (num_trials, seq_length, output_channels).
        history_length (int): Number of time steps to look back.
        t_p (int): The time point at which to predict the next output.

    Returns:
        stacked_inputs (Tensor): Shape (num_trials, history_length * input_channels + output_channels).
        targets (Tensor): Shape (num_trials, output_channels).

    Raises:
        IndexError: If ``t_p`` or ``t_p + 1`` is out of range for ``outputs``.
    """
    num_trials, _, input_channels = inputs.shape

    # The window bounds are the same for every trial, so slice all trials at once.
    start_idx = max(0, t_p - history_length + 1)
    window = inputs[:, start_idx:t_p + 1, :]

    # Left-pad with zeros when the sequence start cuts the window short.
    # new_zeros keeps the padding on the same dtype/device as `inputs`.
    missing_steps = history_length - window.shape[1]
    if missing_steps > 0:
        padding = inputs.new_zeros((num_trials, missing_steps, input_channels))
        window = torch.cat((padding, window), dim=1)

    # Explicit target width (instead of -1) so empty windows/trials still reshape.
    flat_window = window.reshape(num_trials, window.shape[1] * input_channels)

    # Only the first `num_trials` output rows are paired with the inputs.
    prev_output = outputs[:num_trials, t_p, :]
    stacked_inputs = torch.cat((flat_window, prev_output), dim=1)

    # clone() so the returned targets do not alias the caller's `outputs`.
    targets = outputs[:num_trials, t_p + 1, :].clone()
    return stacked_inputs, targets


class FFNNPredictor(nn.Module):
    """
    Feed-forward network with two hidden blocks followed by a linear read-out.

    Each hidden block is Linear -> Tanh -> BatchNorm1d -> Dropout(0.1); the
    final layer maps ``hidden_dim`` to ``output_dim`` with no activation.

    Args:
        input_dim (int): Size of each input feature vector.
        output_dim (int): Size of each predicted output vector.
        hidden_dim (int): Width of both hidden layers.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64):
        super().__init__()

        layers = []
        in_width = input_dim
        # Two identical hidden blocks; only the first one sees `input_dim`.
        for _ in range(2):
            layers.extend(
                [
                    nn.Linear(in_width, hidden_dim),
                    nn.Tanh(),
                    nn.BatchNorm1d(hidden_dim),
                    nn.Dropout(0.1),
                ]
            )
            in_width = hidden_dim
        layers.append(nn.Linear(hidden_dim, output_dim))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the prediction."""
        prediction = self.model(x)
        return prediction

def compute_accuracy(predictions, targets, threshold=0.1):
    """
    Percentage of samples whose every output channel lies within ``threshold`` of its target.

    Args:
        predictions (Tensor): Model predictions of shape (batch_size, output_size).
        targets (Tensor): Ground truth targets of shape (batch_size, output_size).
        threshold (float): Largest absolute error still counted as accurate.

    Returns:
        accuracy (float): Share of accurate samples, expressed in percent.
    """
    abs_diff = (predictions - targets).abs()
    # A sample counts only if all of its channels are close enough.
    sample_is_accurate = torch.le(abs_diff, threshold).all(dim=1)
    hit_rate = sample_is_accurate.float().mean().item()
    return hit_rate * 100


def train_model(model, criterion, optimizer, scheduler, train_loader, val_loader, num_epochs=10, eval_interval=50):
    """
    Train the model and periodically print the validation loss.

    Every epoch performs one optimisation pass over ``train_loader`` and then
    steps ``scheduler`` with that epoch's sample-weighted mean training loss,
    so the learning rate can be adapted while training is still running.

    Args:
        model (nn.Module): Model to optimise in place.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        optimizer: Optimizer updating the model's parameters.
        scheduler: Learning-rate scheduler whose ``step`` accepts a metric
            (e.g. ``ReduceLROnPlateau``).
        train_loader (DataLoader): Yields ``(stacked_inputs, targets)`` training batches.
        val_loader (DataLoader): Data used for the periodic validation loss.
        num_epochs (int): Number of passes over ``train_loader``.
        eval_interval (int): Validation runs on epochs where ``epoch % eval_interval == 0``.
    """
    for epoch in range(num_epochs):
        model.train()  # evaluate_model() leaves the model in eval mode
        running_loss = 0.0
        num_seen = 0
        for stacked_inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(stacked_inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller final batch does not skew the mean.
            num_in_batch = stacked_inputs.size(0)
            running_loss += loss.item() * num_in_batch
            num_seen += num_in_batch

        # Evaluate on validation set
        if epoch % eval_interval == 0:
            val_loss = evaluate_model(model, criterion, val_loader)
            print(f"Epoch {epoch + 1}/{num_epochs}, Validation Loss: {val_loss:.4f}")

        # Step once per epoch; skip when the loader produced no samples so the
        # scheduler is never fed an undefined metric.
        if num_seen > 0:
            scheduler.step(running_loss / num_seen)


from sklearn.model_selection import train_test_split

def train_test_split_data(stacked_inputs, targets, test_size=0.2):
    """
    Randomly partition samples into a training part and a held-out test part.

    A single random permutation of the sample indices is drawn; its first
    ``int(num_samples * test_size)`` entries form the test set and the rest
    the training set. Inputs and targets are indexed with the same indices,
    so each input stays paired with its target.

    Args:
        stacked_inputs (Tensor): Input data of shape (num_samples, input_size).
        targets (Tensor): Target outputs of shape (num_samples, output_size).
        test_size (float): Fraction of the samples assigned to the test set.

    Returns:
        train_inputs, test_inputs, train_targets, test_targets: Tensors for training and testing.
    """
    total = stacked_inputs.size(0)
    shuffled = torch.randperm(total)
    num_test = int(total * test_size)

    held_out, kept = shuffled[:num_test], shuffled[num_test:]

    return (
        stacked_inputs[kept],
        stacked_inputs[held_out],
        targets[kept],
        targets[held_out],
    )

def evaluate_model(model, criterion, data_loader):
    """
    Return the per-sample average loss of ``model`` over ``data_loader``.

    The model is switched to evaluation mode (and left in it) and gradients
    are disabled while iterating. Each batch loss is weighted by its batch
    size, and the sum is divided by the size of the underlying dataset.

    Args:
        model (nn.Module): Trained model to evaluate.
        criterion: Loss function used for evaluation.
        data_loader (DataLoader): DataLoader for evaluation data.

    Returns:
        average_loss (float): Average loss over ``data_loader.dataset``.
    """
    model.eval()
    weighted_loss_sum = 0.0
    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            batch_loss = criterion(model(batch_x), batch_y).item()
            # Undo the per-batch averaging so every sample weighs the same.
            weighted_loss_sum += batch_loss * batch_x.size(0)
    return weighted_loss_sum / len(data_loader.dataset)


if __name__ == "__main__":
    # Script entry point: for each history length, train a fresh FFNNPredictor
    # to predict the output at t_p + 1 and record its loss on a held-out split.
    parser = argparse.ArgumentParser(description='Tasks')
    parser.add_argument('--task', type=str, default='3bff')
    parser.add_argument('--epoch', type=int, default=100)
    # NOTE(review): --test_thresh is parsed but not read anywhere in the visible code.
    parser.add_argument('--test_thresh', type=float, default=0.1)
    parser.add_argument('--lr', type=float, default=1e-2)
    args = parser.parse_args()

    path = './dynamics'
    task = args.task

    inputs = np.load(os.path.join(path, 'task_inputs', f'{task}.npy'))
    outputs = np.load(os.path.join(path, 'task_outputs', f'{task}.npy'))

    if 'sine' in task:
        # Take index 0 of the last axis and swap axes 1 and 2 — presumably to
        # reach (trials, time, channels); TODO confirm against the data files.
        inputs = inputs[..., 0].swapaxes(1, 2)
    if 'delayed' in task:
        # Keep only the last 200 entries along the first axis and sweep a
        # coarser, longer range of history lengths (1, 9, ..., 41).
        inputs = inputs[-200:, :, :]
        outputs = outputs[-200:, :, :]
        history_lengths = range(1, 49, 8)
    else:
        history_lengths = range(1, 5)

    # Targets are read at index t_p + 1, so sequences need at least 49 steps.
    t_p = 47
    batch_size = 32
    num_epochs = args.epoch
    test_size = 0.2  # Proportion of data for testing

    # The tensor conversion does not depend on the history length, so do it
    # once here instead of copying the full arrays on every loop iteration.
    inputs_tensor = torch.tensor(inputs, dtype=torch.float32)
    outputs_tensor = torch.tensor(outputs, dtype=torch.float32)

    performance_results = []

    for history_length in history_lengths:
        print(f"\nTraining with history length: {history_length}")

        # Prepare data
        stacked_inputs, targets = prepare_data_with_stacked_inputs(
            inputs_tensor, outputs_tensor, history_length, t_p
        )

        # Train/Test split (a new random split is drawn for each history length)
        train_inputs, test_inputs, train_targets, test_targets = train_test_split_data(
            stacked_inputs, targets, test_size=test_size
        )

        train_dataset = TensorDataset(train_inputs, train_targets)
        test_dataset = TensorDataset(test_inputs, test_targets)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        # Define model
        input_size = stacked_inputs.shape[-1]
        output_size = targets.shape[-1]

        model = FFNNPredictor(input_size, output_size)

        # Training setup
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6
        )

        # Train model; the held-out split doubles as the validation set here.
        train_model(model, criterion, optimizer, scheduler, train_loader, test_loader, num_epochs)

        # Evaluate model on test set
        test_loss = evaluate_model(model, criterion, test_loader)
        print(f"Validation loss for history length {history_length}: {test_loss:.2f}")
        performance_results.append((history_length, test_loss))


