import numpy as np

def foo(arr, seed=False):
    """Summarise the mse / corr / r2 column groups of ``arr`` row by row.

    Args:
        arr: 2-D NumPy array; each row is summarised independently.
        seed: selects the column layout. When truthy the three metrics are
            interleaved (columns 0, 3, 6, ... are mse; 1, 4, 7, ... are corr;
            2, 5, 8, ... are r2). Otherwise columns 0-2 are mse, 3-5 are corr
            and 6-8 are r2.

    Returns:
        dict with keys ``'mse'``, ``'corr'`` and ``'r2'``. Each value has
        shape ``(n_rows, 2)``: column 0 is the NaN-ignoring mean of the group
        and column 1 its NaN-ignoring standard deviation.
    """
    if seed:
        # Interleaved layout: every third column belongs to the same metric.
        layout = {
            'mse': slice(None, None, 3),
            'corr': slice(1, None, 3),
            'r2': slice(2, None, 3),
        }
    else:
        # Contiguous layout: three consecutive columns per metric.
        layout = {
            'mse': slice(None, 3),
            'corr': slice(3, 6),
            'r2': slice(6, 9),
        }

    summary = {}
    for metric, cols in layout.items():
        group = arr[:, cols]
        # Stack as (2, n_rows), then flip to (n_rows, 2) so each row reads (mean, std).
        stats = np.stack(
            (np.nanmean(group, axis=1), np.nanstd(group, axis=1)),
            axis=0,
        )
        summary[metric] = stats.transpose(1, 0)
    return summary

def print_mean_std(arr):
    """Print alternating values as ``value (value) & value (value) ...``.

    Items at even positions are printed with three decimals; items at odd
    positions are printed in parentheses followed by `` &``. Any ``&`` at the
    very end of the joined line is stripped before printing.

    Args:
        arr: iterable of numbers (e.g. a flattened mean/std array).
    """
    cells = []
    for position, value in enumerate(arr):
        text = f'{value:.3f}'
        if position % 2:
            # Odd positions are wrapped in parentheses and close the cell.
            text = '(' + text + ') &'
        cells.append(text)
    line = ' '.join(cells)
    print(line.rstrip('&'))

def print_mean_row(arr):
    """Print every value with three decimals, each followed by `` &``.

    Cells are joined with single spaces and any ``&`` at the very end of the
    line is stripped before printing.

    Args:
        arr: iterable of numbers.
    """
    cells = []
    for value in arr:
        cells.append(format(value, '.3f') + ' &')
    line = ' '.join(cells)
    print(line.rstrip('&'))

def change(arr):
    """Return the relative change of ``arr[0]`` with respect to ``arr[1]``.

    Computed as ``(arr[0] - arr[1]) / |arr[1]|``; dividing by the absolute
    value keeps the sign of the result equal to the sign of the difference.

    Args:
        arr: indexable holding the compared value at 0 and the reference at 1.
    """
    value, reference = arr[0], arr[1]
    difference = value - reference
    return difference / np.abs(reference)

def pc(x, y):
    """Print grouped sums of two metric rows and the relative change between them.

    Each input is read as a flat row of numbers. Positions 0-4 are summed as
    ``mse``, 5-9 as ``corr`` and 10-14 as ``r2``. Two lines are printed:
    first the three pairs of sums (``x`` first, ``y`` second), then for each
    metric ``change`` of that pair, i.e. ``(x_sum - y_sum) / |y_sum|``.

    Args:
        x: whitespace-separated string of numbers, or a 1-D numeric
            sequence / array.
        y: same accepted forms as ``x``; must have the same length.

    Raises:
        ValueError: if the inputs are not 1-D or differ in length. Pairing
            such rows by reshaping could silently shift values between them.
    """
    def _to_vector(values):
        # Strings are parsed token by token; anything else goes to NumPy as is.
        if isinstance(values, str):
            values = [float(token) for token in values.split()]
        return np.array(values)

    x, y = _to_vector(x), _to_vector(y)
    if x.ndim != 1 or x.shape != y.shape:
        raise ValueError(
            f'pc expects two 1-D inputs of equal length, got shapes {x.shape} and {y.shape}'
        )
    # Row 0 holds x, row 1 holds y.
    xy = np.stack((x, y))

    mse, corr, r2 = xy[:, :5], xy[:, 5:10], xy[:, 10:15]
    mse_sum = np.sum(mse, axis=-1)
    corr_sum = np.sum(corr, axis=-1)
    r2_sum = np.sum(r2, axis=-1)
    print(mse_sum, corr_sum, r2_sum)

    print(change(mse_sum), change(corr_sum), change(r2_sum))

