import os

import numpy as np

def sqrt_sigma(Sigma):
    """Return the symmetric square root of a symmetric (covariance) matrix.

    ``Sigma`` is eigendecomposed as ``V diag(D) V^T``. The result is
    ``V diag(sqrt(D)) V^T``, so that ``result @ result`` approximates ``Sigma``.
    Negative eigenvalues are replaced by the smallest strictly positive
    eigenvalue before the root is taken. On a nearly singular PSD matrix these
    are typically round-off.

    Parameters
    ----------
    Sigma : array_like, shape (p, p)
        Symmetric matrix. ``np.linalg.eigh`` reads only its lower triangle.

    Returns
    -------
    numpy.ndarray, shape (p, p)
        Symmetric square root of ``Sigma``.

    Raises
    ------
    ValueError
        If ``Sigma`` has no strictly positive eigenvalue.
    """
    # eigh is the solver for symmetric input. It returns real eigenvalues and
    # orthonormal eigenvectors, so V.T really is V's inverse. np.linalg.eig
    # gives no such guarantee (e.g. with repeated eigenvalues) and may return
    # complex values.
    D, V = np.linalg.eigh(Sigma)
    positive = D[D > 0]
    if positive.size == 0:
        raise ValueError("Sigma has no positive eigenvalue; cannot form its square root")
    D = np.where(D < 0, positive.min(), D)
    # Scaling V's columns by sqrt(D) equals V @ diag(sqrt(D)) without building
    # the diagonal matrix.
    return (V * np.sqrt(D)) @ V.T


def generate_diag_matrix(p, T, n_arm,seed=0,case=1,decay=50):
    """Build one ``p x p`` diagonal covariance matrix per arm.

    Parameters
    ----------
    p : int
        Dimension of each matrix.
    T : int
        Used only by ``case=2``, where it sets the eigenvalue exponent
        ``-1 + 1/T``.
    n_arm : int
        Number of matrices (arms) to generate.
    seed : int, default 0
        Seed passed to NumPy's global RNG (``np.random.seed``).
    case : {1, 2}, default 1
        1 -- every arm gets the identity matrix.
        2 -- arm ``i`` gets ``c_i * diag(k ** (-1 + 1/T))`` for ``k = 1..p``,
             where ``c_i ~ Uniform(0.5, 1.5)``.
    decay : int, default 50
        Currently unused; kept for interface compatibility.

    Returns
    -------
    numpy.ndarray, shape (n_arm, p, p)

    Raises
    ------
    ValueError
        If ``case`` is not 1 or 2.
    """
    np.random.seed(int(seed))
    # Both supported cases draw the per-arm scales first. This leaves the global
    # RNG in the same state as before; case 1 simply ignores the scales.
    cs = np.random.uniform(0.5, 1.5, size=n_arm)
    if case == 1:
        return np.tile(np.eye(p), (n_arm, 1, 1))
    if case == 2:
        lambdas = np.arange(1, p + 1, dtype=float) ** (-1 + 1 / T)
        # Broadcast each arm's scale over the shared diagonal matrix.
        return cs[:, None, None] * np.diag(lambdas)
    raise ValueError(f"unknown case {case!r}; expected 1 or 2")



