import numpy as np
from scipy.stats import norm
from scipy.optimize import fsolve, minimize
from sklearn.linear_model import LogisticRegression
import math, itertools

## For quick update of Vinv
def sherman_morrison(X, V, w=1):
    """Return ``V`` after a weighted rank-one Sherman-Morrison correction.

    Evaluates ``V - w * (V X)(X^T V) / (1 + w * X^T V X)``. When ``V`` is the
    inverse of some matrix ``A``, the result is the inverse of
    ``A + w * X X^T``, which avoids a fresh matrix inversion.

    Args:
        X: 1-D vector of length d.
        V: (d, d) matrix to be corrected; it is not modified.
        w: scalar weight applied to the rank-one term (default 1).

    Returns:
        A new (d, d) array holding the corrected matrix.
    """
    X = np.asarray(X)
    V = np.asarray(V)
    # Two matrix-vector products and one outer product keep the cost at
    # O(d^2); a single four-operand einsum would sum over every (i, j, k, l).
    left = V @ X    # V X
    right = X @ V   # X^T V (kept separate so V need not be symmetric)
    denom = 1. + w * (X @ left)
    return V - (w * np.outer(left, right)) / denom

## Pareto Front via index [Y(1), ..., Y(K)]
def pareto_front(Y):
    """Return the row indices of ``Y`` that are not strictly dominated.

    Rows are visited in order. A row is discarded when some row still in
    the candidate list is strictly larger than it in every column; rows that
    tie in at least one column are both kept.

    Args:
        Y: 2-D array with one row per candidate and one column per objective.

    Returns:
        List of the surviving row indices, in increasing order.
    """
    n_rows = Y.shape[0]
    survivors = list(range(n_rows))
    for cand in range(n_rows):
        # The largest component of (cand - other) being negative means
        # `other` beats `cand` in every objective.
        beaten = any(
            np.max(Y[cand, :] - Y[other, :]) < 0 for other in survivors
        )
        if beaten:
            survivors.remove(cand)
    return survivors

# Find the theta that minimizes the target function inside the unit ball

from scipy.optimize import minimize

def minimize_in_unit_ball(F, dim):
    """Minimise ``F`` over vectors whose Euclidean norm is at most 1.

    Runs SciPy's SLSQP solver from the zero vector with the inequality
    constraint ``1 - ||v|| >= 0``.

    Args:
        F: callable taking a length-``dim`` vector and returning a scalar.
        dim: dimension of the search space.

    Returns:
        The solution vector reported by the solver (``result.x``).
    """
    def _norm_slack(v):
        # Non-negative exactly when v lies inside the closed unit ball.
        return 1 - np.linalg.norm(v)

    start = np.zeros(dim)
    ball = {'type': 'ineq', 'fun': _norm_slack}
    solution = minimize(F, start, method='SLSQP', constraints=ball)
    return solution.x


'''
P-UCB
'''
class P_UCB:
    """UCB agent for ``K`` arms with ``m``-dimensional rewards.

    The agent keeps a running mean reward vector per arm. During the first
    ``K`` rounds it plays the arms once each in index order; afterwards it
    adds a confidence width to every mean and draws uniformly from the
    indices returned by ``pareto_front`` on those upper bounds.

    Attributes:
        K: number of arms.
        m: number of objectives.
        t: 1-based round counter, incremented by ``update``.
        TPF_size: constant that enters the confidence width through
            ``log(m * TPF_size)``.
        num: length-``K`` array of pull counts.
        y_hat: (K, m) array of per-arm mean rewards.
    """

    def __init__(self, m, K, TPF_size):
        self.K = K
        self.m = m
        self.t = 1
        self.TPF_size = TPF_size
        self.num = np.zeros(K)
        self.y_hat = np.zeros((K, m))

    def select_ac(self):
        """Return the index of the arm to play in the current round."""
        # The arm count lives on the instance; a bare `K` is not defined here.
        if self.t > self.K:
            const = math.sqrt(
                2 * math.log(self.t)
                + (1 / 2) * math.log(self.m * self.TPF_size)
            )
            # One width per arm, shared by all of that arm's objectives.
            width = const * np.sqrt(1.0 / self.num)
            ucbs = self.y_hat + width[:, None]
            empirical_pareto_front = pareto_front(ucbs)
            a_t = np.random.choice(empirical_pareto_front)
        else:
            # Initial sweep: round t plays arm t-1.
            a_t = self.t - 1
        return a_t

    def update(self, a_t, reward):
        """Fold ``reward`` into arm ``a_t``'s running mean and advance the round.

        Args:
            a_t: index of the arm that was played.
            reward: length-``m`` reward vector observed for that arm.
        """
        pulls = self.num[a_t]
        self.y_hat[a_t] = (self.y_hat[a_t] * pulls + reward) / (pulls + 1)
        self.num[a_t] += 1
        self.t += 1

