import numpy as np
import torch
import h5py
import os
import copy
from einops import reduce


def load_h5file(path, keys=None):
    """
    Load objects from an HDF5 file into numpy arrays.

    :param path: path to h5 file
    :param keys: names of the objects to load; if None, every name reported
        by ``h5file.visit`` (groups as well as datasets) is loaded
    :return: dict mapping each key to ``np.array(h5file[key])``
    :raises KeyError: if a requested key is missing; the message names the
        key, the file and the file's top-level keys
    """
    with h5py.File(path, "r") as h5file:
        if keys is None:
            keys = []
            h5file.visit(keys.append)
        loaded = {}
        for k in keys:
            try:
                obj = h5file[k]
            except KeyError as err:
                # Keep raising KeyError (callers may catch it), but report what
                # was missing instead of a bare, message-less KeyError.
                raise KeyError(
                    f"{k!r} not found in {path!r}; "
                    f"top-level keys: {list(h5file.keys())}"
                ) from err
            loaded[k] = np.array(obj)
        return loaded


def mask(paths, hemisphere, mask_type, return_full_brain=False):
    """
    Select the voxels of one hemisphere that fall inside an h5 voxel mask.

    :param paths: ``(fmri_path, mask_path)``; ``fmri_path`` is passed to
        ``load_neural_data_new`` and ``mask_path`` is an h5 file holding the
        voxel indices
    :param hemisphere: hemisphere prefix; ``hemisphere + "h"`` must be a key
        returned by ``load_neural_data_new`` ("lh" or "rh")
    :param mask_type: suffix of the mask dataset; indices are read from
        ``"data/" + hemisphere + mask_type`` in the mask file
    :param return_full_brain: if True, also return a vector over all voxels of
        the hemisphere holding the NaN-aware mean over axis 0 of each masked
        voxel and zeros elsewhere
    :return: ``fmri_data`` of shape (num_images, num_masked_voxels), or
        ``(fmri_data, full_brain)`` when ``return_full_brain`` is True
    """
    fmri_path, mask_path = paths
    fmri_data_full = load_neural_data_new(fmri_path)[hemisphere + "h"]

    # Only load the one mask dataset we need, not every object in the file.
    mask_key = "data/" + hemisphere + mask_type
    # Indices are shifted by -1, i.e. presumably stored 1-based in the file
    # (e.g. MATLAB-exported) -- confirm against the mask files.
    fmri_mask = [int(i) - 1 for i in load_h5file(mask_path, keys=[mask_key])[mask_key]]

    fmri_data = fmri_data_full[:, fmri_mask]  # num_images, num_voxels
    if not return_full_brain:
        return fmri_data

    # Full hemisphere with zeroes outside the mask.
    full_brain = np.zeros(fmri_data_full.shape[1])
    full_brain[fmri_mask] = np.nanmean(fmri_data, axis=0)
    return fmri_data, full_brain


def normalize_session_data(data, idx):
    """
    Z-score every column of ``data`` using statistics from a subset of rows.

    :param data: 2-D array; mean and std are taken per column (axis 0)
    :param idx: row indices whose mean and population std (numpy's default
        ddof=0) define the normalization
    :return: ``(data - mean) / std`` applied to all rows of ``data``; columns
        with zero std produce inf/nan as plain numpy division does
    """
    reference_rows = data[idx, :]
    center = np.mean(reference_rows, axis=0)
    scale = np.std(reference_rows, axis=0)
    return (data - center) / scale


def load_neural_data_new(root_path):
    """
    Load, normalize and pool the two stimulus-group recordings in ``root_path``.

    Reads ``stimgroup1_data.h5`` and ``stimgroup2_data.h5`` (datasets
    ``data/lhdata`` and ``data/rhdata``) and transposes each dataset. Each
    group is z-scored column-wise with ``normalize_session_data``, using that
    group's reference rows. The first 15 rows of the normalized stimgroup2
    data are dropped, and the two groups are concatenated along axis 0.

    NOTE(review): after the transpose, rows are presumably images and columns
    voxels -- confirm against the h5 files. The first 15 stimgroup2 rows serve
    only as the normalization reference and are then discarded; presumably
    they repeat stimgroup1 stimuli -- verify.

    :param root_path: directory containing the two h5 files
    :return: dict with keys "lh" and "rh", each the pooled array passed
        through ``.squeeze()``
    """
    # Rows used to compute each group's normalization statistics.
    # NOTE(review): 99 sits out of ascending order in stimgroup1; confirm it is
    # intended and not a typo (e.g. for 9).
    normalizer_idx = {
        "stimgroup1": [0, 4, 99, 14, 20, 23, 33, 40, 43, 45, 49, 76, 78, 83, 86],
        "stimgroup2": list(range(15)),
    }
    # Leading stimgroup2 rows removed after normalization.
    n_dropped_stimgroup2 = 15
    hemi_keys = {"lh": "data/lhdata", "rh": "data/rhdata"}

    # Read each group's file once, transposing every dataset.
    group_data = {}
    for group in normalizer_idx:
        raw = load_h5file(
            os.path.join(root_path, f"{group}_data.h5"),
            keys=list(hemi_keys.values()),
        )
        group_data[group] = {k: v.T for k, v in raw.items()}

    # Normalize each group separately, then pool them per hemisphere.
    output = {}
    for hemi, key in hemi_keys.items():
        g1_norm = normalize_session_data(
            group_data["stimgroup1"][key], normalizer_idx["stimgroup1"]
        )
        g2_norm = normalize_session_data(
            group_data["stimgroup2"][key], normalizer_idx["stimgroup2"]
        )[n_dropped_stimgroup2:]
        output[hemi] = np.concatenate((g1_norm, g2_norm), axis=0).squeeze()

    return output


