import numpy as np
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from catenets.models.torch import TARNet
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, GradientBoostingClassifier
from sklearn.linear_model import LinearRegression
from econml.dml import CausalForestDML
from econml.dr import DRLearner
import torch
from engression import engression

import zepid
from zepid.causal.doublyrobust import TMLE


import rpy2
from rpy2.robjects.packages import importr
from rpy2.robjects.packages import SignatureTranslatedAnonymousPackage

rpy2.robjects.numpy2ri.activate()
from rpy2.robjects import r, pandas2ri,numpy2ri
from rpy2.robjects.vectors import StrVector
import logging


# Only surface errors through the logging subsystem.
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.ERROR)

# Enable automatic conversion of pandas and numpy objects to R objects
numpy2ri.activate()
pandas2ri.activate()

# Install any R package this module depends on that is not present yet.
# `utils` is imported once; the second, identical import was redundant.
utils = importr('utils')
utils.chooseCRANmirror(ind=1)  # select the first mirror in the list
packnames = ('dbarts', 'grf', 'marginaleffects')
names_to_install = [x for x in packnames if not rpy2.robjects.packages.isinstalled(x)]
if names_to_install:
    utils.install_packages(StrVector(names_to_install))


def S_Linear_fit(Z_train, X_train, Y_train):
    """
    Train a linear-regression S-learner.

    A single LinearRegression is fitted on the covariates with the
    treatment appended as the last column of the design matrix.

    Z_train: Covariates
    X_train: Treatment assignment
    Y_train: Outcome

    Returns:
    The fitted LinearRegression model.
    """
    # Design matrix = [covariates | treatment]
    design = np.column_stack((Z_train, X_train))

    model = LinearRegression()
    model.fit(design, Y_train)
    return model

def S_Linear_predict(Z_test, s_linear):
    """
    Predict potential outcomes and CATE with a fitted S-learner.

    The model is queried twice on the same covariates: once with the
    appended treatment column set to 1 and once with it set to 0.

    Z_test: Covariates for test data (must expose ``.shape``).
    s_linear: Fitted model whose ``predict`` accepts [covariates | treatment].

    Returns:
    (cate_estimates, potential_outcomes_control, potential_outcomes_treated)
    where the CATE is the treated prediction minus the control prediction.
    """
    n_rows = Z_test.shape[0]

    # Same covariates under both treatment arms
    treated_design = np.column_stack((Z_test, np.ones(n_rows)))
    control_design = np.column_stack((Z_test, np.zeros(n_rows)))

    mu1 = s_linear.predict(treated_design)
    mu0 = s_linear.predict(control_design)

    return mu1 - mu0, mu0, mu1

def T_Linear_fit(Z_train, X_train, Y_train):
    """
    Fit a T-learner using Linear Regression.

    Two independent LinearRegression models are fitted: one on the rows
    where the treatment equals 1 and one on the rows where it equals 0.

    Z_train: Covariates
    X_train: Treatment assignment (values 0 / 1); may be 1-D or a single column
    Y_train: Outcome

    Returns:
    lr_control: Linear Regression model for the control group.
    lr_treated: Linear Regression model for the treated group.
    (in that order: control first, treated second)
    """
    # Flatten the treatment so a column-shaped (n, 1) input still yields a
    # 1-D row mask; a 2-D mask cannot select rows of Z_train.
    treatment = np.asarray(X_train).ravel()
    treated_indices = treatment == 1
    control_indices = treatment == 0

    # Fit separate models for treated and control groups
    lr_treated = LinearRegression().fit(Z_train[treated_indices], Y_train[treated_indices])
    lr_control = LinearRegression().fit(Z_train[control_indices], Y_train[control_indices])

    return lr_control, lr_treated

def T_Linear_predict(Z_test, lr_control, lr_treated):
    """
    Predict potential outcomes and CATE with a fitted T-learner.

    Z_test: Covariates for test data.
    lr_control: Fitted model for the control group.
    lr_treated: Fitted model for the treated group.

    Returns:
    (cate_estimates, potential_outcomes_control, potential_outcomes_treated)
    where the CATE is the treated prediction minus the control prediction.
    """
    # Each arm's model scores the same covariates
    mu0 = lr_control.predict(Z_test)
    mu1 = lr_treated.predict(Z_test)

    treatment_effect = mu1 - mu0
    return treatment_effect, mu0, mu1



