import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt


class Discretizing_Layer(nn.Module):
    """Soft interval predicates ``lower < x < upper`` for every feature/predicate pair.

    For feature ``i`` and predicate ``j`` the layer holds learnable cut points
    ``cut_points[i, 0, j]`` (lower) and ``cut_points[i, 1, j]`` (upper). It outputs
    a value in [0, 1] that is large when the input lies between them and small
    outside. ``temperature`` controls how sharp that transition is.
    """

    def __init__(self, n_features, predicates_per_feature, initial_cutpoints, temperature=0.1):
        """
        Args:
            n_features: number of input features.
            predicates_per_feature: size of the last axis of ``cut_points``; used by
                ``get_predicates`` when iterating over predicates.
            initial_cutpoints: tensor indexed as ``[feature, 0 (lower) / 1 (upper), predicate]``.
                It is copied, not aliased.
            temperature: softmax temperature; smaller values give sharper predicates.
        """
        super().__init__()
        self.temperature = temperature
        self.predicates_per_feature = predicates_per_feature

        self.cut_points = nn.Parameter(initial_cutpoints.detach().clone(), requires_grad=True)

        # NOTE(review): fixed_weights ([1, 2, 3] repeated per feature) is not read anywhere
        # in this class. It is kept as a frozen parameter so existing state_dicts still load.
        D = 2
        fixed_weights = torch.linspace(1.0, D + 1.0, D + 1, dtype=torch.float32)
        self.fixed_weights = nn.Parameter(
            fixed_weights.repeat(n_features, 1).clone().detach(), requires_grad=False
        )
        # Per-feature flag consulted by get_predicates; nothing in this class sets it to True.
        self.is_discrete = [False] * n_features

    def forward(self, x):
        """Return the soft membership of ``x`` in every learned interval.

        Args:
            x: input batch. After ``x.unsqueeze(2)`` it is broadcast against the
                ``cut_points[:, 0, :]`` / ``cut_points[:, 1, :]`` slices, so a
                ``(batch, n_features)`` input gives a ``(batch, n_features, n_predicates)`` output.

        Returns:
            Softmax weight of the "inside the interval" logit, in [0, 1].
        """
        x = x.unsqueeze(2)
        lower_threshold = self.cut_points[:, 0, :]
        upper_threshold = self.cut_points[:, 1, :]

        # Three linear region logits with slopes 2, 1 and 3:
        # - "in" beats "below" exactly when x > lower;
        # - "above" beats "in" exactly when x > upper.
        # So, when lower < upper, "in" is the largest logit precisely for lower < x < upper.
        in_interval = (2 * x - lower_threshold) / self.temperature
        below_interval = x / self.temperature
        above_interval = (3 * x - lower_threshold - upper_threshold) / self.temperature

        # Softmax with the (detached) max subtracted: every exponent is <= 0, so exp()
        # cannot overflow and the denominator is at least 1 for finite inputs.
        max_interval = torch.maximum(torch.maximum(in_interval, below_interval), above_interval).detach()
        numerator = torch.exp(in_interval - max_interval)
        denominator = (
            numerator
            + torch.exp(below_interval - max_interval)
            + torch.exp(above_interval - max_interval)
        )
        return numerator / denominator

    def fix_parameters(self):
        """Sort each lower/upper pair in place so that ``lower <= upper``."""
        self.cut_points.data, _ = torch.sort(self.cut_points.data, dim=1)
        return

    def get_predicates(self, data_limits, feature_names=None, scaler=None, as_string=True):
        """Describe every learned predicate, feature by feature.

        Args:
            data_limits: tensor indexed as ``[feature, 0 (min) / 1 (max)]``, used to clip
                the reported intervals.
            feature_names: display names; defaults to ``"Feature i"``.
            scaler: optional object whose ``inverse_transform`` is applied to the transposed
                cut-point slices and to ``data_limits``, mapping them back to the original scale.
            as_string: if True return strings, otherwise ``(lower, upper)`` tuples.

        Returns:
            A flat list ordered feature-major (all predicates of feature 0, then feature 1, ...).
            An interval that extends past both data limits is reported as ``"True"``.
        """
        cut_points = self.cut_points.data
        if feature_names is None:
            feature_names = [f"Feature {i}" for i in range(cut_points.shape[0])]
        if scaler is not None:
            all_thresholds = np.zeros(cut_points.shape)
            for i in range(self.predicates_per_feature):
                all_thresholds[:, :, i] = scaler.inverse_transform(
                    cut_points[:, :, i].detach().cpu().numpy().T
                ).T
            data_limits = scaler.inverse_transform(data_limits.detach().cpu().numpy().T).T
            cut_points = all_thresholds
        else:
            cut_points = cut_points.detach().cpu().numpy()
            data_limits = data_limits.detach().cpu().numpy()

        predicates = []
        tuple_predicates = []
        for i in range(cut_points.shape[0]):
            for j in range(self.predicates_per_feature):
                lower_bound = cut_points[i, 0, j]
                upper_bound = cut_points[i, 1, j]
                if lower_bound < data_limits[i, 0] and upper_bound > data_limits[i, 1]:
                    predicates.append("True")
                    tuple_predicates.append((lower_bound, upper_bound))
                    continue

                if self.is_discrete[i]:
                    lower_bound = np.ceil(lower_bound)
                    upper_bound = np.floor(upper_bound)
                    # NOTE(review): the rounded lower bound is collapsed to a bool, which
                    # only makes sense for binary features; confirm for other discrete ones.
                    lower_bound = bool(lower_bound)
                    predicates.append(f"{feature_names[i]} = {lower_bound}")
                else:
                    lower_bound = np.max([data_limits[i, 0], lower_bound])
                    upper_bound = np.min([data_limits[i, 1], upper_bound])
                    predicates.append(f"{lower_bound:.2f} < {feature_names[i]} < {upper_bound:.2f}")
                tuple_predicates.append((lower_bound, upper_bound))
        if as_string:
            return predicates
        return tuple_predicates

