import numpy as np
import torch
from Hyperrectangle import Hyperrectangle


class DesignPoint:
    """A candidate point together with the bookkeeping attached to it.

    Attributes:
        x: Coordinates of the point. Children created by ``get_child_list``
            get a float64 ``torch.Tensor``; other callers may pass a NumPy
            array (the constructor does not convert).
        R: Cumulative confidence hyperrectangle, narrowed in place by
            ``update_cumulative_conf_rect``.
        design_index: Integer identifier of the point.
        depth: Depth in the cell-splitting tree (``None`` if not tracked).
        cell_bound: Per-dimension ``[lower, upper]`` bounds of the cell this
            point represents (``None`` if not tracked).
        mu: Last mean passed to ``update_cumulative_conf_rect``; only set
            after the first call.

    Note: ``__eq__`` is defined without ``__hash__``, so instances are
    unhashable.
    """

    def __init__(self, x: "np.ndarray | torch.Tensor", R: "Hyperrectangle", design_index: int, depth=None, cell_bound=None):
        self.x = x
        self.R = R  # The confidence region (Hyperrectangle)
        self.design_index = design_index
        self.depth = depth
        self.cell_bound = cell_bound

    def __eq__(self, other):
        """Return True when both points have element-wise identical ``x``.

        Non-``DesignPoint`` operands yield ``NotImplemented`` so Python falls
        back to its default comparison. Coordinates of different shapes
        compare unequal instead of broadcasting or raising.
        """
        if not isinstance(other, DesignPoint):
            return NotImplemented
        if tuple(self.x.shape) != tuple(other.x.shape):
            return False
        return bool((self.x == other.x).all())

    def __str__(self):
        name = "\nDesign Point: x " + str(self.x) +\
               "\nHyperrectangle" + str(self.R)
        return name

    def update_cumulative_conf_rect(self, mu, cov, beta, t):
        """Narrow ``self.R`` with a fresh confidence box around ``mu``.

        Builds the per-coordinate interval
        ``mu ± sqrt(beta) * sqrt(diag(cov))``, wraps it in a
        ``Hyperrectangle`` Q, and replaces ``self.R`` with
        ``self.R.intersect(Q, t)``. Also stores ``mu`` on ``self.mu``.

        Args:
            mu: Mean values; flattened with ``reshape(-1)``.
            cov: Covariance as a torch tensor (detached and copied to CPU
                here); only its diagonal is used.
            beta: Confidence scaling; the half-width is ``sqrt(beta)``
                standard deviations.
            t: Passed through unchanged to ``Hyperrectangle.intersect``.
        """
        # Per-coordinate standard deviation, computed once for both bounds.
        sigma = np.sqrt(np.diag(cov.detach().cpu().numpy()))
        half_width = sigma * np.sqrt(beta)
        mu_flat = mu.reshape(-1)
        # High probability lower and upper bounds.
        L = mu_flat - half_width
        U = mu_flat + half_width
        # Confidence hyperrectangle, Q
        Q = Hyperrectangle(L.reshape(-1).tolist(), U.reshape(-1).tolist())
        # Cumulative confidence hyperrectangle, R
        self.R = self.R.intersect(Q, t)
        self.mu = mu

    def get_child_list(self):
        """Split this point's cell at its midpoint in every dimension.

        For a cell with ``d`` dimensions this returns ``2**d`` children, one
        per sub-cell. The list is ordered so the last dimension varies
        fastest; for ``d == 2`` the order is lower/lower, lower/upper,
        upper/lower, upper/upper. Each child:
          * has ``cell_bound`` equal to its sub-cell ``[[lo, hi], ...]``;
          * has ``x`` at the sub-cell centre, as a float64 ``torch.Tensor``;
          * shares this point's ``R`` object;
          * has ``depth = self.depth + 1`` and
            ``design_index = self.design_index * 2**d + j``, where ``j`` is
            its position in the returned list (``* 4 + j`` in 2-D).

        Raises:
            ValueError: if ``cell_bound`` or ``depth`` is ``None``.
        """
        import itertools

        if self.cell_bound is None or self.depth is None:
            raise ValueError(
                "get_child_list requires both cell_bound and depth to be set"
            )
        # For each dimension, the two halves [lo, mid] and [mid, hi].
        halves_per_dim = []
        for lo, hi in self.cell_bound:
            mid = (lo + hi) / 2
            halves_per_dim.append(([lo, mid], [mid, hi]))
        n_children = 2 ** len(halves_per_dim)

        list_children = []
        # itertools.product varies the last dimension fastest, matching the
        # original nested-loop order in two dimensions.
        for ind, combo in enumerate(itertools.product(*halves_per_dim)):
            # Copy each interval so children never share mutable bound lists.
            bound = [list(interval) for interval in combo]
            # Centre of the sub-cell: mean of [lo, hi] along each dimension.
            x = torch.tensor(bound, dtype=torch.float64).mean(dim=1)
            list_children.append(
                DesignPoint(
                    x=x,
                    R=self.R,
                    # TODO: with a root index of 0 the first child reuses
                    # index 0; a heap-style `* n_children + ind + 1` would keep
                    # indices unique. Confirm against callers before changing.
                    design_index=self.design_index * n_children + ind,
                    depth=self.depth + 1,
                    cell_bound=bound,
                )
            )
        return list_children

