import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

    
def onehot2radius(state_batch, args):
    """Decode the radius code stored in the trailing features of a state.

    The last ``args.radius_dim`` features are combined, via a dot product,
    with the integer radii listed in ``args.radius_candidate`` (a
    comma-separated string). For an exact one-hot code this selects the
    matching candidate radius.

    Args:
        state_batch: array whose trailing feature axis holds the code.
            1-D input is a single state; 3-D input is indexed as
            (batch, time, feature) and only the last time step is used;
            any other rank is sliced as ``[:, -radius_dim:]``.
        args: namespace providing ``radius_dim`` (int) and
            ``radius_candidate`` (comma-separated ints).

    Returns:
        The decoded radius: a scalar for 1-D input, otherwise one value per
        leading-axis entry.
    """
    candidates = np.array([int(token) for token in args.radius_candidate.split(',')])
    width = args.radius_dim

    if state_batch.ndim == 1:
        code = state_batch[-width:]
    elif state_batch.ndim == 3:
        code = state_batch[:, -1, -width:]
    else:
        code = state_batch[:, -width:]

    return np.dot(code, candidates)


def get_minibatch_deprecated(state_batch, radius_batch, args):
    """Loop-based sampler of (before, before', after, radius) tuples.

    The name marks this as deprecated; ``get_minibatch`` in this module is the
    vectorized counterpart that applies the same sampling rule.

    Args:
        state_batch: array of shape (minibatch, traj_len, dim).
        radius_batch: array of shape (minibatch, traj_len). Only column 0 is
            read, converted with ``.astype(int)``, as each trajectory's
            radius L.
        args: namespace providing ``num_samples_per_trajectory``.

    Returns:
        ``(before, before_prime, after, radius)`` numpy arrays. For a random
        start index s of a trajectory, these hold the states at s, s + 1 and
        s + L, plus L itself. Draws whose state at s + L is all zeros are
        skipped, so fewer than minibatch * num_samples_per_trajectory rows may
        be returned.
    """
    draws_per_traj = args.num_samples_per_trajectory
    n_traj, traj_len, _ = state_batch.shape

    before_rows = []
    before_next_rows = []
    after_rows = []
    radii = []

    for traj_idx in range(n_traj):
        radius = radius_batch[traj_idx, 0].astype(int)
        # randint's upper bound is exclusive; clamping to 1 keeps the call
        # valid for short trajectories (start is then always 0).
        start_upper = max(1, traj_len - radius - 1)
        trajectory = state_batch[traj_idx]

        for _ in range(draws_per_traj):
            start = np.random.randint(0, start_upper)
            target = trajectory[start + radius]
            # All-zero target states are skipped (presumably padding - confirm).
            if np.all(target == 0):
                continue
            before_rows.append(trajectory[start])
            before_next_rows.append(trajectory[start + 1])
            after_rows.append(target)
            radii.append(radius)

    return (
        np.array(before_rows),
        np.array(before_next_rows),
        np.array(after_rows),
        np.array(radii),
    )


def get_minibatch(state_batch, radius_batch, args):
    """Vectorized sampler of (before, before', after, radius) tuples.

    For every trajectory, ``args.num_samples_per_trajectory`` start indices s
    are drawn uniformly from ``[0, max(1, traj_len - L - 1))``, where L is the
    trajectory's radius. The states at s, s + 1 and s + L are gathered, and
    draws whose state at s + L is all zeros are discarded.

    NOTE(review): because randint's upper bound is exclusive, s + L never
    reaches the final time step. Also, L >= traj_len raises IndexError. Both
    behaviours are shared with ``get_minibatch_deprecated``.

    Args:
        state_batch: array of shape (minibatch, traj_len, dim).
        radius_batch: array of shape (minibatch, traj_len). Only column 0 is
            read, converted with ``.astype(int)``.
        args: namespace providing ``num_samples_per_trajectory``.

    Returns:
        ``(before, before_prime, after, radius)``. The first three have shape
        (K, dim) and ``radius`` has shape (K,), where K is the number of kept
        draws. Rows are ordered trajectory-major.
    """
    n_traj, traj_len, _ = state_batch.shape
    draws_per_traj = args.num_samples_per_trajectory

    # Radius per trajectory as a column so it broadcasts across the draws axis.
    radius_col = radius_batch[:, 0].astype(int)[:, np.newaxis]
    start_upper = np.maximum(1, traj_len - radius_col - 1)
    starts = np.random.randint(0, start_upper, size=(n_traj, draws_per_traj))

    # Trajectory index per row, broadcast against `starts`.
    rows = np.arange(n_traj)[:, np.newaxis]
    before = state_batch[rows, starts]
    before_next = state_batch[rows, starts + 1]
    after = state_batch[rows, starts + radius_col]

    # Keep only draws whose target state has at least one non-zero feature.
    keep = ~np.all(after == 0, axis=-1)
    # Boolean-mask indexing walks the (n_traj, draws) grid in row-major order,
    # matching the order of the masked state arrays.
    radii = np.broadcast_to(radius_col, keep.shape)[keep]

    return before[keep], before_next[keep], after[keep], radii


