import numpy as np
import torch as th


def clipped(x, theta_min=None, theta_max=None):
    """Optionally clip ``x`` into ``[theta_min, theta_max]`` and renormalise it.

    Works on both torch tensors and NumPy arrays with the same semantics:

    * If neither bound is given, ``x`` is returned unchanged (no clipping and
      no renormalisation).
    * Otherwise ``x`` is clipped element-wise. It is then divided by its sum
      along the last axis, so each last-axis slice sums to 1 again.

    Args:
        x: tensor or array of (non-negative) weights.
        theta_min: optional lower bound (``None`` = unbounded below).
        theta_max: optional upper bound (``None`` = unbounded above).

    Returns:
        The input itself when no bound is given, else the clipped,
        last-axis-renormalised tensor/array.
    """
    if theta_min is None and theta_max is None:
        # Nothing to clip. Previously the NumPy branch would call np.clip with
        # both bounds None (a ValueError on older NumPy) and always renormalise,
        # unlike the tensor branch.
        return x

    if th.is_tensor(x):
        y = th.clamp(x, min=theta_min, max=theta_max)
        return y / y.sum(dim=-1, keepdim=True)

    # Positional bounds work across NumPy versions (keyword names changed in 2.1).
    y = np.clip(x, theta_min, theta_max)
    # Normalise along the last axis, matching the tensor branch.
    return y / y.sum(axis=-1, keepdims=True)


def reshape_action_for_tensor(x, theta, low=0., high=1.):
    """Map ``x`` through a monotone piecewise-linear function parameterised by ``theta``.

    The normalised domain [0, 1] is split into ``len(theta)`` equal-width bins.
    ``softmax(theta)`` gives the share of the output range covered by each bin.
    So bin ``i`` has slope ``softmax(theta)[i] * len(theta)``, and the function
    maps [0, 1] onto [0, 1]. Inputs are rescaled from ``[low, high]`` before
    the mapping, and outputs are rescaled back to ``[low, high]`` afterwards.
    Inputs outside ``[low, high]`` saturate at ``low`` / ``high``.

    Args:
        x: tensor of inputs. The mapping is applied element-wise, so any shape
            (including a 0-d scalar tensor) is accepted.
        theta: 1-D tensor of per-bin logits.
        low: lower bound of the input/output range.
        high: upper bound of the input/output range.

    Returns:
        ``(y, slopes)``. ``y`` has the shape of ``x``. ``slopes`` holds, per
        element, the slope dy/dx of the bin the element falls in. An element on
        an interior bin edge uses the left-hand bin. Elements at or below
        ``low`` use the first bin, and elements above ``high`` use the last bin.
    """
    # Normalize x to [0, 1]
    x = (x - low) / (high - low)

    action_bins = len(theta)

    # Softmax so the per-bin output shares sum to 1.
    theta = th.nn.functional.softmax(theta, dim=-1).to(x.device)

    # Clip if needed (no bounds are passed, so this currently leaves theta as is).
    theta = clipped(theta)

    # Slope of each bin in the normalised domain.
    theta = theta * action_bins

    # Left edges of the equal-width bins: 0, 1/n, ..., (n-1)/n.
    x_vec = th.linspace(0., 1., action_bins + 1, device=x.device)[:-1]

    # Progress of each element into each bin, limited to [0, bin width].
    # `[..., None]` (instead of `[:, None]`) lets x have any shape, including 0-d.
    x_delta = x[..., None] - x_vec
    clip_x_delta = th.clamp(x_delta, min=0.0, max=1 / action_bins)

    y = th.sum(theta * clip_x_delta, dim=-1)

    y = y * (high - low) + low

    # calculate \partial a / \partial e: the slope of the bin containing x.
    # bucketize returns 0 for x <= 0, so `- 1` would give index -1 and silently
    # wrap to the LAST bin's slope. Clamp so such inputs use the first bin.
    index = (th.bucketize(x, x_vec) - 1).clamp(min=0)
    slopes = theta[index]

    return y, slopes