class And_Layer(nn.Module):
    """Soft conjunction: an epsilon-smoothed weighted harmonic mean of predicate values.

    For rule ``r`` with non-negative weights ``w = relu(and_weights[r])`` the output is
    ``sum(w) / sum(w * (1 + eta) / (x + eta))`` with ``eta = epsilon / sum(w)``.
    Features with zero weight do not influence the rule. A rule whose weights are all
    <= 0 has no active conjunct and outputs 0. That matches HardRuleModel, which never
    fires such a rule, and avoids the 0/0 that would otherwise produce NaN.
    """

    def __init__(self, n_features, n_rules, epsilon=1e-6):
        super().__init__()
        self.epsilon = epsilon
        self.n_rules = n_rules
        # torch.rand is kept (instead of torch.ones) so the global RNG stream advances
        # exactly as before; every weight is then overwritten with 1.
        self.and_weights = nn.Parameter(
            torch.rand([n_rules, n_features], dtype=torch.float32), requires_grad=True
        )
        self.and_weights.data[:] = 1

        self.relu = nn.ReLU()

    def forward(self, x):
        """Combine predicate values into one soft truth value per rule.

        Args:
            x: predicate values; after ``permute(0, 2, 1)`` the last axis must broadcast
                against ``and_weights`` of shape ``(n_rules, n_features)``. In practice
                that means an input of shape ``(batch, n_features, n_rules)``.

        Returns:
            Tensor of shape ``(batch, n_rules)``.
        """
        and_weights = self.relu(self.and_weights)
        x = x.permute(0, 2, 1)  # -> (batch, n_rules, n_features)

        weight_sum = and_weights.sum(dim=1)  # (n_rules,)
        # Rules with no positive weight would divide by zero twice (in eta and in the
        # final ratio). Substitute 1 in both places so they evaluate to 0 / 1 = 0 with
        # finite gradients. torch.where alone is not enough, because a NaN in the
        # unselected branch still leaks NaN into the gradients.
        active = weight_sum > 0
        safe_weight_sum = torch.where(active, weight_sum, torch.ones_like(weight_sum))

        # Smoothing term that keeps 1 / x finite at x == 0; treated as a constant.
        eta = (self.epsilon / safe_weight_sum).detach().view(1, -1, 1)
        inverse_sum = (1 + eta) / (x + eta)

        weighted_inverse_sum = (inverse_sum * and_weights.unsqueeze(0)).sum(dim=2)  # (batch, n_rules)
        safe_inverse_sum = torch.where(
            active.unsqueeze(0), weighted_inverse_sum, torch.ones_like(weighted_inverse_sum)
        )
        return weight_sum / safe_inverse_sum

    def fix_parameters(self):
        """Hook called after each optimisation step; currently a no-op."""
        pass
    
