import numpy as np
import torch
import matplotlib.pyplot as plt
from models import *
from utils import *
from configs import *
from scipy.stats import gaussian_kde

from sklearn.linear_model import LinearRegression
from sklearn.ensemble import HistGradientBoostingRegressor, HistGradientBoostingClassifier, RandomForestClassifier, RandomForestRegressor


def _fit_target_densities(s0, s1, Y0, Y1, discrete_target):
    """Estimate both groups' soft-subgroup target distributions and the log-terms of the objective.

    Each group's samples are weighted by their soft membership (``s0`` / ``s1``).
    All returned tensors are detached from the autograd graph, so they act as
    fixed targets until the next refit.

    Args:
        s0, s1: 1-D soft-membership tensors for group A=0 / A=1.
        Y0, Y1: group targets. With ``discrete_target`` these are float tensors
            with one column per class; otherwise array-likes accepted by
            ``scipy.stats.gaussian_kde``.
        discrete_target: categorical estimate (True) or weighted KDE (False).

    Returns:
        ``(p_y0, p_y1, log_p0_y0, log_p1_y1, mixture_0, mixture_1)``:
        ``p_y*`` are class-probability tensors (discrete) or fitted
        ``gaussian_kde`` objects (continuous). ``log_p0_y0`` / ``log_p1_y1`` are
        each group's log-density at its own samples. ``mixture_0`` /
        ``mixture_1`` are the log-density of the 50/50 mixture of both group
        distributions at group 0's / group 1's samples. In the discrete case the
        per-column log terms are summed. For one-hot rows this adds the same
        constant to the own-density and mixture terms, and it cancels in their
        difference.
    """
    if discrete_target:
        # Membership-weighted class frequencies inside the soft subgroup.
        p_y0 = ((s0[:, None] * Y0).sum(dim=0) / (s0.sum() + 1e-8)).detach()
        p_y1 = ((s1[:, None] * Y1).sum(dim=0) / (s1.sum() + 1e-8)).detach()
        log_p0_y0 = torch.log(p_y0 * Y0 + 1e-8).sum(dim=1)
        log_p1_y1 = torch.log(p_y1 * Y1 + 1e-8).sum(dim=1)
        mixture_0 = torch.log((p_y0 * Y0 + p_y1 * Y0) / 2 + 1e-8).sum(dim=1)
        mixture_1 = torch.log((p_y0 * Y1 + p_y1 * Y1) / 2 + 1e-8).sum(dim=1)
        return p_y0, p_y1, log_p0_y0, log_p1_y1, mixture_0, mixture_1

    p_y0 = gaussian_kde(Y0, weights=s0.detach().numpy())
    p_y1 = gaussian_kde(Y1, weights=s1.detach().numpy())
    # Evaluate each (density, sample set) pair once, in log space.
    log_p0_at_0 = p_y0.logpdf(Y0)
    log_p1_at_0 = p_y1.logpdf(Y0)
    log_p0_at_1 = p_y0.logpdf(Y1)
    log_p1_at_1 = p_y1.logpdf(Y1)
    # log((p0 + p1) / 2) written as logaddexp(log p0, log p1) - log 2. The two
    # forms are mathematically equal, but this one stays finite where both pdfs
    # underflow to 0. The direct form would give -inf there and make the loss
    # inf/NaN.
    log_half = np.log(2.0)
    mixture_0 = np.logaddexp(log_p0_at_0, log_p1_at_0) - log_half
    mixture_1 = np.logaddexp(log_p0_at_1, log_p1_at_1) - log_half
    return (
        p_y0,
        p_y1,
        torch.tensor(log_p0_at_0, dtype=torch.float32),
        torch.tensor(log_p1_at_1, dtype=torch.float32),
        torch.tensor(mixture_0, dtype=torch.float32),
        torch.tensor(mixture_1, dtype=torch.float32),
    )


def _aligned_class_proba(clf, X, n_classes):
    """Return ``clf.predict_proba(X)`` as a float32 tensor with ``n_classes`` columns.

    ``predict_proba`` only has columns for the labels seen during ``fit``
    (``clf.classes_``). Scattering them into a full-width array keeps the columns
    aligned with the class-frequency vectors even when a group lacks some class.
    Without this, the shapes would mismatch.
    """
    seen_proba = clf.predict_proba(X)
    proba = np.zeros((seen_proba.shape[0], n_classes), dtype=np.float32)
    proba[:, clf.classes_] = seen_proba
    return torch.tensor(proba, dtype=torch.float32)


