import numpy as np
import scipy
import scipy.optimize as so
import torch
import copy
import os
from tqdm import tqdm


class Node:
    """A cell of the adaptive partition tree over the context space.

    Each node owns an axis-aligned hyper-rectangle (``boundary``), a price
    interval ``[pl, pu]`` discretised into ``Nk`` grid prices, and per-price
    running statistics: ``N`` (number of observations) and ``Y`` (running
    mean of the observed values passed to ``update``).
    """

    def __init__(self, level, boundary, pl, pu, Nk):
        """Create a leaf node.

        Args:
            level: Depth of the node in the tree (the root is 0).
            boundary: Array of shape ``(d, 2)``; row ``_d`` holds the
                ``[lower, upper]`` bounds of the cell along dimension ``_d``.
            pl: Lower end of the price interval searched in this cell.
            pu: Upper end of the price interval searched in this cell.
            Nk: Number of grid prices the interval is divided into.
        """
        self.leaf = True
        self.children = []
        self.boundary = boundary
        self.d = boundary.shape[0]
        self.level = level
        self.count = 0  # number of visits; incremented by the caller, not here
        self.pl = pl
        self.pu = pu
        self.Nk = Nk
        self.N = np.zeros((self.Nk))  # observations recorded per grid price
        self.Y = np.zeros((self.Nk))  # running mean of recorded values per grid price

    def _price_at(self, j):
        """Return grid price ``j``: ``pl + j * (pu - pl) / Nk``."""
        return self.pl + j*(self.pu-self.pl)/self.Nk

    def get_price(self):
        """Return the next grid price, cycling round-robin over the grid.

        The grid index is ``(count - 1) % Nk``. ``count`` is expected to be
        incremented by the caller before each call, so successive calls visit
        indices ``0, 1, ..., Nk-1, 0, ...``.

        Returns:
            Tuple ``(idx, price)`` of the grid index and its price.
        """
        idx = (self.count-1) % self.Nk
        return idx, self._price_at(idx)

    def update(self, j, z):
        """Fold observation ``z`` into the running mean for grid price ``j``."""
        self.Y[j] = (self.N[j]*self.Y[j]+z)/(self.N[j]+1)
        self.N[j] += 1

    def split(self, T):
        """Turn this leaf into an internal node with ``2**d`` children.

        The grid price with the highest mean ``Y`` (``p_star``) becomes the
        centre of the children's price interval. That interval has width
        ``log(T) / 2**(level + 1)`` and is clipped to ``[0, 1]``. Child ``i``
        covers the upper half of this cell along dimension ``_d`` exactly when
        bit ``_d`` of ``i`` is set. This is the same index convention used to
        descend the tree (``x[_d] >= midpoint`` sets bit ``_d``).

        Args:
            T: Time horizon; it scales the width of the children's price range.

        Returns:
            The price ``p_star`` selected at split time.
        """
        self.leaf = False

        j = np.argmax(self.Y)
        p_star = self._price_at(j)
        delta_k = np.log(T)/np.power(2,self.level+1)
        pl_next = np.maximum(0, p_star-delta_k/2)
        pu_next = np.minimum(1, p_star+delta_k/2)

        lower = self.boundary[:, 0]
        upper = self.boundary[:, 1]
        mid = (lower + upper) / 2
        for i in range(2 ** self.d):
            # Bit _d of i selects the upper half along dimension _d. Integer bit
            # extraction is required: true division (i / 2**_d) leaves a
            # fractional remainder that is truthy even when the bit is clear.
            boundary = np.array([
                [mid[_d], upper[_d]] if (i >> _d) & 1 else [lower[_d], mid[_d]]
                for _d in range(self.d)
            ])
            self.children.append(Node(self.level+1, boundary, pl_next, pu_next, self.Nk))

        return p_star