def TARNet_fit(Z_train, Y_train, X_train, hyperparams=None):
    """
    Fit a TARNet model.

    Z_train: Covariates (must expose ``.shape``; column count sets the input width)
    Y_train: Outcome
    X_train: Treatment assignment
    hyperparams: Optional dict overriding any of the defaults below. Keys
        that are omitted keep their default value.
        'n_layers_r': 2     -- number of layers in the representation network
        'n_units_r': 128    -- number of units per layer in the representation network
        'batch_size': 64    -- batch size for training
        'lr': 0.0001        -- learning rate
        'n_iter': 2000      -- number of iterations (epochs)

    Returns:
    tarnet: The fitted TARNet model.
    """
    # Defaults are built per call instead of living in a shared mutable
    # default argument; user-supplied entries take precedence.
    defaults = {
        'n_layers_r': 2,
        'n_units_r': 128,
        'batch_size': 64,
        'lr': 0.0001,
        'n_iter': 2000,
    }
    settings = {**defaults, **(hyperparams or {})}

    # Initialize TARNet (similar to CFRNet but simpler)
    tarnet = TARNet(Z_train.shape[1],
                    n_layers_r=settings['n_layers_r'],
                    n_units_r=settings['n_units_r'],
                    batch_size=settings['batch_size'],
                    lr=settings['lr'],
                    n_iter=settings['n_iter'])

    # Fit the model
    tarnet.fit(Z_train, Y_train, X_train)

    return tarnet

def TARNet_predict(Z_test, tarnet):
    """
    Predict CATE and potential outcomes with a fitted TARNet.

    Z_test: Covariates for test data.
    tarnet: Fitted model whose ``predict(Z, return_po=True)`` returns a
        sequence of tensors read here as (cate, outcome under control,
        outcome under treatment).

    Returns:
    (cate_test, mu0_test, mu1_test) as NumPy arrays.
    """
    result_test = tarnet.predict(Z_test, return_po=True)

    def _to_numpy(tensor):
        # .cpu() keeps this working when the model produced tensors on a
        # non-CPU device, where .numpy() alone would raise.
        return tensor.detach().cpu().numpy()

    cate_test = _to_numpy(result_test[0])
    mu0_test = _to_numpy(result_test[1])  # Potential outcome under no treatment
    mu1_test = _to_numpy(result_test[2])  # Potential outcome under treatment

    return cate_test, mu0_test, mu1_test

def CausalForestDML_fit(Z_train, Y_train, X_train):
    """
    Fit a CausalForestDML estimator.

    Both first-stage models (treatment and outcome) are default
    RandomForestRegressor instances; the forest itself uses 1000 trees,
    at least 10 samples per leaf and a maximum depth of 3.

    Z_train: Covariates
    Y_train: Outcome
    X_train: Treatment assignment

    Returns:
    causal_forest: The fitted CausalForestDML estimator.
    """
    forest_options = dict(n_estimators=1000, min_samples_leaf=10, max_depth=3)
    causal_forest = CausalForestDML(
        model_t=RandomForestRegressor(),
        model_y=RandomForestRegressor(),
        **forest_options,
    )

    # econml argument order: outcome, treatment, then heterogeneity features
    causal_forest.fit(Y_train, X_train, X=Z_train)
    return causal_forest
def CausalForestDML_predict(Z_test, causal_forest):
    """
    Predict CATE and potential outcomes with a fitted CausalForestDML.

    Z_test: Covariates for test data.
    causal_forest: Fitted estimator exposing ``effect``, ``models_y`` and
        ``models_t``.

    Returns:
    (cate_test, mu0_test, mu1_test)

    Notes:
    ``effect(Z, T0=0)`` is the same call as ``effect(Z)`` (the baseline
    already defaults to 0), so it returns the CATE again and cannot serve as
    the outcome under control. The baseline is instead recovered from the
    fitted first-stage models via the partially linear identity
    E[Y|Z] = mu0(Z) + E[T|Z] * cate(Z), i.e.
    mu0(Z) = E[Y|Z] - E[T|Z] * cate(Z).
    Assumes a single outcome, a single treatment column, and first-stage
    models fitted on the covariates alone (no extra controls), as done by
    CausalForestDML_fit.
    NOTE(review): relies on ``models_y`` / ``models_t`` returning nested
    lists of fitted estimators with a ``predict`` method -- confirm against
    the installed econml version.
    """
    cate_test = causal_forest.effect(Z_test)

    def _crossfit_mean(nested_models):
        # Average the predictions of every cross-fitted first-stage model.
        preds = [np.ravel(mdl.predict(Z_test))
                 for fold_models in nested_models
                 for mdl in fold_models]
        return np.mean(preds, axis=0).reshape(np.shape(cate_test))

    outcome_mean = _crossfit_mean(causal_forest.models_y)    # E[Y | Z]
    treatment_mean = _crossfit_mean(causal_forest.models_t)  # E[T | Z]

    mu0_test = outcome_mean - treatment_mean * cate_test
    mu1_test = mu0_test + cate_test
    return cate_test, mu0_test, mu1_test