def run_method_single_subgroup(X,X0,X1,Y,Y0,Y1,scaler_X,scaler_Y,features_names,is_discrete,our_config,discrete_target=False,maximize=True,verbose=False, refitting_steps=10, plot=True):
    """Learn one soft rule-based subgroup where the two groups' target distributions differ.

    Groups A=0 (``X0``/``Y0``) and A=1 (``X1``/``Y1``) share one rule model. The
    objective is the sum of the two membership-weighted KL(P_g || M) terms, with
    M the 50/50 mixture of P(Y | A=0, subgroup) and P(Y | A=1, subgroup). This
    sum is twice their Jensen-Shannon divergence. It is maximized
    (``maximize=True``) or minimized, and multiplied by a subgroup-size
    correction. The target distributions are detached and re-estimated every
    ``refitting_steps + 1`` epochs.

    Args:
        X: full feature matrix. Used for cut-point initialisation and rule display.
        X0, X1: feature matrices of group A=0 / A=1.
        Y: full target. Only its min/max are used, for the continuous plot range.
        Y0, Y1: group targets. With ``discrete_target`` they must be 2-D with one
            column per class (labels are the argmax over columns). Otherwise they
            are passed to ``gaussian_kde`` and ``RandomForestRegressor`` as-is.
        scaler_X, scaler_Y: used only to display rules / plot axes in original units.
        features_names: feature names used when rendering the rule.
        is_discrete: per-feature flags. Discrete features start with cut-points
            spanning their full range in ``X``.
        our_config: dict read for ``"gamma"`` (size-correction exponent),
            ``"lr_classifier"`` (Adam learning rate), ``"n_epochs"`` and
            ``"lambd"`` (regularizer weight; only added when > 0).
        discrete_target: categorical (True) vs. continuous KDE (False) targets.
        maximize: maximize (True) or minimize (False) the divergence.
        verbose: print loss, rule and coverage every 100 epochs.
        refitting_steps: epochs between target-distribution refits.
        plot: show diagnostic plots and print the final rule.

    Returns:
        ``(sg_model, in0, in1)``. ``in0`` / ``in1`` are boolean tensors
        ``membership > 0.5`` for the rows of X0 / X1. They come from the last
        epoch's forward pass, which runs before that epoch's optimizer step.
    """
    rule_model_config = Rule_Config(X.shape[1], 1)

    X_full = torch.tensor(X, dtype=torch.float32)
    X0_tensor = torch.tensor(X0, dtype=torch.float32)
    X1_tensor = torch.tensor(X1, dtype=torch.float32)

    data_limits = get_data_limits(X_full)
    n_rows, n_features = X_full.shape
    n_rules = rule_model_config.n_rules
    initial_cutpoints = torch.zeros((n_features, 2, n_rules), dtype=torch.float32)
    for i in range(n_rules):
        if rule_model_config.init_set_size == -1:
            X_init = X_full
        else:
            # Clamp so a subsample size larger than the data cannot make
            # np.random.choice(..., replace=False) fail.
            init_size = min(rule_model_config.init_set_size, n_rows)
            init_set = np.random.choice(n_rows, size=init_size, replace=False)
            X_init = X_full[init_set]
        initial_cutpoints[:, :, i] = get_data_limits(X_init)
        for j in range(n_features):
            if is_discrete[j]:
                initial_cutpoints[j, 0, i] = X_full[:, j].min()
                initial_cutpoints[j, 1, i] = X_full[:, j].max()
    sg_model = RuleLearner(rule_model_config, initial_cutpoints)

    gamma = our_config["gamma"]
    sg_optimizer = torch.optim.Adam(sg_model.parameters(), lr=our_config["lr_classifier"])
    sg_model.train()
    temp_schedule = Temperature_Scheduler(our_config["n_epochs"], rule_model_config.schedule_predicate_temperature)

    # Per-epoch histories for the diagnostic plot.
    losses, js_divs, regs = [], [], []
    sg_sizes0, sg_sizes1, size_corrections = [], [], []

    s0 = sg_model(X0_tensor).squeeze()
    s1 = sg_model(X1_tensor).squeeze()

    if discrete_target:
        # as_tensor accepts numpy arrays and tensors alike, so .argmax(dim=1)
        # below works for either input type.
        Y0 = torch.as_tensor(Y0, dtype=torch.float32)
        Y1 = torch.as_tensor(Y1, dtype=torch.float32)

    p_y0, p_y1, log_p0_y0, log_p1_y1, mixture_0, mixture_1 = _fit_target_densities(
        s0, s1, Y0, Y1, discrete_target
    )

    # Reference predictors f0 / f1, fitted once on each group's full data. The
    # regularizer measures how much their predictions vary inside the subgroup.
    if discrete_target:
        Y0_labels = Y0.argmax(dim=1).numpy()
        Y1_labels = Y1.argmax(dim=1).numpy()
        pred0_f0 = RandomForestClassifier().fit(X0, Y0_labels)
        pred1_f1 = RandomForestClassifier().fit(X1, Y1_labels)
        p_y0_f0 = _aligned_class_proba(pred0_f0, X0, Y0.shape[1])
        p_y1_f1 = _aligned_class_proba(pred1_f1, X1, Y1.shape[1])
    else:
        pred0_f0 = RandomForestRegressor().fit(X0, Y0)
        pred1_f1 = RandomForestRegressor().fit(X1, Y1)
        p_y0_f0 = torch.tensor(pred0_f0.predict(X0), dtype=torch.float32)
        p_y1_f1 = torch.tensor(pred1_f1.predict(X1), dtype=torch.float32)

    refit_counter = 0
    for epoch in range(our_config["n_epochs"]):
        sg_optimizer.zero_grad()
        s0 = sg_model(X0_tensor).squeeze()
        s1 = sg_model(X1_tensor).squeeze()

        # Re-estimate the detached target distributions under the current
        # memberships once refitting_steps epochs have passed since the last fit.
        if refit_counter >= refitting_steps:
            p_y0, p_y1, log_p0_y0, log_p1_y1, mixture_0, mixture_1 = _fit_target_densities(
                s0, s1, Y0, Y1, discrete_target
            )
            refit_counter = 0
        else:
            refit_counter += 1

        # Membership-weighted KL(P_g || M) for each group.
        kl_div_0 = (s0 * (log_p0_y0 - mixture_0)).sum() / (s0.sum() + 1e-8)
        kl_div_1 = (s1 * (log_p1_y1 - mixture_1)).sum() / (s1.sum() + 1e-8)

        # For gamma > 0 both forms of the size correction reward larger subgroups.
        coverage_product = s0.mean() * s1.mean()
        if maximize:
            d0, d1 = -kl_div_0, -kl_div_1
            size_correction = coverage_product ** (gamma / 2)
        else:
            d0, d1 = kl_div_0, kl_div_1
            size_correction = (1 / coverage_product) ** (gamma / 2)

        if discrete_target:
            # Membership-weighted KL(f_g(x) || P(Y | A=g, subgroup)).
            reg_0 = (p_y0_f0 * (torch.log(p_y0_f0 + 1e-8) - torch.log(p_y0[None, :] + 1e-8))).sum(dim=1)
            reg_1 = (p_y1_f1 * (torch.log(p_y1_f1 + 1e-8) - torch.log(p_y1[None, :] + 1e-8))).sum(dim=1)
            difference_0 = (s0 * reg_0).sum() / (s0.sum() + 1e-8)
            difference_1 = (s1 * reg_1).sum() / (s1.sum() + 1e-8)
        else:
            # Membership-weighted squared deviation of f_g(x) from its
            # (detached) membership-weighted mean.
            mean_0 = ((p_y0_f0 * s0).sum() / (s0.sum() + 1e-8)).detach()
            mean_1 = ((p_y1_f1 * s1).sum() / (s1.sum() + 1e-8)).detach()
            difference_0 = (s0 * (p_y0_f0 - mean_0) ** 2).sum() / (s0.sum() + 1e-8)
            difference_1 = (s1 * (p_y1_f1 - mean_1) ** 2).sum() / (s1.sum() + 1e-8)
        regularizer = (difference_0 + difference_1) / 2

        # The regularizer is applied from the first epoch whenever lambd > 0.
        loss = (d0 + d1) * size_correction
        if our_config["lambd"] > 0:
            loss = loss + our_config["lambd"] * regularizer
        loss.backward()

        losses.append(loss.item())
        js_divs.append((kl_div_0 + kl_div_1).item())
        sg_sizes0.append(s0.mean().item())
        sg_sizes1.append(s1.mean().item())
        size_corrections.append(size_correction.item())
        regs.append(regularizer.item())

        sg_optimizer.step()

        if epoch % 100 == 0 and verbose:
            print(f"Epoch {epoch}, Loss: {loss.item()}")
            rule = sg_model.get_rule(0,data_limits,scaler_x=scaler_X,feature_names=features_names)
            print(f"Rule: {rule}")
            cov0 = (s0 > 0.5).float().mean().item()
            cov1 = (s1 > 0.5).float().mean().item()
            print(f"Rule coverage  0: {cov0:.2f}, Rule coverage 1: {cov1:.2f}")

        sg_model.discretizer.temperature = temp_schedule.get_temperature()

    if plot:
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 3, 1)
        plt.suptitle("Maximizing JS Divergence" if maximize else "Minimizing JS Divergence")
        plt.title("Loss over epochs")
        plt.plot(losses, label='Loss')
        plt.plot(js_divs, label='Weighted Exceptionality')
        plt.plot(regs, label='Regularizer')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 3, 2)
        plt.ylim(0,1)
        plt.plot(sg_sizes0, label='Size of Subgroup A=0')
        plt.plot(sg_sizes1, label='Size of Subgroup A=1')
        plt.plot(size_corrections, label='Size Correction')
        plt.legend()

        plt.subplot(1, 3, 3)
        rule = sg_model.get_rule(0,data_limits,scaler_x=scaler_X,feature_names=features_names)
        if discrete_target:
            bar_width = 0.3
            class_positions = np.arange(len(p_y0))
            plt.bar(class_positions, p_y0.detach().numpy(), label='P(Y|A=0, s(X)=1)', width=bar_width)
            # Shift by a full bar width so the two groups' bars sit side by side.
            plt.bar(class_positions + bar_width, p_y1.detach().numpy(), label='P(Y|A=1, s(X)=1)', width=bar_width)
        else:
            Y_domain = np.linspace(Y.min(), Y.max(), 100)
            Y_display = scaler_Y.inverse_transform(Y_domain[:, None])
            plt.plot(Y_display, p_y0.pdf(Y_domain), label='P(Y|A=0, s(X)=1)')
            plt.plot(Y_display, p_y1.pdf(Y_domain), label='P(Y|A=1, s(X)=1)')
        rule_coverage0 = (s0 > 0.5).float().mean().item()
        rule_coverage1 = (s1 > 0.5).float().mean().item()
        plt.title(f"Coverage A=0: {rule_coverage0:.2f}, A=1: {rule_coverage1:.2f}")
        plt.legend()
        plt.show()
        print(rule)
    return sg_model, s0 > 0.5, s1 > 0.5
    

