import numpy as np
from scipy.special import expit

# Default base standard deviations of the transition noise (the `sigma_0` and
# `sigma_1` defaults of clinician_tx and policy_tx).
SIGMA0 = 0.3
SIGMA1 = 1.0

# State thresholds: X_C and X_P are the default `xc` / `xp` cut-offs used by
# policy_tx and the policy functions; X_C and X_H are the default `xc` / `xh`
# bounds of reward_fn.
X_C = -1
X_P = 0
X_H = 1.3

# Default reward weights of reward_fn: C multiplies the `x < xc` indicator
# (entering with a minus sign), cc multiplies the `x > xh` indicator.
C = 1
cc = 0.1

# Default `pp` probability of the policy functions.
P_Y = 0.95

# 1 keeps the time-dependent terms of the transition functions; 0 replaces
# them with their constant counterparts.
NON_STATIONARY = 1


def sigma_func(sigma, tt):
    """Scale ``sigma`` by the logistic sigmoid of ``0.4 * tt``.

    Args:
        sigma: Base value to be scaled.
        tt: Time index fed to the sigmoid.

    Returns:
        ``sigma * expit(0.4 * tt)``; half of ``sigma`` at ``tt == 0`` and
        approaching ``sigma`` as ``tt`` grows.
    """
    time_factor = expit(0.4 * tt)
    return sigma * time_factor


def clinician_tx(s, a, tt, sigma_0=SIGMA0, sigma_1=SIGMA1):
    """Draw one Gaussian next-state sample for the clinician transition.

    The mean and the base standard deviation are chosen by the action; the
    standard deviation is then passed through ``sigma_func`` when
    ``NON_STATIONARY`` is 1, or used as-is when it is 0.

    Args:
        s: Current state; the mean is a shift of this value.
        a: Action. ``a == 1`` selects the ``sigma_0`` branch, any other
            value the ``sigma_1`` branch.
        tt: Time index; it feeds ``sigma_func`` and, in the ``a != 1``
            branch, the ``0.1 * tt`` mean shift.
        sigma_0: Base standard deviation used when ``a == 1``.
        sigma_1: Base standard deviation used otherwise.

    Returns:
        Tuple ``(sample, mean, std)``: the draw from ``np.random.normal``
        (requested with ``size=1``) and the mean / standard deviation it
        was drawn with.
    """
    if a == 1:
        mean = s + NON_STATIONARY * 0.03 * 1
        base_sigma = sigma_0
    else:
        mean = s - NON_STATIONARY * 0.1 * tt - (1 - NON_STATIONARY) * 0.1
        base_sigma = sigma_1

    # Blend of the time-scaled and the constant standard deviation; with
    # NON_STATIONARY in {0, 1} exactly one of the two terms survives.
    std = NON_STATIONARY * sigma_func(base_sigma, tt) + (1 - NON_STATIONARY) * base_sigma
    draw = np.random.normal(mean, scale=std, size=1)
    return draw, mean, std


def policy_tx(s, a, tt, sigma_0=SIGMA0, sigma_1=SIGMA1, xp=X_P, xc=X_C):
    """Draw one Gaussian next-state sample for the policy transition.

    Args:
        s: Current state; the mean is a shift of this value.
        a: Action. ``a == 1`` always uses mean ``s + 0.1`` with ``sigma_0``.
        tt: Time index; it feeds ``sigma_func`` and the ``0.1 * tt`` mean
            shift of the in-band, ``a != 1`` case.
        sigma_0: Base standard deviation for ``a == 1`` and for ``a != 1``
            outside the band.
        sigma_1: Base standard deviation for ``a != 1`` inside the band.
        xp: Upper edge of the band; ``s >= xp`` counts as outside.
        xc: Lower edge of the band; ``s <= xc`` counts as outside.

    Returns:
        The draw from ``np.random.normal`` (requested with ``size=1``).
    """
    # Evaluated before the action test, as the original outer branch did.
    outside_band = bool(s >= xp or s <= xc)

    if a == 1:
        # Same mean and spread whether or not the state is inside the band.
        mean = s + 0.1
        base_sigma = sigma_0
    elif outside_band:
        mean = s - 0.001
        base_sigma = sigma_0
    else:
        mean = s - NON_STATIONARY * 0.1 * tt - (1 - NON_STATIONARY) * 0.1
        base_sigma = sigma_1

    # With NON_STATIONARY in {0, 1} exactly one of the two terms survives.
    std = NON_STATIONARY * sigma_func(base_sigma, tt) + (1 - NON_STATIONARY) * base_sigma
    return np.random.normal(mean, scale=std, size=1)


def reward_fn(x, xc=X_C, xh=X_H, CC=C, cu=cc):
    """Return the reward for state ``x`` as a weighted sum of indicators.

    Args:
        x: State to score.
        xc: Lower threshold; ``x < xc`` contributes ``-CC``.
        xh: Upper threshold; ``x > xh`` contributes ``cu``.
        CC: Weight of the below-threshold indicator (applied negated).
        cu: Weight of the above-threshold indicator.

    Returns:
        ``-CC * [x < xc] + cu * [x > xh]``; the in-range indicator is
        weighted by 0 and so adds nothing to the total.
    """
    below = int(x < xc)
    above = int(x > xh)
    in_range = int(xc <= x <= xh)
    return -CC * below + cu * above + 0 * in_range


def target_policy_func(x, xp=X_P, xc=X_C, pp=P_Y):
    """Return ``pp`` when ``x >= xp`` or ``x <= xc``, otherwise ``1 - pp``.

    Args:
        x: State at which the policy is evaluated.
        xp: Upper cut-off (inclusive).
        xc: Lower cut-off (inclusive).
        pp: Probability returned at or beyond either cut-off.
    """
    return pp if (x >= xp or x <= xc) else 1 - pp


def target_policy_func_t(x, t, xp=X_P, xc=X_C, pp=P_Y):
    """Return ``pp`` when ``x >= xp`` or ``x <= xc``, otherwise ``1 - pp``.

    Args:
        x: State at which the policy is evaluated.
        t: Accepted but not used by the body.
        xp: Upper cut-off (inclusive).
        xc: Lower cut-off (inclusive).
        pp: Probability returned at or beyond either cut-off.
    """
    outside_band = x >= xp or x <= xc
    if not outside_band:
        return 1 - pp
    return pp


def clinician_policy_func(x, xp=X_P, xc=X_C, pp=P_Y):
    """Return ``pp`` unchanged, whatever the state ``x``.

    ``x``, ``xp`` and ``xc`` are accepted but not used by the body;
    presumably they are kept so the signature matches target_policy_func
    (TODO confirm against callers).
    """
    return pp


def clinician_policy_func_t(x, t, xp=X_P, xc=X_C, pp=P_Y):
    """Return ``pp`` unchanged, whatever the state ``x`` and time ``t``.

    ``x``, ``t``, ``xp`` and ``xc`` are accepted but not used by the body;
    presumably they are kept so the signature matches target_policy_func_t
    (TODO confirm against callers).
    """
    return pp


def behavior_policy_func(x, xp=X_P, xc=X_C, pp=P_Y):
    """Return ``pp`` lowered by 0.1, whatever the state ``x``.

    ``x``, ``xp`` and ``xc`` are accepted but not used by the body.
    """
    offset = 0.1
    return pp - offset


def behavior_policy_func_t(x, t, xp=X_P, xc=X_C, pp=P_Y):
    """Return ``pp`` lowered by 0.1, whatever the state ``x`` and time ``t``.

    ``x``, ``t``, ``xp`` and ``xc`` are accepted but not used by the body.
    """
    offset = 0.1
    return pp - offset
