#coding:utf8
import os
from torch.utils import data
import numpy as np
from sklearn.utils import shuffle
import nibabel as nib
import random
import pandas as pd
import glob
from sklearn.model_selection import StratifiedKFold
import math
import lmdb
import warnings
import h5py
from nilearn.connectome import ConnectivityMeasure
from sklearn.utils  import shuffle
from sklearn.model_selection import StratifiedShuffleSplit
# NOTE(review): this silences every warning, process-wide, as soon as this
# module is imported (including warnings raised by the numpy / sklearn /
# nilearn calls below and by any code that imports this module). Consider
# scoping it with ``warnings.catch_warnings()`` around the noisy calls instead.
warnings.filterwarnings("ignore")

def mask_timeseries(timeser, mask=30):
    """Randomly drop ``mask`` time points (columns) from ``timeser``.

    With probability 0.2 the input is returned unchanged (no augmentation).
    Otherwise ``mask`` distinct column indices are drawn without replacement
    and those columns are removed; the remaining columns keep their order.

    Args:
        timeser: 2-D array indexed as ``[:, time]``.
        mask: number of time points (columns) to drop.

    Returns:
        ``timeser`` itself in the no-augmentation case, otherwise a new array
        of shape ``(timeser.shape[0], timeser.shape[1] - mask)``.

    Raises:
        ValueError: from ``random.sample`` if ``mask`` is negative or larger
            than the number of time points.
    """
    # The coin is drawn first, so the numpy/random RNG call order is fixed.
    if np.random.random() < 0.2:
        return timeser

    time_len = timeser.shape[1]
    # Index with the sampled Python ints directly. Wrapping them in
    # np.array() yields an empty *float* array when mask == 0, and numpy
    # refuses float arrays as indices (IndexError).
    drop_index = random.sample(range(time_len), mask)
    keep = np.ones(time_len, dtype=bool)
    keep[drop_index] = False
    return timeser[:, keep]

def mask_timeseries_per(timeser, mask=30):
    """Randomly drop ``mask`` percent of the time points (columns) of ``timeser``.

    With probability 0.2 the input is returned unchanged (no augmentation).
    Otherwise ``int(mask * n_time / 100)`` distinct column indices are drawn
    without replacement (the count is truncated toward zero) and those columns
    are removed; the remaining columns keep their order.

    Args:
        timeser: 2-D array indexed as ``[:, time]``.
        mask: percentage of time points to drop.

    Returns:
        ``timeser`` itself in the no-augmentation case, otherwise a new array
        with the sampled columns removed.

    Raises:
        ValueError: from ``random.sample`` if the computed drop count is
            negative or larger than the number of time points.
    """
    # The coin is drawn first, so the numpy/random RNG call order is fixed.
    if np.random.random() < 0.2:
        return timeser

    time_len = timeser.shape[1]
    mask_len = int(mask * time_len / 100)
    # Index with the sampled Python ints directly. Wrapping them in
    # np.array() yields an empty *float* array when mask_len == 0 (short
    # series / small percentage), which numpy rejects as an index.
    drop_index = random.sample(range(time_len), mask_len)
    keep = np.ones(time_len, dtype=bool)
    keep[drop_index] = False
    return timeser[:, keep]

