import random
import numpy as np
import math
from numpy.random import seed
from numpy.random import rand
from scipy.stats import bernoulli
from itertools import combinations
from itertools import product

class linear_Env:
    """Synthetic multinomial-logit environment with ``N`` items and ``K`` groups.

    Every item ``n`` has a unit-norm feature vector ``x[n]`` and every group
    ``k`` has a unit-norm parameter ``theta[k]``; the utility of item ``n`` in
    group ``k`` is ``w[n, k] = x[n] @ theta[k]``.  An assortment ``S`` is a list
    of ``K`` lists of item indices: each item appears in at most one group and
    each group holds at most ``L`` items.
    """

    def __init__(self, seed, d, N, K, L):
        """Draw features, group parameters and per-(group, item) revenues.

        Args:
            seed: seed applied to both ``numpy.random`` and ``random``.
            d: feature dimension.
            N: number of items.
            K: number of groups.
            L: maximum number of items allowed in a single group.
        """
        np.random.seed(seed)
        random.seed(seed)
        self.d = d
        self.N = N
        self.K = K
        self.L = L
        self.x = np.zeros((self.N, self.d))  # observed features, unit-norm rows
        self.theta = np.zeros((self.K, self.d))  # unit-norm parameter per group
        for n in range(self.N):
            self.x[n] = np.random.uniform(-1, 1, self.d)
            self.x[n] = self.x[n] / np.sqrt(np.sum(self.x[n] ** 2))
        for k in range(self.K):
            self.theta[k] = np.random.uniform(-1, 1, self.d)
            self.theta[k] = self.theta[k] / np.sqrt(np.sum(self.theta[k] ** 2))
        self.w = self.x @ self.theta.T  # (N, K) utilities
        self.index = np.zeros(self.K)  # not read by any method of this class
        self.p = np.zeros(self.K + 1)  # not read by any method of this class
        self.rev = np.zeros((self.K, self.N))  # rev[k, n]: revenue of item n in group k
        # Draw order (items outer, groups inner) is kept so seeded runs reproduce.
        for n in range(self.N):
            for k in range(self.K):
                self.rev[k, n] = random.uniform(0, 1)

    def delta(self):
        """Return the smallest best-vs-second-best utility gap over all items.

        For each item the utilities ``x[n] @ theta.T`` across groups are sorted
        and the difference between the top two is taken; the minimum over items
        is returned.  The running minimum starts at 2, the largest possible gap
        between two inner products of unit vectors, so ``N == 0`` yields 2.

        Requires ``K >= 2`` whenever ``N > 0`` (the second-best utility is
        indexed directly).
        """
        gap = 2
        for n in range(self.N):
            ranked = np.sort(self.x[n] @ self.theta.T)
            gap = min(gap, ranked[-1] - ranked[-2])
        return gap

    def construct_M(self):
        """Enumerate every feasible assortment.

        Each item is either left out (``None``) or placed in exactly one of the
        ``K`` groups; all ``(K + 1) ** N`` assignments are generated and those
        placing more than ``L`` items in some group are discarded.

        Returns:
            list of assortments, each a list of ``K`` lists of item indices in
            increasing order.
        """
        choices = [None] + list(range(self.K))
        M = []
        for assignment in product(choices, repeat=self.N):
            S = [[] for _ in range(self.K)]
            for n, k in enumerate(assignment):
                if k is not None:
                    S[k].append(n)
            if all(len(group) <= self.L for group in S):
                M.append(S)
        return M

    def oracle(self):
        """Find the assortment with the highest expected revenue by brute force.

        Prints the chosen assortment prefixed with ``'oracle'``.

        Returns:
            ``(oracle_reward, S)``: the best value of ``exp_reward`` and the
            first assortment (in ``construct_M`` order) that attains it.  When
            the only feasible assortment is the empty one (e.g. ``L == 0``) the
            reward is 0 and ``S`` is ``K`` empty lists.
        """
        # Start below any achievable reward so S is always bound; the strict
        # comparison still keeps the first maximiser.
        oracle_reward = -np.inf
        S = None
        for partition in self.construct_M():
            reward = self.exp_reward(partition)
            if oracle_reward < reward:
                oracle_reward = reward
                S = partition
        print('oracle', S)
        return oracle_reward, S

    def observe(self, S):
        """Sample one multinomial-logit choice per group.

        In group ``k`` offering ``S[k]``, item ``n`` is chosen with probability
        ``exp(w[n, k]) / (1 + sum_{m in S[k]} exp(w[m, k]))``; the outside
        option, encoded as index ``N``, takes the remaining probability.

        Returns:
            list of length ``K`` holding the chosen index per group, or
            ``None`` for groups with nothing offered.
        """
        index = []
        for k in range(self.K):
            if len(S[k]) == 0:
                index.append(None)
                continue
            weights = np.exp(self.w[S[k], k])
            prob = np.zeros(len(S[k]) + 1)
            prob[1:] = weights / (1 + np.sum(weights))
            prob[0] = 1 - np.sum(prob[1:])
            candidates = np.insert(S[k], 0, self.N)
            index.append(np.random.choice(candidates, p=prob))
        return index

    def exp_reward(self, S):
        """Return the expected revenue of assortment ``S``.

        Sums, over groups, ``sum_n rev[k, n] * exp(w[n, k]) / (1 + sum_n exp(w[n, k]))``
        with ``n`` ranging over ``S[k]``; empty groups contribute 0.
        """
        total = 0
        for k in range(self.K):
            weights = np.exp(self.w[S[k], k])
            total += np.sum(self.rev[k, S[k]] * weights) / (1 + np.sum(weights))
        return total

    

