import numpy as np
import time
import sys

# Problem sizes (number of states, actions, feature dimension) and the
# magnitude used for both the reward noise and the kernel perturbation.
S = 6
A = 4
d = 3
eps = 5e-3

# Fixed seed so the generated instance is reproducible. For a random instance,
# draw a seed with np.random.randint(100), print it, and seed with that instead.
np.random.seed(15)

# Raw random features, per-state measure and reward parameter.
phi_ = np.random.uniform(size=(S, A, d))
mu = np.random.uniform(size=(d, S))
theta = np.random.uniform(size=d)

# Transition kernel built from the features: every (s, a) row is divided by its
# own sum so it becomes a distribution, and the features get the same factor.
P_ = np.matmul(phi_, mu)
normalizer = np.sum(P_, axis=2, keepdims=True)
P = np.divide(P_, normalizer)
phi = np.divide(phi_, normalizer)

# Rescale the features so that the largest feature norm is exactly one.
norm = np.max(np.linalg.norm(phi, axis=2))
phi /= norm

# Linear rewards scaled to peak at one (theta receives the same scaling, so
# phi @ theta matches the scaled rewards), plus uniform noise in [-eps, eps).
R = np.matmul(phi, theta)
theta /= np.max(R)
R = R / np.max(R) + np.random.uniform(low=-eps, high=eps, size=(S, A))

# Perturb the kernel: each (s, a) row is shifted by a freshly shuffled vector
# with S // 2 entries of +eps / (S // 2) and S // 2 entries of -eps / (S // 2).
# The shifts cancel, so the row sums are unchanged.
x = np.concatenate((np.full(S // 2, eps), np.full(S // 2, -eps)))
perturb = np.zeros((S, A, S))  # allocated here but never filled in or read
for s, a in np.ndindex(S, A):
    np.random.shuffle(x)
    P[s, a] += x / (S // 2)

# Optimal quantities: first-step action values are the expected best
# second-step reward under the perturbed kernel; delta is the smallest
# strictly positive gap between the best and any other first-step action.
Q_star = np.matmul(P, np.max(R, axis=1))
V_star = np.max(Q_star, axis=1, keepdims=True)
gap = V_star - Q_star
delta = np.min(gap[gap > 0])
print(delta, delta / eps)

def LSVI(beta: float, gamma: float, K: int) -> np.ndarray:
    """Simulate K episodes of two-step optimistic LSVI with rare updates.

    Every episode draws the first state uniformly, acts greedily on the
    optimistic first-step estimate, samples the second state from ``P`` and
    acts greedily on the optimistic second-step estimate.  The regression
    statistics (and therefore the estimates and the greedy policy) are only
    refreshed when the exploration bonus at one of the two visited
    state-action pairs exceeds ``gamma``.

    The function reads the module-level instance ``S``, ``d``, ``phi``, ``P``,
    ``R`` and ``V_star`` and consumes the global NumPy random state.

    Args:
        beta: Multiplier of the exploration bonus added to both Q estimates.
        gamma: Bonus threshold; an update happens only when the bonus of a
            visited pair is strictly larger than this value.
        K: Number of episodes to simulate.

    Returns:
        A float array of length ``K`` whose k-th entry is the regret of the
        greedy policy used in episode k, averaged over the first state.
    """
    # Ridge-regression state for step 1 and step 2: weights, targets,
    # Gram matrices (initialised to the identity) and their inverses.
    w1, w2 = np.zeros(d), np.zeros(d)
    b1, b2 = np.zeros(d), np.zeros(d)
    U1, U2 = np.eye(d), np.eye(d)
    U1inv, U2inv = np.eye(d), np.eye(d)

    states = np.arange(S)
    # V_star is stored with a trailing singleton axis; flatten it so the
    # subtraction below is element-wise instead of broadcasting to (S, S).
    # The mean is mathematically the same either way.
    v_star = np.ravel(V_star)
    # Preallocate instead of growing a Python list of K boxed floats.
    regret = np.empty(max(K, 0))

    # The estimates, the greedy policy and its regret depend only on the
    # regression state, which changes solely inside the update branch, so
    # they are recomputed lazily rather than on every episode.
    refresh = True
    for k in range(K):
        if refresh:
            # optimistic Q functions: linear estimate plus elliptical bonus
            UCB1 = np.sqrt(np.einsum('sad, dp, sap -> sa', phi, U1inv, phi))
            UCB2 = np.sqrt(np.einsum('sad, dp, sap -> sa', phi, U2inv, phi))
            Q1 = phi @ w1 + beta * UCB1
            Q2 = phi @ w2 + beta * UCB2
            pi1 = Q1.argmax(axis=1)
            pi2 = Q2.argmax(axis=1)

            # true value of the current greedy policy and its regret
            V2_real = R[states, pi2]
            Q1_real = P @ V2_real
            V1_real = Q1_real[states, pi1]
            episode_regret = np.mean(v_star - V1_real)
            refresh = False

        # data collection (same random draws, in the same order, as before)
        s1 = np.random.choice(S)
        a1 = pi1[s1]
        s2 = np.random.choice(S, p=P[s1, a1])
        a2 = pi2[s2]

        regret[k] = episode_regret

        # rare-switching update: refit only when a visited pair is uncertain
        if UCB1[s1, a1] > gamma or UCB2[s2, a2] > gamma:
            # step-1 target is the optimistic step-2 value, clipped to [0, 2]
            b1 += phi[s1, a1] * np.clip(Q2[s2, a2], 0, 2)
            U1 += np.outer(phi[s1, a1], phi[s1, a1])
            b2 += phi[s2, a2] * R[s2, a2]
            U2 += np.outer(phi[s2, a2], phi[s2, a2])
            U1inv, U2inv = np.linalg.inv(U1), np.linalg.inv(U2)
            w1, w2 = U1inv @ b1, U2inv @ b2
            refresh = True
    return regret

if __name__ == '__main__':
    # Re-seed from the wall clock so the simulation runs are not tied to the
    # fixed seed used to generate the instance above.
    np.random.seed(int(time.time()))

    # The switching threshold comes from the first command-line argument.
    threshold = float(sys.argv[1])

    # Eight independent runs, stacked into one row per run.
    runs = [LSVI(3, threshold, K=3_000_000) for _ in range(8)]
    reg_list = np.vstack(runs)

    # Persist the per-episode regrets and report the mean cumulative regret.
    np.save('{}.npy'.format(threshold), reg_list)
    print(threshold, reg_list.sum(axis=1).mean())