import numpy as np
from scipy.optimize import minimize, linprog
import itertools, copy
def generate_diverse_theta(d, m, lambda0xm=1/4, seed=0):
    """Draw ``m`` random unit vectors from the non-negative orthant of R^d.

    Each vector is the element-wise absolute value of a standard-normal draw,
    rescaled to unit Euclidean norm. Draws come from NumPy's global random
    state.

    Args:
        d: Dimension of each objective parameter.
        m: Number of objective parameters to generate.
        lambda0xm: Intended diversity threshold (lambda0 * m). Currently
            unused: the minimum-eigenvalue acceptance test is disabled, so
            the first sample is always returned.
        seed: Currently unused; kept for interface compatibility. Seed
            ``np.random`` in the caller for reproducible output.

    Returns:
        ``np.ndarray`` holding the ``m`` unit vectors, one per row.
    """
    theta = []
    for _ in range(m):
        random_vector = np.absolute(np.random.randn(d))
        theta.append(random_vector / np.linalg.norm(random_vector))

    # NOTE: an earlier revision resampled until the smallest eigenvalue of
    # sum_i theta_i theta_i^T reached ``lambda0xm``. That check is disabled,
    # so no rejection loop (and no Gram-matrix accumulation) is needed here.
    return np.array(theta)

## Initial objective parameters

def initial_thetalist(m, d, contexts=None, reducing=False):
    """Choose ``m`` initial unit-norm objective parameters in R^d.

    Two modes:

    * ``contexts is None``: return normalized 0/1 indicator vectors,
      enumerated by support size (1, 2, ...) and, within a size, in
      ``itertools.combinations`` order. If fewer than ``m`` distinct vectors
      exist, the enumeration restarts and vectors repeat.
    * ``contexts`` given: among all ``m``-subsets of the (possibly reduced)
      contexts, pick the one whose matrix ``sum_i x_i x_i^T`` has the largest
      minimum eigenvalue, and return those contexts scaled to unit norm.

    Args:
        m: Number of vectors to return.
        d: Dimension of the vectors.
        contexts: Optional sequence of 1-D NumPy arrays of length ``d``.
        reducing: When true, only the longest contexts are considered, to
            limit the number of subsets examined.

    Returns:
        ``np.ndarray`` with one selected unit vector per row.

    Raises:
        ValueError: If ``d`` is not positive in the ``contexts is None`` mode,
            or if no ``m``-subset of the contexts exists.
    """
    if contexts is None:
        # Guard the degenerate sizes: the enumeration below would otherwise
        # never reach its exit condition.
        if d <= 0:
            raise ValueError("d must be positive when contexts is None")
        if m <= 0:
            return np.zeros((0, d))

        initial_vectors = []
        while True:
            for size in range(1, m + 1):
                for support in itertools.combinations(range(d), size):
                    vector = np.zeros(d)
                    vector[list(support)] = 1
                    initial_vectors.append(vector / np.linalg.norm(vector))
                    if len(initial_vectors) == m:
                        return np.array(initial_vectors)

    # Fixed contexts: search for the exploration-facilitating subset.
    pool_size = len(contexts)
    if reducing:
        lengths = np.array([np.linalg.norm(vec) for vec in contexts])
        longest_first = np.argsort(lengths)[::-1]
        if m > 5:
            # NOTE(review): m + 2 contexts are kept but only the first m + 1
            # are searched; preserved as in the original -- confirm intent.
            keep, pool_size = int(m + 2), int(m + 1)
        else:
            keep = pool_size = int(1.5 * m)
        contexts = [contexts[i] for i in longest_first[:keep]]
        # Never index past the contexts that actually exist.
        pool_size = min(pool_size, len(contexts))

    # Start from -inf so a subset is selected even when every candidate is
    # rank-deficient (minimum eigenvalue <= 0), e.g. when m < d.
    best_eigen = -np.inf
    best_combi = None
    for combi in itertools.combinations(range(pool_size), m):
        V = np.zeros((d, d))
        for index in combi:
            V += contexts[index].reshape(-1, 1) @ contexts[index].reshape(1, -1)
        # V is symmetric by construction, so the symmetric solver applies and
        # always returns real eigenvalues.
        smallest = np.min(np.linalg.eigvalsh(V))
        if smallest > best_eigen:
            best_eigen = smallest
            best_combi = combi

    if best_combi is None:
        raise ValueError("need at least m contexts to choose m initial vectors")

    initial_vectors = [
        contexts[index] / np.linalg.norm(contexts[index]) for index in best_combi
    ]
    return np.array(initial_vectors)