'''
MOGLM UCB
'''
class MOGLM_UCB:
    """Multi-objective UCB agent with one parameter estimate per objective.

    The agent maintains a shared matrix ``V`` (and its inverse ``Vinv``) and
    an (m, d) array ``theta_hat``. ``select_ac`` scores each context with
    ``context @ theta_hat[j]`` plus a width proportional to
    ``sqrt(x^T Vinv x)`` and draws uniformly from the indices returned by
    ``pareto_front``. ``update`` applies one gradient-style step per
    objective and projects it onto the unit ball via
    ``minimize_in_unit_ball``.

    Args:
        d: context dimension.
        m: number of objectives.
        c: tuning parameter that scales the confidence width.
        sig: stored as ``self.sig``; not used inside this class.
        lam: regularisation weight for ``V``; defaults to ``max(1, k/2)``.
        k: the rank-one updates of ``V`` are weighted by ``k/2``.
        T: optional value stored as ``self.T``; not used inside this class.
    """

    def __init__(self, d, m, c=1, sig=0.1, lam=None, k=1, T=None):
        self.d = d
        self.m = m
        self.c = c
        # T is an explicit optional argument instead of an undefined global.
        self.T = T
        self.sig = sig
        self.k = k
        self.lam = max(1, k / 2) if lam is None else lam

        self.t = 1
        self.yx = np.zeros((m, d))
        self.V = self.lam * np.eye(d)
        self.Vinv = (1 / self.lam) * np.eye(d)
        self.X_a = np.zeros(d)        # context of the most recently chosen arm
        self.target = np.zeros(d)     # centre of the quadratic used by theta_F
        self.theta_hat = np.zeros((m, d))
        self.settings = {'lambda': self.lam, 'c': self.c}

    def select_ac(self, contexts):
        """Choose an arm from ``contexts`` and remember its context vector.

        Args:
            contexts: (n_arms, d) array with one context per arm.

        Returns:
            Index of the selected arm.
        """
        contexts = np.asarray(contexts)
        means = np.matmul(contexts, self.theta_hat.T)
        # sqrt(x^T Vinv x) for every context row; identical for all objectives.
        stds = np.sqrt(np.sum((contexts @ self.Vinv) * contexts, axis=1))
        # NOTE(review): the subtraction sits inside the logarithm, i.e.
        # log(det(V) - d*log(lam)). If log(det(V)) - d*log(lam) was intended
        # the parentheses need moving -- confirm against the algorithm's
        # specification. Both forms coincide when lam == 1.
        width = self.c * math.log(
            np.linalg.det(self.V) - self.d * math.log(self.lam)
        )
        ucbs = means + width * stds[:, None]
        empirical_pareto_front = pareto_front(ucbs)
        a_t = np.random.choice(empirical_pareto_front)
        self.X_a = contexts[a_t, :]
        return a_t

    def theta_F(self, theta):
        """Return ``(theta - target)^T V (theta - target)`` for the current state."""
        diff = theta - self.target
        return diff.T @ self.V @ diff

    def update(self, reward):
        """Update ``V``, ``Vinv`` and every ``theta_hat[j]`` from one observation.

        Uses the context stored by the latest ``select_ac`` call.

        Args:
            reward: indexable with ``m`` entries, one reward per objective.
        """
        x = self.X_a
        weight = self.k / 2
        self.V += weight * np.outer(x, x)
        self.Vinv = sherman_morrison(x, self.Vinv, weight)
        for j in range(self.m):
            # Same value as -reward[j]*x + <theta_hat[j], x>*x.
            grad_l = (np.inner(self.theta_hat[j, :], x) - reward[j]) * x
            self.target = self.theta_hat[j] - np.matmul(self.Vinv, grad_l)
            # theta_F reads self.target, so it must be set before minimising.
            self.theta_hat[j] = minimize_in_unit_ball(self.theta_F, self.d)
        self.t += 1


## Initial objective parameters

