import time
import math
import glob 
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import argparse
import logging
import numpy as np
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
torch.autograd.set_detect_anomaly(True)
import gymnasium as gym
import torch.distributions as dist
from scipy.stats import chi2
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import random
# TODO: make a logging object instead of printing to terminal

def asMinutes(s):
    """Format a duration given in seconds as ``'<minutes>m <seconds>s'``.

    Both fields are rendered with ``%d``, so any fractional part of the
    remaining seconds is truncated rather than rounded.
    """
    minutes = math.floor(s / 60)
    seconds = s - minutes * 60
    return '%dm %ds' % (minutes, seconds)

def timeSince(since, percent):
    """Report elapsed time and a linear estimate of the time remaining.

    since:   start timestamp, as returned by ``time.time()``.
    percent: fraction of the work completed so far; the total duration is
             extrapolated as ``elapsed / percent`` (so it must be non-zero).

    Returns a string of the form ``'<elapsed> (- <remaining>)'`` where both
    parts are formatted by ``asMinutes``.
    """
    elapsed = time.time() - since
    estimated_total = elapsed / percent
    remaining = estimated_total - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))

def mahalanobis_distance(mean, cov, true_value):
    """Return the Mahalanobis distance of ``true_value`` from ``mean`` under ``cov``.

    Computes ``sqrt(d^T cov^{-1} d)`` with ``d = true_value - mean``. The
    covariance is inverted explicitly, so ``np.linalg.inv`` raises
    ``LinAlgError`` for a singular matrix.
    """
    residual = true_value - mean
    precision = np.linalg.inv(cov)
    left = np.dot(residual.T, precision)
    squared = np.dot(left, residual)
    return np.sqrt(squared)

def euclidean_distance(mean, true_value):
    """Return the norm (``np.linalg.norm`` default) of ``true_value - mean``."""
    residual = true_value - mean
    return np.linalg.norm(residual)

# Helper function to create block-diagonal matrix for each batch element
def batch_block_diag(cov1, cov2):
    """Build one block-diagonal matrix per batch element.

    For every index ``b`` along the leading dimension of ``cov1``, the result
    holds ``torch.block_diag(cov1[b], cov2[b])``; the per-element matrices are
    stacked back along a new leading dimension.
    """
    n_items = cov1.shape[0]
    blocks = [torch.block_diag(cov1[b], cov2[b]) for b in range(n_items)]
    return torch.stack(blocks, dim=0)

def confidence_set_test(means, covs, true_values, confidence_levels):
    """Check, per sample, whether the true value lies inside each confidence set.

    For sample ``i`` the Mahalanobis distance between ``true_values[i]`` and
    the Gaussian ``(means[i], covs[i])`` is compared with the square root of
    the chi-square quantile at each confidence level, using
    ``means[i].shape[0]`` degrees of freedom.

    Returns a dict mapping each confidence level to a list with one boolean
    per sample (True when the true value falls inside the set).
    """
    coverage = {level: [] for level in confidence_levels}

    for idx, mean in enumerate(means):
        distance = mahalanobis_distance(mean, covs[idx], true_values[idx])
        dof = mean.shape[0]

        for level in confidence_levels:
            radius = np.sqrt(chi2.ppf(level, df=dof))
            coverage[level].append(distance <= radius)

    return coverage

def uncertainty_error_correlation(means, covs, true_values):
    """Correlate Mahalanobis distance with Euclidean error across samples.

    For each sample both distances between the predicted mean and the true
    value are computed; the function returns the Pearson correlation
    coefficient (off-diagonal entry of ``np.corrcoef``) of the two series.
    """
    distance_pairs = [
        (
            mahalanobis_distance(means[idx], covs[idx], true_values[idx]),
            euclidean_distance(means[idx], true_values[idx]),
        )
        for idx in range(len(means))
    ]
    mahalanobis_series = [pair[0] for pair in distance_pairs]
    euclidean_series = [pair[1] for pair in distance_pairs]

    return np.corrcoef(mahalanobis_series, euclidean_series)[0, 1]

class WeightedMSELoss(nn.Module):
    """Mean squared error with exponentially decaying weights along the sequence axis.

    Step ``t`` of the sequence is weighted by ``decay_factor ** t``; the
    weighted squared errors are then averaged over every element
    (batch, sequence and feature dimensions together).
    """

    def __init__(self, decay_factor=0.98):
        super().__init__()
        self.decay_factor = decay_factor

    def forward(self, predicted, target):
        """Return the scalar weighted MSE between two (batch, seq, feats) tensors."""
        assert predicted.shape == target.shape, "Shape of predicted and target must match"

        # Only the sequence length is needed; unpacking also enforces a 3-D input.
        _, n_steps, _ = predicted.shape

        # One weight per time step, shaped (1, seq, 1) so it broadcasts over
        # the batch and feature dimensions.
        step_weights = torch.tensor(
            [self.decay_factor ** step for step in range(n_steps)],
            device=predicted.device,
        ).view(1, n_steps, 1)

        squared_error = (predicted - target) ** 2
        return (squared_error * step_weights).mean()


class TrajectoryDataset(Dataset):
    """Dataset of flattened state/action trajectories stored in a ``.npy`` file.

    The file is expected to hold an array of shape
    ``(num_trajectories, sequence_length * (state_dim + action_dim))``. Each
    row is reshaped to ``(sequence_length, state_dim + action_dim)`` and split
    into a history part (first ``history_size`` steps) and a future part
    (all remaining steps).

    Args:
        data_file: path to the numpy file.
        history_size: number of leading time steps used as history.
        prediction_horizon: number of future steps the caller intends to predict
            (only used for the sanity check below).
        state_dim: number of state features per time step.
        action_dim: number of action features per time step.
        max_samples: keep at most this many trajectories (default 10000, the
            previously hard-coded cap); ``None`` keeps all of them.

    Raises:
        Exception: if ``history_size + prediction_horizon`` exceeds the length
            of the last dimension of the loaded array.
    """

    def __init__(self, data_file, history_size, prediction_horizon, state_dim, action_dim, max_samples=10000):
        super().__init__()
        # NOTE(security): allow_pickle=True lets numpy unpickle arbitrary objects;
        # only load files from trusted sources.
        data = torch.from_numpy(np.load(data_file, allow_pickle=True))
        self.data = data if max_samples is None else data[:max_samples]
        self.history_size = history_size
        self.prediction_horizon = prediction_horizon
        self.state_dim = state_dim
        self.action_dim = action_dim

        # NOTE(review): this compares a step count with the flattened row length
        # (steps * features), so it is a loose lower bound — confirm whether a
        # per-step check was intended.
        if self.history_size + self.prediction_horizon > self.data.shape[-1]:
            raise Exception("History size and prediction horizons doesn't match the dataset trajectories.")

    @staticmethod
    def _as_float32(block):
        """Return an independent float32 copy of ``block``.

        Equivalent to ``torch.tensor(block, dtype=torch.float32)`` but without
        the copy-construct UserWarning torch emits for tensor inputs.
        """
        return block.detach().clone().to(torch.float32)

    def __getitem__(self, idx):
        """Return ``(past_states, past_actions, future_states, future_actions)`` for one trajectory."""
        trajectory = self.data[idx].reshape([-1, self.state_dim + self.action_dim])

        history = trajectory[:self.history_size]
        future = trajectory[self.history_size:]

        return (
            self._as_float32(history[:, :self.state_dim]),  # Past states
            self._as_float32(history[:, self.state_dim:]),  # Past actions
            self._as_float32(future[:, :self.state_dim]),   # Future states
            self._as_float32(future[:, self.state_dim:]),   # Future actions
        )

    def __len__(self):
        return len(self.data)