def fixed_regular_cxts(K, d, Tlist, x_max=1, sigma=0.1, TPF_size=None, seed=0):
    """Generate ``K`` context vectors, ``m`` of them close to the objectives.

    With ``m = len(Tlist)``:

    * The first ``m`` contexts are drawn from a normal distribution centred on
      each normalized objective vector with covariance ``sigma * I``, then
      rescaled to a radius uniform in ``[3/4 * x_max, x_max]``.
    * The other ``K - m`` contexts have a random direction (normalized
      standard-normal draw). The first ``3 * m`` of them get a radius uniform
      in ``[3/4 * x_max, x_max]``, the rest uniform in ``[0, 3/4 * x_max]``.

    When ``TPF_size`` is given, the whole set is redrawn until the Pareto
    front of the expected rewards ``X @ Tlist.T`` has at most ``TPF_size``
    arms. There is no retry limit, so an unreachable ``TPF_size`` never
    returns.

    Args:
        K: Number of contexts (arms).
        d: Context dimension.
        Tlist: The ``m`` objective parameters, one per row.
        x_max: Upper bound on the context norm.
        sigma: Variance scale of the perturbation around each objective.
        TPF_size: Optional upper bound on the Pareto-front size.
        seed: Currently unused; draws come from NumPy's global random state.

    Returns:
        ``np.ndarray`` with one context per row.
    """
    m = len(Tlist)
    if (m > K) or (len(Tlist[0]) != d):
        # Kept as a warning only, as in the original: execution continues.
        print("Invalid Input!")

    # Accept any array-like for Tlist; the matrix product below needs ``.T``.
    theta = np.asarray(Tlist)
    Tlist_normalized = [t / np.linalg.norm(t) for t in theta]
    cov_matrix = sigma * np.identity(d)

    while True:
        # m vectors that are similar to the objective vectors.
        similar_vectors = []
        for mean_vector in Tlist_normalized:
            sample = np.random.multivariate_normal(mean_vector, cov_matrix)
            direction = sample / np.linalg.norm(sample)
            radius = np.random.uniform(x_max * 3 / 4, x_max)
            similar_vectors.append(direction * radius)

        # K - m vectors with random directions.
        remaining_vectors = []
        for k in range(K - m):
            random_vector = np.random.randn(d)
            norm = np.linalg.norm(random_vector)
            if k < 3 * m:
                scale = np.random.uniform(x_max * 3 / 4, x_max)
            else:
                scale = np.random.uniform(0, x_max * 3 / 4)
            remaining_vectors.append(random_vector / norm * scale)

        all_vectors = np.array(similar_vectors + remaining_vectors)
        if TPF_size is None:
            return all_vectors

        exp_rewards = all_vectors @ theta.T
        # pareto_front returns (indices, gaps); the front size is the number
        # of indices, not the length of that 2-tuple.
        pareto_index, _gaps = pareto_front(exp_rewards)
        if len(pareto_index) <= TPF_size:
            return all_vectors

def uniform_cxts(K, d, x_max=1, seed=0):
    """Sample ``K`` random context vectors in R^d with norm below ``x_max``.

    Each context has a random direction (a normalized Gaussian draw) and a
    radius equal to ``x_max`` times a uniform draw from ``[0, 1)``. Draws come
    from NumPy's global random state; ``seed`` is accepted but not used.

    Args:
        K: Number of contexts.
        d: Context dimension.
        x_max: Scale applied to the uniform radius.
        seed: Unused.

    Returns:
        ``np.ndarray`` of shape ``(K, d)``, one context per row.
    """
    # One Gaussian column per context; dividing by the column norm leaves
    # only the direction.
    gaussians = np.random.normal(size=(d, K))
    column_norms = np.linalg.norm(gaussians, axis=0)
    unit_columns = gaussians / column_norms

    radii = x_max * np.random.random(K)
    scaled_columns = unit_columns * radii
    return np.array(scaled_columns.T)


