import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
import matplotlib as mpl
# Embed fonts as TrueType (Type 42) instead of matplotlib's default Type 3 in
# PDF/PS output -- typically done so text in the saved figures remains
# selectable/editable in vector-graphics tools.
mpl.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Load the combined attack results and keep only rows whose
# 'sensitivity_filtered_direction' flag equals 1.
df = pd.read_csv('results_new/combined_attack_results.csv')
df = df.loc[df['sensitivity_filtered_direction'].eq(1)].copy()

# Model-wise L2 sensitivity: first average the rows of each voxel
# (mean over images), where a voxel is keyed by (model, subject, roi,
# voxel_idx); then average those per-voxel means within each model.
l2 = df.loc[df['attack_type'].eq('l2')].copy()
_voxel_keys = ['model', 'subject', 'roi', 'voxel_idx']
_per_voxel_sens = l2.groupby(_voxel_keys)['sensitivity'].mean()
l2_voxel_mean = _per_voxel_sens.groupby(level='model').mean()

# model-wise predictivity: per-voxel corr(old_response, label) averaged per model
def _voxel_corr(g):
    """Return the Pearson correlation of ``old_response`` with ``label``.

    Parameters
    ----------
    g : pandas.DataFrame
        Rows of a single group; must contain the columns ``old_response``
        and ``label``.

    Returns
    -------
    float
        Pearson r between the two columns, or ``np.nan`` when either column
        has fewer than two distinct non-NaN values (r is undefined for a
        constant column).
    """
    response = g['old_response']
    target = g['label']
    # A column with < 2 distinct values has zero variance -> r would be 0/0.
    if min(response.nunique(), target.nunique()) < 2:
        return np.nan
    corr_matrix = np.corrcoef(response, target)
    return corr_matrix[0, 1]

# Per-voxel predictivity: corr(old_response, label) within each voxel, then
# averaged over voxels per model. A voxel is keyed by
# (model, subject, roi, voxel_idx), the same keys the sensitivity aggregation
# uses; grouping by voxel_idx alone would pool voxels from different
# subjects/ROIs that share an index into a single correlation. If voxel_idx
# is already globally unique, the extra keys do not change the groups.
# Only the two columns _voxel_corr reads are passed to apply.
vox_corr = (
    l2.groupby(['model', 'subject', 'roi', 'voxel_idx'])[['old_response', 'label']]
    .apply(_voxel_corr)
    .groupby(level='model')
    .mean()
)


# Align the two per-model scores: keep only models that have a non-NaN value
# for both predictivity (R) and L2 sensitivity (S_l2).
common = vox_corr.dropna().index.intersection(l2_voxel_mean.dropna().index)
if len(common) == 0:
    # Without this check np.nanmax below fails on an empty array with an
    # opaque "zero-size array" error.
    raise ValueError(
        'No model has both a predictivity and an L2 sensitivity score; '
        'cannot compare their distributions.'
    )
vals_corr = vox_corr.loc[common].values
vals_sens = l2_voxel_mean.loc[common].values


def _normalize_by_max(values):
    """Divide *values* by their (NaN-ignoring) maximum.

    The input is returned unchanged when the maximum is not positive, since
    dividing by zero or a negative maximum would be undefined or flip signs.
    """
    peak = np.nanmax(values)
    return values / peak if peak > 0 else values


# Scale each vector by its own maximum before taking the variance, so the
# reported "normalized variance" does not depend on each metric's raw units.
vals_corr_norm = _normalize_by_max(vals_corr)
vals_sens_norm = _normalize_by_max(vals_sens)

var_corr_norm = np.nanvar(vals_corr_norm)
var_sens_norm = np.nanvar(vals_sens_norm)

def sparseness(x):
    """Return the sparseness ``1 - mean(x)**2 / mean(x**2)``, ignoring NaNs.

    The result is 0 for a constant non-zero vector and grows as the values
    concentrate in fewer entries (e.g. one non-zero out of n gives 1 - 1/n).

    Parameters
    ----------
    x : array-like of numbers
        Values to summarise. Lists and integer arrays are accepted; they are
        converted to a float array first.

    Returns
    -------
    float
        The sparseness, or ``np.nan`` when ``mean(x**2)`` is not positive
        (all-zero, empty or all-NaN input).
    """
    # Cast to float: plain lists would raise TypeError on ``** 2``, and
    # squaring large integer arrays could silently overflow.
    arr = np.asarray(x, dtype=float)
    m1, m2 = np.nanmean(arr), np.nanmean(arr ** 2)
    return 1 - (m1 * m1) / m2 if m2 > 0 else np.nan

spar_corr = sparseness(vals_corr_norm)
spar_sens = sparseness(vals_sens_norm)


def _save_bar_comparison(values, ylabel, path):
    """Save a grey, hatched two-bar chart (R vs S_l2) of *values* to *path*.

    The figure is closed after saving so that repeated calls do not leave
    figures open in pyplot's figure manager.
    """
    fig, ax = plt.subplots(figsize=(7, 5), dpi=200)
    ax.bar(['R', 'Sₗ₂'], values, color='gray', alpha=0.85, hatch='//', edgecolor='k')
    ax.set_ylabel(ylabel)
    fig.tight_layout()
    fig.savefig(path, bbox_inches='tight')
    plt.close(fig)


_save_bar_comparison([var_corr_norm, var_sens_norm], 'Normalized Variance',
                     'normalized_variance_comparison.pdf')
_save_bar_comparison([spar_corr, spar_sens], 'Sparseness',
                     'sparseness_comparison.pdf')

print('Models used (N):', len(common))
print('Normalized variance  R, S_l2 :', var_corr_norm, var_sens_norm)
print('Sparseness           R, S_l2 :', spar_corr,      spar_sens)
