import numpy as np
import pandas as pd
import random
import torch
from torch.utils.data import Dataset
import h5py
import scipy.stats as stats

class fMRIDataset(Dataset):
    """
    Loads per-subject fMRI time series from HDF5 files for a fixed subject list.

    For each subject, the file ``<datadir><subject>_glmOutput_data.h5`` is read and
    the datasets ``<run>/nuisanceReg_resid_<model>`` of all runs in ``self.runs``
    are concatenated along the last axis. The resulting dictionary (see
    ``data2dict``) can be handed to a sampler to draw training examples
    (e.g., to train a transformer).

    Args:
        datadir (str): path prefix of the dataset directory. It is concatenated
            directly with the file name, so it must end with a path separator.
        subjectset (int): which predefined subject list to use (1 or 2).

    Raises:
        ValueError: if ``subjectset`` is not 1 or 2.
    """
    def __init__(self,
                 datadir='../../../hcpPostProcCiric/',
                 subjectset=1
                 ):
        self.datadir = datadir
        self.runs = ['rfMRI_REST1_RL', 'rfMRI_REST1_LR', 'rfMRI_REST2_RL', 'rfMRI_REST2_LR']

        if subjectset == 1:
            # Earlier, larger subject list (disabled):
            #self.subjects = [
            #    '178950','189450','199453','209228','220721','298455','356948','419239','499566','561444','618952','680452','757764','841349','908860',
            #    '103818','113922','121618','130619','137229','151829','158035','171633','179346','190031','200008','210112','221319','299154','361234',
            #    '424939','500222','570243','622236','687163','769064','845458','911849','104416','114217','122317','130720','137532','151930','159744',
            #    '172029','180230','191235','200614','211316','228434','300618','361941','432332','513130','571144','623844','692964','773257','857263',
            #    '926862','105014','114419','122822','130821','137633','152427','160123','172938','180432','192035','200917','211417','239944','303119',
            #    '365343','436239','513736','579665','638049','702133','774663','865363','930449','106521','114823','123521','130922','137936','152831',
            #    '160729','173334','180533','192136','201111','211619','249947','305830','366042','436845','516742','580650','645450','715041','782561',
            #    '871762','942658','106824','117021','123925','131823','138332','153025','162026','173536','180735','192439','201414','211821','251833',
            #    '310621','371843','445543','519950','580751','647858','720337','800941','871964','955465','107018','117122','125222','132017','138837',
            #    '153227','162329','173637','180937','193239','201818','211922','257542','314225','378857','454140','523032','585862','654350','725751',
            #    '803240','872562','959574','107422','117324','125424','133827','142828','153631','164030','173940','182739','194140','202719','212015',
            #    '257845','316633','381543','459453','525541','586460','654754','727553','812746','873968','966975'
            #    ]
            self.subjects = [
                '100408', '135629', '174437', '205725', '389357', '588565', '814649',
                '103818', '135932', '175237', '205826', '394956', '598568', '820745',
                '104416', '142828', '176441', '209228', '395756', '599671', '826454',
                '105923', '146331', '176845', '212015', '406432', '601127', '835657',
                '107018', '146432', '177645', '212823', '406836', '622236', '871762',
                '110007', '147636', '177746', '213017', '414229', '654754', '871964',
                '112516', '148133', '179346', '213421', '419239', '657659', '877269',
                '112920', '151829', '183034', '213522', '424939', '667056', '878877',
                '113316', '151930', '185341', '228434', '436239', '671855', '891667',
                '113922', '152427', '187850', '239944', '436845', '675661', '898176',
                # Remaining subjects of this set are currently disabled:
                #'114419', '153025', '189450', '251833', '445543', '679568', '907656',
                #'117021', '153227', '190031', '257542', '467351', '679770', '911849',
                #'117122', '154229', '191033', '257845', '481042', '687163', '926862',
                #'117324', '154532', '191235', '268749', '500222', '702133', '930449',
                #'118124', '154936', '191437', '268850', '516742', '707749', '955465',
                #'118225', '159744', '192439', '285345', '523032', '727553', '959574',
                #'120212', '160123', '192641', '290136', '525541', '735148', '989987',
                #'125222', '162329', '193239', '303119', '529953', '742549', '991267',
                #'125424', '162935', '196144', '342129', '530635', '749361', '992774',
                #'127933', '164939', '196346', '352738', '553344', '769064', '993675',
            ]
        elif subjectset == 2:
            # Earlier, larger subject list (disabled):
            #self.subjects = [
            #    '100206','108020','117930','126325','133928','143224','153934','164636','174437','183034','194443','204521','212823','268749','322224',
            #    '385450','463040','529953','587664','656253','731140','814548','877269','978578','100408','108222','118124','126426','134021','144832',
            #    '154229','164939','175338','185139','194645','204622','213017','268850','329844','389357','467351','530635','588565','657659','737960',
            #    '816653','878877','987074','101006','110007','118225','127933','134324','146331','154532','165638','175742','185341','195445','205119',
            #    '213421','274542','341834','393247','479762','545345','597869','664757','742549','820745','887373','989987','102311','111009','118831',
            #    '128632','135528','146432','154936','167036','176441','186141','196144','205725','213522','285345','342129','394956','480141','552241',
            #    '598568','671855','744553','826454','896879','990366','102513','112516','118932','129028','135629','146533','156031','167440','176845',
            #    '187850','196346','205826','214423','285446','348545','395756','481042','553344','599671','675661','749058','832651','899885','991267',
            #    '102614','112920','119126','129129','135932','147636','157336','168745','177645','188145','198350','208226','214726','286347','349244',
            #    '406432','486759','555651','604537','679568','749361','835657','901442','992774','103111','113316','120212','130013','136227','148133',
            #    '157437','169545','178748','188549','198451','208327','217429','290136','352738','414229','497865','559457','615744','679770','753150',
            #    '837560','907656','993675','103414','113619','120414','130114','136833','150726','157942','171330'
            #]
            self.subjects = [
                '129129', '165638', '200008', '353740', '555651', '774663',
                '130619', '167036', '201111', '358144', '561444', '782561',
                '132017', '171330', '201414', '361941', '568963', '800941',
                '133827', '172938', '202719', '365343', '586460', '812746',
                '135528', '173940', '204521', '381543', '587664', '814548',
            ]
        else:
            # Fail here rather than with an AttributeError on self.subjects later.
            raise ValueError(
                f"subjectset must be 1 or 2, got {subjectset!r}"
            )

    def data2dict(self, model='24pXaCompCorXVolterra', zscore=True):
        """
        Load every subject in ``self.subjects`` into memory.

        Args:
            model (str): suffix of the HDF5 dataset name to read for each run.
            zscore (bool): if True, z-score each run along axis 1 before
                concatenation.

        Returns:
            dict: maps subject ID (str) to the array returned by ``_loadsubj``.
        """
        return {
            subj: self._loadsubj(subj, model=model, zscore=zscore)
            for subj in self.subjects
        }

    def _loadsubj(self, subj, model='24pXaCompCorXVolterra', zscore=True):
        """
        Read all runs of one subject and concatenate them along the last axis.

        Args:
            subj (str): subject ID; selects the file
                ``<datadir><subj>_glmOutput_data.h5``.
            model (str): suffix of the HDF5 dataset name to read for each run.
            zscore (bool): if True, z-score each run separately along axis 1.

        Returns:
            numpy.ndarray: the per-run arrays joined along the last axis
            (presumably rows x concatenated time points -- confirm against the
            HDF5 layout).
        """
        datafile = self.datadir + subj + '_glmOutput_data.h5'
        run_arrays = []
        # The context manager closes the file even if a key is missing or
        # z-scoring raises.
        with h5py.File(datafile, 'r') as h5f:
            for run in self.runs:
                ##### EDIT THIS BLOCK
                # For files keyed as f'{subj}_rest{k}_residuals_noGSR' (k being the
                # 1-based run index), build that key here instead.
                dataid = run + '/nuisanceReg_resid_' + model
                ####
                tmp = h5f[dataid][:]
                if zscore:
                    # Normalise each run on its own, before concatenation.
                    tmp = stats.zscore(tmp, axis=1)
                run_arrays.append(tmp)
        if not run_arrays:
            # No runs configured: return an empty array.
            return np.asarray(run_arrays)
        return np.concatenate(run_arrays, axis=-1)