class ABE:
    """Contextual pricing policy driven by an adaptive partition tree.

    The context space ``[-1, 1]^d`` is recursively split into ``2**d``
    sub-cells (``Node`` objects).
    - A cell at depth ``k < K`` cycles through its grid of candidate prices.
    - Once its visit count reaches ``n[k]``, the cell is split, and its
      children search a narrower price interval around the best price found.
    - Cells at depth ``K`` or deeper post the midpoint of their price interval.
    """

    def __init__(self, d, c, T):
        """Configure the policy.

        Args:
            d: Dimension of the context vector.
            c: Scaling constant for the per-depth split thresholds ``n[k]``.
            T: Number of rounds per run. It must be at least 2, because
                ``log(T)`` and ``log(log(T))`` enter the tuning formulas; for
                ``T == 1`` they are degenerate (division by zero, ``log(0)``).

        Raises:
            ValueError: If ``T < 2``.
        """
        if T < 2:
            raise ValueError(f"horizon T must be at least 2, got {T!r}")

        # Not referenced within this class; kept for external users.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.d = d
        self.T = T
        # Depth below which cells keep exploring and splitting.
        self.K = np.maximum(1, int(np.ceil(np.log(self.T)/(self.d+4)/np.log(2))))
        # Number of grid prices per cell.
        self.N = int(np.ceil(np.log(self.T)))
        self.c = c

        # Per-round price * probability for the policy, and the value reported
        # by env.optimal_action, for the current run.
        self.reward         = np.zeros((T,))
        self.optimal_reward = np.zeros((T,))

    def reset(self):
        """Recompute the split thresholds ``n`` and start a fresh tree."""
        log_T = np.log(self.T)
        self.n = np.zeros((self.K,))
        for k in range(self.K):
            # Visit count at which a depth-k cell is split (clipped at 0).
            self.n[k] = np.maximum(0, int(self.c * np.power(2,4*k+15) / log_T**3 *
                                          (log_T + np.log(log_T) - k*(self.d+2)*np.log(2))))
        boundary = np.tile(np.array([-1,1]),(self.d,1))
        self.root = Node(0, boundary, 0, 1, self.N)

    def find_node(self, x):
        """Return the leaf cell that contains context ``x``.

        At each internal node, bit ``_d`` of the child index is set when
        ``x[_d]`` is at or above the cell's midpoint along dimension ``_d``.
        """
        node = self.root
        while not node.leaf:
            idx = 0
            for _d in range(self.d):
                if x[_d] >= (node.boundary[_d,0]+node.boundary[_d,1])/2:
                    idx += 2 ** _d
            node = node.children[idx]
        return node

    def run(self, rep, env, basedir):
        """Run the policy ``rep`` times against ``env`` and save the rewards.

        Args:
            rep: Number of independent runs.
            env: Environment providing ``reset()``, ``gen_context()``,
                ``act(x, price) -> (realization, probability)`` and
                ``optimal_action(x)``, whose second element is stored as the
                optimal reward. The context is converted via
                ``x.numpy(force=True)``, presumably a ``torch.Tensor`` —
                confirm against the env implementation.
            basedir: Output directory, created if missing. ``reward.npy`` and
                ``optimal_reward.npy`` (each of shape ``(rep, T)``) are
                written into it.
        """
        rewards = np.zeros((rep,self.T))
        optimal_rewards = np.zeros((rep,self.T))

        for r in range(rep):
            print(f'run {r}')
            env.reset()
            self.reset()

            for t in range(self.T):
                x = env.gen_context()
                node = self.find_node(x.numpy(force=True))
                node.count += 1
                k = node.level

                if k < self.K:
                    if node.count < self.n[k]:
                        # Exploration: try the next grid price round-robin.
                        j, price = node.get_price()
                        realization, probability = env.act(x, price)
                        node.update(j, price*realization)
                    else:
                        # Threshold reached: split and post the best grid price.
                        price = node.split(self.T)
                        _, probability = env.act(x, price)
                else:
                    # Maximum depth: post the midpoint of the price interval.
                    price = (node.pl + node.pu)/2
                    _, probability = env.act(x, price)

                self.reward[t] = price*probability
                _, self.optimal_reward[t] = env.optimal_action(x)

            # Row assignment copies the values into the preallocated arrays.
            rewards[r] = self.reward
            optimal_rewards[r] = self.optimal_reward

        # save results
        os.makedirs(basedir, exist_ok=True)
        np.save(os.path.join(basedir, 'reward.npy'), rewards)
        np.save(os.path.join(basedir, 'optimal_reward.npy'), optimal_rewards)