def DRLearner_fit(Z_train, Y_train, X_train, regressor = 'rf'):
    """
    Fit a doubly robust learner (econml DRLearner).

    Z_train: Covariates
    Y_train: Outcome
    X_train: Treatment assignment
    regressor: Which model family to use for the nuisance and final models:
        'gbdt' -- gradient boosting (regressor / classifier / regressor)
        'rf'   -- random forests (regressor / classifier / regressor)

    Returns:
    dr_learner: The fitted DRLearner.

    Raises:
    ValueError: if ``regressor`` is not one of the supported names.
    """
    if regressor == 'gbdt':
        dr_learner = DRLearner(model_regression=GradientBoostingRegressor(),
                    model_propensity=GradientBoostingClassifier(),
                    model_final=GradientBoostingRegressor())
    elif regressor == 'rf':
        # DRLearner takes model_regression / model_propensity / model_final;
        # it has no model_t, model_y or forest keywords, so the previous call
        # failed on construction. The forest settings that were passed
        # directly are applied to the final-stage forest instead.
        from sklearn.ensemble import RandomForestClassifier

        dr_learner = DRLearner(model_regression=RandomForestRegressor(),
                    model_propensity=RandomForestClassifier(),
                    model_final=RandomForestRegressor(n_estimators=1000,
                                                      min_samples_leaf=10,
                                                      max_depth=3))
    else:
        # Fail with a clear message rather than an unbound-variable error.
        raise ValueError(
            f"Unknown regressor {regressor!r}; expected 'rf' or 'gbdt'."
        )

    dr_learner.fit(y=Y_train, T=X_train, X=Z_train)
    return dr_learner
def DRLearner_predict(Z_test, dr_learner):
    """
    Predict CATE and potential outcomes with a fitted DRLearner.

    Z_test: Covariates for test data.
    dr_learner: Fitted estimator exposing ``effect`` and ``models_regression``.

    Returns:
    (cate_test, mu0_test, mu1_test)

    Notes:
    ``effect(Z, T0=0)`` is the same call as ``effect(Z)`` (the baseline
    already defaults to 0), so it returns the CATE again rather than the
    outcome under control. The baseline is taken from the fitted outcome
    regressions instead, evaluated with the treatment indicator set to 0 and
    averaged over the cross-fitted models; the treated outcome is that
    baseline plus the CATE.
    Assumes a binary treatment and no extra controls, so each outcome
    regression was trained on [covariates | one treatment-indicator column].
    NOTE(review): relies on ``models_regression`` returning a nested list of
    fitted estimators with a ``predict`` method -- confirm against the
    installed econml version.
    """
    cate_test = dr_learner.effect(Z_test)

    covariates = np.asarray(Z_test)
    # Control arm: treatment-indicator column fixed at 0
    control_design = np.hstack([covariates, np.zeros((covariates.shape[0], 1))])

    fold_preds = [np.ravel(mdl.predict(control_design))
                  for fold_models in dr_learner.models_regression
                  for mdl in fold_models]
    mu0_test = np.mean(fold_preds, axis=0).reshape(np.shape(cate_test))

    mu1_test = mu0_test + cate_test
    return cate_test, mu0_test, mu1_test


def S_engression_fit(Z_train, X_train, Y_train, lr=0.01, num_epochs=500, batch_size=100, device='cpu', verbose=False):
    """
    Train an engression S-learner.

    The treatment is appended to the covariates as the last feature column
    and a single engression model is trained on the combined input.

    Z_train: Covariates
    X_train: Treatment assignment
    Y_train: Outcome
    lr, num_epochs, batch_size, device, verbose: forwarded to ``engression``.

    Returns:
    engressor: Trained Engression model
    """
    covariates = torch.Tensor(Z_train)
    treatment_col = torch.Tensor(X_train).unsqueeze(1)

    # [covariates | treatment], then a trailing singleton axis
    features = torch.cat([covariates, treatment_col], dim=1)
    features = features.unsqueeze(2).to(device)

    targets = torch.Tensor(Y_train).unsqueeze(1).to(device)

    return engression(
        x=features,
        y=targets,
        lr=lr,
        num_epochs=num_epochs,
        batch_size=batch_size,
        device=device,
        verbose=verbose,
    )

