import rpy2.robjects as robjects
import numpy as np
from rpy2.robjects import pandas2ri, numpy2ri
from rpy2.robjects import conversion
import pandas as pd
import random

# Switch on rpy2's global pandas and NumPy converters so that values crossing
# the Python/R boundary are converted automatically for every call below.
# NOTE(review): newer rpy2 releases discourage global activate() in favour of
# conversion.localconverter — confirm against the installed rpy2 version.
pandas2ri.activate()
numpy2ri.activate()


# Evaluate the R script at import time. The path is relative, so importing
# this module only works when 'data_simulations.R' can be found from the
# current working directory.
robjects.r['source']('data_simulations.R')
# Python handles to R functions looked up in R's global environment
# (presumably defined by the script sourced above — a missing name fails here,
# at import time).
sampleCorrelationMatrix = robjects.globalenv['sampleCorrelationMatrix']
generate_train_data = robjects.globalenv['generate_train_data']
generate_test_data = robjects.globalenv['generate_test_data']

def _build_sample_frame(x, y, z):
    """Combine the X vector, Y vector and Z matrix into one DataFrame.

    Args:
        x: Array-like flattened into the ``'X'`` column.
        y: Array-like flattened into the ``'Y'`` column.
        z: Array-like whose columns become ``'Z1'``, ``'Z2'``, ... A 1-D
            input is treated as a single column.

    Returns:
        A pandas DataFrame with columns ``X``, ``Y``, ``Z1`` ... ``Zk``.
    """
    covariates = np.asarray(z)
    if covariates.ndim == 1:
        # R may hand back a plain vector when there is only one covariate;
        # promote it to a single column instead of failing on shape[1].
        covariates = covariates.reshape(-1, 1)

    columns = {'X': np.array(x).flatten(), 'Y': np.array(y).flatten()}
    for j in range(covariates.shape[1]):
        columns[f'Z{j+1}'] = covariates[:, j]
    return pd.DataFrame(columns)


def generate_train_test_data(n_samples_train, n_samples_test, corr_matrix, eta, num_cont_covariates, num_disc_covariates,
                             train_cont_margin_params, train_cont_margin_family, train_disc_margin_params,
                             train_disc_margin_family, test_cont_margin_params, test_cont_margin_family,
                             test_disc_margin_params, test_disc_margin_family, treatment_family,
                             prop_score_params, causal_effect_family, causal_effect_params, seed=42):
    """Simulate a train and a test data set through the sourced R functions.

    Calls the R functions ``generate_train_data`` and ``generate_test_data``
    and repackages their ``X_*``, ``Y_*`` and ``Z_*`` results as pandas
    DataFrames.

    Args:
        n_samples_train: Forwarded to ``generate_train_data`` as its first
            argument.
        n_samples_test: Forwarded to ``generate_test_data`` as its first
            argument.
        corr_matrix: 2-D array (must expose ``.shape``); turned into an R
            matrix with ``corr_matrix.shape[0]`` rows and passed to both R
            functions.
        eta: Not used by this function. It is kept so existing callers keep
            working (the correlation matrix is supplied by the caller rather
            than sampled here).
        num_cont_covariates: Forwarded unchanged to both R functions.
        num_disc_covariates: Forwarded unchanged to both R functions.
        train_cont_margin_params: Sequence of numbers, converted to an R
            numeric vector; forwarded to ``generate_train_data`` only.
        train_cont_margin_family: Forwarded to ``generate_train_data`` only.
        train_disc_margin_params: Sequence of numbers, converted to an R
            numeric vector; forwarded to ``generate_train_data`` only.
        train_disc_margin_family: Forwarded to ``generate_train_data`` only.
        test_cont_margin_params: Sequence of numbers, converted to an R
            numeric vector; forwarded to both R functions.
        test_cont_margin_family: Forwarded to both R functions.
        test_disc_margin_params: Sequence of numbers, converted to an R
            numeric vector; forwarded to both R functions.
        test_disc_margin_family: Forwarded to both R functions.
        treatment_family: Forwarded to both R functions.
        prop_score_params: Sequence of numbers, converted to an R numeric
            vector; forwarded to both R functions.
        causal_effect_family: Forwarded to both R functions.
        causal_effect_params: Sequence of numbers, converted to an R numeric
            vector; forwarded to both R functions.
        seed: Seed for Python's ``random`` module and, when it is an
            integer, for R's random number generator.

    Returns:
        Tuple ``(data_samples_train, data_samples_test)`` of pandas
        DataFrames, each with columns ``X``, ``Y`` and ``Z1`` ... ``Zk``.
    """
    random.seed(seed)
    # The sampling happens inside R, which Python's `random` does not
    # influence; seed R's generator too so that `seed` actually makes the
    # simulated data reproducible.
    if isinstance(seed, (int, np.integer)):
        # R seeds are 32-bit integers: fold larger/negative values into range.
        robjects.r['set.seed'](int(seed) % 2**31)

    corr_matrix_trial = robjects.r.matrix(corr_matrix, nrow=corr_matrix.shape[0])

    # Convert Python sequences to R numeric vectors
    train_cont_margin_params = robjects.FloatVector(train_cont_margin_params)
    train_disc_margin_params = robjects.FloatVector(train_disc_margin_params)
    test_cont_margin_params = robjects.FloatVector(test_cont_margin_params)
    test_disc_margin_params = robjects.FloatVector(test_disc_margin_params)
    prop_score_params = robjects.FloatVector(prop_score_params)
    causal_effect_params = robjects.FloatVector(causal_effect_params)

    # Generate training data (the R function receives both the train and the
    # test margin specifications).
    train_data = generate_train_data(
        n_samples_train,
        corr_matrix_trial,
        num_cont_covariates,
        num_disc_covariates,
        train_cont_margin_params,
        train_cont_margin_family,
        train_disc_margin_params,
        train_disc_margin_family,
        test_cont_margin_params,
        test_cont_margin_family,
        test_disc_margin_params,
        test_disc_margin_family,
        treatment_family,
        prop_score_params,
        causal_effect_family,
        causal_effect_params)

    # Pull the named elements out of the returned R list
    data_samples_train = _build_sample_frame(
        train_data.rx2('X_train'),
        train_data.rx2('Y_train'),
        train_data.rx2('Z_train'))

    # Generate testing data
    test_data = generate_test_data(
        n_samples_test,
        corr_matrix_trial,
        num_cont_covariates,
        num_disc_covariates,
        test_cont_margin_params,
        test_cont_margin_family,
        test_disc_margin_params,
        test_disc_margin_family,
        treatment_family,
        prop_score_params,
        causal_effect_family,
        causal_effect_params)

    data_samples_test = _build_sample_frame(
        test_data.rx2('X_test'),
        test_data.rx2('Y_test'),
        test_data.rx2('Z_test'))

    # Both frames were built with pd.DataFrame above, so they are already
    # pandas objects when they reach this conversion pass.
    # NOTE(review): the rpy2py step looks redundant — confirm and drop it.
    with conversion.localconverter(robjects.default_converter + pandas2ri.converter):
        data_samples_train = conversion.rpy2py(data_samples_train)
        data_samples_test = conversion.rpy2py(data_samples_test)

    return data_samples_train, data_samples_test
