import numpy as np

def generate_diverse_theta(d, m, lambda0xm=1/4, seed=0):
    """Sample m non-negative unit vectors whose Gram matrix is well conditioned.

    Candidate sets are drawn by rejection sampling: each vector is the
    element-wise absolute value of a standard normal draw, scaled to unit
    Euclidean norm.  A set is accepted once the smallest eigenvalue of
    ``sum_i theta_i theta_i^T`` is at least ``lambda0xm``.

    Parameters
    ----------
    d : int
        Dimension of every parameter vector.
    m : int
        Number of objective parameter vectors.
    lambda0xm : float
        Lower bound on the smallest eigenvalue of ``theta.T @ theta``
        (i.e. ``m`` times the smallest eigenvalue of the averaged matrix).
    seed : int
        Currently ignored.  Draws come from NumPy's global RNG; call
        ``np.random.seed`` beforehand for reproducible output.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(m, d)``; every row is non-negative with unit norm.

    Raises
    ------
    ValueError
        If ``d`` or ``m`` is not positive, or if ``lambda0xm`` can never be
        reached (the loop would otherwise run forever).

    Notes
    -----
    The feasibility checks are necessary, not sufficient: a target close to
    ``m / d`` may still take very long (or forever) to hit, because the
    vectors are restricted to the non-negative orthant.
    """
    if d < 1 or m < 1:
        raise ValueError(f"d and m must be positive, got d={d}, m={m}")
    if lambda0xm > 0:
        # Fewer than d vectors give a rank-deficient Gram matrix (min eigenvalue 0).
        if m < d:
            raise ValueError(
                f"need m >= d for a positive eigenvalue bound, got m={m}, d={d}"
            )
        # trace(theta.T @ theta) == m for unit rows, so the smallest of the
        # d eigenvalues can never exceed m / d.
        if lambda0xm > m / d:
            raise ValueError(
                f"lambda0xm={lambda0xm} exceeds the upper bound m/d={m / d}"
            )

    while True:
        raw = np.absolute(np.random.randn(m, d))
        theta = raw / np.linalg.norm(raw, axis=1, keepdims=True)

        # theta.T @ theta equals the sum of the outer products theta_i theta_i^T.
        # It is symmetric, so eigvalsh applies and returns real eigenvalues in
        # ascending order; element 0 is the minimum.
        min_eigen = np.linalg.eigvalsh(theta.T @ theta)[0]
        if min_eigen >= lambda0xm:
            return theta

## Pareto front of the rows Y(1), ..., Y(K), returned as a list of row indices
def pareto_front(Y):
    """Return the indices of the rows of Y that are not strictly dominated.

    Row ``i`` is dropped when some surviving row ``j`` is strictly larger in
    every coordinate (``max(Y[i] - Y[j]) < 0``).  Rows that tie in at least
    one coordinate are both kept.

    Parameters
    ----------
    Y : numpy.ndarray
        2-D array with one row per candidate.

    Returns
    -------
    list of int
        Surviving row indices in increasing order.
    """
    num_rows = Y.shape[0]
    survivors = list(range(num_rows))
    for candidate in range(num_rows):
        # Only rows still in the front are consulted; the candidate itself is
        # among them but can never strictly dominate itself.
        is_dominated = any(
            np.max(Y[candidate, :] - Y[rival, :]) < 0 for rival in survivors
        )
        if is_dominated:
            survivors.remove(candidate)
    return survivors