def reshape_action(x, theta, low=0., high=1., need_slope=False):
    """NumPy-facing wrapper around ``reshape_action_for_tensor``.

    Converts the inputs to tensors, applies the reshaping, and returns detached
    NumPy arrays.

    Args:
        x: array-like or tensor of inputs in ``[low, high]``.
        theta: per-bin logits (array-like or tensor). Integer input is cast to
            the default floating dtype, since softmax needs floats.
        low: lower bound of the input/output range.
        high: upper bound of the input/output range.
        need_slope: if True, also return the per-element slopes.

    Returns:
        ``y`` as a NumPy array, or ``(y, slopes)`` when ``need_slope`` is True.
    """
    if not th.is_tensor(theta):
        theta = th.as_tensor(theta)
    if not theta.is_floating_point():
        # softmax is not implemented for integer tensors (e.g. theta=[0, 1, 0]).
        theta = theta.to(th.get_default_dtype())

    # as_tensor accepts tensors without the copy-construct warning th.tensor emits.
    x = th.as_tensor(x, device=theta.device)
    y, slopes = reshape_action_for_tensor(x=x, theta=theta, low=low, high=high)
    if need_slope:
        return y.detach().cpu().numpy(), slopes.detach().cpu().numpy()
    return y.detach().cpu().numpy()


def inverse_reshape_action_for_tensor(y, theta, low=0., high=1.):
    """Invert the piecewise-linear map built from ``theta``.

    Rebuilds the same bins as ``reshape_action_for_tensor``: ``len(theta)``
    equal-width input bins, with ``softmax(theta)`` giving each bin's share of
    the output range. It then walks them in output space, so that ``y`` maps
    back to the input that produced it. Outputs outside ``[low, high]``
    saturate at ``low`` / ``high``.

    Args:
        y: tensor of outputs. Applied element-wise, so any shape (including a
            0-d scalar tensor) is accepted.
        theta: 1-D tensor of per-bin logits.
        low: lower bound of the input/output range.
        high: upper bound of the input/output range.

    Returns:
        Tensor ``x`` with the shape of ``y``.
    """
    # Normalize y to [0, 1]
    y = (y - low) / (high - low)

    action_bins = len(theta)

    # Softmax so the per-bin output shares sum to 1.
    theta = th.nn.functional.softmax(theta, dim=-1).to(y.device)

    # Clip if needed (no bounds are passed, so this currently leaves theta as is).
    theta = clipped(theta)

    # Left edge of each bin in output space: total share of all earlier bins.
    y_vec = th.cumsum(theta, dim=-1) - theta

    # Progress of each element into each output-space bin, limited to that
    # bin's height. `[..., None]` lets y have any shape, including 0-d.
    y_delta = y[..., None] - y_vec
    clip_y_delta = th.clamp(y_delta, min=th.zeros_like(theta), max=theta)

    # Each output-space bin of height theta_i spans 1/n of the input domain.
    inv_theta = 1. / (theta * action_bins)

    x = th.sum(inv_theta * clip_y_delta, dim=-1)

    x = x * (high - low) + low

    return x


def inverse_reshape_action(y, theta, low=0., high=1.):
    """NumPy-facing wrapper around ``inverse_reshape_action_for_tensor``.

    Args:
        y: array-like or tensor of outputs in ``[low, high]``.
        theta: per-bin logits (array-like or tensor). Integer input is cast to
            the default floating dtype, since softmax needs floats.
        low: lower bound of the input/output range.
        high: upper bound of the input/output range.

    Returns:
        The recovered inputs as a detached NumPy array.
    """
    if not th.is_tensor(theta):
        theta = th.as_tensor(theta)
    if not theta.is_floating_point():
        # softmax is not implemented for integer tensors (e.g. theta=[0, 1, 0]).
        theta = theta.to(th.get_default_dtype())

    # as_tensor accepts tensors without the copy-construct warning th.tensor emits.
    y = th.as_tensor(y, device=theta.device)
    x = inverse_reshape_action_for_tensor(y=y, theta=theta, low=low, high=high)

    return x.detach().cpu().numpy()




if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Demo: plot the reshape function f, its inverse, and both compositions.
    # The two compositions should lie on the identity line y = x.
    grid = np.linspace(0, 1, num=1000)
    logits = np.array([0., -1., 1., 0.])

    forward = reshape_action(grid, theta=logits, low=0, high=1)
    backward = inverse_reshape_action(grid, theta=logits, low=0, high=1)

    curves = [
        (forward, 'Reshape function f'),
        (backward, 'Inverse Reshape function f^-1'),
        (reshape_action(backward, theta=logits, low=0, high=1), 'f(f^-1(x))'),
        (inverse_reshape_action(forward, theta=logits, low=0, high=1), 'f^-1(f(x))'),
    ]

    plt.figure(figsize=(10, 6))
    for values, label in curves:
        plt.plot(grid, values, label=label)

    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Reshape Function Plot with theta=' + str(logits))
    plt.grid(True)
    plt.legend()
    plt.show()
