import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA


def onehot2radius(state_batch, args):
    """Decode the radius encoded as a trailing one-hot segment of the state.

    Args:
        state_batch: NumPy array holding a single state (1-D) or a batch of
            trajectories (3-D; only the last time step of each trajectory is
            read). Any other rank is sliced like a 2-D batch with one state
            per row.
        args: namespace providing ``radius_dim`` (length of the one-hot
            segment at the end of the last axis) and ``radius_candidate``
            (comma-separated string of integer radius values, one per slot).

    Returns:
        ``np.dot`` of the one-hot segment with the candidate values: a scalar
        for 1-D input, otherwise one value per batch entry. For an exact
        one-hot vector this equals the selected candidate; the dtype follows
        NumPy promotion (float when ``state_batch`` is float).
    """
    width = args.radius_dim
    candidates = np.array(list(map(int, args.radius_candidate.split(','))))

    ndim = state_batch.ndim
    if ndim == 1:
        onehot = state_batch[-width:]
    elif ndim == 3:
        # Trajectory batch: decode from the final time step only.
        onehot = state_batch[:, -1, -width:]
    else:
        onehot = state_batch[:, -width:]

    return np.dot(onehot, candidates)


def get_minibatch_depricated(state_batch, radius_input_batch, radius_batch, args):
    """Legacy loop-based sampler of (s_t, s_{t+1}, s_{t+L}) triples.

    Expected shapes:
        state_batch:        minibatch x traj_len x C x H x W (must be 5-D)
        radius_input_batch: minibatch x traj_len x feature
        radius_batch:       minibatch x traj_len

    For every trajectory, the radius L is ``radius_batch[b, 0]`` cast to int.
    ``args.num_samples_per_trajectory`` start indices are drawn one at a time
    from ``[0, max(1, traj_len - L - 1))``. A draw is skipped when the state at
    ``start + L`` is entirely zero (presumably padding; confirm with the data
    pipeline).

    Returns:
        Five NumPy arrays built from the kept draws: states at ``start``,
        ``start + 1`` and ``start + L``, the radius input of time step 0, and L.
    """
    n_draws = args.num_samples_per_trajectory
    n_traj, traj_len, _, _, _ = state_batch.shape

    befores, befores_next, afters = [], [], []
    radius_inputs, radii = [], []

    for traj in range(n_traj):
        horizon = radius_batch[traj, 0].astype(int)
        horizon_input = radius_input_batch[traj, 0, :]
        # Exclusive upper bound for the start index; fixed per trajectory.
        high = max(1, traj_len - horizon - 1)

        for _ in range(n_draws):
            t0 = np.random.randint(0, high)
            target = state_batch[traj, t0 + horizon, :]
            if np.all(target == 0):
                continue

            befores.append(state_batch[traj, t0, :])
            befores_next.append(state_batch[traj, t0 + 1, :])
            afters.append(target)
            radius_inputs.append(horizon_input)
            radii.append(horizon)

    return (
        np.array(befores),
        np.array(befores_next),
        np.array(afters),
        np.array(radius_inputs),
        np.array(radii),
    )


import numpy as np

