import random
import numpy as np
import math
from numpy.random import seed
from numpy.random import rand
from Environment import *
from tqdm import tqdm
from scipy.optimize import fmin_tnc
from scipy.optimize import minimize
from scipy.linalg import cholesky
from itertools import product

#################    
    
class Elimination:
    """Warm-up + epoch-based elimination policy for K agents over N arms.

    Each agent k is offered a group ``S[k]`` of at most ``L`` arms per round
    (the feasible assignments are enumerated by ``construct_M``).  Utilities
    are modelled as ``z @ theta`` and choices follow a multinomial logit with
    an outside option: item probability ``exp(u_i) / (1 + sum_j exp(u_j))``.
    After a warm-up phase, the main phase repeats epochs of: refit ``theta``,
    eliminate (agent, arm) pairs whose optimistic value is below the best
    pessimistic value, and play the surviving optimistic assignments.
    """

    def compute_prob_loss(self, theta, x, y):
        """Return MNL probabilities used by ``cost_function``.

        If any entry of ``y`` is 1 (an item was chosen), return the per-item
        probabilities ``u / (1 + sum(u))``.  Otherwise return the scalar
        probability of the outside option, ``1 / (1 + sum(u))``.
        """
        means = np.dot(x, theta)
        u = np.exp(means.astype(float))
        SumExp = np.sum(u) + 1
        if 1 in y:
            prob = u / SumExp
        else:
            prob = 1 / SumExp
        return prob

    def compute_prob_grad(self, theta, x, y):
        """Return per-item MNL probabilities ``u / (1 + sum(u))`` (``y`` is unused)."""
        means = np.dot(x, theta)
        u = np.exp(means)
        SumExp = np.sum(u) + 1
        return u / SumExp

    def cost_function(self, theta, *args):
        """Return the regularised negative MNL log-likelihood of ``theta``.

        ``args`` is ``(x, y, S)``:
        - ``x[i]`` is the feature vector of arm ``i``.
        - ``S[t]`` lists the arms offered in round ``t``.
        - ``y[t]`` holds the matching 0/1 outcomes.

        Rounds with an empty offer are skipped.  A round in which no item was
        chosen contributes ``-log P(outside option)``.
        """
        x, y, S = args[0], args[1], args[2]
        loss = 0
        for t, arms in enumerate(S):
            if len(arms) == 0:
                continue
            x_ = [x[i] for i in arms]
            y_ = y[t]
            prob = self.compute_prob_loss(theta, x_, y_)
            if 1 in y_:
                loss += -np.sum(np.multiply(y_, np.log(prob)))
            else:
                # Outside option chosen: prob is the scalar 1 / (1 + sum(u)).
                # Weighting it by the all-zero y_ would silently drop the round.
                loss += -np.log(prob)
        return loss + (1/2) * self.lamb * np.linalg.norm(theta)**2

    def gradient(self, theta, *args):
        """Return the gradient of ``cost_function`` with respect to ``theta``.

        Per round the gradient is ``sum_i (p_i - y_i) x_i``, where ``p_i`` are
        the MNL item probabilities.  This matches ``cost_function`` when at most
        one item is chosen per round.  ``fit`` currently uses a numerical
        gradient instead (``approx_grad=True``).
        """
        x, y, S = args[0], args[1], args[2]
        grad = 0
        for t, arms in enumerate(S):
            if len(arms) == 0:
                continue
            # Build arrays explicitly so ``x`` may be a list of vectors or an ndarray.
            x_ = np.array([x[i] for i in arms])
            y_ = np.asarray(y[t])
            prob = self.compute_prob_grad(theta, x_, y_)
            eps = prob - y_
            grad += (eps[:, np.newaxis] * x_).sum(axis=0)
        return grad + self.lamb * theta

    def fit(self, theta, *args):
        """Minimise ``cost_function`` from the start point ``theta`` with TNC.

        ``args`` is forwarded as ``(x, y, S)``.  Because ``approx_grad=True``,
        SciPy approximates the gradient numerically and ignores ``fprime``.
        """
        opt_weights = fmin_tnc(func=self.cost_function, x0=theta, fprime=self.gradient,
                               args=args, ftol=1e-6, disp=False, approx_grad=True)
        return opt_weights[0]

    def divide_into_groups(self, t, k):
        """Return the warm-up assignment for agent ``k`` at round ``t``.

        Agent ``k``'s active arms ``self.A[k]`` are visited cyclically in
        windows of ``min(L, |A[k]|)`` arms, and the window for round ``t``
        becomes ``groups[k]``.  If the window does not cover all of ``A[k]``,
        the result is instead a uniformly random assignment from ``self.M``
        whose ``k``-th group has the same arms as the window.
        """
        groups = [[] for _ in range(self.K)]
        arms = self.A[k]
        width = min(self.L, len(arms))
        a = (width * (t - 1)) % len(arms)
        b = (width * t) % len(arms)
        if a < b:
            groups[k] = arms[a:b]
        elif a == b:
            groups[k] = arms
        else:
            # The window wraps around the end of the arm list.
            groups[k] = arms[:b] + arms[a:]

        # NOTE(review): only agent k's own arms are considered here; confirm
        # this (rather than the union over all agents) is intended.
        uncovered = set(arms).difference(groups[k])
        if uncovered:
            window = set(groups[k])
            candidates = [S for S in self.M if set(S[k]) == window]
            groups = random.choice(candidates)
        return groups

    def divide_into_groups_max(self, k):
        """Return the assignment in ``self.M`` with the largest ``upper_R``.

        Only assignments whose ``k``-th group equals ``self.A[k]`` are
        considered.  Raises UnboundLocalError if no assignment qualifies.
        """
        R_sum = 0
        for S in self.M:
            if self.A[k] == S[k]:
                tmp = self.upper_R(S)
                if tmp > R_sum:
                    R_sum = tmp
                    S_max = S
        return S_max

    def divide_into_groups_random(self):
        """Assign every arm to a uniformly random agent whose active set contains it."""
        groups = [[] for _ in range(self.K)]
        A_ = [[] for _ in range(self.N)]
        for k in range(self.K):
            for n in self.A[k]:
                A_[n].append(k)
        for n in np.arange(self.N):
            groups[random.choice(A_[n])].append(n)
        return groups

    def match_elements(self, S, index):
        """Return, for each group in ``S``, a 0/1 list marking members found in ``index``."""
        return [[1 if element in index else 0 for element in sublist] for sublist in S]

    def construct_M(self):
        """Enumerate all feasible assignments of arms to agents.

        Each arm ``n`` is either unassigned or given to an agent ``k`` with
        ``n`` in ``self.A[k]``.  An assignment is kept when every group has at
        most ``L`` arms and at least ``min(L*K, N)`` arms are assigned in
        total.  Enumeration is exponential in ``N``.
        """
        A_ = [[None] for _ in range(self.N)]
        for k in range(self.K):
            for n in self.A[k]:
                A_[n].append(k)

        M = []
        for combination in product(*A_):
            S = [[] for _ in range(self.K)]
            for n, k in enumerate(combination):
                if k is not None:
                    S[k].append(n)
            if all(len(sublist) <= self.L for sublist in S) and \
                    sum(len(sublist) for sublist in S) >= min(self.L * self.K, self.N):
                M.append(S)
        return M

    def _sum_choice_prob(self, util, S):
        """Sum, over agents k, of ``sum(exp(util[k][S[k]])) / (1 + that sum)``."""
        R_sum = 0
        for k in range(self.K):
            u_sum = np.sum(np.exp(util[k][S[k]]))
            R_sum += u_sum / (u_sum + 1)
        return R_sum

    def upper_R(self, S):
        """Return the optimistic value of assignment ``S`` (from upper utilities ``self.p``)."""
        return self._sum_choice_prob(self.p, S)

    def lower_R(self, S):
        """Return the pessimistic value of assignment ``S`` (from lower utilities ``self.b``)."""
        return self._sum_choice_prob(self.b, S)

    def S_argmax(self, n, k):
        """Return the assignment in ``self.M`` with ``n`` in ``S[k]`` that maximises ``upper_R``."""
        R_sum = 0
        for S in self.M:
            if n in S[k]:
                tmp = self.upper_R(S)
                if tmp > R_sum:
                    R_sum = tmp
                    S_max = S
        return S_max

    def max_lower(self):
        """Return the largest ``lower_R`` over ``self.M``, or 0 if none is positive."""
        R_sum = 0
        for S in self.M:
            tmp = self.lower_R(S)
            if tmp > R_sum:
                R_sum = tmp
        return R_sum

    def elimination(self):
        """Keep arm ``n`` for agent ``k`` only if ``upper_R(S_max[k][n]) >= max_lower()``."""
        A = [[] for _ in range(self.K)]
        max_lower = self.max_lower()
        for k in range(self.K):
            for n in self.A[k]:
                if max_lower <= self.upper_R(self.S_max[k][n]):
                    A[k].append(n)
        return A

    def objective_function(self, pi, z):
        """Return ``-log det(sum_i pi_i z_i z_i^T)``.

        ``slogdet`` is used so that a very small or very large determinant does
        not under- or overflow.  Returns ``inf`` when the determinant is not
        positive.
        """
        cov_matrix = sum(pi[i] * np.outer(z[i], z[i]) for i in range(len(pi)))
        sign, log_det_cov = np.linalg.slogdet(cov_matrix)
        if sign <= 0:
            return np.inf
        return -log_det_cov

    def constraint(self, pi):
        """Return the equality constraint ``1 - sum(pi)`` (zero on the simplex)."""
        return 1.0 - np.sum(pi)

    def find_optimal_policy(self, k):
        """Return a sampling distribution over agent ``k``'s active arms.

        Minimises ``objective_function`` over ``pi`` with ``sum(pi) == 1`` and
        ``0 <= pi_i <= 1`` using ``scipy.optimize.minimize``.
        - Returns ``[]`` if agent ``k`` has no active arms.
        - Returns ``[1.0]`` for a single arm.
        - Falls back to the uniform distribution if the optimiser reports
          failure.
        """
        num_agents = len(self.A[k])
        if num_agents == 0:
            return []

        initial_pi = np.ones(num_agents) / num_agents
        if num_agents == 1:
            return initial_pi

        z = [self.z[k][i] for i in self.A[k]]
        constraints = ({'type': 'eq', 'fun': self.constraint})
        result = minimize(self.objective_function, initial_pi, args=(z,), constraints=constraints,
                          bounds=[(0, 1) for _ in range(num_agents)])
        # Fall back to the uniform design instead of aborting the run.
        return result.x if result.success else initial_pi

    def reset(self):
        """Re-seed NumPy's and Python's global RNGs with the stored seed."""
        np.random.seed(self.seed)
        random.seed(self.seed)

    def __init__(self, seed, x, N, K, L, T):
        """Initialise policy state.

        Args:
            seed: seed for NumPy's and Python's global RNGs.
            x: arm feature matrix, indexed as ``x[arms, :]`` in ``run``.
            N: number of arms.
            K: number of agents.
            L: maximum group size per agent.
            T: horizon, used in the confidence and phase-length constants.
        """
        print('Elimination')
        np.random.seed(seed)
        random.seed(seed)
        self.seed = seed
        self.y_hist = []
        self.x = x
        self.S = []
        self.S_hist = []
        self.r = np.zeros(K)
        self.kappa = 0.1
        self.N = N
        self.K = K
        self.L = L
        self.T = T
        self.lamb = 0
        self.h = np.zeros((N, K))
        self.V = [0] * self.K
        self.z = [[[] for _ in range(self.N)] for _ in range(self.K)]
        self.Ur = []
        self.A = []
        self.T1 = np.zeros(self.K)
        self.T2 = [None] * self.K
        self.t1 = np.zeros(self.K)
        self.t2 = np.zeros((self.K, self.N))
        self.T_tau = (3/10) * math.log(self.K * self.T * self.N)
        self.start_epoch = True
        self.bool1 = False
        self.bool2 = False
        self.bool_initial = True
        self.k = 0
        self.n = 0
        self.beta = (1/30) * (1/self.kappa) * np.sqrt(math.log(self.T * self.K * self.N))
        self.p = np.zeros((self.K, self.N))
        self.b = np.zeros((self.K, self.N))
        self.M = []
        self.pi = [None] * self.K
        self.S_max = [[[] for _ in range(self.N)] for _ in range(self.K)]
        self.theta = [None] * self.K
        self.theta_0 = [None] * self.K
        self.epoch = 0
        self.S_ = []
        self.initial = True
        self.bool_play1 = False
        self.bool_play2 = False
        self.bool_warmup = True
        self.bool_main = False
        self.main_start = 0
        for k in range(self.K):
            self.A.append(list(range(self.N)))

    def _record_feedback(self, index):
        """Store the outcome of the last offer ``self.S`` and update the Gram matrices ``V``."""
        y = self.match_elements(self.S, index)
        self.y_hist.append(y)
        self.S_hist.append(self.S)
        for k in range(self.K):
            g = np.array([self.z[k][i] for i in self.S[k]])
            self.V[k] += g.T @ g

    def _update_estimates(self):
        """Refit ``theta`` per agent, refresh the bounds ``p``/``b``, and recompute ``S_max``."""
        for k in range(self.K):
            y_k = [sublist[k] for sublist in self.y_hist]
            S_k = [sublist[k] for sublist in self.S_hist]
            self.theta[k] = self.fit(self.theta_0[k], self.z[k], y_k, S_k)
            multi_dim = len(self.theta[k]) > 1
            # Invert V[k] once per agent, and only when an arm will use it.
            V_inv = np.linalg.inv(self.V[k]) if (multi_dim and self.A[k]) else None
            for n in self.A[k]:
                z_n = self.z[k][n]
                mean = z_n @ self.theta[k]
                if multi_dim:
                    width = self.beta * np.sqrt(z_n @ V_inv @ z_n)
                else:
                    width = self.beta * np.sqrt(z_n * (1 / self.V[k]) * z_n)
                self.p[k][n] = mean + width
                self.b[k][n] = mean - width
        for k in range(self.K):
            for n in self.A[k]:
                self.S_max[k][n] = self.S_argmax(n, k)

    def run(self, t, index):
        """Advance the policy by one round and set ``self.S`` (read via ``offer``).

        ``index`` is the round's feedback, turned into 0/1 outcomes for the
        previous offer by ``match_elements``.

        - Warm-up: agents are explored one after another, agent ``k`` for
          ``T1[k]`` rounds, via ``divide_into_groups``.
        - Main phase: epochs of refit / eliminate / play.  Each surviving
          (agent, arm) pair's assignment ``S_max[k][n]`` is offered for its
          ``T2`` budget.
        """
        if self.bool_warmup:
            if self.initial:
                self.M = self.construct_M()
                log_tkn = math.log(self.T * self.K * self.N)
                for k in range(self.K):
                    X = self.x[self.A[k], :]
                    U, Sigma, _ = np.linalg.svd(X.T)
                    # Project features onto their column space of rank r[k].
                    self.r[k] = np.count_nonzero(Sigma)
                    self.Ur = U[:, :int(self.r[k])]
                    z = X @ self.Ur
                    w, _ = np.linalg.eigh(z.T @ z)
                    lambda_min = np.amin(w)
                    for i, n in enumerate(self.A[k]):
                        self.z[k][n] = z[i]
                    self.T1[k] = (1/800) * int(len(self.A[k]) / (self.L * self.kappa**2 * lambda_min * log_tkn)
                                               * (self.r[k] + log_tkn)**2)
                self.t1[0] = 1
                self.k = 0
                self.initial = False

            if t != 1:
                self._record_feedback(index)

            if t <= self.t1[self.k] + self.T1[self.k] - 1:
                self.S = self.divide_into_groups(t, self.k)
            elif self.k != self.K - 1:
                self.k += 1
                self.t1[self.k] = t
                self.S = self.divide_into_groups(t, self.k)
            else:
                self.bool_warmup = False
                self.bool_main = True
                self.main_start = t
                self.initial = True

        if self.bool_main:
            self.bool_play1 = False
            while not self.bool_play1:
                if self.initial:
                    if self.main_start != t:
                        # End of an epoch: refit, eliminate, double the epoch length.
                        self._update_estimates()
                        self.A = self.elimination()
                        self.T_tau = 2 * self.T_tau
                        self.epoch += 1

                    self.M = self.construct_M()
                    for k in range(self.K):
                        self.pi[k] = self.find_optimal_policy(k)
                        if len(self.pi[k]) == 0:
                            self.T2[k] = 0
                        else:
                            self.T2[k] = [math.ceil(share) for share in self.pi[k] * self.r[k] * self.T_tau]
                        self.theta_0[k] = np.ones(int(self.r[k]))
                        self.theta[k] = np.zeros(int(self.r[k]))

                    # Start with the first agent that still has active arms.
                    self.k = 0
                    while len(self.A[self.k]) == 0 and self.k != self.K - 1:
                        self.k += 1

                    self.n = self.A[self.k][0]
                    self.t2[self.k][self.n] = t

                    self.initial = False
                    self._update_estimates()

                if self.main_start != t and self.bool_play2:
                    self._record_feedback(index)
                    self.bool_play2 = False

                ind_n = self.A[self.k].index(self.n)
                if t <= self.t2[self.k][self.n] + self.T2[self.k][ind_n] - 1:
                    # Keep playing the optimistic assignment for (k, n).
                    self.S = self.S_max[self.k][self.n]
                    self.bool_play1 = True
                    self.bool_play2 = True
                elif self.n != self.A[self.k][-1]:
                    # Move on to the next active arm of the same agent.
                    self.n = self.A[self.k][ind_n + 1]
                    self.t2[self.k][self.n] = t
                elif self.k != self.K - 1:
                    # Move on to the next agent that has active arms.
                    self.k += 1
                    while len(self.A[self.k]) == 0 and self.k != self.K - 1:
                        self.k += 1
                    if len(self.A[self.k]) != 0:
                        self.n = self.A[self.k][0]
                        self.t2[self.k][self.n] = t
                    if len(self.A[self.k]) == 0 and self.k == self.K - 1:
                        self.initial = True
                else:
                    # Every (agent, arm) pair has been played: start a new epoch.
                    self.initial = True

    def offer(self):
        """Return the current assignment ``S`` (one arm list per agent)."""
        return self.S

    def name(self):
        """Return the policy's display name."""
        return 'Elimination'
    