class RuleLearner(nn.Module):
    """Differentiable rule set: soft interval predicates followed by soft conjunctions."""

    def __init__(self, config, initial_cutpoints):
        """
        Args:
            config: object providing ``n_features``, ``n_rules`` and ``predicate_temperature``.
            initial_cutpoints: initial cut points passed to the discretizing layer.
        """
        super().__init__()

        # Each rule owns its own predicate per feature, so the number of predicates
        # per feature equals the number of rules.
        self.discretizer = Discretizing_Layer(
            config.n_features, config.n_rules, initial_cutpoints, config.predicate_temperature
        )
        self.rules = And_Layer(config.n_features, config.n_rules)

    def forward(self, x):
        """Evaluate every rule on ``x`` and return the conjunction layer's output."""
        predicates = self.discretizer(x)
        return self.rules(predicates)

    def fix_parameters(self):
        """Forward the post-step parameter fix-up to both sub-layers."""
        self.discretizer.fix_parameters()
        self.rules.fix_parameters()
        return

    def get_rule(self, index, data_limits, threshold=0., scaler_x=None, feature_names=None):
        """Render rule ``index`` as a conjunction string such as ``"0.20 < age < 0.80 ∧ x < 3.00"``.

        Args:
            index: rule index (last axis of the cut points, first axis of the AND weights).
            data_limits: tensor indexed as ``[feature, 0 (min) / 1 (max)]``; interval ends
                beyond these limits are omitted.
            threshold: features whose AND weight is ``<= threshold`` are left out.
            scaler_x: optional object whose ``inverse_transform`` maps cut points and limits
                back to the original feature scale.
            feature_names: display names; defaults to ``"Feature i"``.

        Returns:
            The predicates joined with ``" ∧ "``; an empty string if none apply.
        """
        if feature_names is None:
            feature_names = [f"Feature {i}" for i in range(data_limits.shape[0])]

        cut_points = self.discretizer.cut_points
        if scaler_x is not None:
            all_thresholds = np.zeros(cut_points.shape)
            for i in range(self.discretizer.predicates_per_feature):
                all_thresholds[:, :, i] = scaler_x.inverse_transform(
                    cut_points[:, :, i].detach().cpu().numpy().T
                ).T
            data_limits = scaler_x.inverse_transform(data_limits.detach().cpu().numpy().T).T
            cut_points = all_thresholds
        else:
            cut_points = cut_points.detach().cpu().numpy()
            data_limits = data_limits.detach().cpu().numpy()

        rule = []
        and_weights = self.rules.and_weights.detach().cpu().numpy()[index, :]
        for i, feature_weight in enumerate(and_weights):
            if feature_weight <= threshold:
                continue
            lower_bound = data_limits[i, 0]
            upper_bound = data_limits[i, 1]
            # Index the ndarray directly; ``ndarray.data`` is the raw memoryview buffer.
            lower_threshold = cut_points[i, 0, index]
            upper_threshold = cut_points[i, 1, index]
            if self.discretizer.is_discrete[i]:
                # NOTE(review): the rounded lower threshold is collapsed to a bool, which
                # only makes sense for binary features; confirm for other discrete ones.
                lower_bound = np.max([lower_bound, bool(np.ceil(lower_threshold))])
            else:
                lower_bound = np.max([lower_bound, lower_threshold])
                upper_bound = np.min([upper_bound, upper_threshold])
            # An interval that covers the whole data range constrains nothing.
            if data_limits[i, 0] == lower_bound and data_limits[i, 1] == upper_bound:
                continue
            if self.discretizer.is_discrete[i]:
                rule.append(f"{feature_names[i]} = {lower_bound}")
            else:
                predicate = ""
                if lower_bound > data_limits[i, 0]:
                    predicate = predicate + f"{lower_bound:.2f} < "
                predicate = predicate + f"{feature_names[i]}"
                if upper_bound < data_limits[i, 1]:
                    predicate = predicate + f" < {upper_bound:.2f}"
                rule.append(predicate)
        return " ∧ ".join(rule)

    def to_hard_rule(self):
        """Build a HardRuleModel from a snapshot of the current parameters.

        The tensors are cloned so later training (or ``fix_parameters``, which reassigns
        ``cut_points.data``) cannot silently change the returned model.
        """
        cut_points = self.discretizer.cut_points.detach().clone()
        and_weights = self.rules.and_weights.detach().clone()
        rule_order = torch.arange(cut_points.shape[2], dtype=torch.float32)
        return HardRuleModel(cut_points, and_weights, rule_order)

    def plot_rule_predicates(self, feature_names, X, rule_index=0):
        """Plot a histogram of each feature used by rule ``rule_index``, with its cut points.

        Features whose AND weight for that rule is ``<= 0`` are skipped. One matplotlib
        figure is created per plotted feature.
        """
        cutpoints = self.discretizer.cut_points.detach().cpu().numpy()
        and_weights = self.rules.and_weights.detach().cpu().numpy()
        for i, feature_name in enumerate(feature_names):
            if and_weights[rule_index, i] <= 0:
                continue
            lower = cutpoints[i, 0, rule_index]
            upper = cutpoints[i, 1, rule_index]
            plt.figure(figsize=(10, 5))
            # NOTE(review): both histograms show the same data under different labels;
            # presumably they were meant to be split by a target/subgroup. Confirm.
            plt.hist(X[:, i], bins=30, alpha=0.5, label='s0(X)=1')
            plt.hist(X[:, i], bins=30, alpha=0.5, label='s0(X)=0')
            plt.vlines(lower, 0, 100, color='black', linestyle='--', label="cutpoint")
            plt.vlines(upper, 0, 100, color='black', linestyle='--')
            plt.xlabel(feature_name)
            plt.title(f"{lower:.2f} < {feature_name} < {upper:.2f}")
            plt.legend()

        