def get_minibatch(state_batch, radius_input_batch, radius_batch, args):
    """Sample (s_t, s_{t+1}, s_{t+L}) triples from every trajectory in one vectorized pass.

    Args:
        state_batch: array of shape (batch, trajectory_length, *state_shape).
            Any trailing state shape works, e.g. image states (C, H, W) or
            flat feature vectors (F,).
        radius_input_batch: array of shape (batch, trajectory_length, feature);
            only time step 0 of each trajectory is read.
        radius_batch: array of shape (batch, trajectory_length); entry [b, 0]
            is cast to int and used as the radius L of trajectory b.
        args: namespace providing ``num_samples_per_trajectory``.

    Returns:
        Tuple ``(before, before_prime, after, radius_input, radius)``:
        the states at ``start``, ``start + 1`` and ``start + L``, the radius
        input of the sample's trajectory, and L. Start indices are drawn
        uniformly from ``[0, max(1, trajectory_length - L - 1))``. Samples
        whose ``after`` state is entirely zero (presumably padding; confirm
        with the data pipeline) are dropped, so the leading dimension is at
        most ``batch * num_samples_per_trajectory``.

    Raises:
        IndexError: if some radius L is >= trajectory_length (``start + L``
            then falls outside the trajectory).
    """
    num_samples_per_trajectory = args.num_samples_per_trajectory
    batch_size, trajectory_length = state_batch.shape[:2]

    # One integer radius per trajectory, taken from its first time step.
    L = radius_batch[:, 0].astype(int)
    # Exclusive upper bound for start indices (broadcast over samples below);
    # keeps start + L <= trajectory_length - 2 whenever that is possible.
    max_indices = np.maximum(1, trajectory_length - L[:, None] - 1)

    start_indices = np.random.randint(0, max_indices, size=(batch_size, num_samples_per_trajectory))
    end_indices = start_indices + L[:, None]

    # Row-major flattening keeps samples grouped by trajectory, which matches
    # the np.repeat ordering used for the batch indices and radius arrays.
    flat_batch_indices = np.repeat(np.arange(batch_size), num_samples_per_trajectory)
    flat_start_indices = start_indices.ravel()
    flat_end_indices = end_indices.ravel()

    # Advanced indexing on the first two axes already yields
    # (batch * num_samples, *state_shape), whatever the trailing state shape.
    minibatch_before = state_batch[flat_batch_indices, flat_start_indices]
    minibatch_before_prime = state_batch[flat_batch_indices, flat_start_indices + 1]
    minibatch_after = state_batch[flat_batch_indices, flat_end_indices]

    # Keep a sample unless its target state is zero across all state axes.
    state_axes = tuple(range(1, minibatch_after.ndim))
    valid_samples = ~np.all(minibatch_after == 0, axis=state_axes)

    flat_L_input = np.repeat(radius_input_batch[:, 0], num_samples_per_trajectory, axis=0)
    flat_L = np.repeat(L, num_samples_per_trajectory)

    return (
        minibatch_before[valid_samples],
        minibatch_before_prime[valid_samples],
        minibatch_after[valid_samples],
        flat_L_input[valid_samples],
        flat_L[valid_samples],
    )



def get_evalbatch(states_eval, radius_value_eval, radius_eval, num_samples=20):
    """Draw random states for plotting plus one contiguous evaluation window.

    Args:
        states_eval: array of shape (batch, trajectory_length, feature_dim).
        radius_value_eval: not used by this function; kept so existing call
            sites keep working.
        radius_eval: array whose entry [b, 0] is the radius L of trajectory b.
            It is converted with ``int()`` before use, so float radii are
            accepted.
        num_samples: number of random time steps drawn (with replacement)
            from each trajectory; defaults to 20, the previously hard-coded
            value.

    Returns:
        minibatches: array of shape (batch * num_samples, feature_dim).
        minibatch_eval: the 2 * L consecutive states
            ``states_eval[b, s:s + 2 * L]`` of one randomly chosen trajectory b,
            with s drawn from ``[0, trajectory_length - (2 * L + 1))``.

    Raises:
        ValueError: raised by ``np.random.randint`` when
            ``trajectory_length <= 2 * L + 1`` or ``batch`` is 0 (empty range).
    """
    batch_size, trajectory_length, feature_dim = states_eval.shape

    # Random time steps from every trajectory (one randint call per trajectory).
    minibatches = []
    for batch_index in range(batch_size):
        start_indices = np.random.randint(0, trajectory_length, size=num_samples)
        minibatches.extend(states_eval[batch_index, start_indices, :])
    minibatches = np.array(minibatches).reshape(-1, feature_dim)

    # Evaluation window from one random trajectory. The radius is cast to int
    # because a float radius would make the slice bounds below floats (TypeError).
    batch_index = np.random.randint(0, batch_size)
    L = int(radius_eval[batch_index, 0])
    start_index = np.random.randint(0, trajectory_length - (2 * L + 1))
    minibatch_eval = states_eval[batch_index, start_index:start_index + 2 * L, :]

    return minibatches, minibatch_eval