def generate_sparse_vectors(n_arm, p, rho, seed=0, norm_value=1):
    """Draw one sparse coefficient vector per arm.

    Each vector has ``floor(rho * p)`` non-zero entries at positions chosen
    uniformly without replacement. The entries are filled with standard-normal
    draws and then rescaled so the vector has Euclidean norm ``norm_value``.
    A vector that is entirely zero is left unscaled.

    Parameters
    ----------
    n_arm : int
        Number of vectors to generate.
    p : int
        Length of each vector.
    rho : float
        Fraction of non-zero entries; must satisfy ``0 <= rho * p <= p``.
    seed : int, default 0
        Seed passed to NumPy's global RNG (``np.random.seed``).
    norm_value : float, default 1
        Target L2 norm of every non-zero vector.

    Returns
    -------
    numpy.ndarray, shape (n_arm, p)
    """
    np.random.seed(seed)
    # The small tolerance stops float round-off from dropping a coordinate.
    # For example, 0.29 * 100 == 28.999999999999996, which int() would
    # truncate to 28.
    num_non_zero = int(np.floor(rho * p + 1e-9))
    vectors = np.zeros((n_arm, p))
    # Each row is a view, so the in-place writes below fill `vectors` directly.
    for vector in vectors:
        non_zero_indices = np.random.choice(p, num_non_zero, replace=False)
        vector[non_zero_indices] = np.random.normal(loc=0.0, scale=1.0, size=num_non_zero)
        current_norm = np.linalg.norm(vector, ord=2)
        if current_norm > 0:
            vector *= norm_value / current_norm
    return vectors



def generate_data(betas, Sigmas, p, T,n_arm, seed=0,noise_variance=0.01):
    """Simulate a linear-Gaussian regression sample for each arm.

    For arm ``i`` in ``range(n_arm)``:
    ``X ~ N(0, Sigmas[i])`` with ``T`` rows, and
    ``Y = X @ betas[i] + eps`` with ``eps ~ N(0, noise_variance)``.
    The global RNG is seeded once with ``seed``. For each arm, X is drawn
    before that arm's noise.

    Returns
    -------
    tuple
        ``(datasets, signals)``. ``datasets`` is a list of ``n_arm`` arrays of
        shape ``(T, p + 1)``, each holding ``X`` with ``Y`` as its last column.
        ``signals`` is an array stacking the noiseless responses
        ``X @ betas[i]``, one row per arm.
    """
    np.random.seed(seed)
    zero_mean = np.zeros(p)
    datasets = []
    signals = []
    for arm in range(n_arm):
        features = np.random.multivariate_normal(mean=zero_mean, cov=Sigmas[arm], size=T)
        signal = features.dot(betas[arm])
        noise = np.random.normal(loc=0, scale=np.sqrt(noise_variance), size=T)
        signals.append(signal)
        datasets.append(np.column_stack((features, signal + noise)))
    return datasets, np.array(signals)


def gen_data(p,T,n_arm,rho,seed=0,case=3,noise_variance=1, norm_value=10, out_dir="/data"):
    """Generate a synthetic multi-arm regression dataset and save it to ``.npz``.

    The function builds per-arm covariances (``generate_diag_matrix``) and
    sparse coefficient vectors (``generate_sparse_vectors``). It then samples
    data (``generate_data``), all from the same ``seed``. The result is saved
    with ``np.savez`` under the keys ``Sigmas``, ``betas``, ``Data`` and
    ``Y_true``.

    NOTE(review): ``generate_diag_matrix`` in this file handles only
    ``case`` 1 and 2, so the default ``case=3`` makes this call fail. Pass
    ``case=1`` or ``case=2`` explicitly.

    Parameters
    ----------
    p, T, n_arm, rho, seed, case, noise_variance, norm_value
        Forwarded to the generator helpers above.
    out_dir : str, default "/data"
        Directory the ``.npz`` file is written to. It must already exist.

    Returns
    -------
    str
        Path of the written file.
    """
    Sigmas = generate_diag_matrix(p, T, n_arm, seed, case)
    betas = generate_sparse_vectors(n_arm, p, rho, seed, norm_value)
    Gen_data, y_true = generate_data(betas, Sigmas, p, T, n_arm, seed, noise_variance)
    # The filename format is kept verbatim (including "seed{seed}" with no
    # underscore) so existing readers still find the files.
    path = os.path.join(out_dir, f"p_{p}_T_{T}_n_{n_arm}_rho_{rho}_seed{seed}_case_{case}.npz")
    np.savez(path, Sigmas=Sigmas, betas=betas, Data=Gen_data, Y_true=y_true)
    return path
