import rpy2.robjects as robjects
import numpy as np
from rpy2.robjects import pandas2ri, numpy2ri
from rpy2.robjects import conversion
import pandas as pd


# Register rpy2's pandas and numpy converters globally (module-import side
# effect). The order of the two calls is kept as-is on purpose.
pandas2ri.activate()
numpy2ri.activate()


# Load the R helper script into the embedded R session. The path is relative,
# so presumably R's working directory must contain 'inference_utils.R' -- confirm.
robjects.r['source']('inference_utils.R')
# Handle to the R function defined by the sourced script; called from Python below.
experiments_generate_realistic_data  = robjects.globalenv['experiments_generate_realistic_data']

## To keep the fitted parameters fixed and simulate only new test data, change
## `test_data_seed` and leave `marginal_cdf_seed` and `training_data_seed` unchanged.
def _jitter_sampler(jitter_seed):
    """Return a ``normal(loc=, scale=, size=)`` callable used to draw jitter noise."""
    if jitter_seed is None:
        # Legacy behaviour: draw from NumPy's global random state.
        return np.random.normal
    # Dedicated generator: reproducible and independent of the global state.
    return np.random.default_rng(jitter_seed).normal


def generate_generate_realistic_data(train_df, test_df, n_samples_test=10, n_samples_train=10, marginal_cdf_seed=2,
                                     training_data_seed=1,
                                     test_data_seed=2,
                                     prop_score_params=None,
                                     jitter_seed=None):
    """Simulate train/test data with the R routine ``experiments_generate_realistic_data``.

    The test frame is stacked with a mirrored copy of itself (treatment
    flipped, factual and counterfactual outcomes swapped), the covariates are
    pre-processed, and everything is handed to the R function. Neither input
    frame is modified.

    Parameters
    ----------
    train_df : pandas.DataFrame
        Must contain the columns ``Z1`` .. ``Z25``; only these are used.
    test_df : pandas.DataFrame
        Must contain ``treatment``, ``y_factual``, ``y_cfactual`` and
        ``Z1`` .. ``Z25``. ``treatment`` is flipped with ``1 - treatment``,
        so it is expected to be numeric 0/1.
    n_samples_test, n_samples_train : int
        Forwarded unchanged to the R function.
    marginal_cdf_seed, training_data_seed, test_data_seed : int
        Seeds forwarded unchanged to the R function.
    prop_score_params : array-like or None
        Forwarded to the R function. ``None`` is replaced by a vector of
        ``d + 1`` zeros, where ``d`` is the number of covariate columns.
    jitter_seed : int or None
        Seed for the Gaussian jitter added to ``Z1`` .. ``Z6``. With ``None``
        (the default) the jitter comes from NumPy's global random state, so
        results are only reproducible if the caller seeds it. Pass an integer
        to make the Python-side pre-processing deterministic.

    Returns
    -------
    tuple
        ``(sim_Z_train, sim_Y_train, sim_X_train, sim_Z_test, sim_Y_test,
        sim_X_test)`` as converted from the R result and cast to float; the
        ``Y`` entries are flattened to 1-D and the ``X`` entries are NumPy
        arrays.

    Raises
    ------
    KeyError
        If a required column is missing from ``train_df`` or ``test_df``.
    """
    covariate_cols = [f'Z{i}' for i in range(1, 26)]
    jitter_columns = ['Z1', 'Z2', 'Z3', 'Z4', 'Z5', 'Z6']
    draw_jitter = _jitter_sampler(jitter_seed)

    # Mirror the *test* data: flip the treatment and swap factual/counterfactual
    # outcomes, so every unit appears once under each treatment arm.
    mirrored = test_df.copy()
    mirrored['treatment'] = 1 - mirrored['treatment']
    mirrored['y_factual'] = test_df['y_cfactual']
    mirrored['y_cfactual'] = test_df['y_factual']

    # Stack original and mirrored rows; ignore_index gives a clean 0..n-1 index.
    combined_dat = pd.concat([test_df, mirrored], ignore_index=True)

    # Keep treatment, outcome and covariates, using the R-side names X and Y.
    dat = combined_dat.loc[:, ['treatment', 'y_factual'] + covariate_cols]
    dat = dat.rename(columns={'y_factual': 'Y', 'treatment': 'X'})

    # Shift Z14 down by one and translate Y so its minimum is 1e-4 (strictly positive).
    dat['Z14'] = dat['Z14'] - 1
    dat['Y'] = dat['Y'] - dat['Y'].min() + 1e-4

    # Tiny Gaussian jitter on Z1..Z6 (presumably to break ties -- confirm).
    # Test jitter is drawn before train jitter, as in the original code.
    for col in jitter_columns:
        dat[col] = dat[col] + draw_jitter(loc=0, scale=0.0001, size=dat.shape[0])

    Y = dat['Y'].astype(float)
    X = dat['X'].astype(float)
    Z = dat.drop(columns=['X', 'Y']).reset_index(drop=True).astype(float)
    d = Z.shape[1]

    # Same covariate pre-processing for the training frame. The explicit copy
    # guarantees the caller's DataFrame is never modified.
    Z_train = train_df.loc[:, covariate_cols].copy()
    Z_train['Z14'] = Z_train['Z14'] - 1
    for col in jitter_columns:
        Z_train[col] = Z_train[col] + draw_jitter(loc=0, scale=0.0001, size=Z_train.shape[0])
    Z_train = Z_train.reset_index(drop=True).astype(float)

    # Convert to R objects.
    r_Y = robjects.FloatVector(Y)
    r_X = robjects.FloatVector(X)
    r_Z = pandas2ri.py2rpy(Z)
    r_Z_train = pandas2ri.py2rpy(Z_train)

    if prop_score_params is None:
        prop_score_params = np.array([0] * (d + 1))
    elif not isinstance(prop_score_params, np.ndarray):
        prop_score_params = np.array(prop_score_params)

    # cov_datatype: 0 for the jittered columns, 1 for the remaining ones
    # (presumably continuous vs. discrete -- confirm against inference_utils.R).
    n_jittered = len(jitter_columns)
    full_sim_data = experiments_generate_realistic_data(
        n_samples_test=n_samples_test,
        n_samples_train=n_samples_train,
        prop_score_params=prop_score_params,
        Y_test=r_Y,
        X_test=r_X,
        Z_test=r_Z,
        Z_train=r_Z_train,
        cov_datatype=np.array([0] * n_jittered + [1] * (d - n_jittered)),
        marginal_cdf_seed=marginal_cdf_seed,
        training_data_seed=training_data_seed,
        test_data_seed=test_data_seed
    )

    # Equivalent to full_sim_data$train_data / full_sim_data$test_data in R.
    train_data = full_sim_data.rx2('train_data')
    test_data = full_sim_data.rx2('test_data')

    sim_Z_train = train_data.rx2('Z_train').astype('float')
    sim_Y_train = train_data.rx2('Y_train').astype('float').reshape(-1)
    sim_X_train = np.array(train_data.rx2('X_train')).astype('float')

    sim_Z_test = test_data.rx2('Z_test').astype('float')
    sim_Y_test = test_data.rx2('Y_test').astype('float').reshape(-1)
    sim_X_test = np.array(test_data.rx2('X_test')).astype('float')

    return sim_Z_train, sim_Y_train, sim_X_train, sim_Z_test, sim_Y_test, sim_X_test


