import matplotlib.pyplot as plt
import numpy as np
import torch
import joblib
import os
import sys

def rmse(true, pred):
    """Return the root-mean-square error between ``true`` and ``pred``.

    Both arrays are flattened into column vectors before subtracting, so inputs
    with the same number of elements are compared element by element regardless
    of their original shapes (a single-element input broadcasts against the other).
    """
    residual = true.reshape(-1, 1) - pred.reshape(-1, 1)
    mean_squared_error = np.mean(residual ** 2)
    return np.sqrt(mean_squared_error)

# Command line arguments: number of features and total number of blocks.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} <features> <blocks_total>")
features = int(sys.argv[1])
blocks_total = int(sys.argv[2])
if features < 1 or blocks_total < 1:
    sys.exit("features and blocks_total must both be positive integers")
fold = 0  # Single fold


def _load_predictions(split, block):
    """Load ``{split}_prediction_{block}_fold_{fold}.npy``, drop its trailing
    length-1 axis and transpose the result.

    NOTE(review): callers index the result as ``[:, k]`` over iterations k, so
    it is presumably (samples, iterations) after the transpose — confirm against
    whatever writes these files.
    """
    return np.transpose(np.load(f"{split}_prediction_{block}_fold_{fold}.npy").squeeze(-1))


# Load data, keyed by block index 0 .. blocks_total-1.
scalers = {}
prefshap_values = {}
ytrain_values = {}
ytest_values = {}
train_preds = {}
test_preds = {}

for block in range(blocks_total):
    scalers[block] = joblib.load(f'y_block_tr_{block}_scaler_fold_{fold}.pkl')
    # ndmin=2 keeps the (rows, features) layout even when there is a single
    # feature; by default loadtxt squeezes a one-column file to 1-D and the
    # later [:, f_idx] indexing would fail.
    prefshap_values[block] = np.loadtxt(f"prefshap_block_{block}_fold_{fold}.txt", ndmin=2)
    ytrain_values[block] = np.loadtxt(f"Ytrain_{block}_fold_{fold}.txt")
    ytest_values[block] = np.loadtxt(f"Ytest_{block}_fold_{fold}.txt")
    test_preds[block] = _load_predictions("test", block)
    train_preds[block] = _load_predictions("train", block)

# Validation data for (at most) the first three blocks.
# NOTE(review): val_data / val_preds are not read anywhere in this script.
# The block set is bounded by blocks_total so a run with fewer than three
# blocks does not try to open files for blocks that were never produced.
val_blocks = range(min(3, blocks_total))
val_data = {b: np.loadtxt(f"Yval_{b}_fold_{fold}.txt") for b in val_blocks}
val_preds = {b: _load_predictions("val", b) for b in val_blocks}

# Compute per-iteration test RMSEs.
# Block 0 is scored on its own. For each of the two compared methods, the summed
# test predictions of that method's blocks are scored against the summed Ytest
# values of the same blocks:
#   blocks 1 .. blocks_eachmethod                -> "block original"
#   blocks blocks_eachmethod+1 .. blocks_total-1 -> "GPref-SHAP"
# NOTE(review): with this split the GPref-SHAP group always gets more blocks than
# the "block original" group (e.g. blocks_total=6 gives 2 vs 3); confirm that
# this layout is intended.
blocks_eachmethod = int((blocks_total - 2) / 2)
if blocks_eachmethod < 1:
    # Otherwise the "block original" sums below stay the integer 0 and rmse()
    # fails with an opaque AttributeError on .reshape.
    sys.exit(
        f"blocks_total={blocks_total} leaves no blocks for 'block original'; "
        "at least 4 blocks are required"
    )

original_blocks = range(1, blocks_eachmethod + 1)
deepNN_blocks = range(blocks_eachmethod + 1, blocks_total)

# The targets do not depend on the iteration k, so they are summed only once.
target_sum_block_original = sum(ytest_values[b] for b in original_blocks)
target_sum_block_deepNN = sum(ytest_values[b] for b in deepNN_blocks)

trainrmse_block0, testrmse_block0, testrmse_block_original, testrmse_block_deepNN = [], [], [], []