class DatasetSampler(Dataset):
    """
    Pytorch Dataset that draws random single-time-point samples from a data dictionary.

    Each item is one column (one time point) of one randomly chosen subject's
    array, returned as a float32 column vector. Sampling uses Python's ``random``
    module and is independent of the index passed to ``__getitem__``.

    Args:
        datadict: dictionary mapping subject ID to a 2D array whose second axis
            is indexed as time (e.g., the output of ``fMRIDataset.data2dict``).
        subjects: list of subject IDs (keys of ``datadict``) to sample from.
    """
    def __init__(self,
                 datadict,
                 subjects,
                 ):
        self.datadict = datadict
        self.subjects = subjects
        # Length of the first subject's time axis; used only to size __len__.
        self.n_timepoints = datadict[subjects[0]].shape[1]

    def __len__(self):
        """
        Return the nominal number of samples (subjects x time points).

        Samples are drawn at random rather than indexed, so this only controls
        how many draws make up one pass over the dataset.
        """
        return len(self.subjects) * self.n_timepoints

    def __getitem__(self, idx):
        """
        Return one randomly selected time point of one randomly selected subject.

        Args:
            idx: ignored -- it's just a requirement for dataset objects specified
                by torch.

        Returns:
            torch.Tensor: float32 tensor of shape (n_rows, 1).
        """
        sub_id = random.choice(self.subjects)
        series = self.datadict[sub_id]
        # Use the chosen subject's own length so a shorter subject cannot be
        # indexed out of range. randrange draws the index directly, with no
        # temporary array.
        time_idx = random.randrange(series.shape[1])

        sample = torch.tensor(series[:, time_idx], dtype=torch.float32)
        return sample.reshape(-1, 1)