# NSD1000: file-name prefix for each supported ROI and the label printed on load.
# The full file name is f"{prefix}_{subject_dir}.pth", e.g. "..._subj01.pth".
_NSD_ROI_FILES = {
    "ffa": ("nsd_floc_facestval_threshold_6_responses_1000_all_reps_all_voxels", "FFA"),
    "vtc": ("nsd_hcp_glasser_parcellation_cortex_Ventral_Stream_Visual", "VTC"),
    "eba": ("nsd_floc_bodiestval_threshold_6_responses_1000_all_reps_all_voxels", "EBA"),
    "ppa": ("nsd_floc_placestval_threshold_6_responses_1000_all_reps_all_voxels", "PPA"),
    "vwfa": ("nsd_floc_wordtval_threshold_6_responses_1000_all_reps_all_voxels", "VWFA"),
    "v1v": ("nsd_V1v_voxels_responses_1000_all_reps_all_voxels", "v1v"),
    "v1d": ("nsd_V1d_voxels_responses_1000_all_reps_all_voxels", "v1d"),
    "v2v": ("nsd_V2v_voxels_responses_1000_all_reps_all_voxels", "v2v"),
    "v2d": ("nsd_V2d_voxels_responses_1000_all_reps_all_voxels", "v2d"),
    "v3v": ("nsd_V3v_voxels_responses_1000_all_reps_all_voxels", "v3v"),
    "v3d": ("nsd_V3d_voxels_responses_1000_all_reps_all_voxels", "v3d"),
    "hv4": ("nsd_hV4_voxels_responses_1000_all_reps_all_voxels", "hv4"),
}


def _load_185_raw(dataset_root, subject, roi, hemisphere):
    """Return the masked murty_185 data (num_images, num_voxels) for one hemisphere."""
    if hemisphere not in ("l", "r"):
        raise ValueError(
            f"dataset '185' needs hemisphere='l' or 'r', got {hemisphere!r}"
        )
    fmri_path = os.path.join(dataset_root, "datasets", "murty_185", f"subject{subject}")
    mask_path = os.path.join(fmri_path, "handmade_rel_mask.h5")
    return mask([fmri_path, mask_path], hemisphere, roi)


def _load_nsd1000_raw(dataset_root, subject, roi):
    """Load the NSD1000 response tensor for one subject/ROI onto the CPU."""
    try:
        prefix, label = _NSD_ROI_FILES[roi]
    except KeyError:
        raise ValueError(
            f"unknown NSD1000 roi {roi!r}; expected one of {sorted(_NSD_ROI_FILES)}"
        ) from None
    subject_dir = f"subj0{subject}"
    # os.path.join works whether or not dataset_root ends with a slash.
    path = os.path.join(dataset_root, subject_dir, f"{prefix}_{subject_dir}.pth")
    # weights_only=True refuses arbitrary pickled objects, as every ROI should.
    data = torch.load(path, weights_only=True).detach().cpu()
    print(f"using {label} data")
    return data
    # Previously referenced file location (kept for reference):
    # /research/XXXX-3/repos/robust-brainmodels/temp_root/datasets/nsd/nsd/subj01/nsd_ffa1_tval_greater_than_6_responses_1000_subj01.pth


def load_brain_data(
    dataset_name,
    subject="1",
    roi="ffa",
    dataset_root="/home/XXXX-5/repos/robust-brainmodels/temp_root/",
    averaging=False,
    hemisphere=None,
):
    """
    Load fMRI responses for one dataset, subject and ROI.

    :param dataset_name: "185" (murty_185 h5 data, masked via ``mask``) or
        "NSD1000" (per-ROI ``.pth`` tensors)
    :param subject: subject id; "185" reads ``subject{subject}``, "NSD1000"
        reads ``subj0{subject}``
    :param roi: for "185", the mask suffix passed to ``mask`` as ``mask_type``;
        for "NSD1000", one of the keys of ``_NSD_ROI_FILES``
    :param dataset_root: root directory of the data
    :param averaging: if True, return one averaged value per sample instead of
        the raw data
    :param hemisphere: "l" or "r"; required for "185", ignored for "NSD1000"
    :return: with ``averaging=False``, the raw data. For "185" this is the
        numpy array from ``mask`` (num_images, num_voxels); for "NSD1000" it is
        the loaded tensor. With ``averaging=True``, a 1-D torch tensor: for
        "185" the NaN-aware mean over voxels, for "NSD1000" the mean over the
        "reps" and "voxels" axes.
    :raises ValueError: unknown ``dataset_name`` or ``roi``, or missing/invalid
        ``hemisphere`` for "185"
    """
    if dataset_name == "185":
        brain_data_raw = _load_185_raw(dataset_root, subject, roi, hemisphere)
        if averaging:
            return torch.from_numpy(np.nanmean(brain_data_raw, axis=1))
        return brain_data_raw

    if dataset_name == "NSD1000":
        brain_data_raw = _load_nsd1000_raw(dataset_root, subject, roi)
        if averaging:
            return reduce(
                brain_data_raw, "num_samples reps voxels -> num_samples", "mean"
            )
        return brain_data_raw

    raise ValueError(
        f"unknown dataset_name {dataset_name!r}; expected '185' or 'NSD1000'"
    )