def get_dataloaders(train_data_file, test_data_file, history_size, prediction_horizon, state_dim, action_dim, batch_size=1024):
    """Build shuffled train and test ``DataLoader`` objects over trajectory files.

    Args:
        train_data_file: path of the numpy file with training trajectories.
        test_data_file: path of the numpy file with test trajectories.
        history_size, prediction_horizon, state_dim, action_dim: forwarded
            unchanged to ``TrajectoryDataset``.
        batch_size: batch size used by both loaders (default 1024, the
            previously hard-coded value).

    Returns:
        ``(train_dataloader, test_dataloader)``. Both loaders shuffle their data.
    """
    train_dataset = TrajectoryDataset(train_data_file, history_size, prediction_horizon, state_dim, action_dim)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

    test_dataset = TrajectoryDataset(test_data_file, history_size, prediction_horizon, state_dim, action_dim)
    # The test loader is shuffled as well, matching the existing behaviour.
    test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=batch_size)

    return train_dataloader, test_dataloader

class TraditionalPositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding followed by dropout.

    A ``(max_len, 1, d_model)`` table is precomputed and stored as the
    non-trainable buffer ``pe``: even feature indices hold sines and odd
    indices hold cosines of ``position * div_term``.

    Args:
        d_model: feature dimension of the inputs.
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions in the precomputed table.
    """

    def __init__(self, d_model, dropout=0.05, max_len=1000):
        # Zero-argument super(): the previous call named a different class
        # (PositionalEncoding), which made construction fail.
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create a matrix of [SeqLen, D_model] with positional encodings
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # NOTE(review): the frequency base here is 1000.0, not the 10000.0 used
        # elsewhere in this file; kept as-is — confirm it is intentional.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(1000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)  # Apply sin to even indices
        # For odd d_model there is one fewer odd column than even columns, so
        # trim div_term to the number of odd columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])  # Apply cos to odd indices
        pe = pe.unsqueeze(1)  # Add batch dimension: (max_len, 1, d_model)
        self.register_buffer('pe', pe)  # Register as buffer to avoid updates

    def forward(self, x):
        """
        x: Tensor of shape (seq_len, batch_size, embedding_dim)

        Returns x plus the first ``seq_len`` rows of the table (broadcast over
        the batch dimension), passed through dropout.
        """
        x = x + self.pe[:x.size(0), :]  # Add positional encoding
        return self.dropout(x)

class RelativePositionalEncoding(nn.Module):
    """Adds learned embeddings indexed by pairwise relative position.

    Args:
        d_model: embedding dimension.
        dropout: accepted for signature parity with the other encodings; unused.
        max_len: the embedding table holds ``2 * max_len - 1`` rows.
    """

    def __init__(self, d_model, dropout=0.05, max_len=250):
        # Zero-argument super(): the previous call named a different class
        # (PositionalEncoding), which made construction fail.
        super().__init__()
        self.relative_positions = nn.Embedding(2 * max_len - 1, d_model)  # Relative positions

    def forward(self, x):
        """Add relative-position embeddings to ``x``.

        ``seq_len`` is taken from ``x.size(0)``. The looked-up embeddings have
        shape ``(seq_len, seq_len, d_model)``.
        NOTE(review): adding them to a ``(seq_len, batch, d_model)`` input relies
        on broadcasting, which only lines up when batch is 1 or equals seq_len —
        confirm the intended input layout.
        """
        seq_len = x.size(0)
        # Generate relative position indices matrix
        positions = torch.arange(seq_len, dtype=torch.long, device=x.device)
        relative_positions_matrix = positions[None, :] - positions[:, None]  # Shape (seq_len, seq_len)
        relative_positions_matrix += seq_len - 1  # Shift values to be positive (0 to 2*seq_len-2)

        # Fetch relative positional encodings
        rel_pos_embeddings = self.relative_positions(relative_positions_matrix)
        return x + rel_pos_embeddings[:x.shape[0], :]  # Add relative positional encodings to input

class PositionalEncoding(nn.Module):
    """Multiplicative sinusoidal encoding applied to the first half of the features.

    The first ``d_model // 2`` features of each position are scaled
    element-wise by concatenated sine/cosine values of position-dependent
    angles; the second half is passed through untouched. ``dropout`` and
    ``max_len`` are accepted but not used.
    """

    def __init__(self, d_model, dropout=0.05, max_len=350):
        super().__init__()
        self.d_model = d_model

    def forward(self, x):
        """Encode ``x`` of shape (seq_len, batch_size, d_model); the output has the same shape."""
        n_steps, _, feat_dim = x.shape

        # The feature axis is split into two equal halves below.
        assert feat_dim % 2 == 0, "d_model must be divisible by 2 for RoPE."
        half = feat_dim // 2

        # One frequency per pair of features in the first half.
        exponents = torch.arange(0, half, 2, dtype=torch.float32)
        freqs = torch.exp(-exponents * (math.log(10000.0) / half))
        steps = torch.arange(n_steps, dtype=torch.float32)
        angles = torch.einsum("i,j->ij", steps, freqs)

        # sin block followed by cos block along the feature axis.
        sin_part = torch.sin(angles)
        cos_part = torch.cos(angles)
        scale = torch.cat([sin_part, cos_part], dim=-1).unsqueeze(1).to(x.device)

        head = x[:, :, :half] * scale  # broadcast over the batch dimension
        tail = x[:, :, half:]
        return torch.cat([head, tail], dim=-1)


class TransformerDynamicsModelNonCausal(nn.Module):
    """Encoder/decoder transformer without any decoder attention mask.

    Inputs are linearly embedded to ``hidden_dim``, positionally encoded, run
    through ``nn.Transformer``'s encoder and decoder, and the decoder output is
    projected to ``dec_output_dim``.

    Args:
        enc_input_dim: feature size of the encoder input sequence.
        dec_input_dim: feature size of the decoder input sequence.
        dec_output_dim: feature size of the projected decoder output.
        hidden_dim: transformer model dimension.
        num_heads: number of attention heads.
        num_encoder_layers: number of encoder layers.
        num_decoder_layers: number of decoder layers.
        dropout: dropout passed to the positional encoders and the transformer.
    """

    def __init__(self, enc_input_dim, dec_input_dim, dec_output_dim, hidden_dim, num_heads, num_encoder_layers, num_decoder_layers, dropout=0.1):
        # Zero-argument super(): the previous call named TransformerDynamicsModel,
        # of which this class is not a subclass, so construction raised.
        super().__init__()

        # Embedding layers to project input and output sequences to hidden dimension
        self.enc_input_embedding = nn.Linear(enc_input_dim, hidden_dim)
        self.dec_input_embedding = nn.Linear(dec_input_dim, hidden_dim)

        # Positional encoding layers
        self.pos_encoder = PositionalEncoding(hidden_dim, dropout)
        self.pos_decoder = PositionalEncoding(hidden_dim, dropout)

        # Transformer layer (feed-forward width is fixed at 64)
        self.transformer = nn.Transformer(d_model=hidden_dim,
                                          nhead=num_heads,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=64,
                                          dropout=dropout)

        # Final linear layer to project back to output dimension
        self.output_layer = nn.Linear(hidden_dim, dec_output_dim)

    def forward(self, enc_src_seq, dec_src_seq):
        """
        enc_src_seq: Encoder input sequence of shape (batch_size, seq_length, enc_input_dim)
        dec_src_seq: Decoder input sequence of shape (batch_size, seq_length, dec_input_dim)

        Returns ``(memory, output)``: the encoder memory as
        (batch_size, src_seq_len, hidden_dim) and the projected decoder output
        as (batch_size, tgt_seq_len, dec_output_dim).
        """
        # Embed the input and target sequences
        enc_src_seq = self.enc_input_embedding(enc_src_seq)  # (batch_size, src_seq_len, hidden_dim)
        dec_src_seq = self.dec_input_embedding(dec_src_seq)  # (batch_size, tgt_seq_len, hidden_dim)

        # nn.Transformer is built with its default sequence-first layout, so
        # move the sequence axis to the front.
        enc_src_seq = enc_src_seq.permute(1, 0, 2)  # (src_seq_len, batch_size, hidden_dim)
        dec_src_seq = dec_src_seq.permute(1, 0, 2)  # (tgt_seq_len, batch_size, hidden_dim)

        # Apply positional encoding
        enc_src_seq = self.pos_encoder(enc_src_seq)
        dec_src_seq = self.pos_decoder(dec_src_seq)

        # Transformer forward pass (no masks: every decoder step sees all others)
        memory = self.transformer.encoder(enc_src_seq)  # (src_seq_len, batch_size, hidden_dim)
        decoded = self.transformer.decoder(dec_src_seq, memory)  # (tgt_seq_len, batch_size, hidden_dim)

        # Transpose back to (batch_size, seq_len, hidden_dim)
        decoded = decoded.permute(1, 0, 2)  # (batch_size, tgt_seq_len, hidden_dim)

        # Apply the final linear layer
        output = self.output_layer(decoded)  # (batch_size, tgt_seq_len, dec_output_dim)

        # Return the encoded memory and the predicted sequence
        return memory.permute(1, 0, 2), output



class TransformerDynamicsModel(nn.Module):
    """Encoder/decoder transformer whose decoder uses a causal sliding-window mask.

    A learnable start token is prepended to the embedded decoder input; the
    decoder output at the start-token position is dropped before the final
    projection, so the output has as many steps as the decoder input.

    Args:
        enc_input_dim: feature size of the encoder input sequence.
        dec_input_dim: feature size of the decoder input sequence.
        dec_output_dim: feature size of the projected decoder output.
        hidden_dim: transformer model dimension.
        num_heads: number of attention heads.
        num_encoder_layers: number of encoder layers.
        num_decoder_layers: number of decoder layers.
        window_size: how many previous decoder positions each position may
            attend to, in addition to itself.
        dropout: dropout passed to the positional encoders and the transformer.
    """

    def __init__(self, enc_input_dim, dec_input_dim, dec_output_dim, hidden_dim, num_heads, num_encoder_layers, num_decoder_layers, window_size=8, dropout=0.1):
        super().__init__()

        self.window_size = window_size

        # Linear projections of both input streams into the model dimension.
        self.enc_input_embedding = nn.Linear(enc_input_dim, hidden_dim)
        self.dec_input_embedding = nn.Linear(dec_input_dim, hidden_dim)

        # Learnable start token, shape (1, 1, hidden_dim).
        self.start_token = nn.Parameter(torch.randn(1, 1, hidden_dim))

        self.pos_encoder = PositionalEncoding(hidden_dim, dropout)
        self.pos_decoder = PositionalEncoding(hidden_dim, dropout)

        # Feed-forward width is fixed at 64.
        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=64,
            dropout=dropout,
        )

        # Projection from the model dimension to the requested output size.
        self.output_layer = nn.Linear(hidden_dim, dec_output_dim)

    def generate_causal_sliding_window_mask(self, size, window_size):
        """
        Build an additive attention mask of shape (size, size).

        Entry (i, j) is 0 when ``i - window_size <= j <= i`` (position i may
        attend to itself and up to ``window_size`` earlier positions) and
        ``-inf`` everywhere else.
        """
        index = torch.arange(size)
        lag = index.unsqueeze(1) - index.unsqueeze(0)  # lag[i, j] = i - j
        visible = (lag >= 0) & (lag <= window_size)
        return torch.zeros((size, size)).masked_fill(~visible, float('-inf'))

    def forward(self, enc_src_seq, dec_src_seq):
        """
        enc_src_seq: Encoder input sequence of shape (batch_size, seq_length, enc_input_dim)
        dec_src_seq: Decoder input sequence of shape (batch_size, seq_length, dec_input_dim)

        Returns ``(memory, output)``: the encoder memory as
        (batch_size, src_seq_len, hidden_dim) and the projected decoder output
        as (batch_size, tgt_seq_len, dec_output_dim).
        """
        n_batch, _, _ = enc_src_seq.size()
        _, n_dec_steps, _ = dec_src_seq.size()

        source = self.enc_input_embedding(enc_src_seq)   # (batch, src_len, hidden)
        target = self.dec_input_embedding(dec_src_seq)   # (batch, tgt_len, hidden)

        # Put one copy of the start token in front of every decoder sequence.
        prefix = self.start_token.expand(n_batch, -1, -1)  # (batch, 1, hidden)
        target = torch.cat([prefix, target], dim=1)        # (batch, 1 + tgt_len, hidden)
        n_dec_steps = n_dec_steps + 1

        # The transformer runs sequence-first: (seq_len, batch, hidden).
        source = self.pos_encoder(source.permute(1, 0, 2))
        target = self.pos_decoder(target.permute(1, 0, 2))

        # The mask covers the start token as well, hence the incremented length.
        attn_mask = self.generate_causal_sliding_window_mask(n_dec_steps, self.window_size).to(source.device)

        memory = self.transformer.encoder(source)                            # (src_len, batch, hidden)
        decoded = self.transformer.decoder(target, memory, tgt_mask=attn_mask)  # (1 + tgt_len, batch, hidden)

        # Drop the start-token position and return to batch-first layout.
        decoded = decoded[1:].permute(1, 0, 2)  # (batch, tgt_len, hidden)

        output = self.output_layer(decoded)  # (batch, tgt_len, dec_output_dim)

        return memory.permute(1, 0, 2), output


class MultivariateBayesianRegressionModel(nn.Module):
    """Linear map ``y = M x`` with a closed-form posterior update of its parameters.

    The model keeps four non-trainable parameters: ``M_hat`` (output_dim x
    input_dim), ``V_hat`` (input_dim x input_dim), ``nu_hat`` (scalar) and
    ``Psi_hat`` (output_dim x output_dim). They are updated by
    ``update_posterior`` rather than by gradient descent.
    """

    def __init__(self, input_dim, output_dim, prior_sample_weight = 10, init_params=None):
        """
        input_dim:  dimension of the regression inputs
        output_dim: dimension of the regression outputs
        prior_sample_weight: default value for ``nu_hat`` when ``init_params``
            does not provide one
        init_params: optional dict overriding the prior parameters
            (keys 'M_hat', 'V_hat', 'nu_hat', 'Psi_hat')
        """
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.init_prior(init_params, prior_sample_weight, input_dim, output_dim)

    def init_prior(self, init_params, n, input_dim, output_dim):
        """
        Initialize prior parameters for MN-IW distribution.

        Values found in ``init_params`` take precedence; missing ones default
        to a zero ``M_hat``, identity ``V_hat`` and ``Psi_hat``, and
        ``nu_hat = n``. All four are stored with ``requires_grad=False``.
        """
        overrides = {} if init_params is None else init_params

        m_prior = overrides.get('M_hat', torch.zeros((output_dim, input_dim)))
        v_prior = overrides.get('V_hat', torch.eye(input_dim))
        nu_prior = torch.tensor(float(overrides.get('nu_hat', n)))
        psi_prior = overrides.get('Psi_hat', torch.eye(output_dim))

        self.M_hat = nn.Parameter(m_prior, requires_grad=False)
        self.V_hat = nn.Parameter(v_prior, requires_grad=False)
        self.nu_hat = nn.Parameter(nu_prior, requires_grad=False)
        self.Psi_hat = nn.Parameter(psi_prior, requires_grad=False)

    def update_posterior(self, X, Y):
        """
        Update the posterior parameters in place given data (X, Y).

        X: (input_dim  x N) matrix of inputs
        Y: (output_dim x N) matrix of outputs
        """
        n_points = X.shape[1]

        # Sufficient statistics combining the data with the current prior.
        prior_cross = self.M_hat @ self.V_hat
        S_yx = Y @ X.T + prior_cross
        S_xx = X @ X.T + self.V_hat
        S_yy = Y @ Y.T + prior_cross @ self.M_hat.T

        # S_yx @ pinv(S_xx) is both the new mean and a factor of the Psi update.
        gain = S_yx @ torch.linalg.pinv(S_xx)
        psi_posterior = self.Psi_hat + S_yy - gain @ S_yx.T
        nu_posterior = self.nu_hat + n_points

        # Swap the underlying storage so the Parameter objects stay the same.
        self.M_hat.data = gain
        self.V_hat.data = S_xx
        self.nu_hat.data = nu_posterior
        self.Psi_hat.data = psi_posterior

    def forward(self, x_t, calculate_covariance=False):
        """
        Predict from x_t.

        With ``calculate_covariance=False`` x_t is a tensor whose last axis has
        size input_dim and the result is ``x_t @ M_hat.T``.

        With ``calculate_covariance=True`` x_t is a ``(mean, covariance)`` pair
        and the result is ``(mean @ M_hat.T, M_hat cov M_hat.T + Psi_hat * (1 +
        mean^T V_hat mean))``. The mean may be 2-D (batch, dim) or 3-D
        (batch, steps, dim).
        """
        if not calculate_covariance:
            return torch.matmul(x_t, self.M_hat.T)

        x_mean, x_cov = x_t
        mean = torch.matmul(x_mean, self.M_hat.T)

        # Quadratic form mean^T V_hat mean, evaluated per batch (and per step).
        subscripts = 'bki,ij,bkj->bk' if len(x_mean.shape) == 3 else 'bi,ij,bj->b'
        adjustment_term = 1 + torch.einsum(subscripts, x_mean, self.V_hat, x_mean)

        propagated = torch.matmul(self.M_hat, torch.matmul(x_cov, self.M_hat.T))
        scaled_psi = self.Psi_hat * adjustment_term.unsqueeze(-1).unsqueeze(-1)

        return mean, propagated + scaled_psi

def _embed_batch_on_cpu(model, data, training_horizon):
    """Run one batch through ``model`` on the CPU and return what the posterior fits need.

    Returns ``(past_states, future_states, history_embeddings,
    predicted_future_embeddings)`` with the future tensors cut to
    ``training_horizon`` steps.
    """
    past_states, past_actions, future_states, future_actions = data
    past_states = past_states.to("cpu")
    past_actions = past_actions.to("cpu")
    future_states = future_states.to("cpu")[:, :training_horizon, :]
    future_actions = future_actions.to("cpu")[:, :training_horizon, :]

    enc_inp = torch.cat((past_states, past_actions), dim=2)

    # A variational model returns one extra leading element (its KL loss).
    variational = model.is_variational
    outputs = model(enc_inp, future_actions)
    if variational:
        _, history_embeddings, _, _, _, predicted_future_embeddings = outputs
    else:
        history_embeddings, _, _, _, predicted_future_embeddings = outputs

    return past_states, future_states, history_embeddings, predicted_future_embeddings


def train_bayesian_operators(dataloader, model, training_horizon):
    """Fit the model's Koopman operator and observation matrix in closed form.

    The model is moved to the CPU and put in train mode. A first pass over
    ``dataloader`` gathers (embedding, next state-embedding) pairs and calls
    ``model.update_k_op_posterior``. A second pass re-runs the model with the
    updated operator and gathers (state-embedding, ground-truth state) pairs
    for ``model.update_o_mat_posterior``. Finally the model is moved to the
    module-level ``device``.

    Note that state-embedding / ground-truth pairs are accumulated in *both*
    passes, so the observation fit sees every batch twice (once from before and
    once from after the operator update).

    Returns None.
    """
    model.to("cpu").train()
    latent_dim = model.state_embedding_dim

    operator_inputs = []
    operator_targets = []
    latent_states = []
    true_states = []

    def _append_observation_pairs(past_states, future_states, history_emb, future_emb):
        # Pair every state embedding (history then future) with its true state.
        latent_states.append(
            torch.cat((
                history_emb[:, :, :latent_dim].reshape((-1, latent_dim)),
                future_emb[:, :, :latent_dim].reshape((-1, latent_dim))
            ))
        )
        true_states.append(
            torch.cat((
                past_states.reshape((-1, past_states.shape[-1])),
                future_states.reshape((-1, future_states.shape[-1]))
            ))
        )

    # Pass 1: one-step transition pairs for the Koopman operator.
    with torch.no_grad():
        for data in dataloader:
            past_states, future_states, history_emb, future_emb = \
                _embed_batch_on_cpu(model, data, training_horizon)

            # Inputs are steps 0..T-2 of each sequence, targets are the state
            # part of steps 1..T-1; both lists keep the history/future order.
            for sequence in (history_emb, future_emb):
                operator_inputs.append(sequence[:, :-1, :].reshape((-1, sequence.shape[-1])))
            for sequence in (history_emb, future_emb):
                operator_targets.append(sequence[:, 1:, :latent_dim].reshape((-1, latent_dim)))

            _append_observation_pairs(past_states, future_states, history_emb, future_emb)

    # The posterior update expects (features x samples) matrices.
    operator_inputs = torch.t(torch.cat(operator_inputs, dim=0))
    operator_targets = torch.t(torch.cat(operator_targets, dim=0))
    model.update_k_op_posterior(operator_inputs, operator_targets)

    # Pass 2: embeddings produced with the updated operator.
    with torch.no_grad():
        for data in dataloader:
            past_states, future_states, history_emb, future_emb = \
                _embed_batch_on_cpu(model, data, training_horizon)
            _append_observation_pairs(past_states, future_states, history_emb, future_emb)

    latent_states = torch.t(torch.cat(latent_states, dim=0))
    true_states = torch.t(torch.cat(true_states, dim=0))
    model.update_o_mat_posterior(latent_states, true_states)
    model.to(device)
    return

class Blast(nn.Module):
    """Transformer-embedded linear latent dynamics model.

    A ``TransformerDynamicsModel`` maps the state/action history to history
    embeddings and the future actions to per-step action embeddings. A linear
    operator ``k_op`` advances ``[state embedding, action embedding]`` to the
    next state embedding, and ``o_mat`` maps state embeddings to states. Both
    start as bias-free ``nn.Linear`` layers and can be replaced by
    ``MultivariateBayesianRegressionModel`` instances through the
    ``update_*_posterior`` methods.

    Args:
        state_embedding_dim: size of the latent state embedding.
        action_embedding_dim: size of the latent action embedding.
        state_dim: size of the observed state.
        action_dim: size of the action.
        hidden_dim: transformer model dimension. Since ``k_op`` is applied to
            the transformer memory, it presumably has to equal
            ``state_embedding_dim + action_embedding_dim`` — TODO confirm.
        num_heads, num_encoder_layers, num_decoder_layers: transformer sizes.
        dropout: accepted for interface compatibility; the transformer is built
            with ``dropout=0``.
        is_variational: when True the action embeddings are sampled from a
            learned Gaussian and training-mode ``forward`` also returns a KL loss.
    """

    def __init__(self, state_embedding_dim, action_embedding_dim, state_dim, action_dim, hidden_dim, \
                  num_heads, num_encoder_layers, num_decoder_layers, dropout=0.1, is_variational = False):
        super().__init__()

        self.state_embedding_dim  = state_embedding_dim
        self.action_embedding_dim = action_embedding_dim
        self.state_dim = state_dim
        self.action_dim = action_dim

        # Koopman Operator
        self.k_op = nn.Linear(state_embedding_dim + action_embedding_dim, state_embedding_dim, bias=False)

        # Observation Matrix
        self.o_mat = nn.Linear(state_embedding_dim, state_dim, bias=False)

        # Define embedding transformer
        enc_input_dim  = state_dim + action_dim
        dec_input_dim  = action_dim

        self.is_variational = is_variational
        # Toggled externally: selects the covariance-propagating eval path.
        self.predict_uncertainty = False
        self.final_linear_layer = nn.Linear(hidden_dim, state_embedding_dim + action_embedding_dim)
        if self.is_variational:
            # The decoder emits a wider feature vector that the two heads below
            # reduce to the mean and log-variance of the action embedding.
            self.embedding_transformer = TransformerDynamicsModel(enc_input_dim, dec_input_dim, 12*action_embedding_dim, \
                                                                hidden_dim, num_heads, num_encoder_layers, num_decoder_layers, dropout=0)

            # Variational layers for mean and variance of future embeddings
            self.future_embedding_mu = nn.Linear(12*action_embedding_dim, action_embedding_dim)        # Mean layer
            self.future_embedding_logvar = nn.Linear(12*action_embedding_dim, action_embedding_dim)    # Log variance layer
        else:
            self.embedding_transformer = TransformerDynamicsModel(enc_input_dim, dec_input_dim, action_embedding_dim, \
                                                                hidden_dim, num_heads, num_encoder_layers, num_decoder_layers, dropout=0)

    def reparameterize(self, mu, logvar):
        """Reparameterization trick to sample z from the learned distribution."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def compute_kl_loss(self, mu, logvar):
        """KL divergence loss between the learned distribution and standard Gaussian (mean over all elements)."""
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        return kl_loss

    def update_k_op_posterior(self, embeddings, propagated_embeddings):
        """Replace ``k_op`` by a Bayesian regression model fitted to the given pairs.

        The current ``k_op.weight`` seeds the prior mean. ``embeddings`` and
        ``propagated_embeddings`` are (features x samples) matrices.
        NOTE(review): this reads ``self.k_op.weight``, so it assumes ``k_op`` is
        still the original ``nn.Linear`` (i.e. it is called once) — confirm.
        """
        self.k_op = MultivariateBayesianRegressionModel(self.state_embedding_dim + self.action_embedding_dim, \
                                                         self.state_embedding_dim, 10, {"M_hat":self.k_op.weight})

        self.k_op.update_posterior(embeddings, propagated_embeddings)

    def update_o_mat_posterior(self, state_embedings, state_ground_truth):
        """Replace ``o_mat`` by a Bayesian regression model fitted to the given pairs.

        The current ``o_mat.weight`` seeds the prior mean. Both arguments are
        (features x samples) matrices.
        NOTE(review): same single-call assumption as ``update_k_op_posterior``.
        """
        self.o_mat = MultivariateBayesianRegressionModel(self.state_embedding_dim, self.state_dim, 10, {"M_hat":self.o_mat.weight})

        self.o_mat.update_posterior(state_embedings, state_ground_truth)

    def _rollout(self, history_embeddings, future_embeddings, prediction_horizon):
        """Deterministically roll the latent dynamics forward.

        Returns ``(propagated_history_embeddings, predicted_future_embeddings)``
        where the latter stacks, for each future step, the predicted state
        embedding concatenated with that step's action embedding.
        """
        propagated_history_embeddings = self.k_op(history_embeddings)
        # The one-step prediction from the last history step starts the rollout.
        current_state_embedding = propagated_history_embeddings[:, -1, :]

        steps = [torch.cat((current_state_embedding, future_embeddings[:, 0, :]), dim=1)]
        for i in range(1, prediction_horizon):
            current_state_embedding = self.k_op(steps[-1])
            steps.append(torch.cat((current_state_embedding, future_embeddings[:, i, :]), dim=1))

        return propagated_history_embeddings, torch.stack(steps, dim=1)

    def _rollout_with_uncertainty(self, history_embeddings, future_embeddings, prediction_horizon):
        """Roll the latent dynamics forward propagating mean and covariance.

        Requires ``k_op`` and ``o_mat`` to accept ``calculate_covariance=True``
        (i.e. to have been replaced via the ``update_*_posterior`` methods).
        Returns ``(predicted_future_states_mean, predicted_future_states_cov)``.
        """
        # Allocate helper tensors on the inputs' device instead of relying on a
        # module-level ``device`` global.
        dev = history_embeddings.device
        embedding_width = history_embeddings.shape[-1]
        action_width = future_embeddings.shape[-1]

        def _augment(state_mean, state_cov, step):
            # Action embeddings are treated as exactly known: zero covariance block.
            zero_action_cov = torch.zeros((state_cov.shape[0], action_width, action_width), device=dev)
            return (torch.cat((state_mean, future_embeddings[:, step, :]), dim=1),
                    batch_block_diag(state_cov, zero_action_cov))

        # The last history embedding is the starting point, with zero input covariance.
        state_mean, state_cov = self.k_op(
            (history_embeddings[:, -1, :], torch.zeros((embedding_width, embedding_width), device=dev)),
            calculate_covariance=True)

        steps = [_augment(state_mean, state_cov, 0)]
        for i in range(1, prediction_horizon):
            state_mean, state_cov = self.k_op(steps[-1], calculate_covariance=True)
            steps.append(_augment(state_mean, state_cov, i))

        predicted_future_embeddings_mean = torch.stack([step[0] for step in steps], dim=1)
        predicted_future_embeddings_cov = torch.stack([step[1] for step in steps], dim=1)

        # o_mat handles both mean and covariance; only the state part is observed.
        s_dim = self.state_embedding_dim
        predicted_future_states_mean, predicted_future_states_cov = self.o_mat(
            (predicted_future_embeddings_mean[:, :, :s_dim],
             predicted_future_embeddings_cov[:, :, :s_dim, :s_dim]),
            calculate_covariance=True)

        return predicted_future_states_mean, predicted_future_states_cov

    def forward(self, history_sequence, future_sequence):
        """Predict future states from a state/action history and future actions.

        history_sequence: states and actions, (batch size, seq length, features)
        future_sequence:  actions only,       (batch size, seq length, features)

        Returns:
            Training mode: ``(history_embeddings, propagated_history_embeddings,
            predicted_history_states, predicted_future_states,
            predicted_future_embeddings)``, preceded by the KL loss when
            ``is_variational`` is True.
            Eval mode with ``predict_uncertainty``: ``(mean, covariance)`` of the
            predicted future states.
            Eval mode otherwise: the predicted future states.
        """
        prediction_horizon = future_sequence.shape[1]

        history_embeddings, future_embeddings = self.embedding_transformer(history_sequence, future_sequence)

        if self.is_variational:
            # Compute the mean and log variance for the future embeddings and
            # sample with the reparameterization trick (in eval mode as well).
            mu = self.future_embedding_mu(future_embeddings)
            logvar = self.future_embedding_logvar(future_embeddings)
            future_embeddings = self.reparameterize(mu, logvar)

        if self.training:
            propagated_history_embeddings, predicted_future_embeddings = \
                self._rollout(history_embeddings, future_embeddings, prediction_horizon)

            predicted_future_states  = self.o_mat(predicted_future_embeddings[:, :, :self.state_embedding_dim])
            predicted_history_states = self.o_mat(propagated_history_embeddings)

            if self.is_variational:
                kl_loss = self.compute_kl_loss(mu, logvar)
                return kl_loss, history_embeddings, propagated_history_embeddings, predicted_history_states, predicted_future_states, predicted_future_embeddings
            return history_embeddings, propagated_history_embeddings, predicted_history_states, predicted_future_states, predicted_future_embeddings

        # The model is in eval mode.
        if self.predict_uncertainty:
            return self._rollout_with_uncertainty(history_embeddings, future_embeddings, prediction_horizon)

        _, predicted_future_embeddings = self._rollout(history_embeddings, future_embeddings, prediction_horizon)

        # o_mat handles deterministic predictions
        predicted_future_states = self.o_mat(predicted_future_embeddings[:, :, :self.state_embedding_dim])
        return predicted_future_states
            