def get_evalbatch(states_eval, radius_eval, num_samples=20):
    """Sample random states plus one contiguous evaluation window.

    Args:
        states_eval: array of shape (batch, trajectory_length, feature_dim).
        radius_eval: 2-D array indexed as ``[batch, time]``. Only column 0 of
            the randomly chosen trajectory is read and is cast to ``int``,
            since it may be stored as float.
        num_samples: number of random time steps drawn (with replacement) from
            every trajectory for ``minibatches``. Defaults to 20, the value
            that was previously hard-coded.

    Returns:
        minibatches: array of shape (batch * num_samples, feature_dim),
            ordered trajectory-major.
        minibatch_eval: up to ``2 * L`` consecutive states from one randomly
            chosen trajectory, where L is that trajectory's radius. The window
            is shorter only when the trajectory has fewer than
            ``2 * L`` steps.
    """
    batch_size, trajectory_length, feature_dim = states_eval.shape

    # Draw every start index in one call; row-major flattening reproduces the
    # trajectory-major ordering of the old per-trajectory loop.
    start_indices = np.random.randint(0, trajectory_length, size=(batch_size, num_samples))
    batch_indices = np.arange(batch_size)[:, None]
    minibatches = np.asarray(states_eval[batch_indices, start_indices]).reshape(-1, feature_dim)

    # (eval) randomly choose a trajectory and a start index
    batch_index = np.random.randint(0, batch_size)
    # Cast: a float radius cannot be used as a slice bound below.
    L = int(radius_eval[batch_index, 0])
    # Clamp like get_minibatch so trajectories of length <= 2L+1 still yield a
    # window (starting at 0) instead of randint raising on an empty range.
    # NOTE(review): the exclusive bound also leaves the last two valid starts
    # (T-2L-1, T-2L) unsampled; kept as-is to preserve the distribution.
    start_index = np.random.randint(0, max(1, trajectory_length - (2 * L + 1)))

    # (eval) consecutive window of 2L states
    minibatch_eval = states_eval[batch_index, start_index:start_index + 2 * L, :]

    return minibatches, minibatch_eval



def plot_graph(data, eval_data, writer, updates, args):
    """Scatter-plot latent points and log the figure via ``writer``.

    ``data`` is drawn in red and ``eval_data`` in blue, and each evaluation
    point is labelled with its row index. When ``args.radius_latent_dim == 2``
    the first two columns are plotted directly. Otherwise both sets are
    projected onto two PCA components fitted on their concatenation, so they
    share the same axes.

    Args:
        data: 2-D array of latent points, one per row.
        eval_data: 2-D array of evaluation latent points, one per row.
        writer: object exposing ``add_figure(tag, figure, step)`` (presumably
            a TensorBoard ``SummaryWriter`` - confirm against caller).
        updates: step value forwarded to ``writer.add_figure``.
        args: namespace providing ``radius_latent_dim``.
    """
    plt.figure(figsize=(12, 8))

    if args.radius_latent_dim == 2:  # already 2-D: plot directly
        points, eval_points = data, eval_data
        label_alpha = 1.0
    else:  # higher dimension: one PCA fitted on both sets
        pca = PCA(n_components=2)
        projected = pca.fit_transform(np.concatenate([data, eval_data], axis=0))
        points = projected[:len(data)]
        eval_points = projected[len(data):]
        label_alpha = 0.5

    # One scatter call per point set instead of one per point: a single
    # artist each, which is far cheaper to build and render for large N.
    if len(points):
        plt.scatter(points[:, 0], points[:, 1], color='red', alpha=0.2)
    if len(eval_points):
        plt.scatter(eval_points[:, 0], eval_points[:, 1], color='blue', alpha=0.2)
    for idx, point in enumerate(eval_points):
        plt.text(point[0], point[1], str(idx), color='blue', alpha=label_alpha, fontsize=12)

    # labeling
    plt.title('Latent space visualization')
    plt.xlabel('X axis')
    plt.ylabel('Y axis')

    # The figure is not closed here. If `writer.add_figure` does not close it
    # (torch's SummaryWriter does by default), figures accumulate across calls.
    writer.add_figure('numpy_plot', plt.gcf(), updates)
    

def plot_pca_fft(skill_states, dt=1.0, max_freq=1.0):
    """Plot the FFT amplitude of each skill's 1-D PCA projection.

    Each episode is reduced to a 1-D signal by a 1-component PCA fitted on
    that episode alone. The amplitude of the signal's real FFT is then
    computed, and the spectra of all episodes are drawn on one figure. The
    sign ambiguity of PCA does not affect the amplitude.

    Args:
        skill_states (np.ndarray or list): shape (episodes, time_steps,
            joint_dim). All episodes must share the same number of time steps.
        dt (float): sampling interval in seconds (e.g. 0.01 for a 10 ms time
            step). It sets the units of the frequency axis.
        max_freq (float or None): highest frequency in Hz to plot. Defaults to
            1.0, the previously hard-coded cut-off. Pass None to plot every
            non-negative frequency up to Nyquist (0.5 / dt).
    """
    skill_states = np.array(skill_states)
    print("skill_states.shape :", skill_states.shape)  # e.g., (4, T, 6)

    time_steps = skill_states.shape[1]

    # The frequency axis depends only on the shared episode length, so build it
    # once. rfftfreq returns just the non-negative half, which is all a
    # real-valued signal needs (and it includes the Nyquist bin for even T).
    freqs = np.fft.rfftfreq(time_steps, d=dt)
    if max_freq is None:
        freq_mask = np.ones(freqs.shape, dtype=bool)
    else:
        freq_mask = freqs <= max_freq

    pca = PCA(n_components=1)
    plt.figure(figsize=(10, 5))

    for skill_idx, X_skill in enumerate(skill_states):
        # (time_steps, joint_dim) -> (time_steps,). Use [:, 0] rather than
        # squeeze() so that a single time step stays 1-D.
        signal = pca.fit_transform(X_skill)[:, 0]
        fft_amp = np.abs(np.fft.rfft(signal))
        plt.plot(freqs[freq_mask], fft_amp[freq_mask], label=f"Skill {skill_idx+1}")

    plt.title("FFT of 1D PCA Projection for Each Skill")
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Amplitude")
    plt.legend()
    plt.tight_layout()
    plt.show()
    
