import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA


def onehot2radius(state_batch, args):
    """Decode the one-hot radius stored in the trailing features of a state.

    The last ``args.radius_dim`` features are dotted with the candidate radii
    parsed from the comma-separated string ``args.radius_candidate``.

    Input layouts are distinguished by ``state_batch.ndim``:
        * 3 -- only the last entry along axis 1 is decoded (presumably the
          final step of each trajectory; confirm against callers).
        * 1 -- a single state vector.
        * otherwise -- indexed as a 2-D array with one state per row.

    Returns:
        The dot product of the selected one-hot slice with the candidate
        radii (a scalar for 1-D input, otherwise one value per row).
    """
    n_radius = args.radius_dim
    candidates = np.array(list(map(int, args.radius_candidate.split(','))))

    rank = state_batch.ndim
    if rank == 1:
        onehot = state_batch[-n_radius:]
    elif rank == 3:
        onehot = state_batch[:, -1, -n_radius:]
    else:
        onehot = state_batch[:, -n_radius:]

    return np.dot(onehot, candidates)


def get_minibatch_depricated(state_batch, radius_batch, args):
    """Loop-based sampler of (state, next state, state L steps later) triples.

    ``state_batch`` is indexed as (minibatch, traj_len, dim) and
    ``radius_batch`` as (minibatch, traj_len); only column 0 of
    ``radius_batch`` is read and cast to int to obtain the offset ``L`` of
    each trajectory.

    For every trajectory, ``args.num_samples_per_trajectory`` start indices
    are drawn from ``np.random.randint(0, max(1, traj_len - L - 1))``.
    A draw is discarded when the state ``L`` steps after the start is all
    zeros.

    Returns:
        Tuple ``(before, before_prime, after, radius)`` of arrays with one
        entry per retained sample.
    """
    samples_per_traj = args.num_samples_per_trajectory
    n_traj, n_steps, _ = state_batch.shape

    anchors, successors, targets, radii = [], [], [], []

    for traj in range(n_traj):
        trajectory = state_batch[traj]
        horizon = radius_batch[traj, 0].astype(int)
        # The exclusive upper bound is the same for every draw of a trajectory.
        upper = max(1, n_steps - horizon - 1)

        for _ in range(samples_per_traj):
            t0 = np.random.randint(0, upper)
            target = trajectory[t0 + horizon]

            # An all-zero target state is skipped rather than sampled.
            if not np.any(target != 0):
                continue

            anchors.append(trajectory[t0])
            successors.append(trajectory[t0 + 1])
            targets.append(target)
            radii.append(horizon)

    # Convert the accumulated Python lists to arrays.
    return tuple(np.array(buf) for buf in (anchors, successors, targets, radii))


def get_minibatch(state_batch, radius_batch, args):
    """Vectorised sampler of (state, next state, state L steps later) triples.

    ``state_batch`` is unpacked as (batch, trajectory_length, features).
    The per-trajectory offset ``L`` is column 0 of ``radius_batch`` cast to
    int. For each trajectory, ``args.num_samples_per_trajectory`` start
    indices are drawn in a single ``np.random.randint`` call with exclusive
    upper bound ``max(1, trajectory_length - L - 1)``.

    Samples whose state ``L`` steps after the start is all zeros are
    removed from every returned array.

    Returns:
        Tuple ``(before, before_prime, after, radius)``; the first three are
        (n_kept, features) arrays and ``radius`` holds the matching ``L``.
    """
    n_traj, n_steps, _ = state_batch.shape
    samples_per_traj = args.num_samples_per_trajectory

    # Column vector so it broadcasts across the samples of each trajectory.
    horizon = radius_batch[:, 0].astype(int).reshape(-1, 1)
    upper = np.maximum(1, n_steps - horizon - 1)
    starts = np.random.randint(0, upper, size=(n_traj, samples_per_traj))

    rows = np.arange(n_traj).reshape(-1, 1)
    after = state_batch[rows, starts + horizon]

    # Keep a sample only if its target state has at least one non-zero entry.
    keep = np.any(after != 0, axis=-1)

    before = state_batch[rows, starts][keep]
    before_prime = state_batch[rows, starts + 1][keep]
    radius = np.broadcast_to(horizon, keep.shape)[keep]

    return before, before_prime, after[keep], radius