def train_epoch(dataloader, training_horizon, model, optimizer, criterion, kl_weight):
    """Train ``model`` for one pass over ``dataloader`` and return summed losses.

    Relies on the module-level ``device`` and ``state_embedding_dim`` names
    (both are assigned in the ``__main__`` section of this script).

    Arguments:
    dataloader -- iterable of (past_states, past_actions, future_states,
                  future_actions) batches
    training_horizon -- number of leading future steps used for the loss
    model -- model exposing ``is_variational``; called as
             ``model(enc_inp, future_actions)``
    optimizer -- optimizer that is zeroed and stepped once per batch
    criterion -- loss callable applied to each (prediction, target) pair
    kl_weight -- multiplier for the loss returned by a variational model

    Returns:
    A length-4 numpy array of loss terms summed (not averaged) over batches:
    [model loss * kl_weight (0 when not variational), embedding alignment,
     future-state prediction, history-state reconstruction].
    """
    model.train()
    # Fixed-size accumulator so an empty dataloader still yields an indexable
    # array (callers read individual slots such as total_loss[0]).
    total_loss = np.zeros(4)

    for past_states, past_actions, future_states, future_actions in dataloader:
        past_states = past_states.to(device)
        past_actions = past_actions.to(device)
        future_states = future_states.to(device)
        future_actions = future_actions.to(device)[:, :training_horizon, :]

        # Encoder input is the per-step concatenation of state and action.
        enc_inp = torch.cat((past_states, past_actions), dim=2)

        optimizer.zero_grad()
        outputs = model(enc_inp, future_actions)
        if model.is_variational:
            # Variational models additionally return their own loss term first.
            (model_loss, history_embeddings, propagated_history_embeddings,
             predicted_history_states, predicted_future_states, _) = outputs
            loss_terms = [model_loss * kl_weight]
        else:
            (history_embeddings, propagated_history_embeddings,
             predicted_history_states, predicted_future_states, _) = outputs
            loss_terms = []

        # Alignment: the embedding propagated from step t is compared with the
        # state part of the encoded embedding at step t + 1.
        loss_terms.append(criterion(history_embeddings[:, 1:, :state_embedding_dim],
                                    propagated_history_embeddings[:, :-1, :]))
        loss_terms.append(criterion(predicted_future_states,
                                    future_states[:, :training_horizon, :]))
        loss_terms.append(criterion(predicted_history_states, past_states))

        sum(loss_terms).backward()
        optimizer.step()

        batch_values = [term.item() for term in loss_terms]
        if not model.is_variational:
            # Keep slot 0 reserved for the variational term.
            batch_values = [0] + batch_values
        total_loss += np.array(batch_values)

    return total_loss


