import numpy as np
import pandas as pd
import statsmodels.api as sm

# Short identifiers for the datasets under study. The dictionaries below are
# keyed by these labels and the rows of `features` follow the same order.
datasets = ['B', 'C', 'E', 'L', 'H2', 'H']
# Checkpoint labels. There are ten of them, matching the ten values stored per
# dataset in `asr_values` and `drift_values` -- presumably one value per
# checkpoint, in this order (TODO confirm).
checkpoints = ['50-c', '100-c', '150-c', '200-c', '250-c',
               '300-c', '350-c', '400-c', '450-c', '500-c']

# Per-dataset "asr" series, ten values each.
# NOTE(review): the values look like percentages and "asr" presumably stands
# for attack success rate -- confirm against the experiment write-up.
asr_values = {
    'B': [65.00, 71.25, 62.50, 65.00, 68.75, 67.50, 65.00, 61.25, 66.25, 65.00],
    'C': [61.25, 60.00, 60.00, 65.00, 61.25, 65.00, 57.50, 61.25, 65.00, 70.00],
    'E': [57.50, 62.50, 60.00, 58.75, 61.25, 63.75, 61.25, 60.00, 63.75, 60.00],
    'L': [62.50, 61.25, 63.75, 61.25, 65.00, 71.25, 65.00, 63.75, 66.25, 57.50],
    'H2': [66.25, 67.50, 68.75, 65.00, 70.00, 71.25, 70.00, 73.75, 70.00, 70.00],
    'H': [71.25, 71.25, 71.25, 70.00, 70.00, 66.25, 70.00, 67.50, 71.25, 68.75],
}

# Per-dataset drift series, ten values each. Their per-dataset mean is stored
# under the column name 'cosine_drift' further down in this script.
drift_values = {
    'B': [0.00027881, 0.00045183, 0.00018238, 0.00028907, 0.00053083,
          0.00093876, 0.00139291, 0.00193000, 0.00248633, 0.00255867],
    'C': [0.00006913, 0.00012138, 0.00011745, 0.00016412, 0.00016192,
          0.00009768, 0.00007791, 0.00006973, 0.00006056, 0.00004712],
    'E': [0.00009826, 0.00015218, 0.00011106, 0.00018275, 0.00023687,
          0.00018143, 0.00012203, 0.00007132, 0.00006210, 0.00005945],
    'L': [0.00009014, 0.00016387, 0.00010130, 0.00008583, 0.00010755,
          0.00012990, 0.00015618, 0.00018942, 0.00015409, 0.00015265],
    'H2': [0.00144813, 0.00278553, 0.00663939, 0.00630987, 0.00507533,
           0.00269942, 0.00138713, 0.00066924, 0.00027455, 0.00010488],
    'H': [0.00647930, 0.01156052, 0.00778648, 0.00788402, 0.00217110,
          0.00271518, 0.00203949, 0.00076138, 0.00027753, 0.00010579],
}

# One row of static descriptive features per dataset (six rows, one for each
# entry of `datasets`, in the same order).
# NOTE(review): column meanings are not documented in this file. The p*/r*
# prefixes presumably separate prompt-side from response-side statistics
# (toxicity, length, sentiment, type-token ratio) -- TODO confirm.
features = pd.DataFrame({
    'dataset': datasets,
    'ptoxicity': [0.0016, 0.0062, 0.0008, 0.0007, 0.0292, 0.0328],
    'plength':   [13.0, 47.5, 30.2, 40.7, 15.1, 16.9],
    'psentiment':[0.06, 0.0376, 0.0097, 0.0218, -0.0583, -0.0147],
    'pttr': [0.958, 0.921, 0.923, 0.856, 0.969, 0.966],
    'rtoxicity': [0.0044, 0.0009, 0.0007, 0.0008, 0.0162, 0.0203],
    'rttr': [0.848, 0.757, 0.841, 0.882, 0.67, 0.636]
})

# Collapse each dataset's series to its mean and attach that dataset's static
# feature row, giving one record per dataset.
agg = [
    {
        'dataset': name,
        'asr': np.mean(asr_values[name]),
        'cosine_drift': np.mean(drift_values[name]),
        # First (and only expected) feature row for this dataset; its own
        # 'dataset' entry simply re-writes the key set above.
        **features[features['dataset'] == name].iloc[0].to_dict(),
    }
    for name in datasets
]

