import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .nn_utils import sparsemax, sparsemoid, ModuleWithInit
from .utils import check_numpy
from warnings import warn


# Module-level ReLU instance; ODST.update_response applies it to the raw
# per-tree responses when a monotone ("incre"/"decre") constraint is requested.
relu = nn.ReLU()


class ODST(ModuleWithInit):
    def __init__(self, in_features, num_trees, tree_dim=1, flatten_output=True,
                 choice_function=sparsemax, bin_function=sparsemoid,
                 initialize_response_=nn.init.normal_, initialize_selection_logits_=nn.init.uniform_,
                 threshold_init_beta=1.0, threshold_init_cutoff=1.0, device="cpu", monotone=False):
        """
        Oblivious tree layer whose depth is tied to the number of input features.

        :param in_features: number of features in the input tensor; also used as the tree depth
        :param num_trees: number of trees in this layer
        :param tree_dim: number of response channels of every tree
        :param flatten_output: stored on the instance and reported by __repr__
        :param choice_function: stored on the instance (the methods of this class do not call it)
        :param bin_function: f(tensor) -> tensor, turns threshold logits into soft bin memberships
        :param initialize_response_: in-place initializer applied to the per-tree response tensor
        :param initialize_selection_logits_: in-place initializer applied to the feature selection logits
        :param threshold_init_beta: both shape parameters of the Beta distribution from which
            ``initialize`` draws the threshold percentiles
        :param threshold_init_cutoff: controls the temperature percentile / scaling in ``initialize``
        :param device: device on which ``forward``/``initialize`` create their helper tensors
        :param monotone: "incre", "decre" or anything else (no constraint); read by ``update_response``
        """
        super().__init__()

        # One split per input feature: the depth is not a free hyper-parameter here.
        depth = in_features
        num_leaves = 2 ** depth

        # NOTE(review): q / qq are not read by any method visible in this class.
        self.q = 0.01
        self.qq = 0.0
        self.device = device

        self.depth = depth
        self.num_trees = num_trees
        self.tree_dim = tree_dim
        self.flatten_output = flatten_output

        self.choice_function = choice_function
        self.bin_function = bin_function

        self.threshold_init_beta = threshold_init_beta
        self.threshold_init_cutoff = threshold_init_cutoff
        self.monotone = monotone

        # Trainable per-tree response: [num_trees, tree_dim]
        self.response_ = nn.Parameter(torch.zeros([num_trees, tree_dim]), requires_grad=True)
        initialize_response_(self.response_)

        # Frozen per-leaf response used outside of training; starts as the raw
        # response copied to every leaf: [num_trees, tree_dim, 2 ** depth]
        replicated = torch.stack([self.response_ for _ in range(num_leaves)], dim=2)
        self.final_response = nn.Parameter(replicated, requires_grad=False)

        selection_logits = torch.zeros([in_features, num_trees, depth])
        self.feature_selection_logits = nn.Parameter(selection_logits, requires_grad=True)
        initialize_selection_logits_(self.feature_selection_logits)

        # NaN placeholders: real values are written by the data-aware ``initialize``.
        nan_thresholds = torch.full([num_trees, depth], float('nan'), dtype=torch.float32)
        self.feature_thresholds = nn.Parameter(nan_thresholds, requires_grad=True)

        nan_temperatures = torch.full([num_trees, depth], float('nan'), dtype=torch.float32)
        self.log_temperatures = nn.Parameter(nan_temperatures, requires_grad=True)

        # Binary codes for mapping between 1-hot vectors and bin (leaf) indices.
        with torch.no_grad():
            leaf_ids = torch.arange(num_leaves)
            bit_values = 2 ** torch.arange(depth)
            # codes[d, c] is bit d of leaf index c
            codes = (leaf_ids.view(1, -1) // bit_values.view(-1, 1) % 2).to(torch.float32)
            codes_1hot = torch.stack([codes, 1.0 - codes], dim=-1)
            # ^-- [depth, 2 ** depth, 2]
            self.bin_codes_1hot = nn.Parameter(codes_1hot, requires_grad=False)

    def forward(self, input, training):
        """
        Evaluate the trees on ``input``.

        :param input: tensor of shape [..., in_features]. Inputs with more than two
            dimensions are flattened to [-1, in_features], evaluated, and viewed back
            to ``[*input.shape[:-1], -1]``.
        :param training: when equal to ``True`` the per-leaf responses are recomputed
            from the current batch via ``update_response``; otherwise the stored
            ``final_response`` is used. The value is also written to ``self.training``.
        :return: for 2-D input, a tensor of shape [batch_size, num_trees, tree_dim]
        """
        self.training = training
        assert len(input.shape) >= 2
        if len(input.shape) > 2:
            # The recursive call must forward ``training`` as well: it is a required
            # positional argument, so omitting it raises TypeError.
            flat_output = self.forward(input.view(-1, input.shape[-1]), training)
            return flat_output.view(*input.shape[:-1], -1)
        # new input shape: [batch_size, in_features]

        # Fixed (non-learned) selector: split level d of every tree reads feature d.
        feature_selectors = (
            torch.eye(self.depth).unsqueeze(1).repeat(1, self.num_trees, 1).to(self.device)
        )
        # ^--[in_features, num_trees, depth]

        feature_values = torch.einsum('bi,ind->bnd', input, feature_selectors)
        # ^--[batch_size, num_trees, depth]

        threshold_logits = (feature_values - self.feature_thresholds) * torch.exp(-self.log_temperatures)

        threshold_logits = torch.stack([-threshold_logits, threshold_logits], dim=-1)
        # ^--[batch_size, num_trees, depth, 2]

        bins = self.bin_function(threshold_logits)
        # ^--[batch_size, num_trees, depth, 2], approximately binary

        bin_matches = torch.einsum('btds,dcs->btdc', bins, self.bin_codes_1hot)
        # ^--[batch_size, num_trees, depth, 2 ** depth]

        response_weights = torch.prod(bin_matches, dim=-2)
        # ^-- [batch_size, num_trees, 2 ** depth]

        # Kept on the instance so that save_id_constants can reuse the last batch.
        self.response_weights = response_weights

        # ``== True`` (rather than plain truthiness) is kept on purpose so that
        # non-bool flags are treated exactly as before.
        if self.training == True:
            self.update_response(response_weights, self.monotone)
            leaf_response = self.response
        else:
            leaf_response = self.final_response

        output = torch.einsum('bnd,ncd->bnc', response_weights, leaf_response).flatten(1, 2)

        # Flatten-then-reshape is retained from the original implementation so the
        # returned tensor is produced the same way as before.
        output = output.reshape(input.shape[0], self.num_trees, self.tree_dim)
        # ^-- [batch_size, num_trees, tree_dim]

        return output
    
        
    def update_response(self, response_weights, monotone):
        """
        Recompute ``self.response`` (per-leaf responses) from the current batch.

        Every leaf response is the raw per-tree response ``self.response_`` times a
        leaf-specific scale. For each split level d let

            ratio_d = - (batch weight of leaves on the L side of split d)
                      / (batch weight of leaves on the R side of split d)

        The scale of a leaf is the product of ``ratio_d`` over all splits d at which
        the leaf lies on the R side; the all-L leaf therefore has scale 1.
        NOTE(review): this is presumably what enforces the "sum-to-zero condition"
        mentioned next to save_id_constants -- confirm against the model description.

        The computation is driven by ``self.bin_codes_1hot`` and therefore works for
        any depth (the previous hand-unrolled version silently did nothing unless
        depth was 1, 2, 3 or 4).

        :param response_weights: [batch_size, num_trees, 2 ** depth] leaf weights
        :param monotone: "incre" / "decre" switch the base response to
            ``-(relu(response_) + 1)`` / ``relu(response_) + 1``. As before, this is
            only honoured for depth == 1; any other value or depth uses ``response_``.
        :return: None; the result, [num_trees, tree_dim, 2 ** depth], is stored in
            ``self.response``
        """
        # Total weight the batch puts on every leaf: [num_trees, 2 ** depth]
        weight_ = torch.sum(response_weights, dim=0)

        # Leaf order of weight_ / response_weights, R : Right prob , L : Left prob
        # (letter k corresponds to bit k of the leaf index; L <=> bit set)
        # depth 1: [R,L]
        # depth 2: [RR,LR,RL,LL]
        # depth 3: [RRR,LRR,RLR,LLR,RRL,LRL,RLL,LLL]
        # depth 4: [RRRR,LRRR,RLRR,LLRR,
        #           RRLR,LRLR,RLLR,LLLR,
        #           RRRL,LRRL,RLRL,LLRL,
        #           RRLL,LRLL,RLLL,LLLL]

        # For stable learning: avoid zero denominators below.
        weight_[weight_ == 0] = 0.1

        # Membership masks, [depth, 2 ** depth] each: on_left[d, c] == 1 iff bit d of
        # leaf c is set, on_right is its complement.
        on_left = self.bin_codes_1hot[..., 0]
        on_right = self.bin_codes_1hot[..., 1]

        # Masked sums (element-wise product + reduction) give the weight on each side
        # of every split: [num_trees, depth]
        per_split = weight_.unsqueeze(1)
        left_mass = (per_split * on_left).sum(dim=-1)
        right_mass = (per_split * on_right).sum(dim=-1)

        ratio = (-(left_mass / right_mass)).unsqueeze(-1)
        # ^-- [num_trees, depth, 1]

        # A split contributes its ratio to leaves on its R side and 1 to leaves on its
        # L side; multiplying over the splits yields the per-leaf scale.
        factors = torch.where(on_left > 0.5, torch.ones_like(ratio), ratio)
        # ^-- [num_trees, depth, 2 ** depth]
        leaf_scale = factors.prod(dim=1)
        # ^-- [num_trees, 2 ** depth]

        if self.depth == 1 and monotone == "incre":
            base = -(relu(self.response_) + 1)
        elif self.depth == 1 and monotone == "decre":
            base = relu(self.response_) + 1
        else:
            base = self.response_
        # ^-- [num_trees, tree_dim]

        self.response = base.unsqueeze(-1) * leaf_scale.unsqueeze(1)
        # ^-- [num_trees, tree_dim, 2 ** depth]
            
    def save_id_constants(self):
        """
        Save the constants for the sum-to-zero condition.

        Recomputes ``self.response`` from the leaf weights cached by the most recent
        ``forward`` call and freezes the result as ``self.final_response``, which
        ``forward`` uses when it is not in training mode.
        """
        cached_weights = self.response_weights
        self.update_response(cached_weights, self.monotone)
        frozen_response = nn.Parameter(self.response, requires_grad=False)
        self.final_response = frozen_response
    
            
            

    def initialize(self, input, eps=1e-6):
        """
        Data-aware initializer for ``feature_thresholds`` and ``log_temperatures``.

        :param input: 2-D tensor [batch_size, in_features] used to estimate percentiles
        :param eps: added to the temperatures before taking the logarithm
        """
        assert len(input.shape) == 2
        if input.shape[0] < 1000:
            warn("Data-aware initialization is performed on less than 1000 data points. This may cause instability."
                 "To avoid potential problems, run this model on a data batch with at least 1000 data samples."
                 "You can do so manually before training. Use with torch.no_grad() for memory efficiency.")

        with torch.no_grad():
            # Fixed selector: split level d of every tree reads input feature d.
            selectors = torch.eye(self.depth).unsqueeze(1).repeat(1, self.num_trees, 1).to(self.device)
            # ^--[in_features, num_trees, depth]

            values = torch.einsum('bi,ind->bnd', input, selectors)
            # ^--[batch_size, num_trees, depth]

            # Thresholds: a random percentile of the data for every (tree, split).
            beta = self.threshold_init_beta
            percentiles_q = 100 * np.random.beta(beta, beta, size=[self.num_trees, self.depth])
            per_split_values = check_numpy(values.flatten(1, 2).t())
            sampled = [
                np.percentile(column, q)
                for column, q in zip(per_split_values, percentiles_q.flatten())
            ]
            thresholds = torch.as_tensor(sampled, dtype=values.dtype, device=values.device)
            self.feature_thresholds.data[...] = thresholds.view(self.num_trees, self.depth)

            # Temperatures: make sure enough data points are in the linear region
            # of sparse-sigmoid.
            cutoff = self.threshold_init_cutoff
            distances = check_numpy((values - self.feature_thresholds).abs())
            temperatures = np.percentile(distances, q=100 * min(1.0, cutoff), axis=0)

            # If threshold_init_cutoff > 1, scale everything down by it.
            temperatures /= max(1.0, cutoff)
            self.log_temperatures.data[...] = torch.log(torch.as_tensor(temperatures) + eps)

    def __repr__(self):
        """Return a one-line summary of the layer's main hyper-parameters."""
        in_features = self.feature_selection_logits.shape[0]
        return (
            f"{self.__class__.__name__}(in_features={in_features}, "
            f"num_trees={self.num_trees}, depth={self.depth}, "
            f"tree_dim={self.tree_dim}, flatten_output={self.flatten_output})"
        )
        
        

## GAM_NODE


class GAM_NODE_ODST(ModuleWithInit):
    def __init__(self, in_features, num_trees, depth=6, tree_dim=1, flatten_output=True,
                 choice_function=sparsemax, bin_function=sparsemoid,
                 initialize_response_=nn.init.normal_, initialize_selection_logits_=nn.init.uniform_,
                 threshold_init_beta=1.0, threshold_init_cutoff=1.0,
                 ):
        """
        Oblivious Differentiable Sparsemax Trees. http://tinyurl.com/odst-readmore
        One can drop (sic!) this module anywhere instead of nn.Linear

        :param in_features: number of features in the input tensor
        :param num_trees: number of trees in this layer
        :param depth: number of splits in every tree
        :param tree_dim: number of response channels in the response of an individual tree
        :param flatten_output: if False, returns [..., num_trees, tree_dim],
            by default returns [..., num_trees * tree_dim]
        :param choice_function: f(tensor, dim) -> R_simplex, computes feature weights
            such that f(tensor, dim).sum(dim) == 1
        :param bin_function: f(tensor) -> R[0, 1], computes tree leaf weights

        :param initialize_response_: in-place initializer for the tree output tensor
        :param initialize_selection_logits_: in-place initializer for the logits that select
            features for the tree. Both thresholds and scales are initialized with the
            data-aware init (or .load_state_dict)
        :param threshold_init_beta: initializes each threshold to a q-th quantile of data points,
            where q ~ Beta(:threshold_init_beta:, :threshold_init_beta:)
            If this param is set to 1, initial thresholds will have the same distribution as data points
            If greater than 1 (e.g. 10), thresholds will be closer to the median data value
            If less than 1 (e.g. 0.1), thresholds will approach min/max data values.

        :param threshold_init_cutoff: threshold log-temperatures initializer, in (0, inf)
            By default (1.0), log-temperatures are initialized in such a way that all bin selectors
            end up in the linear region of sparse-sigmoid. The temperatures are then scaled by this parameter.
            Setting this value < 1.0 will cause (1 - value) part of data points to end up in the flat
            sparse-sigmoid region. For instance, threshold_init_cutoff = 0.9 will set 10% of points
            equal to 0.0 or 1.0
            Setting this value > 1.0 will result in a margin between data points and the sparse-sigmoid
            cutoff value: all points will be between (0.5 - 0.5 / threshold_init_cutoff) and
            (0.5 + 0.5 / threshold_init_cutoff)
        """
        super().__init__()

        self.depth = depth
        self.num_trees = num_trees
        self.tree_dim = tree_dim
        self.flatten_output = flatten_output

        self.choice_function = choice_function
        self.bin_function = bin_function

        self.threshold_init_beta = threshold_init_beta
        self.threshold_init_cutoff = threshold_init_cutoff

        num_leaves = 2 ** depth

        # Per-leaf responses: [num_trees, tree_dim, 2 ** depth]
        self.response = nn.Parameter(torch.zeros([num_trees, tree_dim, num_leaves]), requires_grad=True)
        initialize_response_(self.response)

        selection_logits = torch.zeros([in_features, num_trees, depth])
        self.feature_selection_logits = nn.Parameter(selection_logits, requires_grad=True)
        initialize_selection_logits_(self.feature_selection_logits)

        # NaN values will be initialized on the first batch (data-aware init).
        nan_thresholds = torch.full([num_trees, depth], float('nan'), dtype=torch.float32)
        self.feature_thresholds = nn.Parameter(nan_thresholds, requires_grad=True)

        nan_temperatures = torch.full([num_trees, depth], float('nan'), dtype=torch.float32)
        self.log_temperatures = nn.Parameter(nan_temperatures, requires_grad=True)

        # Binary codes for mapping between 1-hot vectors and bin indices.
        with torch.no_grad():
            leaf_ids = torch.arange(num_leaves)
            bit_values = 2 ** torch.arange(depth)
            # codes[d, c] is bit d of leaf index c
            codes = (leaf_ids.view(1, -1) // bit_values.view(-1, 1) % 2).to(torch.float32)
            codes_1hot = torch.stack([codes, 1.0 - codes], dim=-1)
            # ^-- [depth, 2 ** depth, 2]
            self.bin_codes_1hot = nn.Parameter(codes_1hot, requires_grad=False)

    def forward(self, input):
        """
        Evaluate the trees on ``input``.

        :param input: tensor of shape [..., in_features]; inputs with more than two
            dimensions are flattened to 2-D, evaluated, and viewed back
        :return: [batch_size, num_trees * tree_dim] if ``flatten_output`` is set,
            otherwise [batch_size, num_trees, tree_dim] (for 2-D input)
        """
        assert len(input.shape) >= 2
        if len(input.shape) > 2:
            leading_shape = input.shape[:-1]
            flat_result = self.forward(input.view(-1, input.shape[-1]))
            return flat_result.view(*leading_shape, -1)
        # new input shape: [batch_size, in_features]

        selectors = self.choice_function(self.feature_selection_logits, dim=0)
        # ^--[in_features, num_trees, depth]

        selected_values = torch.einsum('bi,ind->bnd', input, selectors)
        # ^--[batch_size, num_trees, depth]

        inverse_temperature = torch.exp(-self.log_temperatures)
        logits = (selected_values - self.feature_thresholds) * inverse_temperature

        two_sided_logits = torch.stack([-logits, logits], dim=-1)
        # ^--[batch_size, num_trees, depth, 2]

        bins = self.bin_function(two_sided_logits)
        # ^--[batch_size, num_trees, depth, 2], approximately binary

        matches = torch.einsum('btds,dcs->btdc', bins, self.bin_codes_1hot)
        # ^--[batch_size, num_trees, depth, 2 ** depth]

        leaf_weights = torch.prod(matches, dim=-2)
        # ^-- [batch_size, num_trees, 2 ** depth]

        tree_output = torch.einsum('bnd,ncd->bnc', leaf_weights, self.response)
        # ^-- [batch_size, num_trees, tree_dim]

        if self.flatten_output:
            return tree_output.flatten(1, 2)
        return tree_output

    def initialize(self, input, eps=1e-6):
        """
        Data-aware initializer for ``feature_thresholds`` and ``log_temperatures``.

        :param input: 2-D tensor [batch_size, in_features] used to estimate percentiles
        :param eps: added to the temperatures before taking the logarithm
        """
        assert len(input.shape) == 2
        if input.shape[0] < 1000:
            warn("Data-aware initialization is performed on less than 1000 data points. This may cause instability."
                 "To avoid potential problems, run this model on a data batch with at least 1000 data samples."
                 "You can do so manually before training. Use with torch.no_grad() for memory efficiency.")

        with torch.no_grad():
            selectors = self.choice_function(self.feature_selection_logits, dim=0)
            # ^--[in_features, num_trees, depth]

            values = torch.einsum('bi,ind->bnd', input, selectors)
            # ^--[batch_size, num_trees, depth]

            # Thresholds: a random percentile of the data for every (tree, split).
            beta = self.threshold_init_beta
            percentiles_q = 100 * np.random.beta(beta, beta, size=[self.num_trees, self.depth])
            per_split_values = check_numpy(values.flatten(1, 2).t())
            sampled = [
                np.percentile(column, q)
                for column, q in zip(per_split_values, percentiles_q.flatten())
            ]
            thresholds = torch.as_tensor(sampled, dtype=values.dtype, device=values.device)
            self.feature_thresholds.data[...] = thresholds.view(self.num_trees, self.depth)

            # Temperatures: make sure enough data points are in the linear region
            # of sparse-sigmoid.
            cutoff = self.threshold_init_cutoff
            distances = check_numpy((values - self.feature_thresholds).abs())
            temperatures = np.percentile(distances, q=100 * min(1.0, cutoff), axis=0)

            # If threshold_init_cutoff > 1, scale everything down by it.
            temperatures /= max(1.0, cutoff)
            self.log_temperatures.data[...] = torch.log(torch.as_tensor(temperatures) + eps)

    def __repr__(self):
        """Return a one-line summary of the layer's main hyper-parameters."""
        in_features = self.feature_selection_logits.shape[0]
        return (
            f"{self.__class__.__name__}(in_features={in_features}, "
            f"num_trees={self.num_trees}, depth={self.depth}, "
            f"tree_dim={self.tree_dim}, flatten_output={self.flatten_output})"
        )
    
    