def eval_epoch(dataloader, training_horizon, model, criterion):
    """Accumulate the future-state prediction loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and gradients are disabled. Relies on
    the module-level ``device`` name.

    Arguments:
    dataloader -- iterable of (past_states, past_actions, future_states,
                  future_actions) batches
    training_horizon -- number of leading future steps that are evaluated
    model -- called as ``model(enc_inp, future_actions)``; its output is
             passed directly to ``criterion``
    criterion -- loss callable returning a scalar tensor

    Returns:
    The sum (not the mean) of the per-batch loss values; 0 if no batch is seen.
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for history_states, history_actions, target_states, planned_actions in dataloader:
            history_states = history_states.to(device)
            history_actions = history_actions.to(device)
            target_states = target_states.to(device)
            planned_actions = planned_actions.to(device)[:, :training_horizon, :]

            # States and actions are concatenated per time step for the encoder.
            encoder_input = torch.cat((history_states, history_actions), dim=2)
            prediction = model(encoder_input, planned_actions)

            batch_loss = criterion(prediction, target_states[:, :training_horizon, :])
            running_loss += batch_loss.item()

    return running_loss


def prediction_horizon_loss(dataloader, model, criterion, prediction_horizon):
    """Return the validation loss per prediction step, averaged over batches.

    The model is switched to eval mode and gradients are disabled. Relies on
    the module-level ``device`` name.

    Arguments:
    dataloader -- iterable of (past_states, past_actions, future_states,
                  future_actions) batches; must support ``len``
    model -- called as ``model(enc_inp, future_actions)``
    criterion -- loss callable whose output keeps the batch axis first and the
                 feature axis last (e.g. ``nn.MSELoss(reduction='none')``)
    prediction_horizon -- number of leading future steps to evaluate

    Returns:
    A numpy array obtained by averaging the criterion output over axis 0 and
    then over the last axis, summed over batches and divided by
    ``len(dataloader)``.
    """
    model.eval()
    accumulated = None

    with torch.no_grad():
        for batch in dataloader:
            history_states, history_actions, target_states, planned_actions = batch
            history_states = history_states.to(device)
            history_actions = history_actions.to(device)
            target_states = target_states.to(device)[:, :prediction_horizon, :]
            planned_actions = planned_actions.to(device)[:, :prediction_horizon, :]

            encoder_input = torch.cat((history_states, history_actions), dim=2)
            prediction = model(encoder_input, planned_actions)

            # target_states already lives on ``device``.
            elementwise = criterion(prediction, target_states)

            # Collapse the batch axis first, then the feature axis.
            per_step = elementwise.mean(0).mean(-1).cpu().detach().numpy()

            if accumulated is None:
                accumulated = per_step
            else:
                accumulated += per_step

    return accumulated / len(dataloader)

def eval_model_uncertainty(model, dataloader, device, writer, step):
    """
    Evaluate the uncertainty predicted by model.

    The model is put in eval mode with ``predict_uncertainty`` enabled for the
    duration of the call; the flag is always set back to False afterwards,
    even if evaluation raises. Relies on the module-level ``training_horizon``,
    ``logger`` and ``uncertainty_error_correlation`` names.

    Arguments:
    model -- the trained model; with ``predict_uncertainty`` set it must return
             a (mean, covariance) pair from ``model(enc_inp, future_actions)``
    dataloader -- data loader with test data
    device -- torch device (cpu or gpu)
    writer -- currently unused; kept for interface compatibility
    step -- currently unused; kept for interface compatibility

    Returns:
    The correlation between uncertainty and prediction error.

    Raises:
    ValueError -- if the dataloader yields no batches.
    """
    max_samples = 100000  # cap on the number of flattened samples analysed

    model.eval()
    model.predict_uncertainty = True
    try:
        mean_batches = []
        cov_batches = []
        target_batches = []

        with torch.no_grad():
            for past_states, past_actions, future_states, future_actions in dataloader:
                past_states = past_states.to(device)
                past_actions = past_actions.to(device)
                future_states = future_states.to(device)[:, :training_horizon, :]
                future_actions = future_actions.to(device)[:, :training_horizon, :]

                enc_inp = torch.cat((past_states, past_actions), dim=2)

                # The model returns mean and covariance
                pred_mean, pred_cov = model(enc_inp, future_actions)

                mean_batches.append(pred_mean.cpu().numpy())
                cov_batches.append(pred_cov.cpu().numpy())
                target_batches.append(future_states.cpu().numpy())

        if not mean_batches:
            raise ValueError("eval_model_uncertainty received an empty dataloader")

        # Flatten every leading axis so each row is one predicted state.
        feature_dim = mean_batches[0].shape[-1]
        means = np.concatenate(mean_batches, axis=0).reshape((-1, feature_dim))[:max_samples]
        covs = np.concatenate(cov_batches, axis=0).reshape((-1, feature_dim, feature_dim))[:max_samples]
        true_values = np.concatenate(target_batches, axis=0).reshape((-1, feature_dim))[:max_samples]

        # Perform uncertainty-error correlation test
        correlation = uncertainty_error_correlation(means, covs, true_values)

        logger.info("Correlation {}".format(correlation))
        return correlation
    finally:
        # Leave the model in deterministic-prediction mode for later callers.
        model.predict_uncertainty = False


def train(train_dataloader, val_dataloader, model, training_horizon, prediction_horizon, n_epochs, checkpoints_path, learning_rate=0.0001, writer=None, print_every=5):
    """Train ``model``, fit its Bayesian operators and log horizon-wise losses.

    Steps performed:
    1. ``n_epochs`` of AdamW training with a StepLR schedule, validating after
       every epoch, logging scalars and saving a checkpoint every 25 epochs.
    2. Per-step validation loss over ``prediction_horizon`` steps, logged as
       'Loss/horizon_before'.
    3. ``train_bayesian_operators`` on the training data.
    4. Validation loss and uncertainty evaluation.
    5. Per-step validation loss again, logged as 'Loss/horizon_after'.

    Relies on the module-level ``logger`` and ``device`` names.

    Arguments:
    train_dataloader -- training batches; must support ``len``
    val_dataloader -- validation batches; must support ``len``
    model -- model to optimise
    training_horizon -- number of future steps used for training/validation loss
    prediction_horizon -- number of future steps used for the horizon-wise loss
    n_epochs -- number of training epochs
    checkpoints_path -- directory in which a "Checkpoints" folder is created
    learning_rate -- initial AdamW learning rate
    writer -- TensorBoard SummaryWriter; when None, scalar logging is skipped
    print_every -- log the averaged losses every this many epochs
    """
    checkpoints_dir = os.path.join(checkpoints_path, "Checkpoints")
    os.makedirs(checkpoints_dir, exist_ok=True)

    def log_scalar(tag, value, global_step):
        # ``writer`` defaults to None; skip TensorBoard output instead of crashing.
        if writer is not None:
            writer.add_scalar(tag, value, global_step)

    def log_horizon_loss(tag):
        # Per-step validation loss, one scalar per prediction step.
        per_step_loss = prediction_horizon_loss(
            val_dataloader, model, nn.MSELoss(reduction='none'), prediction_horizon)
        logger.info("Horizon loss shape: {}".format(per_step_loss.shape))
        for step_index, step_loss in enumerate(per_step_loss):
            log_scalar(tag, step_loss, step_index)

    start = time.time()
    train_print_loss_total = 0  # Reset every print_every
    val_print_loss_total = 0

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=.0)

    # StepLR takes a modulo by step_size, so keep it >= 1 when n_epochs < 3.
    scheduler = StepLR(optimizer, step_size=max(1, n_epochs // 3), gamma=0.3)

    weighted_criterion = nn.MSELoss()
    criterion = nn.MSELoss()
    kl_weight = 1

    n_train_batches = len(train_dataloader)
    n_val_batches = len(val_dataloader)

    for epoch in range(1, n_epochs + 1):

        train_loss = train_epoch(train_dataloader, training_horizon, model, optimizer, weighted_criterion, kl_weight)
        val_loss = eval_epoch(val_dataloader, training_horizon, model, criterion)

        train_print_loss_total += np.sum(train_loss)
        val_print_loss_total += val_loss

        log_scalar('Params/lr', scheduler.get_last_lr()[0], epoch)

        # train_loss holds summed per-term losses; divide by batch count for means.
        log_scalar('Loss/train', np.sum(train_loss) / n_train_batches, epoch)
        log_scalar('train/kl', train_loss[0] / n_train_batches, epoch)
        # Tag spelling kept as-is so existing TensorBoard runs stay comparable.
        log_scalar('train/alginment', train_loss[1] / n_train_batches, epoch)
        log_scalar('train/prediction', np.sum(train_loss[2:]) / n_train_batches, epoch)

        log_scalar('Loss/Valid', val_loss / n_val_batches, epoch)

        if epoch % print_every == 0:
            train_print_loss_total_avg = train_print_loss_total / (print_every * n_train_batches)
            val_print_loss_total_avg = val_print_loss_total / (print_every * n_val_batches)

            train_print_loss_total = 0
            val_print_loss_total = 0

            logger.info('%s (%d %d%%) %.4f %.4f' % (timeSince(start, epoch / n_epochs),
                                        epoch, epoch / n_epochs * 100, train_print_loss_total_avg, val_print_loss_total_avg))

        if epoch % 25 == 0:
            torch.save(model.state_dict(), os.path.join(checkpoints_dir, "model"+str(epoch)+".pt"))

        scheduler.step()

    log_horizon_loss('Loss/horizon_before')

    criterion = nn.MSELoss()
    train_bayesian_operators(train_dataloader, model, training_horizon)
    logger.info("Loss on eval: {}".format(eval_epoch(val_dataloader, training_horizon, model, criterion) / n_val_batches))

    # Pass the real writer and final epoch rather than placeholder values.
    eval_model_uncertainty(model, val_dataloader, device, writer, n_epochs)

    log_horizon_loss('Loss/horizon_after')

    # Flush and close
    if writer is not None:
        writer.flush()
        writer.close()
        
if __name__ == "__main__":
    # Script entry point: parse arguments, set up logging/TensorBoard, build the
    # model and data loaders, then run training. Names assigned here (logger,
    # device, training_horizon, state_embedding_dim, ...) are module-level and
    # are read by the functions above.

    logFormatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    logger = logging.getLogger("mainLogger")
    logger.setLevel(logging.INFO)

    parser = argparse.ArgumentParser(description='Train LSTM Encoder Deocoder Model')
    parser.add_argument('-e', '--env', help='Environment Type', required=True)
    parser.add_argument('-x', '--exp-dir', help='Experiments Directory.', required=True)
    parser.add_argument("-v", "--env-variant", help="type of environemnt, e.g.: normal, process_noise, observation_noise, or distribution_shift", required=True)
    parser.add_argument("-i", "--run_id", type=int, help="Run ID.", required=True)
    parser.add_argument("-d", "--datasets_dir", help="Root directory containing the datasets.", default="datasets_dir")
    parser.add_argument("--seed", default=10, type=int)              # Seeds PyTorch, NumPy and Python's random
    parser.add_argument("--device", default="cuda:0")
    args = parser.parse_args()

    # Set Random Seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    model_name = "Blak"
    device = args.device

    # Output layout: <exp_dir>/<env>/<env_variant>/<model_name>/{logs,Tensorboard,...}
    base_dir = os.path.join(args.exp_dir, args.env, args.env_variant, model_name)
    os.makedirs(base_dir, exist_ok=True)
    logs_dir = os.path.join(base_dir, "logs")
    os.makedirs(logs_dir, exist_ok=True)
    logger_path = os.path.join(logs_dir, "run_"+str(args.run_id)+".log")

    # Log both to a per-run file (overwritten on each run) and to the console.
    fileHandler = logging.FileHandler(logger_path, mode="w")
    fileHandler.setLevel(logging.INFO)
    fileHandler.setFormatter(logFormatter)
    logger.addHandler(fileHandler)

    consoleHandler = logging.StreamHandler()
    consoleHandler.setFormatter(logFormatter)
    consoleHandler.setLevel(logging.INFO)
    logger.addHandler(consoleHandler)

    tb_dir = os.path.join(base_dir, "Tensorboard", "run"+str(args.run_id))
    os.makedirs(tb_dir, exist_ok=True)
    writer = SummaryWriter(tb_dir)

    # The environment is only used to read the state and action dimensions.
    env = gym.make(args.env)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    hidden_size = 24
    training_horizon = 200
    prediction_horizon = 300

    enc_input_dim = state_dim + action_dim  # Encoder input: state and action features
    dec_input_dim = action_dim  # Decoder input: action features
    dec_output_dim = state_dim  # Decoder output: state features
    hidden_dim = 48  # Hidden dimension size for the Transformer
    state_embedding_dim = 40  # Size of the state part of the embedding
    action_embedding_dim = 8  # Size of the action part of the embedding
    num_heads = 12  # Number of attention heads
    num_encoder_layers = 1  # Number of encoder layers
    num_decoder_layers = 1  # Number of decoder layers
    dropout = 0.1  # NOTE: not passed to the model, which is built with dropout=0

    # Create the model
    model = Blast(state_embedding_dim, action_embedding_dim, state_dim, action_dim, hidden_dim, \
                  num_heads, num_encoder_layers, num_decoder_layers, dropout=0).to(device)

    print("-"*20)
    print(model)
    # Only the parameters of the embedding transformer are counted here.
    print("Number of parameters is:", sum(p.numel() for p in model.embedding_transformer.parameters()))
    print("-"*20)

    n_epochs = 300

    # Honour --datasets_dir (its default matches the previously hard-coded path).
    dataset_dir = os.path.join(args.datasets_dir, args.env, args.env_variant)
    train_data_file = os.path.join(dataset_dir, "train.npy")
    test_data_file  = os.path.join(dataset_dir, "test.npy")
    history_size = 20

    train_dataloader, val_dataloader = get_dataloaders(train_data_file, test_data_file, history_size, prediction_horizon, state_dim, action_dim)

    train(train_dataloader, val_dataloader, model, training_horizon, prediction_horizon, n_epochs, base_dir, learning_rate=0.003,
          writer=writer, print_every=1)