def initial_thetalist(m, d, contexts=None, reducing=False):
    """Build ``m`` unit-norm initial parameter vectors of dimension ``d``.

    Without ``contexts`` the vectors are normalised 0/1 indicator vectors,
    enumerated by support size (1, 2, ..., m) and, within a size, in
    ``itertools.combinations`` order; the enumeration restarts from the
    beginning if fewer than ``m`` distinct vectors exist.

    With ``contexts`` the function searches the size-``m`` subsets of the
    contexts for the one whose summed outer-product matrix has the largest
    minimum eigenvalue, and returns those contexts normalised to unit length.

    Args:
        m: number of vectors to return.
        d: dimension of each vector.
        contexts: optional sequence of K context vectors of length ``d``.
        reducing: when true, only the longest contexts are considered as
            candidates (``m + 1`` of them if ``m > 5``, otherwise
            ``int(1.5 * m)``), capped at the number of contexts available.

    Returns:
        Array with one initial vector per row.

    Raises:
        ValueError: if ``d`` is not positive (no-context mode) or if no
            size-``m`` subset of the candidate contexts can be evaluated.
    """
    if contexts is None:
        if d <= 0:
            raise ValueError("d must be positive to build indicator vectors")
        if m <= 0:
            # Nothing requested; avoids looping forever on an empty enumeration.
            return np.empty((0, d))

        def _indicator_vectors():
            # Normalised 0/1 vectors ordered by support size, then by position.
            for size in range(1, m + 1):
                for support in itertools.combinations(range(d), size):
                    vec = np.zeros(d)
                    vec[list(support)] = 1
                    yield vec / np.linalg.norm(vec)

        # cycle() restarts the enumeration when m exceeds the number of
        # distinct vectors; islice() stops after exactly m of them.
        picked = itertools.islice(itertools.cycle(_indicator_vectors()), m)
        return np.array(list(picked))

    contexts = np.asarray(contexts, dtype=float)
    if reducing:
        lengths = np.array([np.linalg.norm(vec) for vec in contexts])
        longest_first = np.argsort(lengths)[::-1]
        pool = m + 1 if m > 5 else int(1.5 * m)
        # Never index past the contexts that actually exist.
        pool = min(pool, len(contexts))
        contexts = contexts[longest_first[:pool]]

    best_eig = -np.inf   # accept the best subset even if it is singular
    best_combo = None
    for combo in itertools.combinations(range(len(contexts)), m):
        V = np.zeros((d, d))
        for index in combo:
            V += np.outer(contexts[index], contexts[index])
        # V is symmetric, so eigvalsh applies; it returns ascending values.
        min_eig = np.linalg.eigvalsh(V)[0]
        if min_eig > best_eig:
            best_eig = min_eig
            best_combo = combo

    if best_combo is None:
        raise ValueError(
            "no usable combination of m contexts (need at least m contexts)"
        )

    return np.array(
        [contexts[index] / np.linalg.norm(contexts[index]) for index in best_combo]
    )

'''
MORR-Greedy
'''
class MORR_Greedy:
    """Greedy multi-objective agent with a least-squares estimate per objective.

    Until the smallest eigenvalue of the accumulated matrix ``V`` exceeds
    ``lam``, actions are chosen greedily with respect to fixed initial
    parameter vectors. Once the threshold is crossed, ``V`` is inverted once,
    ``theta_hat`` is computed as ``Vinv @ yx[j]`` per objective, and later
    rounds update ``Vinv`` through ``sherman_morrison`` (``V`` itself is no
    longer updated from that point on).

    Args:
        d: context dimension.
        m: number of objectives.
        lam: threshold on the minimum eigenvalue of ``V``.
        T, sig, lam_0: accepted for interface compatibility; not used in
            this class.
        inTlist: optional (m, d) array of initial parameter vectors.
        contexts: optional contexts forwarded to ``initial_thetalist`` (with
            ``reducing=True``) when ``inTlist`` is not given.
    """

    def __init__(self, d, m, lam=1, T=None, sig=0.1, lam_0=None, inTlist=None, contexts=None):
        self.d = d
        self.m = m

        if inTlist is not None:
            # Coerce so that `.T` in select_ac also works for list input.
            self.ini_theta = np.asarray(inTlist)
        elif contexts is None:
            self.ini_theta = initial_thetalist(m, d)
        else:
            self.ini_theta = initial_thetalist(m, d, contexts, reducing=True)

        self.lam = lam
        self.settings = {'lambda': self.lam, 'initial_objective_parameters:': self.ini_theta}

        self.lam_t = 0                      # current minimum eigenvalue of V
        self.V = np.zeros((d, d))
        self.after_first_update = False     # True once lam_t has exceeded lam
        self.yx = np.zeros((m, d))          # per-objective sum of reward * context
        self.Vinv = np.zeros((d, d))
        self.theta_hat = np.zeros((m, d))
        # Defined up front so update() never meets a missing attribute.
        self.X_a = np.zeros(d)

    def select_ac(self, contexts, j):
        """Pick the context maximising objective ``j`` and remember it.

        Args:
            contexts: (n_arms, d) array with one context per arm.
            j: index of the objective to be greedy for.

        Returns:
            Index of the selected arm.
        """
        thetas = self.theta_hat if self.after_first_update else self.ini_theta
        means = np.matmul(contexts, thetas.T)
        a_t = np.argmax(means[:, j])
        self.X_a = contexts[a_t, :]
        return a_t

    def update(self, reward):
        """Incorporate the reward observed for the last selected context.

        Args:
            reward: indexable with ``m`` entries, one reward per objective.
        """
        x = self.X_a
        for j in range(self.m):
            self.yx[j, :] += reward[j] * x

        if self.after_first_update:
            self.Vinv = sherman_morrison(x, self.Vinv)
            # Row j equals Vinv @ yx[j].
            self.theta_hat = self.yx @ self.Vinv.T

        elif self.lam_t <= self.lam:
            self.V += np.outer(x, x)
            # V is symmetric: eigvalsh always returns real, ascending values.
            self.lam_t = np.linalg.eigvalsh(self.V)[0]
            if self.lam_t > self.lam:
                self.Vinv = np.linalg.inv(self.V)
                self.theta_hat = self.yx @ self.Vinv.T
                self.after_first_update = True
                print("End Exploration")
        else:
            print("Something is wrong!")