# NOTE(review): the iteration count comes from block 0's training predictions;
# every test prediction array is assumed to have at least as many columns.
for k in range(train_preds[0].shape[1]):
    pred_sum_block_original = sum(test_preds[b][:, k] for b in original_blocks)
    pred_sum_block_deepNN = sum(test_preds[b][:, k] for b in deepNN_blocks)

    testrmse_block_original.append(rmse(pred_sum_block_original, target_sum_block_original))
    trainrmse_block0.append(rmse(train_preds[0][:, k], ytrain_values[0]))
    testrmse_block0.append(rmse(test_preds[0][:, k], ytest_values[0]))
    testrmse_block_deepNN.append(rmse(pred_sum_block_deepNN, target_sum_block_deepNN))

# Figure layout: a 2 x features grid. The top row shows only the RMSE plot in
# its first cell (the other top-row cells are hidden); the bottom row gets one
# correlation plot per feature.
ncols = features
# squeeze=False keeps `axes` 2-D even when features == 1. With the default
# squeeze=True a single column yields a 1-D array, and both `axes[0, 0]` and
# iterating `axes[0]` would fail.
fig, axes = plt.subplots(2, ncols, figsize=(5 * ncols, 10), squeeze=False)

# Hide every top-row cell; the RMSE cell is switched back on below.
for ax in axes[0]:
    ax.axis('off')

# Test RMSE per iteration index for block 0 and for the two summed methods.
ax_rmse = axes[0, 0]
ax_rmse.scatter(range(len(testrmse_block0)), testrmse_block0, alpha=0.5, label="Pref-SHAP", color="blue", marker="+")
ax_rmse.scatter(range(len(testrmse_block_original)), testrmse_block_original, alpha=0.5, label="block original", color="red", marker="+")
ax_rmse.scatter(range(len(testrmse_block_deepNN)), testrmse_block_deepNN, alpha=0.5, label="GPref-SHAP", color="green", marker="+")
ax_rmse.set_title(f"fold {fold}:Test RMSEs")
ax_rmse.set_xlabel("Iteration")
ax_rmse.set_ylabel("Test RMSE")
ax_rmse.legend()
ax_rmse.axis('on')  # re-enable the one top-row panel that is used

# Second row: for each feature, compare block 0's Pref-SHAP values against the
# summed SHAP values of each method's blocks (same block split as the RMSEs:
# 1..blocks_eachmethod = "block original", the rest = "GPref-SHAP").
prefshap_block_original = sum(
    prefshap_values[b] for b in range(1, blocks_eachmethod + 1)
)
prefshap_block_deepNN = sum(
    prefshap_values[b] for b in range(blocks_eachmethod + 1, blocks_total)
)

for f_idx in range(features):
    ax_corr = axes[1, f_idx]
    p0 = prefshap_values[0][:, f_idx]
    p_original = prefshap_block_original[:, f_idx]
    p_deepNN = prefshap_block_deepNN[:, f_idx]

    # Colours match the RMSE panel (red = block original, green = GPref-SHAP)
    # so each method is identified by the same colour across the whole figure.
    ax_corr.scatter(p0, p_original, color="red", marker="x", alpha=0.4, label="block original")
    ax_corr.scatter(p0, p_deepNN, color="green", marker="x", alpha=0.4, label="GPref-SHAP")

    # Pearson correlation coefficient; NaN if either column is constant.
    corr_block = np.corrcoef(p0, p_original)[0, 1]
    corr_deep = np.corrcoef(p0, p_deepNN)[0, 1]
    ax_corr.set_title(f"fold {fold}: Feature {f_idx}\n block original={corr_block:.2f}, GPref-SHAP={corr_deep:.2f}")
    ax_corr.set_xlabel("Pref-SHAP")
    ax_corr.set_ylabel("block original/GPref-SHAP (Linearity)")

    ax_corr.legend()

plt.tight_layout()
# Save before show(): once an interactive window is closed the figure may be
# discarded, which would leave the saved file blank.
plt.savefig(f"plot{fold}.png")
plt.show()