df = pd.DataFrame(agg)

# Z-score every numeric column in place (pandas' default sample standard
# deviation, ddof=1); the 'dataset' label column is left untouched.
for column in df.columns:
    if column in ['dataset']:
        continue
    centered = df[column] - df[column].mean()
    df[column] = centered / df[column].std()

def run_regression(y, X):
    """Fit an ordinary least squares regression of ``y`` on ``X``.

    ``X`` is passed through ``sm.add_constant`` before fitting so the model
    includes an intercept term.

    Args:
        y: Response values accepted by ``sm.OLS``.
        X: Predictor values accepted by ``sm.add_constant``.

    Returns:
        The fitted results object returned by ``sm.OLS(...).fit()``.
    """
    design = sm.add_constant(X)
    model = sm.OLS(y, design)
    return model.fit()

def mediation_analysis(df, exposure, mediator, outcome):
    """Decompose the effect of ``exposure`` on ``outcome`` through ``mediator``.

    Three OLS models are fitted with ``run_regression``:

    * mediator ~ exposure             (path ``a``)
    * outcome  ~ exposure + mediator  (paths ``c'`` and ``b``)
    * outcome  ~ exposure             (total effect ``c``)

    Args:
        df: DataFrame holding the ``exposure``, ``mediator`` and ``outcome``
            columns.
        exposure: Name of the exposure (independent variable) column.
        mediator: Name of the mediator column.
        outcome: Name of the outcome (dependent variable) column.

    Returns:
        A dict with keys:

        * ``feature``: the exposure column name.
        * ``indirect_effect``: ``a * b``.
        * ``direct_effect``: ``c'``.
        * ``total_effect``: ``a * b + c'``.
        * ``prop_mediated``: indirect / total, or NaN when the total is 0.
        * ``indirect_pval``: two-sided Sobel-test p-value for ``a * b``.
        * ``direct_pval``: p-value of ``c'`` in the outcome model.
        * ``total_pval``: p-value of the exposure coefficient in the
          outcome ~ exposure model.
    """
    import math

    M = df[mediator]
    Y = df[outcome]
    X_exp = df[[exposure]]

    med_model = run_regression(M, X_exp)
    out_model = run_regression(Y, df[[exposure, mediator]])
    # The total effect needs its own model: the F-test of the outcome model
    # tests exposure and mediator jointly, not the exposure's total effect.
    total_model = run_regression(Y, X_exp)

    a = med_model.params[exposure]
    b = out_model.params[mediator]
    c_prime = out_model.params[exposure]

    indirect_effect = a * b
    direct_effect = c_prime
    total_effect = indirect_effect + direct_effect
    prop_mediated = indirect_effect / total_effect if total_effect != 0 else np.nan

    # Sobel test: first-order delta-method standard error of a*b. The p-value
    # of path a alone says nothing about b, so it cannot stand in for the
    # significance of the indirect effect.
    se_a = med_model.bse[exposure]
    se_b = out_model.bse[mediator]
    sobel_se = math.sqrt(b ** 2 * se_a ** 2 + a ** 2 * se_b ** 2)
    if sobel_se > 0:
        sobel_z = indirect_effect / sobel_se
        # Two-sided standard-normal tail probability: 2 * (1 - Phi(|z|)).
        # NOTE: this is a large-sample approximation; treat it with caution
        # on small samples.
        indirect_pval = math.erfc(abs(sobel_z) / math.sqrt(2))
    else:
        # Zero or NaN standard error: the test statistic is undefined.
        indirect_pval = np.nan

    return {
        'feature': exposure,
        'indirect_effect': indirect_effect,
        'direct_effect': direct_effect,
        'total_effect': total_effect,
        'prop_mediated': prop_mediated,
        'indirect_pval': indirect_pval,
        'direct_pval': out_model.pvalues[exposure],
        'total_pval': total_model.pvalues[exposure]
    }

# Run the mediation analysis once per feature, always with 'cosine_drift' as
# the mediator and 'asr' as the outcome.
results = [
    mediation_analysis(df, exposure=name, mediator='cosine_drift', outcome='asr')
    for name in ['ptoxicity', 'plength', 'psentiment', 'pttr', 'rtoxicity', 'rttr']
]

# Tabulate the per-feature results and print them without column truncation.
results_df = pd.DataFrame(results)
pd.set_option('display.max_columns', None)
print(results_df)
