import numpy as np
import nibabel as nib

import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("Agg")


def get_voxel_mask(
    subject_id: int,
    true_mask_values: "list | None" = None,
    base_path: str = "/research/XXXX-3/repos/nesim/training/nsd/nsd-data/nsd/",
):
    """Load a subject's 1-D floc-faces ROI mask and binarize it.

    note: subject_id should be 1, 2, 5, 7

    Args:
        subject_id: Subject number. It is inserted into the file path as
            ``voxels_masks/subj{subject_id}/roi_1d_mask_subj0{subject_id}_floc-faces.npy``.
            The hard-coded ``0`` means only single-digit ids map to a valid
            file name.
        true_mask_values: Label values in the stored mask that should be
            marked True. ``None`` (the default) selects nothing, which gives
            an all-False mask.
        base_path: Root directory that contains ``voxels_masks/``. A trailing
            slash is optional.

    Returns:
        Boolean numpy array with the same shape as the stored mask. An entry
        is True where the stored label is one of ``true_mask_values``.
    """
    # None instead of a mutable ``[]`` default: a default list is shared
    # across every call, which is a latent bug if it is ever mutated.
    if true_mask_values is None:
        true_mask_values = []
    # Strip any trailing slash so the default base_path (which ends in "/")
    # does not produce a "//" in the joined path.
    mask_path = (
        f"{base_path.rstrip('/')}/voxels_masks/subj{subject_id}/"
        f"roi_1d_mask_subj0{subject_id}_floc-faces.npy"
    )
    mask_numpy = np.load(mask_path)
    return np.isin(mask_numpy, true_mask_values)

# Shared by the mask load and the t-value load below, so the two cannot drift
# apart and silently pair one subject's mask with another subject's t-values.
SUBJECT_ID = 7
NSD_BASE_PATH = "/research/XXXX-3/repos/nesim/training/nsd/nsd-data/nsd/"

# Values 2 and 3 in the mask correspond to FFA-1 and FFA-2; they are combined
# into a single FFA mask here. OFA is value 1 in the mask (not selected).
ffa_voxel_mask = get_voxel_mask(
    subject_id=SUBJECT_ID,
    true_mask_values=[2, 3],
    base_path=NSD_BASE_PATH,
)

# Sanity-check plot of the boolean FFA mask, saved to voxel_mask.jpg.
plt.plot(ffa_voxel_mask)
plt.savefig("voxel_mask.jpg")
plt.close()

# Face t-values for the same subject. The path resolves to
# <NSD_BASE_PATH>/tvals/subj7_faces.npy.
tvals = np.load(f"{NSD_BASE_PATH.rstrip('/')}/tvals/subj{SUBJECT_ID}_faces.npy")

def apply_mask(array_to_mask, binary_mask):
    """Return a copy of ``array_to_mask`` with entries outside the mask set to 0.

    Args:
        array_to_mask: Array-like of values. It is copied and never modified.
        binary_mask: Array-like with the same shape. Truthy entries are kept.
            Integer 0/1 masks are accepted and treated as boolean.

    Returns:
        A new numpy array (same dtype as an ndarray input) that is zero
        wherever the mask is False and unchanged elsewhere.
    """
    # np.array copies by default, so the caller's array is left untouched.
    # It also allows plain lists to be passed in.
    masked_array = np.array(array_to_mask)
    # Coerce to bool first. ``~`` on an integer 0/1 mask is bitwise NOT
    # (yielding -1/-2), which numpy would treat as fancy indices rather than a
    # boolean selection, zeroing the wrong elements.
    keep = np.asarray(binary_mask, dtype=bool)
    masked_array[~keep] = 0
    return masked_array

# FFA-restricted t-values: voxels outside the mask are zeroed and the raw
# array is left as is.
masked_tvals = apply_mask(tvals, ffa_voxel_mask)

# Zero out every masked t-value strictly below the threshold and keep the rest.
threshold = 10
thresholded_masked_tvals = np.where(
    masked_tvals < threshold,
    0,
    masked_tvals,
)
print(f"tvals {tvals.shape}")

# Overlay the raw, masked, and thresholded t-values together with the cut-off.
plt.title("subject 7 - tvals")
plt.plot(tvals, color="c", label="raw tvals")
plt.plot(masked_tvals, color="g", label="ffa masked tvals (2, 3)")

plt.axhline(
    y=threshold,
    color="k",
    linestyle="-",
    alpha=0.2,
    label=f"threshold = {threshold}",
)
plt.plot(
    thresholded_masked_tvals,
    color="r",
    label="masked + thresholded ffa tvals",
)
plt.legend()
plt.savefig("tvals.jpg")
plt.close()

# TODO: decide which of these arrays is the feature and which is the label for our task.