import numpy as np
import torch
import copy
import os

class DEEP_C:
    """Arm-elimination pricing experiment over a uniform (z, theta) grid.

    Each arm is one cell of a grid that partitions the support of ``z``
    (first coordinate) and of ``theta`` (remaining ``d`` coordinates).
    At every round a price is drawn uniformly from the range spanned by the
    active arms, the environment's response is logged, and arms whose upper
    confidence bound falls below another arm's lower bound are deactivated.
    """

    def __init__(self, gamma, d, T, generator):
        """Store the experiment configuration and allocate the reward logs.

        Args:
            gamma: Numerator of the confidence radius ``sqrt(gamma / Ta)``.
            d: Number of ``theta`` coordinates (the grid has ``d + 1`` axes).
            T: Number of rounds per repetition.
            generator: Stored on the instance; not used inside this class.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.z_supp_len = 1
        self.theta_supp_len = 1
        self.gamma = gamma
        self.d = d
        self.T = T
        self.generator = generator

        # Per-round logs; overwritten by every repetition of `run`.
        self.reward = np.zeros((T,))
        self.optimal_reward = np.zeros((T,))

        self.t = 0

    def run(self, rep, env, basedir):
        """Run ``rep`` independent repetitions and save the logged rewards.

        Args:
            rep: Number of repetitions.
            env: Environment object providing ``reset()``, ``gen_context()``,
                ``act(x, price)`` and ``optimal_action(x)``.
            basedir: Output directory (created if missing). Receives
                ``reward.npy`` and ``optimal_reward.npy``, each of shape
                ``(rep, T)``.
        """
        rewards = np.zeros((rep, self.T))
        optimal_rewards = np.zeros((rep, self.T))

        for r in range(rep):
            print(f'run {r}')
            env.reset()
            self._run_once(env)
            # Assigning into a row slice copies the data, so the shared
            # per-round buffers can be reused by the next repetition.
            rewards[r] = self.reward
            optimal_rewards[r] = self.optimal_reward

        # save results
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(basedir, exist_ok=True)
        np.save(os.path.join(basedir, 'reward.npy'), rewards)
        np.save(os.path.join(basedir, 'optimal_reward.npy'), optimal_rewards)

    def _build_grid(self, k):
        """Return the lower corners of all grid cells, one row per arm.

        Column 0 partitions the support of z; columns 1..d partition the
        support of theta. Each axis is split into ``k`` cells.
        """
        ticks = np.arange(k)
        axes = np.meshgrid(*[ticks] * (self.d + 1))
        B = np.vstack(axes).reshape(self.d + 1, -1).T
        B = B / k
        B[:, 0] = B[:, 0] * self.z_supp_len
        B[:, 1:] = B[:, 1:] * self.theta_supp_len
        return B

    def _run_once(self, env):
        """Play ``T`` rounds, filling ``self.reward`` / ``self.optimal_reward``."""
        # Number of cells per grid axis.
        k = np.ceil((self.T / 10) ** (1 / 4))
        B = self._build_grid(k)
        z_step = self.z_supp_len / k
        theta_step = self.theta_supp_len / k

        # arms
        m = int(k ** (self.d + 1))
        # Ta: number of times each arm was played, Sa: its accumulated reward
        Ta = np.zeros(m)
        Sa = np.zeros(m)
        # Estimates and confidence bounds for the arms
        mu = np.zeros(m)
        l = np.zeros(m)
        u = np.ones(m)

        # Set of active arms: 1 means active, 0 means inactive
        A = np.ones(m)

        Pmin = np.zeros(m)
        Pmax = np.zeros(m)

        for t in range(self.T):
            # receive a context
            x = env.gen_context()
            x_np = x.numpy(force=True)

            # Price range of each active arm. For Pmin use the right cell
            # boundary on coordinates with a negative context entry and the
            # left boundary otherwise; Pmax is the mirror image.
            neg_shift = (x_np < 0) * theta_step
            pos_shift = (x_np > 0) * theta_step
            for j in np.flatnonzero(A):
                theta_lo = B[j, 1:self.d + 1]
                Pmin[j] = B[j, 0] * np.exp(np.dot(x_np, theta_lo + neg_shift))
                Pmax[j] = (B[j, 0] + z_step) * \
                    np.exp(np.dot(x_np, theta_lo + pos_shift))

            # play a price drawn uniformly from the range of the active arms
            active = A > 0
            price = np.random.uniform(np.amin(Pmin[active]), np.amax(Pmax[active]))
            # take the action and receive the response of env
            realization, probability = env.act(x, price)
            # log data
            self.reward[t] = price * probability
            _, self.optimal_reward[t] = env.optimal_action(x)

            # Update the statistics of every arm whose range contains the
            # price. The index list is a snapshot: arms deactivated during
            # this pass are still visited.
            A_prev = A.copy()
            for at in np.flatnonzero(A):
                if Pmin[at] < price < Pmax[at]:
                    Ta[at] += 1
                    Sa[at] += price * realization
                    mu[at] = Sa[at] / Ta[at]
                    radius = np.sqrt(self.gamma / Ta[at])
                    l[at] = mu[at] - radius
                    u[at] = mu[at] + radius

                    # Eliminate this arm if its upper bound is too low.
                    # Inactive arms contribute 0 (not -inf) to the maximum
                    # because A is used as a 0/1 multiplier.
                    if u[at] < np.max(A * l):
                        A[at] = 0
                    # Eliminate other arms whose upper bound is below l[at]
                    A[u < l[at] * A] = 0
            if not A.any():
                # If the active set became empty, discard this round's
                # eliminations and keep the previous active set.
                A = A_prev
                # Rounds are reported 1-based.
                print("Warning: active set became empty at time %s. wiil maintain the current active set..." % (t + 1))