import numpy as np
from scipy.ndimage import gaussian_filter1d


def mountain_car_noise(next_state, noise_fraction, current_state, noise_model='heteroskedastic'):
    """
    Adds Gaussian noise to a mountain car transition.

    The base noise standard deviations are fixed constants: 0.005 for
    position and 0.0005 for velocity. With the 'heteroskedastic' model they
    are scaled by state-dependent factors computed from ``current_state``.
    With any other model they are used unscaled.

    Parameters:
        next_state (array): The next state (position, velocity) without noise.
            It is duplicated with ``.copy()``; the caller's object is not modified.
        noise_fraction (float): Currently UNUSED. It was meant to set the base
            noise std as a fraction of each variable's range (position range
            0.6 - (-1.2) = 1.8, velocity range 0.07 - (-0.07) = 0.14). The
            base stds are hard-coded instead. It is kept so the signature
            stays backward compatible.
        current_state (array): The current state, unpacked as
            (position, velocity). It is only read by the heteroskedastic model.
        noise_model (str): 'heteroskedastic' applies state-dependent noise
            factors; any other value applies homogeneous noise.

    Returns:
        noisy_next_state (array): Copy of next_state with noise added to
            elements 0 (position) and 1 (velocity).
        local_stds (tuple): The noise standard deviations used for position
            and velocity.
        position_factor (float): Factor applied to the position base std
            (1.0 for homoskedastic noise).
        velocity_factor (float): Factor applied to the velocity base std
            (1.0 for homoskedastic noise).
    """
    # Base noise std per dimension. These intentionally override the
    # range-based values (noise_fraction * range) that the signature suggests.
    # Alternative values noted in the original "to prove Entropy":
    #   pos_base_noise_std = 0.01, vel_base_noise_std = 0.001
    pos_base_noise_std = 0.005
    vel_base_noise_std = 0.0005

    if noise_model == 'heteroskedastic':
        position, velocity = current_state

        # Gaussian bump centred on position -0.5: up to 2x the base noise
        # in the valley, decaying to 1x away from it.
        position_factor = 1.0 + 1.0 * np.exp(-50.0 * (position + 0.5) ** 2)
        # Variant noted in the original "to prove Entropy":
        # position_factor = 1.0 + 2.0 * np.exp(-200.0 * (position + 0.5) ** 2)

        # NOTE(review): this does NOT peak at zero velocity. It grows with
        # |velocity + 0.5|. For velocities in [-0.07, 0.07] it stays within
        # roughly [1.0043, 1.0057], so it is nearly constant. Confirm the
        # intended formula before relying on velocity heteroskedasticity.
        velocity_factor = 1.0 + 0.01 * abs(velocity + 0.5)
    else:
        # Homoskedastic noise: the base stds are used unscaled.
        # Multiplying by exactly 1.0 leaves the float values unchanged.
        position_factor = 1.0
        velocity_factor = 1.0

    local_std_pos = pos_base_noise_std * position_factor
    local_std_vel = vel_base_noise_std * velocity_factor

    noisy_next_state = next_state.copy()
    # Position noise is drawn before velocity noise. This keeps the global
    # RNG stream order stable for seeded experiments.
    noisy_next_state[0] += np.random.normal(0, local_std_pos)
    noisy_next_state[1] += np.random.normal(0, local_std_vel)

    local_stds = (local_std_pos, local_std_vel)
    return noisy_next_state, local_stds, position_factor, velocity_factor



def add_transition_noise(next_state, noise_std):
    """
    Return ``next_state`` perturbed by zero-mean Gaussian noise.

    One independent draw with standard deviation ``noise_std`` is taken for
    every element of ``next_state`` (matching its ``.shape``). The result is
    a new array produced by addition.
    """
    perturbation = np.random.normal(loc=0, scale=noise_std, size=next_state.shape)
    return next_state + perturbation


def gaussian_smooth(data, sigma=5):
    """
    Smooth ``data`` with a one-dimensional Gaussian kernel.

    This is a thin wrapper around ``scipy.ndimage.gaussian_filter1d``. Only
    ``sigma`` is forwarded, so the filter's own defaults apply for axis,
    boundary mode and kernel truncation.

    Args:
        data: Input array to smooth.
        sigma: Standard deviation of the Gaussian kernel (default 5).
    Returns:
        The smoothed array produced by ``gaussian_filter1d``.
    """
    smoothed = gaussian_filter1d(data, sigma=sigma)
    return smoothed