import sys
import numpy as np
import sklearn.linear_model as skl
import pickle
import argparse
import os

# Command-line interface.
# --fmri_path and --out_path have no usable default: both are passed to
# os.path.join below, which fails on None, so they are required here and
# argparse reports a clear usage error instead of a TypeError.
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument("-sub", "--sub", type=int, choices=[1, 2, 5, 7], default=1,
                    help="Subject Number")
parser.add_argument("-fmri_path", "--fmri_path", required=True,
                    help="Directory containing train_fmri.npy and test_fmri.npy")
parser.add_argument("-out_path", "--out_path", required=True,
                    help="Output directory for cliptext_test_out.npy")
args = parser.parse_args()
fmri_path = args.fmri_path
out_path = args.out_path
# --sub is converted to int and checked against `choices` by argparse; unlike
# an `assert`, that validation still runs under `python -O`.
sub = int(args.sub)

# Resolve both pre-extracted fMRI splits inside --fmri_path, then load them
# (train first, then test -- map is consumed left to right by the unpacking).
train_path, test_path = (
    os.path.join(fmri_path, fname) for fname in ('train_fmri.npy', 'test_fmri.npy')
)
train_fmri, test_fmri = map(np.load, (train_path, test_path))

## Preprocessing fMRI

# Fixed rescale of the raw voxel values. Because every voxel is z-scored just
# below, this division only affects dtype (integer input becomes float) and
# floating-point rounding; the normalized result is otherwise unchanged.
train_fmri = train_fmri/300
test_fmri = test_fmri/300

# Z-score each voxel (column) using statistics from the TRAIN split only, and
# apply the same transform to the test split so no test data leaks into the
# normalization.
norm_mean_train = np.mean(train_fmri, axis=0)
norm_scale_train = np.std(train_fmri, axis=0, ddof=1)
# A voxel that is constant over the training samples has std 0 (and ddof=1
# gives NaN when there is a single sample). Dividing by that would put
# inf/NaN into the design matrix, which sklearn's Ridge.fit rejects, so such
# voxels are left centered but unscaled.
norm_scale_train = np.where(norm_scale_train > 0, norm_scale_train, 1.0)
train_fmri = (train_fmri - norm_mean_train) / norm_scale_train
test_fmri = (test_fmri - norm_mean_train) / norm_scale_train

# Sanity-check output: global mean/std (train should be ~0 / ~1) and value range.
print(np.mean(train_fmri),np.std(train_fmri))
print(np.mean(test_fmri),np.std(test_fmri))

print(np.max(train_fmri),np.min(train_fmri))
print(np.max(test_fmri),np.min(test_fmri))

num_voxels, num_train, num_test = train_fmri.shape[1], len(train_fmri), len(test_fmri)

# Base directory that contains data/extracted_features/. It can be overridden
# with the BRAIN_DIFFUSER_ROOT environment variable; the default is the
# original hard-coded location.
root = os.environ.get("BRAIN_DIFFUSER_ROOT", "/storage/user1/brain-diffuser/")
train_clip = np.load(os.path.join(root,f'data/extracted_features/subj{sub:02d}/nsd_cliptext_train.npy'))
test_clip = np.load(os.path.join(root,f'data/extracted_features/subj{sub:02d}/nsd_cliptext_test.npy'))

# The regression fits train_fmri -> train_clip and scores test_fmri against
# test_clip, so each feature array needs one row per fMRI sample of its split.
# Fail early with a readable message rather than deep inside sklearn.
if len(train_clip) != num_train or len(test_clip) != num_test:
    raise ValueError(
        f"Sample count mismatch: fMRI train/test = {num_train}/{num_test}, "
        f"CLIP-text train/test = {len(train_clip)}/{len(test_clip)}"
    )

## Regression
# train_clip is unpacked as (samples, embeddings, feature dim); presumably one
# CLIP-text token embedding per middle index -- confirm against the extractor.
num_samples,num_embed,num_dim = train_clip.shape

print("Training Regression")
# One Ridge model per embedding index i, mapping all voxels to the num_dim
# features of that embedding. Its weights/biases are collected for export.
reg_w = np.zeros((num_embed,num_dim,num_voxels)).astype(np.float32)
reg_b = np.zeros((num_embed,num_dim)).astype(np.float32)
pred_clip = np.zeros_like(test_clip)
# Raw (not re-scaled) test predictions; filled in below but not saved.
all_pred_test_latents = np.zeros_like(test_clip)
for i in range(num_embed):
    # NOTE(review): max_iter only matters for iterative solvers; with the
    # default solver='auto' on dense input it is likely ignored -- confirm.
    reg = skl.Ridge(alpha=100000, max_iter=50000, fit_intercept=True)
    reg.fit(train_fmri, train_clip[:,i])
    reg_w[i] = reg.coef_
    reg_b[i] = reg.intercept_

    pred_test_latent = reg.predict(test_fmri)
    all_pred_test_latents[:,i] = pred_test_latent
    # Re-scale predictions: standardize each feature over the whole test
    # batch, then give it the train-set mean/std of the target feature
    # (ridge with a large alpha shrinks predictions toward the mean).
    pred_std = np.std(pred_test_latent,axis=0)
    # A feature whose predictions are constant has std 0; map it to the train
    # mean instead of producing NaN.
    pred_std = np.where(pred_std > 0, pred_std, 1.0)
    std_norm_test_latent = (pred_test_latent - np.mean(pred_test_latent,axis=0)) / pred_std
    pred_clip[:,i] = std_norm_test_latent * np.std(train_clip[:,i],axis=0) + np.mean(train_clip[:,i],axis=0)
    # R^2 of the raw Ridge predictions (before the re-scaling above).
    print(i,reg.score(test_fmri,test_clip[:,i]))

# np.save does not create missing directories.
os.makedirs(out_path, exist_ok=True)
np.save(os.path.join(out_path, 'cliptext_test_out.npy'),pred_clip)

# Bundle the per-embedding Ridge parameters under the keys 'weight' (reg_w)
# and 'bias' (reg_b) so they can be serialized together.
datadict = dict(weight=reg_w, bias=reg_b)

# NOTE: saving the regression weights is currently disabled; uncomment to
# write them as a pickle under data/regression_weights/.
#with open('data/regression_weights/subj{:02d}/cliptext_regression_weights.pkl'.format(sub),"wb") as f:
#  pickle.dump(datadict,f)