def S_engression_predict(Z_test, engressor, device='cpu'):
    """
    Predict potential outcomes and CATE with a fitted engression S-learner.

    The model is queried on the test covariates twice, with the appended
    treatment column set to 1 and then to 0, using ``sample_size=100``.

    Z_test: Covariates for test data.
    engressor: Trained Engression model.
    device: Device the input tensors are moved to.

    Returns:
    (cate_estimates, potential_outcomes_control, potential_outcomes_treated)
    as NumPy arrays; the CATE is treated minus control.
    """
    covariates = torch.Tensor(Z_test).to(device)

    def _with_treatment(filler):
        # Append a constant treatment column, then a trailing singleton axis
        arm = filler(covariates.size(0), 1).to(device)
        return torch.cat([covariates, arm], dim=1).unsqueeze(2)

    treated_input = _with_treatment(torch.ones)
    control_input = _with_treatment(torch.zeros)

    # Treated arm is evaluated first, then control
    mu1 = engressor.predict(treated_input, sample_size=100).detach().cpu().numpy()
    mu0 = engressor.predict(control_input, sample_size=100).detach().cpu().numpy()

    return mu1 - mu0, mu0, mu1


def T_engression_fit(Z_train, X_train, Y_train, lr=0.01, num_epochs=500, batch_size=64, device='cpu',verbose=False):
    """
    Train an engression T-learner.

    Rows with treatment equal to 1 and rows with treatment equal to 0 are
    modelled by two separately trained engression models.

    Z_train: Covariates
    X_train: Treatment assignment
    Y_train: Outcome
    lr, num_epochs, batch_size, device, verbose: forwarded to ``engression``.

    Returns:
    engressor_control: Model for control group
    engressor_treated: Model for treated group
    (in that order: control first, treated second)
    """
    treated_mask = X_train == 1
    control_mask = X_train == 0

    def _arm_tensors(mask):
        # Covariates get a trailing singleton axis, outcomes become a column
        features = torch.Tensor(Z_train[mask]).unsqueeze(2).to(device)
        targets = torch.Tensor(Y_train[mask]).unsqueeze(1).to(device)
        return features, targets

    treated_x, treated_y = _arm_tensors(treated_mask)
    control_x, control_y = _arm_tensors(control_mask)

    training_options = dict(lr=lr, num_epochs=num_epochs, batch_size=batch_size,
                            device=device, verbose=verbose)

    # The treated model is trained before the control model
    engressor_treated = engression(x=treated_x, y=treated_y, **training_options)
    engressor_control = engression(x=control_x, y=control_y, **training_options)

    return engressor_control, engressor_treated


def T_engression_predict(Z_test, engressor_control, engressor_treated, device='cpu'):
    """
    Predict potential outcomes and CATE with a fitted engression T-learner.

    Both models are queried on the same test covariates with
    ``sample_size=100``.

    Z_test: Covariates for which to predict CATE.
    engressor_control: Trained Engression model for control group.
    engressor_treated: Trained Engression model for treated group.
    device: Device the input tensor is moved to.

    Returns:
    (cate_estimates, potential_outcomes_control, potential_outcomes_treated)
    as NumPy arrays; the CATE is treated minus control.
    """
    # Covariates with a trailing singleton axis, shared by both arms
    model_input = torch.Tensor(Z_test).unsqueeze(2).to(device)

    def _predict(model):
        return model.predict(model_input, sample_size=100).detach().cpu().numpy()

    # Control arm is evaluated first, then treated
    mu0 = _predict(engressor_control)
    mu1 = _predict(engressor_treated)

    return mu1 - mu0, mu0, mu1