class HardRuleModel(nn.Module):
    """Crisp evaluation of a learned rule set.

    A feature satisfies a rule's predicate when ``lower < x < upper`` (strict), with the
    bounds taken from ``cut_points[:, 0, rule]`` and ``cut_points[:, 1, rule]``. Features
    whose AND weight is ``<= 0`` are ignored by that rule. ``forward`` checks rules in
    index order and marks only the first rule that fires for each sample.

    NOTE(review): ``rule_order`` is stored and saved, and ``index_order`` is derived from
    it on load, but neither is consulted by ``forward``.
    """

    def __init__(self, cut_points, and_weights, rule_order, path=None):
        super().__init__()
        if path:
            self.load_model_variables(path)
        else:
            self.cut_points = cut_points
            self.and_weights = and_weights
            self.rule_order = rule_order

    @classmethod
    def from_file(cls, path):
        """Load a model previously written by ``save``."""
        # Only three positional placeholders: passing a fourth as well as ``path=``
        # would raise "got multiple values for argument 'path'".
        return cls(None, None, None, path=path)

    def forward(self, x):
        """Return a ``(batch, n_rules)`` float tensor with a 1 at the first firing rule per sample."""
        n_rules = self.and_weights.shape[0]
        pred = torch.zeros((x.shape[0], n_rules), dtype=torch.float32)
        # Features a rule ignores (weight <= 0); they count as always satisfied.
        ignored = self.and_weights <= 0.
        for s in range(x.shape[0]):
            sample = x[s, :]
            for i in range(n_rules):
                if ignored[i].all():
                    continue  # a rule with no active feature never fires
                lower_bound = self.cut_points[:, 0, i]
                upper_bound = self.cut_points[:, 1, i]
                predicate = (sample > lower_bound) & (sample < upper_bound)
                if (predicate | ignored[i]).all():
                    pred[s, i] = 1
                    break
        return pred

    def forward_predicates(self, x):
        """Return ``(batch, n_features, n_rules)`` indicators of each active predicate holding."""
        predicates = torch.zeros(
            (x.shape[0], self.cut_points.shape[0], self.cut_points.shape[2]), dtype=torch.float32
        )
        used = self.and_weights > 0.
        for s in range(x.shape[0]):
            sample = x[s, :]
            for i in range(self.and_weights.shape[0]):
                lower_bound = self.cut_points[:, 0, i]
                upper_bound = self.cut_points[:, 1, i]
                predicate = (sample > lower_bound) & (sample < upper_bound)
                predicates[s, :, i] = (predicate & used[i]).float()
        return predicates

    def load_model_variables(self, filepath):
        """Restore ``cut_points``, ``and_weights`` and ``rule_order`` from ``filepath``."""
        # SECURITY: torch.load unpickles the file; only load checkpoints from trusted
        # sources (on PyTorch >= 2.6 the default weights_only=True restricts this).
        checkpoint = torch.load(filepath)
        self.cut_points = checkpoint['cut_points']
        self.and_weights = checkpoint['and_weights']
        self.rule_order = checkpoint['rule_order']
        self.index_order = torch.argsort(self.rule_order, descending=True)
        return self

    def save(self, path):
        """Write ``cut_points``, ``and_weights`` and ``rule_order`` to ``path`` with torch.save."""
        torch.save({
            'cut_points': self.cut_points,
            'and_weights': self.and_weights,
            'rule_order': self.rule_order,
        }, path)
        return
        return
    
    