## For quick update of Vinv
def sherman_morrison(X, V, w=1):
    """Apply a weighted rank-one Sherman-Morrison update to ``V``.

    Computes ``V - w * (V X)(X^T V) / (1 + w * X^T V X)``. When ``V`` is the
    inverse of a matrix ``A``, this is the inverse of ``A + w * X X^T``.

    Args:
        X: 1-D vector.
        V: Square matrix whose size matches ``X``.
        w: Scalar weight on the rank-one term.

    Returns:
        The updated matrix, same shape as ``V``.
    """
    V_times_x = np.einsum('ij,j -> i', V, X)
    x_times_V = np.einsum('k,kl -> l', X, V)
    rank_one = w * np.einsum('i,l -> il', V_times_x, x_times_V)
    denominator = 1. + w * np.einsum('i,i ->', X, V_times_x)
    return V - rank_one / denominator


# Find the theta that minimizes the target function within the unit ball
def minimize_in_unit_ball(F, dim):
    """Minimize ``F`` over vectors whose Euclidean norm is at most 1.

    Runs SciPy's SLSQP solver from the origin with the single inequality
    constraint ``1 - ||v|| >= 0``.

    Args:
        F: Objective taking a 1-D array of length ``dim`` and returning a
            scalar.
        dim: Dimension of the search space.

    Returns:
        The solver's final point (``result.x``).
    """
    def norm_slack(v):
        # Non-negative exactly when v lies inside the unit ball.
        return 1 - np.linalg.norm(v)

    ball_constraint = {'type': 'ineq', 'fun': norm_slack}
    start = np.zeros(dim)
    solution = minimize(F, start, method='SLSQP', constraints=ball_constraint)
    return solution.x

## Pareto Front via index [Y(1), ..., Y(K)]
def pareto_front(Y):
    """Return the Pareto-front row indices of ``Y`` and a gap for every row.

    A row is dropped from the front when some row still on the front is
    strictly larger in every column. The gap of row ``j`` is the maximum,
    over front rows ``i``, of ``min(Y[i] - Y[j])``.

    Args:
        Y: 2-D array, one row per arm and one column per objective.

    Returns:
        ``(pareto_index, subopt_gap)``: the list of surviving row indices and
        a list with one gap per row of ``Y``.
    """
    n_rows = Y.shape[0]
    front = list(range(n_rows))

    for candidate in range(n_rows):
        # Dominated: every coordinate of the candidate is strictly below the
        # matching coordinate of some row currently on the front.
        dominated = any(
            np.max(Y[candidate, :] - Y[other, :]) < 0 for other in front
        )
        if dominated:
            front.remove(candidate)

    gaps = []
    for row in range(n_rows):
        gaps.append(max(min(Y[member, :] - Y[row, :]) for member in front))
    return front, gaps


def pareto_front_v2(Y, p_idx):
    """Refine a Pareto index set using convex combinations of its rows.

    For each row ``i`` of ``Y`` the function searches, over weights in
    ``[0, 1]`` that sum to 1, for a weighted mix of the rows ``Y[p_idx]`` that
    minimizes ``max(Y[i] - mix)``. The search is restarted ``2 * len(p_idx)``
    times from random exponential weights. A minimum below ``-1e-5`` means the
    row is removed from the returned index list (if present) and the
    magnitude is recorded as its regret.

    Args:
        Y: 2-D array, one row per arm and one column per objective.
        p_idx: Candidate front indices; it is deep-copied, not modified.

    Returns:
        ``(pareto_index, regret)``: the surviving indices and a list with one
        non-negative regret per row of ``Y``.
    """
    K = Y.shape[0]
    front = copy.deepcopy(p_idx)
    # Reference rows are selected once; later removals from ``front`` do not
    # change the set being mixed.
    Y_ = Y[front]
    w_size = Y_.shape[0]

    simplex = [{'type': 'eq', 'fun': lambda x: np.sum(x) - 1}]
    box = [(0, 1)] * w_size
    regret = [0] * K

    for i in range(K):
        def maximize_regret(weight, row=Y[i, :]):
            # Largest coordinate by which ``row`` exceeds the weighted mix.
            return np.max(row - weight.T @ Y_)

        for _ in range(2 * w_size):
            init_w = np.random.exponential(scale=1.0, size=w_size)
            ans = minimize(maximize_regret, init_w, bounds=box, constraints=simplex)
            value = maximize_regret(ans.x)
            # The tolerance filters out numerically-zero negatives.
            if value < -1e-5:
                regret[i] = min(regret[i], value)

        if regret[i] < 0 and i in front:
            front.remove(i)
        regret[i] = -regret[i]

    return front, regret
    