import numpy as np
from torch.utils.data import Dataset


class RandomDuelingPolicy:
    '''
    History-independent dueling policy: every call to ``act`` samples one
    ordered pair of distinct actions from a fixed distribution ``self.p``
    that is built once in the constructor.

    Attributes:
        pairs: list of all ordered pairs (i, j) with i != j, in row-major order
        p: probability vector over ``pairs`` (same length as ``pairs``)
        num_A: number of actions
        num_pairs: number of ordered pairs, i.e. ``len(pairs)``
    '''

    def __init__(self, num_A, opt_a=None, cov=0.0, mode='random'):
        '''
        num_A: number of actions
        opt_a: index into ``self.pairs`` (a pair index, not an action id) that
               receives the point mass in 'mix' mode when ``cov > 0``
        cov: mixing weight put on ``opt_a`` in 'mix' mode; when ``cov <= 0`` the
             point mass and the weight are drawn at random instead
        mode: 'random' (uniform over pairs), 'dirich' (one Dirichlet(1,...,1)
              draw) or 'mix' (Dirichlet draw mixed with a point mass)

        Raises ValueError for an unknown ``mode``, or in 'mix' mode when
        ``cov > 0`` but ``opt_a`` is missing or ``cov > 1``.
        '''
        self.pairs = [(i, j) for i in range(num_A) for j in range(num_A) if i != j]
        num_pairs = len(self.pairs)

        # generate the random policy
        if mode == 'random':
            self.p = np.ones(num_pairs) / num_pairs
        elif mode == 'dirich':
            # randomly generate a policy following a Dirichlet prior
            self.p = np.random.dirichlet(np.ones(num_pairs))
        elif mode == 'mix':
            p_1 = np.random.dirichlet(np.ones(num_pairs))
            # point-mass (deterministic) policy on a single pair
            p_2 = np.zeros(num_pairs)
            if cov > 0:
                if opt_a is None:
                    # p_2[None] = 1 would broadcast and set *every* entry to 1,
                    # silently producing a vector that is not a distribution.
                    raise ValueError("mode='mix' with cov > 0 requires opt_a")
                if cov > 1:
                    # (1 - cov) would be negative -> negative probabilities.
                    raise ValueError("cov must not exceed 1")
                p_2[opt_a] = 1
                w = cov
            else:
                p_2[np.random.choice(num_pairs)] = 1
                # randomly choose the mixing weight from {0.0, 0.1, ..., 1.0}
                w = np.random.choice(11) / 10
            self.p = (1 - w) * p_1 + w * p_2  # mixing
        else:
            # Without this, self.p would stay undefined and act() would fail
            # much later with an AttributeError.
            raise ValueError(
                "mode must be 'random', 'dirich' or 'mix', got {!r}".format(mode)
            )

        self.num_A = num_A
        self.num_pairs = num_pairs

    def act(self, As, Rs):  # a policy takes history as input and takes an action
        '''
        As: list of history actions (ignored by this policy)
        Rs: list of history rewards (ignored by this policy)

        Returns a tuple (a_1, a_2) of distinct action indices sampled from self.p.
        '''
        pair_index = np.random.choice(self.num_pairs, p=self.p)
        a_1, a_2 = self.pairs[pair_index]
        return (a_1, a_2)

    def update(self, *args, **kargs):
        '''
        Dummy function; will be removed if parent class is defined
        '''
        pass

    def batch_update(self, *args, **kargs):
        '''
        Dummy function; the policy is fixed, so there is nothing to update
        '''
        pass

class DTS:
    '''
    Thompson-sampling policy for dueling bandits that keeps a pairwise win
    matrix and picks two arms per round.
    NOTE(review): the structure looks like Double Thompson Sampling (D-TS,
    Wu & Liu 2016) -- confirm against the intended reference.

    Attributes:
        B: (num_A, num_A) array; B[i, j] counts how often arm i beat arm j
        num_A: number of arms
        alpha: scale of the confidence-interval width
    '''

    def __init__(self, num_A, alpha=1):
        self.B = np.zeros((num_A, num_A))
        self.num_A = num_A
        self.alpha = alpha

    def _confidence_bounds(self, h):
        '''
        Return (U, L): upper / lower confidence bounds on P(i beats j).
        Diagonal entries are 1/2; pairs that were never compared get the
        widest bounds U = 2 and L = 0.
        '''
        n = self.num_A
        U = np.full((n, n), 0.5)
        L = np.full((n, n), 0.5)
        for i in range(n):
            for j in range(n):
                if i == j:
                    continue
                total = self.B[i, j] + self.B[j, i]
                if total != 0:
                    mean = self.B[i, j] / total
                    width = np.sqrt(self.alpha * np.log(h + 1) / total)
                    U[i, j] = mean + width
                    L[i, j] = mean - width
                else:
                    U[i, j] = 2
                    L[i, j] = 0
        return U, L

    def act(self, h):
        '''
        h: value used inside log(h + 1) for the confidence width
           (presumably the current round index -- verify against caller)

        Returns a tuple (best_b, best_b_2) of arm indices; the two may coincide.
        '''
        n = self.num_A
        U, L = self._confidence_bounds(h)

        # Optimistic Copeland score: fraction of arms that i might still beat.
        # max(n - 1, 1) only protects the single-arm case from dividing by zero.
        cs = (U > 0.5).sum(axis=1) / max(n - 1, 1)
        top = cs.max()
        candidates_1 = [i for i in range(n) if cs[i] == top]

        # One posterior sample per unordered pair; the mirror entry is its complement.
        Theta = np.zeros((n, n))
        for i in range(n):
            for j in range(i + 1, n):
                Theta[i, j] = np.random.beta(self.B[i, j] + 1, self.B[j, i] + 1)
                Theta[j, i] = 1 - Theta[i, j]

        # Choose the first bandit: the candidate that beats the most arms under
        # the sampled preferences (ties go to the lowest index).  The score must
        # be read from row b -- the candidate itself -- not from a leftover
        # loop index.
        max_score = -1
        best_b = None
        for b in candidates_1:
            score = int((Theta[b] > 0.5).sum())
            if score > max_score:
                max_score = score
                best_b = b

        # Choose the second bandit: resample every arm against best_b; best_b
        # versus itself is fixed at 1/2.
        Theta_2 = np.full(n, 0.5)
        for j in range(n):
            if j != best_b:
                Theta_2[j] = np.random.beta(self.B[j, best_b] + 1, self.B[best_b, j] + 1)

        # Skip arms whose lower bound against best_b already exceeds 1/2.
        # best_b itself always qualifies (L is 1/2 on the diagonal), so the
        # initial value below is always overwritten by the loop.
        max_theta = -1
        best_b_2 = best_b
        for j in range(n):
            if L[j, best_b] > 0.5:
                continue
            if Theta_2[j] > max_theta:
                max_theta = Theta_2[j]
                best_b_2 = j

        return (best_b, best_b_2)

    def reset(self):
        '''
        Forget all recorded comparisons
        '''
        self.B = np.zeros((self.num_A, self.num_A))

    def update(self, aw, al):
        '''
        Record one comparison in which arm aw beat arm al
        '''
        self.B[aw, al] += 1

    def batch_update(self, aw_list, al_list):
        '''
        Record several comparisons; aw_list[k] beat al_list[k]
        '''
        for k, aw in enumerate(aw_list):
            self.B[aw, al_list[k]] += 1