def fixed_regular_cxts(K, d, Tlist, x_max=1, sigma=0.1, TPF_size=None, seed=0):
    """Generate K context vectors, the first m of which point near the thetas.

    For each of the ``m`` objective parameters in ``Tlist`` one context is
    drawn from a Gaussian centred on the normalised parameter (covariance
    ``sigma * I``), rescaled to a norm drawn uniformly from
    ``[3/4 * x_max, x_max)``.  The remaining ``K - m`` contexts get a
    uniformly random direction; the first ``3 * m`` of them receive a norm in
    ``[3/4 * x_max, x_max)`` and the rest a norm in ``[0, 3/4 * x_max)``.

    When ``TPF_size`` is given, whole sets are resampled until the Pareto
    front of the expected rewards ``X @ Tlist.T`` (computed with the
    module's ``pareto_front``) has at most ``TPF_size`` members.

    Parameters
    ----------
    K : int
        Total number of contexts to return.
    d : int
        Dimension of each context.
    Tlist : array-like of shape (m, d)
        Objective parameters; a list of vectors or a 2-D array.
    x_max : float
        Upper bound on the context norms.
    sigma : float
        Scalar multiplying the identity covariance (a variance, not a
        standard deviation) of the Gaussian around each parameter.
    TPF_size : int or None
        Maximum allowed Pareto-front size, or ``None`` to skip the check.
    seed : int
        Currently ignored.  Draws come from NumPy's global RNG; call
        ``np.random.seed`` beforehand for reproducible output.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(K, d)``; rows ``0..m-1`` are the theta-aligned ones.

    Raises
    ------
    ValueError
        If ``Tlist`` is not a non-empty ``(m, d)`` collection with ``m <= K``
        and non-zero rows, or if ``TPF_size < 1`` (unreachable, since a
        non-empty set always has a non-empty Pareto front).
    """
    thetas = np.asarray(Tlist, dtype=float)
    if thetas.ndim != 2 or thetas.shape[0] == 0:
        raise ValueError("Invalid Input! Tlist must be a non-empty (m, d) collection")
    m = thetas.shape[0]
    if m > K or thetas.shape[1] != d:
        raise ValueError(
            f"Invalid Input! need m <= K and vectors of length d; "
            f"got m={m}, K={K}, d={d}, Tlist shape={thetas.shape}"
        )
    if TPF_size is not None and TPF_size < 1:
        raise ValueError(f"TPF_size must be at least 1, got {TPF_size}")

    unit_thetas = []
    for theta in thetas:
        theta_norm = np.linalg.norm(theta)
        if theta_norm == 0:
            raise ValueError("Invalid Input! Tlist contains a zero vector")
        unit_thetas.append(theta / theta_norm)

    # Loop invariants: shared covariance and the size of the outer-shell group.
    cov_matrix = sigma * np.identity(d)
    outer_shell_count = 3 * m

    while True:
        contexts = []

        # m contexts pointing roughly along each (normalised) theta.
        for mean_vector in unit_thetas:
            sample = np.random.multivariate_normal(mean_vector, cov_matrix)
            radius = np.random.uniform(x_max * 3 / 4, x_max)
            contexts.append(sample / np.linalg.norm(sample) * radius)

        # K - m contexts with isotropic directions; early ones lie in the
        # outer shell, later ones inside the inner ball.
        for idx in range(K - m):
            direction = np.random.randn(d)
            if idx < outer_shell_count:
                scale = np.random.uniform(x_max * 3 / 4, x_max)
            else:
                scale = np.random.uniform(0, x_max * 3 / 4)
            contexts.append(direction / np.linalg.norm(direction) * scale)

        contexts = np.array(contexts)
        if TPF_size is None:
            return contexts

        # Expected reward of every context under every objective: shape (K, m).
        exp_rewards = contexts @ thetas.T
        if len(pareto_front(exp_rewards)) <= TPF_size:
            return contexts

def uniform_cxts(K, d, x_max=1, seed=0):
    """Draw K random contexts inside the d-dimensional ball of radius x_max.

    Each context has an isotropic direction (a normalised Gaussian draw) and
    a norm drawn uniformly from ``[0, x_max)``.

    NOTE: because the *radius* is uniform, the points are not uniform with
    respect to volume when ``d > 1``; they concentrate towards the origin.

    Parameters
    ----------
    K : int
        Number of contexts.
    d : int
        Dimension of each context.
    x_max : float
        Radius of the ball.
    seed : int
        Not used; the draws come from NumPy's global RNG.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(K, d)``, one context per row.
    """
    # One Gaussian column per context, normalised to unit length.
    gaussian_cols = np.random.normal(size=(d, K))
    col_norms = np.linalg.norm(gaussian_cols, axis=0)
    unit_cols = gaussian_cols / col_norms

    # Independent uniform radius for every context.
    radii = x_max * np.random.random(K)

    scaled_cols = unit_cols * radii
    return np.array(scaled_cols.T)