###########

class ETC_GS:
    """Explore-then-commit policy that commits to a Gale-Shapley matching.

    During exploration, arms are assigned to agents round-robin and the
    empirical mean of the 0/1 feedback is tracked for every (arm, agent)
    pair.  Once ``t > h * K``, a stable matching is computed from those
    means and offered from then on.
    """

    def construct_M(self):
        """Enumerate all assignments of arms to agents (or to no agent).

        An assignment is kept when every group has at most ``L`` arms and at
        least ``min(L*K, N)`` arms are assigned in total.
        """
        A_ = [[None] for _ in range(self.N)]
        for k in range(self.K):
            for n in range(self.N):
                A_[n].append(k)

        M = []
        for combination in product(*A_):
            S = [[] for _ in range(self.K)]
            for n, k in enumerate(combination):
                if k is not None:
                    S[k].append(n)
            if all(len(sublist) <= self.L for sublist in S) and \
                    sum(len(sublist) for sublist in S) >= min(self.L * self.K, self.N):
                M.append(S)
        return M

    def divide_into_groups(self, K, N):
        """Split ``0..N-1`` into ``K`` contiguous, nearly equal groups."""
        return np.array_split(np.arange(N), K)

    def match_elements(self, S, index):
        """Return, for each group in ``S``, a 0/1 list marking members found in ``index``."""
        return [[1 if element in index else 0 for element in sublist] for sublist in S]

    def generate_preferences(self, ucb):
        """Build preference lists from an (arms x agents) score matrix.

        Returns ``(men_preferences, women_preferences)``: each arm ranks the
        agents, and each agent ranks the arms, by descending score.
        """
        n, k = ucb.shape
        men_preferences = [list(np.argsort(-ucb[n_, :])) for n_ in range(n)]
        women_preferences = [list(np.argsort(-ucb[:, k_])) for k_ in range(k)]
        return men_preferences, women_preferences

    def stable_matching(self, men_preferences, women_preferences):
        """Run men-proposing Gale-Shapley, with at most one man per woman.

        Args:
            men_preferences: ``men_preferences[m]`` lists women, most preferred first.
            women_preferences: ``women_preferences[w]`` lists men, most preferred first.

        Returns:
            A list where entry ``m`` is the woman matched to man ``m``, or
            ``None`` if every woman he ranked rejected him (e.g. when there
            are more men than women).
        """
        n = len(men_preferences)
        # rank[w][m] is man m's position in woman w's list (lower = preferred).
        rank = [{man: pos for pos, man in enumerate(prefs)} for prefs in women_preferences]
        engaged_to = [None] * n
        holder = {}             # woman -> man she is currently engaged to
        next_choice = [0] * n   # index of the next woman each man proposes to
        proposed = True
        while proposed:         # stop once no free man has anyone left to ask
            proposed = False
            for man in range(n):
                if engaged_to[man] is not None or next_choice[man] >= len(men_preferences[man]):
                    continue
                proposed = True
                woman = men_preferences[man][next_choice[man]]
                next_choice[man] += 1
                current_man = holder.get(woman)
                if current_man is None:
                    holder[woman] = man
                    engaged_to[man] = woman
                elif rank[woman][man] < rank[woman][current_man]:
                    # The woman prefers the new proposal; her old partner becomes free.
                    engaged_to[current_man] = None
                    holder[woman] = man
                    engaged_to[man] = woman
        return engaged_to

    def reset(self):
        """Re-seed NumPy's and Python's global RNGs with the stored seed."""
        np.random.seed(self.seed)
        random.seed(self.seed)

    def __init__(self, seed, x, N, K, L, T, delta):
        """Initialise state.

        ``delta`` sets the exploration length: exploration ends once
        ``t > h * K`` with ``h = 4/delta**2 * log(1 + T*delta**2*N/4)``.
        """
        print('ETC_GS')
        np.random.seed(seed)
        random.seed(seed)
        self.seed = seed
        self.y_hist = []
        self.x = x
        self.S = []
        self.S_hist = []
        self.r = 0
        self.alpha = 1
        self.kappa = 0.2
        self.N = N
        self.K = K
        self.L = L
        self.T = T
        self.lamb = 1
        self.V = []
        self.z = []
        self.Ur = []
        self.mean = np.zeros((self.N, self.K))
        self.n = np.ones((self.N, self.K))
        self.ucb = np.zeros((self.N, self.K))
        self.explore = True
        self.delta = delta
        self.h = 4 / delta**2 * math.log(1 + (T * delta**2 * N / 4))

    def run(self, t, index):
        """Advance one round.

        While exploring, fold the feedback for the previous offer into the
        running means and offer the next round-robin assignment.  When
        ``t > h * K``, replace that offer with the stable matching on the
        means and stop exploring.
        """
        if not self.explore:
            return
        if t != 1:
            y = self.match_elements(self.S, index)
            for k in range(self.K):
                for i, n in enumerate(self.S[k]):
                    self.n[n, k] += 1
                    self.mean[n, k] = ((self.n[n, k] - 1) * self.mean[n, k] + y[k][i]) / self.n[n, k]

        self.S = [[] for _ in range(self.K)]
        for n in range(self.N):
            self.S[(t + n - 1) % self.K].append(n)
        if t > self.h * self.K:
            men_preferences, women_preferences = self.generate_preferences(self.mean)
            result = self.stable_matching(men_preferences, women_preferences)
            self.S = [[] for _ in range(self.K)]
            for n, k in enumerate(result):
                if k is not None:
                    self.S[k].append(n)
            self.explore = False

    def offer(self):
        """Return the current assignment ``S`` (one arm list per agent)."""
        return self.S

    def name(self):
        """Return the policy's display name."""
        return 'ETC-GS'
    
    