# # table percentage
# x = '0.653 1.071 1.370 1.893 2.646 0.070 0.113 0.165 0.191 0.241 0.005 0.007 0.028 0.033 0.054'
# y = '0.654 1.077 1.406 1.949 2.795 0.081 0.082 0.079 0.092 0.085 0.003 0.002 0.002 0.004 0.001'
# pc(x,y)

# fi mean std metrics
# data = np.array([
# [2.81613,  0.09923, -0.00684,	2.88564,  0.14377, -0.03170,	3.18848,  0.13297, -0.13997],
# ])
# d = (foo(data, seed=True))

# np.set_printoptions(precision=3, suppress=True)

# metrics = []
# for k, v in d.items():
#     print(f'{k} : {v}')
#     metrics.append(v[:,0].flatten())
# metrics = np.array(metrics).flatten()
# mse = d['mse'].flatten()
# corr = d['corr'].flatten()
# r2 = d['r2'].flatten()

# print_mean_std(mse)
# print_mean_std(corr)
# print_mean_std(r2)
# print_mean_row(metrics)

# y = '0.217 0.224 0.222 0.142 0.372 0.000 0.000 0.000 0.000 0.000 -0.001 -0.028 0.012 -0.005 -0.001'
# pc(metrics, y)

# data = '0.2169, -0.0007, -0.0013,	0.2169,  0.0003, -0.0015,	0.2169,  0.0010, -0.0014'.split()
# data = np.array([float(x.strip(',')) for x in data])
# mse = data[::3]
# corr = data[1::3]
# r2 = data[2::3]
# print(np.mean(mse), np.std(mse))
# print(np.mean(corr), np.std(corr))
# print(np.mean(r2), np.std(r2))
    
# # ablation seeds
# data = np.array([[1.8101, 0.2766, 0.747],
# [1.8500, 0.2354, 0.0543],
# [1.8645, 0.2167, 0.0469]])
# means = np.mean(data, axis=0)
# print(means)
# Result rows as whitespace-separated numbers. Rows alternate between the two
# groups compared below: even positions are read as `cvml`, odd as `vanilla`.
data = [
    '0.653 1.071 1.370 1.893 2.646 0.070 0.113 0.165 0.191 0.241 0.005 0.007 0.028 0.033 0.054',
    '0.654 1.077 1.406 1.949 2.795 0.081 0.082 0.079 0.092 0.085 0.003 0.002 0.002 0.004 0.001',
    '0.650 1.042 1.352 1.796 2.548 0.104 0.192 0.205 0.291 0.313 0.010 0.035 0.040 0.082 0.089',
    '0.652 1.073 1.402 1.945 2.782 0.080 0.081 0.074 0.083 0.084 0.006 0.006 0.005 0.006 0.005',
    '0.654 1.084 1.402 2.002 2.649 0.054 0.070 0.088 0.121 0.249 0.002 -0.005 0.005 -0.024 0.053',
    '0.683 1.183 1.582 2.279 3.401 0.045 0.045 0.033 0.063 0.056 -0.041 -0.096 -0.123 -0.165 -0.216',
    '0.642 1.033 1.329 1.807 2.494 0.160 0.221 0.257 0.298 0.353 0.022 0.043 0.056 0.076 0.109',
    '0.657 1.075 1.394 1.888 2.643 0.083 0.110 0.135 0.201 0.271 -0.001 0.004 0.011 0.035 0.055',
]
# Parse every row into floats and build one 2-D array (one line per row).
data = np.array([[float(token) for token in line.split()] for line in data])
cvml = data[::2, :]
vanilla = data[1::2, :]
print(cvml, vanilla, sep='\n')

# Grand total of each metric's five columns over all rows of a group:
# columns 0-4 -> mse, 5-9 -> corr, 10-14 -> r2.
mse_cv, corr_cv, r2_cv = (np.sum(cvml[:, lo:lo + 5]) for lo in (0, 5, 10))
mse_va, corr_va, r2_va = (np.sum(vanilla[:, lo:lo + 5]) for lo in (0, 5, 10))
print(mse_cv, mse_va)
print(corr_cv, corr_va)
print(r2_cv, r2_va)

# Relative change of the cvml total against the vanilla total, scaled by the
# magnitude of the vanilla total.
mse_change, corr_change, r2_change = (
    (cv_total - va_total) / np.abs(va_total)
    for cv_total, va_total in ((mse_cv, mse_va), (corr_cv, corr_va), (r2_cv, r2_va))
)
print(mse_change, corr_change, r2_change, sep='\n')