def plot_graph(data, eval_data, writer, updates, args):
    """Log a scatter plot of latent points to ``writer`` under the tag 'numpy_plot'.

    Args:
        data: array of shape (N, D) of latent points, drawn in red.
        eval_data: array of shape (M, D) of evaluation points, drawn in blue
            and labelled with their row index.
        writer: object with an ``add_figure(tag, figure, step)`` method
            (presumably a TensorBoard ``SummaryWriter``; confirm at call site).
        updates: step value forwarded to ``writer.add_figure``.
        args: namespace providing ``radius_latent_dim``. When it equals 2 the
            first two columns are plotted directly; otherwise both arrays are
            projected jointly onto their first two principal components.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    if args.radius_latent_dim == 2:  # already 2-D: plot directly
        points, eval_points = data, eval_data
        label_alpha = 1.0
    else:  # PCA for higher dimension
        # Fit one PCA on the union so both sets share the same 2-D projection.
        pca = PCA(n_components=2)
        projected = pca.fit_transform(np.concatenate([data, eval_data], axis=0))
        points, eval_points = projected[:len(data)], projected[len(data):]
        label_alpha = 0.5

    # One scatter call per point set draws the same markers as per-point
    # calls but creates a single artist instead of N, keeping large plots fast.
    ax.scatter(points[:, 0], points[:, 1], color='red', alpha=0.2)
    ax.scatter(eval_points[:, 0], eval_points[:, 1], color='blue', alpha=0.2)
    for i in range(len(eval_points)):
        ax.text(eval_points[i, 0], eval_points[i, 1], str(i),
                color='blue', alpha=label_alpha, fontsize=12)

    ax.set_title('Latent space visualization')
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')

    writer.add_figure('numpy_plot', fig, updates)
    # pyplot keeps every figure it creates alive until closed; close it here
    # (harmless if the writer already did) so repeated calls do not leak.
    plt.close(fig)


def center_crop(image, output_size):
    """
    Crop the center of an image.

    Args:
    image: NumPy array of shape (C, H, W).
    output_size: Desired output size (output_size, output_size).

    Returns:
    Cropped image of shape (C, min(H, output_size), min(W, output_size)).
    An axis shorter than ``output_size`` is returned whole, as ``random_crop``
    does, instead of being sliced from a negative (wrapped-around) start.
    """
    _, H, W = image.shape

    # Clamp at 0: a negative start would index from the end of the axis and
    # return a wrong, much smaller crop when the image is smaller than requested.
    top = max((H - output_size) // 2, 0)
    left = max((W - output_size) // 2, 0)

    return image[:, top:top + output_size, left:left + output_size]


def random_crop(image, output_size):
    """
    Crop a random part of the image.

    Args:
    image: NumPy array of shape (C, H, W).
    output_size: Desired output size (output_size, output_size).

    Returns:
    Cropped image of shape (C, min(H, output_size), min(W, output_size)).
    Every top-left offset from 0 to ``H - output_size`` (resp. ``W - output_size``)
    is equally likely; an axis shorter than ``output_size`` is returned whole.
    """
    _, H, W = image.shape
    crop_max_height = H - output_size
    crop_max_width = W - output_size

    # np.random.randint excludes its upper bound, so +1 is needed for the last
    # valid offset (crop_max_*) to be reachable; without it the bottom row and
    # right column of the image could never appear in a crop.
    top = np.random.randint(0, crop_max_height + 1) if crop_max_height > 0 else 0
    left = np.random.randint(0, crop_max_width + 1) if crop_max_width > 0 else 0

    return image[:, top:top + output_size, left:left + output_size]