def T_BART_fit(Z_train, X_train, Y_train):
    """
    Fit a T-learner with Bayesian Additive Regression Trees (R package dbarts).

    One BART model is fitted on the rows where the treatment equals 0 and
    another on the rows where it equals 1. Trees are kept (keeptrees=True)
    so the fitted models can be used for prediction later.

    Parameters:
    Z_train: Covariates.
    X_train: Treatment (binary, values 0 / 1).
    Y_train: Outcome; a 2-D outcome is flattened to 1-D per group.

    Returns:
    bart_control, bart_treatment: Fitted BART models for the control and
    treatment groups, in that order.
    """
    # Only dbarts is needed here; the unused `utils` import was dropped.
    dbarts = importr('dbarts')

    control_rows = X_train == 0
    treatment_rows = X_train == 1

    # Split the data into control and treatment groups
    control_data = Z_train[control_rows]
    treatment_data = Z_train[treatment_rows]

    # Extract the outcomes for control and treatment groups
    Y_control = Y_train[control_rows]
    Y_treatment = Y_train[treatment_rows]

    # Convert Y arrays to 1D if necessary
    if Y_control.ndim == 2:
        Y_control = Y_control.ravel()
    if Y_treatment.ndim == 2:
        Y_treatment = Y_treatment.ravel()

    # Fit BART models for the control and treatment groups
    bart_control = dbarts.bart(control_data, Y_control, keeptrees=True, verbose=False)
    bart_treatment = dbarts.bart(treatment_data, Y_treatment, keeptrees=True, verbose=False)

    return bart_control, bart_treatment


def T_BART_predict(Z_test, bart_model_control, bart_model_treatment):
    """
    Predict potential outcomes and CATE with fitted T-learner BART models.

    Z_test: Covariates for test data (passed to R unchanged).
    bart_model_control: Fitted BART model for the control group.
    bart_model_treatment: Fitted BART model for the treatment group.

    Returns:
    (cate_estimates, predicted_outcomes_control, predicted_outcomes_treatment)
    as NumPy arrays; the CATE is treatment minus control.

    NOTE(review): R's predict on a dbarts fit may return posterior draws
    rather than one value per test row -- confirm the array shape callers
    expect.
    """
    # R's generic predict function, looked up by name
    predict_in_r = r['predict']

    # Control model is evaluated first, then the treatment model
    mu0 = np.array(predict_in_r(bart_model_control, Z_test))
    mu1 = np.array(predict_in_r(bart_model_treatment, Z_test))

    return mu1 - mu0, mu0, mu1


def S_BART_fit(Z_train, X_train, Y_train):
    """
    Fit an S-learner with Bayesian Additive Regression Trees (R package dbarts).

    A single BART model is fitted on the covariates with the treatment
    appended as the last predictor column. Trees are kept (keeptrees=True)
    so the fitted model can be used for prediction later.

    Z_train: Covariates
    X_train: Treatment assignment
    Y_train: Outcome; a single-column 2-D outcome is flattened to 1-D.

    Returns:
    bart_model: The fitted BART model.
    """
    # Only dbarts is needed here; the unused `utils` import was dropped.
    dbarts = importr('dbarts')

    W_train = np.column_stack((Z_train, X_train))

    # Convert the outcome Y_train to a 1D array if necessary
    if Y_train.ndim == 2 and Y_train.shape[1] == 1:
        Y_train = Y_train.ravel()

    # Fit a single BART model with both covariates and treatment as predictors
    bart_model = dbarts.bart(W_train, Y_train, keeptrees=True, verbose=False)

    return bart_model



def S_BART_predict(Z_test, bart_model):
    """
    Predict potential outcomes and CATE with a fitted S-learner BART model.

    The model is queried twice on the same covariates: once with the
    appended treatment column set to 1 and once with it set to 0.

    Z_test: Covariates for test data (must expose ``.shape``).
    bart_model: Fitted BART model.

    Returns:
    (cate_estimates, potential_outcomes_control, potential_outcomes_treated)
    as NumPy arrays; the CATE is treated minus control.

    NOTE(review): R's predict on a dbarts fit may return posterior draws
    rather than one value per test row -- confirm the array shape callers
    expect.
    """
    # R's generic predict function, looked up by name
    predict_in_r = r['predict']
    n_rows = Z_test.shape[0]

    # Same covariates under both treatment arms
    treated_design = np.column_stack((Z_test, np.ones(n_rows)))
    control_design = np.column_stack((Z_test, np.zeros(n_rows)))

    # Treated arm is evaluated first, then control
    mu1 = np.array(predict_in_r(bart_model, treated_design))
    mu0 = np.array(predict_in_r(bart_model, control_design))

    return mu1 - mu0, mu0, mu1