########################
    
    
class UCB_GS:
    """UCB policy that offers a Gale-Shapley matching on UCB scores every round.

    Round 1 offers a random feasible assignment.  Each later round folds the
    0/1 feedback into per-(arm, agent) running means, computes
    ``mean + sqrt(3 log t / (2 n))``, and offers the stable matching.
    """

    def construct_M(self):
        """Enumerate all assignments of arms to agents (or to no agent).

        An assignment is kept when every group has at most ``L`` arms and at
        least ``min(L*K, N)`` arms are assigned in total.
        """
        A_ = [[None] for _ in range(self.N)]
        for k in range(self.K):
            for n in range(self.N):
                A_[n].append(k)

        M = []
        for combination in product(*A_):
            S = [[] for _ in range(self.K)]
            for n, k in enumerate(combination):
                if k is not None:
                    S[k].append(n)
            if all(len(sublist) <= self.L for sublist in S) and \
                    sum(len(sublist) for sublist in S) >= min(self.L * self.K, self.N):
                M.append(S)
        return M

    def divide_into_groups(self, K, N):
        """Split ``0..N-1`` into ``K`` contiguous, nearly equal groups."""
        return np.array_split(np.arange(N), K)

    def match_elements(self, S, index):
        """Return, for each group in ``S``, a 0/1 list marking members found in ``index``."""
        return [[1 if element in index else 0 for element in sublist] for sublist in S]

    def generate_preferences(self, ucb):
        """Build preference lists from an (arms x agents) score matrix.

        Returns ``(men_preferences, women_preferences)``: each arm ranks the
        agents, and each agent ranks the arms, by descending score.
        """
        n, k = ucb.shape
        men_preferences = [list(np.argsort(-ucb[n_, :])) for n_ in range(n)]
        women_preferences = [list(np.argsort(-ucb[:, k_])) for k_ in range(k)]
        return men_preferences, women_preferences

    def stable_matching(self, men_preferences, women_preferences):
        """Run men-proposing Gale-Shapley, with at most one man per woman.

        Args:
            men_preferences: ``men_preferences[m]`` lists women, most preferred first.
            women_preferences: ``women_preferences[w]`` lists men, most preferred first.

        Returns:
            A list where entry ``m`` is the woman matched to man ``m``, or
            ``None`` if every woman he ranked rejected him (e.g. when there
            are more men than women).
        """
        n = len(men_preferences)
        # rank[w][m] is man m's position in woman w's list (lower = preferred).
        rank = [{man: pos for pos, man in enumerate(prefs)} for prefs in women_preferences]
        engaged_to = [None] * n
        holder = {}             # woman -> man she is currently engaged to
        next_choice = [0] * n   # index of the next woman each man proposes to
        proposed = True
        while proposed:         # stop once no free man has anyone left to ask
            proposed = False
            for man in range(n):
                if engaged_to[man] is not None or next_choice[man] >= len(men_preferences[man]):
                    continue
                proposed = True
                woman = men_preferences[man][next_choice[man]]
                next_choice[man] += 1
                current_man = holder.get(woman)
                if current_man is None:
                    holder[woman] = man
                    engaged_to[man] = woman
                elif rank[woman][man] < rank[woman][current_man]:
                    # The woman prefers the new proposal; her old partner becomes free.
                    engaged_to[current_man] = None
                    holder[woman] = man
                    engaged_to[man] = woman
        return engaged_to

    def reset(self):
        """Re-seed NumPy's and Python's global RNGs with the stored seed."""
        np.random.seed(self.seed)
        random.seed(self.seed)

    def __init__(self, seed, x, N, K, L, T):
        """Initialise state for N arms, K agents, group size L, and horizon T."""
        print('UCB_GS')
        np.random.seed(seed)
        random.seed(seed)
        self.seed = seed
        self.y_hist = []
        self.x = x
        self.S = []
        self.S_hist = []
        self.N = N
        self.K = K
        self.L = L
        self.T = T
        self.V = []
        self.z = []
        self.Ur = []
        self.mean = np.zeros((self.N, self.K))
        self.n = np.ones((self.N, self.K))
        self.ucb = np.zeros((self.N, self.K))

    def run(self, t, index):
        """Advance one round and set ``self.S`` (read via ``offer``)."""
        if t == 1:
            self.M = self.construct_M()
            self.S = random.choice(self.M)
            return

        y = self.match_elements(self.S, index)
        for k in range(self.K):
            for i, n in enumerate(self.S[k]):
                self.n[n, k] += 1
                self.mean[n, k] = ((self.n[n, k] - 1) * self.mean[n, k] + y[k][i]) / self.n[n, k]
        # UCB index for every (arm, agent) pair, updated in place.
        self.ucb[:, :] = self.mean + np.sqrt(3 * math.log(t) / (2 * self.n))

        men_preferences, women_preferences = self.generate_preferences(self.ucb)
        result = self.stable_matching(men_preferences, women_preferences)
        self.S = [[] for _ in range(self.K)]
        for n, k in enumerate(result):
            if k is not None:
                self.S[k].append(n)

    def offer(self):
        """Return the current assignment ``S`` (one arm list per agent)."""
        return self.S

    def name(self):
        """Return the policy's display name."""
        return 'UCB-GS'