def run_method_multiple_times(X,X0,X1,Y,Y0,Y1,scaler_X,scaler_Y,features_names,is_discrete,our_config,maximize=True,verbose=False, refitting_steps=10, max_reps=5, plot=True, discrete_target=False, remove_data=False):
    """Call ``run_method_single_subgroup`` up to ``max_reps`` times and collect the results.

    With ``remove_data=True`` the members of each found subgroup are dropped from
    ``X0``/``Y0`` and ``X1``/``Y1`` before the next run, so later runs search the
    remaining rows. ``X`` and ``Y`` are always passed unchanged. The loop stops
    early as soon as either group has fewer than 10 rows left.

    Returns:
        sg_labels: one ``(mask0, mask1)`` pair per run. Each mask is a boolean numpy
            array over the ORIGINAL rows of X0 / X1 that marks that run's subgroup.
        sg_rules: the rule returned by ``sg_model.get_rule(0, ...)`` for each run,
            rendered against the limits of the full ``X``.
        sg_models: the fitted subgroup model of each run.
    """
    n0, n1 = X0.shape[0], X1.shape[0]
    # Positions of the rows of the (possibly shrunken) X0 / X1 in the original arrays.
    s0_indices = np.arange(n0)
    s1_indices = np.arange(n1)
    # Rules are always rendered against the full X, which never changes: compute once.
    full_data_limits = get_data_limits(torch.tensor(X, dtype=torch.float32))

    sg_labels, sg_rules, sg_models = [], [], []
    for i in range(max_reps):
        if verbose:
            print(f"Run {i+1}/{max_reps}")
        sg_model, s0_sg, s1_sg = run_method_single_subgroup(X,X0,X1,Y,Y0,Y1,scaler_X,scaler_Y,features_names,is_discrete,our_config,maximize=maximize,verbose=verbose, refitting_steps=refitting_steps, plot=plot, discrete_target=discrete_target)

        s0_mask = np.zeros(n0, dtype=bool)
        s1_mask = np.zeros(n1, dtype=bool)
        s0_mask[s0_indices[s0_sg]] = True
        s1_mask[s1_indices[s1_sg]] = True
        sg_labels.append((s0_mask, s1_mask))

        if remove_data:
            keep0, keep1 = ~s0_sg, ~s1_sg
            X0, Y0, s0_indices = X0[keep0], Y0[keep0], s0_indices[keep0]
            X1, Y1, s1_indices = X1[keep1], Y1[keep1], s1_indices[keep1]
        sg_rules.append(sg_model.get_rule(0, full_data_limits, scaler_x=scaler_X, feature_names=features_names))
        sg_models.append(sg_model)

        if X0.shape[0] < 10 or X1.shape[0] < 10:
            break

    return sg_labels, sg_rules, sg_models