class Task1Data(data.Dataset):
    """Unlabelled dataset yielding two augmented correlation matrices per sample.

    Every entry of ``root`` whose name contains ``'sch'`` is one sample. Each
    ``__getitem__`` loads that entry with ``np.load``, builds two independently
    augmented views of it, and returns the correlation matrix of each view
    (presumably used as a positive pair for self-supervised pre-training —
    confirm against the training script).

    The arrays are assumed to be laid out as ``(regions, time)``: the masking
    helpers drop columns, and each view is transposed before being passed to
    ``ConnectivityMeasure``.
    """

    def __init__(self, root=None, mask_way='mask', mask_len=10, time_len=30):
        """Index the sample files under ``root``.

        Args:
            root: directory containing the arrays. ``None`` means the current
                working directory.
            mask_way: augmentation applied in ``__getitem__``; one of
                ``'mask'``, ``'mask_per'`` or ``'random'`` (only validated
                when an item is fetched).
            mask_len: number of time points to drop (``'mask'``) or percentage
                of time points to drop (``'mask_per'``).
            time_len: ``sample_len`` forwarded to ``random_timeseries`` when
                ``mask_way == 'random'``.
        """
        self.template = 'sch'
        # os.listdir(None) lists the cwd, but os.path.join(None, ...) in
        # __getitem__ raises TypeError; resolve None once so both agree.
        self.root = os.curdir if root is None else root
        self.mask_way = mask_way
        self.mask_len = mask_len
        self.time_len = time_len

        self.names = self._find_files(self.root, self.template)

        print(f"Finding files: {len(self.names)}")
        self.correlation_measure = ConnectivityMeasure(kind='correlation')

    @staticmethod
    def _find_files(root, template):
        """Return the sorted entry names in ``root`` that contain ``template``."""
        # Sorted so that index -> file is reproducible; os.listdir order is
        # arbitrary and differs between filesystems.
        return sorted(f for f in os.listdir(root) if template in f)

    def __getitem__(self, index):
        """Return the correlation matrices of two augmented views of sample ``index``.

        NaN entries in the matrices are replaced with 0.

        Raises:
            KeyError: if ``mask_way`` is not one of the supported augmentations.
        """
        name = self.names[index]
        img = np.load(os.path.join(self.root, name))
        if self.mask_way == 'mask':
            slices = [mask_timeseries(img, mask=self.mask_len).T for _ in range(2)]
        elif self.mask_way == 'mask_per':
            slices = [mask_timeseries_per(img, mask=self.mask_len).T for _ in range(2)]
        elif self.mask_way == 'random':
            # NOTE(review): random_timeseries is not defined in the reviewed
            # part of this module; unless it is defined further down, this
            # branch raises NameError — confirm.
            slices = [random_timeseries(img, sample_len=self.time_len).T for _ in range(2)]
        else:
            raise KeyError(f"mask way error, your input is {self.mask_way}")
        correlation_matrix = self.correlation_measure.fit_transform(slices)
        correlation_matrix[np.isnan(correlation_matrix)] = 0
        return correlation_matrix[0], correlation_matrix[1]

    def __len__(self):
        """Number of indexed sample files."""
        return len(self.names)
        