def get_evalbatch(states_eval, radius_eval, num_samples_per_trajectory=20):
    """Build an evaluation batch: random background states plus one window.

    Args:
        states_eval: array unpacked as (batch, trajectory_length, features).
        radius_eval: array indexed as ``radius_eval[batch_index, 0]`` to read
            the radius ``L`` of the chosen trajectory. The value is cast to
            int (as ``get_minibatch`` does), so float radii are accepted.
        num_samples_per_trajectory: number of random states drawn (with
            replacement) from each trajectory. Defaults to 20, the value
            that was previously hard-coded.

    Returns:
        Tuple ``(minibatches, minibatch_eval)``:
            * ``minibatches`` -- (batch * num_samples_per_trajectory,
              features) array of randomly chosen states.
            * ``minibatch_eval`` -- a contiguous window of ``2 * L`` states
              from one randomly chosen trajectory. If the trajectory is
              shorter than ``2 * L`` the window is truncated by slicing.

    Raises:
        ValueError: if ``states_eval`` contains no trajectories.
    """
    batch_size, trajectory_length, feature_dim = states_eval.shape
    if batch_size == 0:
        raise ValueError("states_eval must contain at least one trajectory")

    # One draw per trajectory, in trajectory order, from NumPy's global RNG.
    sampled = [
        states_eval[
            batch_index,
            np.random.randint(0, trajectory_length, size=num_samples_per_trajectory),
        ]
        for batch_index in range(batch_size)
    ]
    minibatches = np.concatenate(sampled, axis=0).reshape(-1, feature_dim)

    # (eval) randomly choose the trajectory and the window start.
    batch_index = np.random.randint(0, batch_size)
    # Cast to int: a float radius would otherwise produce a float slice bound.
    L = int(radius_eval[batch_index, 0])
    window = 2 * L
    # max(1, ...) mirrors the guard in the minibatch samplers, so short
    # trajectories fall back to start index 0 instead of raising.
    start_index = np.random.randint(0, max(1, trajectory_length - (window + 1)))

    # (eval) contiguous window of 2 * L states.
    minibatch_eval = states_eval[batch_index, start_index:start_index + window, :]

    return minibatches, minibatch_eval




def _draw_latent_points(ax, points, eval_points, label_alpha):
    """Scatter ``points`` in red and index-labelled ``eval_points`` in blue."""
    # A single scatter call per group replaces one call (and one artist) per sample.
    if len(points):
        ax.scatter(points[:, 0], points[:, 1], color='red', alpha=0.2)
    if len(eval_points):
        ax.scatter(eval_points[:, 0], eval_points[:, 1], color='blue', alpha=0.2)

    # Each evaluation point is annotated with its position in ``eval_points``.
    for i, point in enumerate(eval_points):
        ax.text(point[0], point[1], str(i), color='blue', alpha=label_alpha, fontsize=12)


def plot_graph(data, eval_data, writer, updates, args):
    """Plot latent points in 2-D and hand the figure to ``writer``.

    Args:
        data: 2-D array of latent points, drawn as red markers.
        eval_data: 2-D array of latent points, drawn as blue markers and
            labelled with their row index.
        writer: object exposing ``add_figure(tag, figure, step)``
            (presumably a TensorBoard-style summary writer; confirm).
        updates: step value forwarded to ``writer.add_figure``.
        args: namespace providing ``radius_latent_dim``. When it equals 2 the
            first two columns are plotted directly; otherwise ``data`` and
            ``eval_data`` are projected together onto two PCA components.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        if args.radius_latent_dim == 2:  # directly plot
            _draw_latent_points(ax, data, eval_data, label_alpha=1.0)
        else:  # PCA for higher dimension
            # Fit on both sets at once so they share one projection.
            combined = np.concatenate([data, eval_data], axis=0)
            projected = PCA(n_components=2).fit_transform(combined)
            n_data = len(data)
            _draw_latent_points(
                ax, projected[:n_data], projected[n_data:], label_alpha=0.5
            )

        # labeling
        ax.set_title('Latent space visualization')
        ax.set_xlabel('X axis')
        ax.set_ylabel('Y axis')

        # add figure
        writer.add_figure('numpy_plot', fig, updates)
    finally:
        # A new figure is created on every call; close it so repeated calls
        # do not accumulate open figures. Harmless if the writer already
        # closed it.
        plt.close(fig)