class Task3Data(data.Dataset):
    """Labelled dataset of correlation matrices with site-stratified splits.

    Subjects and labels come from a CSV (columns ``new_name``, ``dx``,
    ``site``); the default paths suggest ABIDE-I. Two ``StratifiedShuffleSplit``
    passes, both stratified by ``site``, produce the splits:

    1. train (70%) vs. held-out (30%), with a fixed ``random_state=42``;
    2. held-out -> validation vs. test (15% of the whole set), seeded by
       ``shuffle_seed``.

    Labels are binary: ``dx == 1`` -> class 1, anything else -> class 0.
    """

    def __init__(self, shuffle_seed, mask_way='mask', mask_len=10,
                 is_train=True, is_test=False,
                 root="../run1_abide1", csv_path="clean_abide1.csv",
                 pretrain_dir="../run1_abide1_train/"):
        """Build the requested split.

        Args:
            shuffle_seed: ``random_state`` of the validation/test split.
            mask_way: training-time augmentation (``'mask'`` or
                ``'mask_per'``); any other value disables augmentation.
            mask_len: argument forwarded to the masking helper.
            is_train: select the training split (ignored if ``is_test``).
            is_test: select the test split; takes precedence over ``is_train``.
                If neither is set, the validation split is selected.
            root: directory holding ``f"{template}_{name}.npy"`` arrays.
            csv_path: CSV with ``new_name``, ``dx`` and ``site`` columns.
            pretrain_dir: directory of the pre-training files; used only to
                verify that none of them belongs to a held-out subject.

        Raises:
            ValueError: if a pre-training file belongs to a validation/test
                subject (data leakage).
        """
        self.template = 'sch'
        self.is_test = is_test
        self.is_train = is_train
        self.root = root
        self.df = pd.read_csv(csv_path)

        self.mask_way = mask_way
        self.mask_len = mask_len

        self.names = list(self.df['new_name'])

        all_data = np.array(self.names)
        lbls = np.array([1 if f == 1 else 0 for f in self.df['dx']])
        sites = np.array(self.df['site'])

        n_total = len(self.df)
        train_length = int(n_total * 0.7)
        val_length = int(n_total * 0.15)
        test_length = int(n_total * 0.15)

        # Split 1 (fixed seed): train vs. held-out, stratified by site.
        split = StratifiedShuffleSplit(n_splits=1, test_size=val_length + test_length,
                                       train_size=train_length, random_state=42)
        train_index, rest_index = next(split.split(all_data, sites))
        data_train, labels_train = all_data[train_index], lbls[train_index]
        data_rest, labels_rest = all_data[rest_index], lbls[rest_index]
        site_rest = sites[rest_index]

        # Guard against leakage: no pre-training file may belong to a
        # validation/test subject.
        self._check_pretrain_overlap(os.listdir(pretrain_dir), data_rest, self.template)

        # Split 2 (seeded by shuffle_seed): held-out -> validation vs. test.
        split2 = StratifiedShuffleSplit(n_splits=1, test_size=test_length,
                                        random_state=shuffle_seed)
        valid_index, test_index = next(split2.split(data_rest, site_rest))
        data_test, labels_test = data_rest[test_index], labels_rest[test_index]
        data_val, labels_val = data_rest[valid_index], labels_rest[valid_index]

        if is_test:
            print("Testing data:")
            self.imgs, self.lbls = data_test, labels_test
        elif is_train:
            print("Training data:")
            # NOTE(review): the validation subjects are folded into the
            # training set here, so the "val" split below overlaps training.
            self.imgs = np.concatenate([data_train, data_val], 0)
            self.lbls = np.concatenate([labels_train, labels_val], 0)
        else:
            print("Val data:")
            self.imgs, self.lbls = data_val, labels_val
        print(self.imgs.shape)
        self.correlation_measure = ConnectivityMeasure(kind='correlation')

    @staticmethod
    def _check_pretrain_overlap(pretrain_files, held_out_names, template):
        """Raise ``ValueError`` if any pre-training file is a held-out subject.

        Pre-training files are matched by the ``f"{template}_{name}.npy"``
        pattern that ``__getitem__`` loads; entries not following that pattern
        are ignored.
        """
        prefix, suffix = f"{template}_", ".npy"
        held_out = {str(n) for n in held_out_names}
        leaked = sorted(
            f for f in pretrain_files
            if f.startswith(prefix) and f.endswith(suffix)
            and f[len(prefix):len(f) - len(suffix)] in held_out
        )
        if leaked:
            raise ValueError(
                f"pre-training files belong to validation/test subjects: {leaked}")

    def __getitem__(self, index):
        """Return ``(correlation_matrix, onehot_label)`` for sample ``index``.

        On the training split (``is_train`` set and ``is_test`` not set) with
        ``mask_way`` ``'mask'``/``'mask_per'``, the matrix is the mean of the
        correlation of a randomly masked view and of the full series;
        otherwise it is the correlation of the full series. NaN entries are
        replaced with 0. The label is a float one-hot vector of length 2.
        """
        name = self.imgs[index]
        lbl = self.lbls[index]
        img = np.load(os.path.join(self.root, f"{self.template}_{name}.npy"))
        # Never augment the test split: is_train defaults to True, so
        # is_test=True alone would otherwise add random masking at test time.
        if self.is_train and not self.is_test:
            if self.mask_way == 'mask':
                slices = [mask_timeseries(img, self.mask_len).T, img.T]
            elif self.mask_way == 'mask_per':
                slices = [mask_timeseries_per(img, self.mask_len).T, img.T]
            else:
                slices = [img.T]
        else:
            slices = [img.T]
        correlation_matrix = self.correlation_measure.fit_transform(slices).mean(0)
        correlation_matrix[np.isnan(correlation_matrix)] = 0
        onehot_lbl = np.zeros(2)
        onehot_lbl[lbl] = 1
        return correlation_matrix, onehot_lbl

    def __len__(self):
        """Number of subjects in the selected split."""
        return len(self.imgs)