import os
from tkinter import Y

import numpy as np
import pandas as pd
import sklearn.preprocessing

from dataset import DataSet
from imageio import imread  # changed from scipy.misc to imageio
from PIL import Image
from scipy.io import loadmat
from sklearn import preprocessing
from torch.utils.data import DataLoader


class KamitaniDataHandler():
    """Access one subject's fMRI samples together with their image labels.

    The subject ``.mat`` file is read with ``scipy.io.loadmat``. In its
    ``dataSet`` matrix every row is a sample; the first three columns are
    per-sample meta data and the remaining columns are the recorded values.
    ``metaData`` provides the field names, their descriptions and, for every
    field, one row of per-column values.
    """

    def __init__(self, config, subject_file_idx, log=0):
        """Load the subject recording and the image-id tables.

        Args:
            config: mapping with ``"fmri_imagenet_data"`` (data directory) and
                ``"subject_filenames"`` (list of ``.mat`` file names).
            subject_file_idx: index into ``config["subject_filenames"]``.
            log: 0 leaves the data as is; 1 applies a signed log before
                normalisation; 2 applies a signed log plus standardisation to
                the train/test splits after normalisation (see ``get_data``).
        """
        data_dir = config["fmri_imagenet_data"]
        test_img_csv = os.path.join(data_dir, 'imageID_test.csv')
        train_img_csv = os.path.join(data_dir, 'imageID_training.csv')
        matlab_file = os.path.join(
            data_dir, config["subject_filenames"][subject_file_idx])

        mat = loadmat(matlab_file)
        self.data = mat['dataSet'][:, 3:]
        self.sample_meta = mat['dataSet'][:, :3]
        meta = mat['metaData']

        self.meta_keys = [entry[0] for entry in meta[0][0][0][0]]
        self.meta_desc = [entry[0] for entry in meta[0][0][1][0]]
        # Drop the three sample-meta columns so the remaining columns line up
        # with the columns of ``self.data``; NaN entries become 0.
        self.voxel_meta = np.nan_to_num(meta[0][0][2][:, 3:])

        test_img_df = pd.read_csv(test_img_csv, header=None)
        train_img_df = pd.read_csv(train_img_csv, header=None)
        self.test_img_id = test_img_df[0].values
        self.train_img_id = train_img_df[0].values
        self.sample_type = {'train': 1, 'test': 2, 'test_imagine': 3}
        self.log = log

    def get_meta_field(self, field='DataType'):
        """Return the values stored for the meta field ``field``.

        The first three keys describe samples (one value per row of the
        data); every other key yields one value per data column.

        Raises:
            ValueError: if ``field`` is not one of ``self.meta_keys``.
        """
        index = self.meta_keys.index(field)
        if index < 3:  # 3 first keys are sample meta
            return self.sample_meta[:, index]
        return self.voxel_meta[index]

    def print_meta_desc(self):
        """Print the descriptions of all meta fields."""
        print(self.meta_desc)

    @staticmethod
    def _nearest_index(values, reference):
        """Return, for each value, the index of the closest ``reference`` entry."""
        return np.array(
            [np.abs(value - reference).argmin() for value in values],
            dtype=np.intp,
        )

    def _roi_mask(self, roi):
        """Build the boolean column mask for one ROI name or a union of ROIs."""
        if isinstance(roi, str):
            return self.get_meta_field(roi).astype(bool)
        if isinstance(roi, (list, tuple)):
            if not roi:
                raise ValueError("roi list must not be empty")
            # A column is selected when it belongs to at least one listed ROI.
            return sum(self.get_meta_field(r) for r in roi).astype(bool)
        raise TypeError(
            "roi must be a str or a list of str, got " + type(roi).__name__)

    def get_labels(self, imag_data=0, test_run_list=None):
        """Return image indices for the train / test (/ imagery) samples.

        Train and test labels are the row index of the closest id in the
        training / test image-id table. Imagery labels are the rank of each
        sample's id among the sorted unique imagery ids.

        Args:
            imag_data: when truthy, the imagery labels are returned as well.
            test_run_list: optional collection of run numbers; when given only
                test samples recorded in those runs are kept.

        Returns:
            ``(train_labels, test_labels)`` or, with ``imag_data``,
            ``(train_labels, test_labels, imag_labels)`` as numpy arrays.
        """
        img_ids = self.get_meta_field('Label')
        data_type = self.get_meta_field('DataType')
        train_mask = data_type == self.sample_type['train']
        test_mask = data_type == self.sample_type['test']
        imag_mask = data_type == self.sample_type['test_imagine']

        train_labels = self._nearest_index(img_ids[train_mask], self.train_img_id)
        test_labels = self._nearest_index(img_ids[test_mask], self.test_img_id)

        if test_run_list is not None:
            # test_labels is an ndarray here, so boolean-mask indexing is valid.
            test_runs = self.get_meta_field('Run')[test_mask]
            test_labels = test_labels[np.isin(test_runs, test_run_list)]

        # Same encoding as sklearn's LabelEncoder.fit_transform: position of
        # every id within the sorted unique ids.
        imag_labels = np.unique(img_ids[imag_mask], return_inverse=True)[1]

        if imag_data:
            return train_labels, test_labels, imag_labels
        return train_labels, test_labels

    def get_data(self,
                 normalize=1,
                 roi='ROI_VC',
                 imag_data=0,
                 test_run_list=None):
        """Return the train / test (/ imagery) data restricted to ``roi``.

        Args:
            normalize: 0 - no normalisation, 1 - standardise every run on its
                own, 2 - standardise each split (train, test, imagery)
                separately.
            roi: name of a meta field used as column mask, or a list of such
                names whose union is used.
            imag_data: when truthy, imagery data and its per-label average are
                returned as well.
            test_run_list: optional collection of run numbers restricting the
                test samples.

        Returns:
            ``(train_data, test_data, test_data_avg)`` or, with ``imag_data``,
            ``(train_data, test_data, test_data_avg, test_imag, test_imag_avg)``.
            The ``*_avg`` arrays hold one row per test label (mean over the
            samples carrying that label; NaN when a label has no sample).

        Raises:
            TypeError: if ``roi`` is neither a string nor a list/tuple.
            ValueError: if ``roi`` is an empty list.
        """
        data_type = self.get_meta_field('DataType')
        train_mask = data_type == self.sample_type['train']
        test_mask = data_type == self.sample_type['test']
        imag_mask = data_type == self.sample_type['test_imagine']

        data = self.data[:, self._roi_mask(roi)]

        if self.log == 1:
            data = np.log(1 + np.abs(data)) * np.sign(data)

        if normalize == 1:
            run = self.get_meta_field('Run').astype('int') - 1
            data_norm = np.zeros(data.shape)
            for r in range(np.max(run) + 1):
                in_run = run == r
                # A run id without samples would make scale() raise.
                if in_run.any():
                    data_norm[in_run] = sklearn.preprocessing.scale(data[in_run])
            data = data_norm

        # Split in every normalisation mode so that test_imag is always data,
        # never a leftover boolean mask.
        train_data = data[train_mask]
        test_data = data[test_mask]
        test_imag = data[imag_mask]

        if normalize == 2:
            train_data = sklearn.preprocessing.scale(train_data)
            test_data = sklearn.preprocessing.scale(test_data)
            if len(test_imag):
                test_imag = sklearn.preprocessing.scale(test_imag)

        if self.log == 2:
            # NOTE(review): only train/test are transformed here, the imagery
            # split is left untouched - confirm whether that is intended.
            train_data = np.log(1 + np.abs(train_data)) * np.sign(train_data)
            test_data = np.log(1 + np.abs(test_data)) * np.sign(test_data)
            train_data = sklearn.preprocessing.scale(train_data)
            test_data = sklearn.preprocessing.scale(test_data)

        _, test_labels, imag_labels = self.get_labels(imag_data=1)
        # The number of averaged rows is fixed before any run filtering.
        num_labels = max(test_labels) + 1
        test_data_avg = np.zeros([num_labels, test_data.shape[1]])
        test_imag_avg = np.zeros([num_labels, test_data.shape[1]])

        if test_run_list is not None:
            test_runs = self.get_meta_field('Run')[test_mask]
            select = np.isin(test_runs, test_run_list)
            test_data = test_data[select, :]
            test_labels = test_labels[select]

        for i in range(num_labels):
            test_data_avg[i] = np.mean(test_data[test_labels == i], axis=0)
            test_imag_avg[i] = np.mean(test_imag[imag_labels == i], axis=0)

        if imag_data:
            return train_data, test_data, test_data_avg, test_imag, test_imag_avg
        return train_data, test_data, test_data_avg  # 1200, 1750, 50

    def get_voxel_loc(self):
        """Return ``[x, y, z]`` coordinate arrays and the extent of each axis.

        The extent is ``max - min + 1`` of the respective coordinate field.
        """
        x = self.get_meta_field('voxel_x')
        y = self.get_meta_field('voxel_y')
        z = self.get_meta_field('voxel_z')
        dim = [int(axis.max() - axis.min() + 1) for axis in (x, y, z)]
        return [x, y, z], dim


def get_data(config, subject_file_idx, ROI_split):
    """Read the fMRI and image files and return the train/test pairs.

    Args:
        config: mapping with ``"images_npz"`` (path of the image archive) and
            the keys needed by ``KamitaniDataHandler``.
        subject_file_idx (int): index of the subject file.
        ROI_split (list): ROI names; an entry may itself be a list of ROI
            names, whose union is then used.

    Returns:
        X_train, X_test, Y_train, Y_test. The X values are dicts with one
        entry per item of ``ROI_split`` (keyed by ``tuple(item)``); X_test
        holds the per-label averaged test data. Y_train holds the training
        images reordered to match the training samples, Y_test the test
        images as stored in the archive.
    """
    ROIS = ["ROI_V1", "ROI_V2", "ROI_V3", "ROI_V4", "ROI_LOC", "ROI_FFA", "ROI_PPA", "ROI_LVC", "ROI_HVC", "ROI_VC", "FFA_PPA"]
    for roi in ROI_split:
        if isinstance(roi, str):
            assert roi in ROIS, "ROI must be in " + str(ROIS)
        elif isinstance(roi, list):
            for r in roi:
                assert r in ROIS, "ROI must be in " + str(ROIS)

    images_npz = config["images_npz"]
    handler = KamitaniDataHandler(config, subject_file_idx)

    X_train, X_test = {}, {}
    for roi in ROI_split:
        x_train, _, x_test = handler.get_data(roi=roi, imag_data=0)
        # NOTE(review): tuple() of a plain string yields a tuple of its
        # characters, so a str ROI is not keyed by its name - kept as is
        # because consumers of these dicts are outside this function.
        key = tuple(roi)
        X_train[key] = x_train
        X_test[key] = x_test

    labels_train, _ = handler.get_labels(imag_data=0)

    # Close the archive once both arrays have been read into memory.
    with np.load(images_npz) as archive:
        Y_train = archive['train_images'][labels_train]
        Y_test = archive['test_images']
    return X_train, X_test, Y_train, Y_test


def make_dataloaders(config, subject_file_idx, ROI_split, mode, batch_size=64):
    """Build the train and test dataloaders for one subject.

    Args:
        config: configuration mapping handed on to ``image_generate`` and
            ``get_data``.
        subject_file_idx: index of the subject file.
        ROI_split (list): the ROIs, handed on to ``get_data``.
        mode (str): ``"dec"`` builds datasets from (fMRI, image) data,
            ``"encdec"`` builds them from the images only.
        batch_size (int, optional): the batch size. Defaults to 64.

    Returns:
        ``(train_dl, test_dl)``; only the train loader shuffles.
    """
    assert mode in ["encdec", "dec"]

    # Creates the image archive on first use; a no-op when it already exists.
    image_generate(config)

    # The fMRI parts are dicts with one entry per ROI_split item.
    fmri_train, fmri_test, img_train, img_test = get_data(
        config, subject_file_idx, ROI_split
    )

    if mode != "dec":
        train_parts, test_parts = (img_train,), (img_test,)
    else:
        train_parts, test_parts = (fmri_train, img_train), (fmri_test, img_test)

    train_ds = DataSet(*train_parts)
    test_ds = DataSet(*test_parts)

    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_dl = DataLoader(test_ds, batch_size=batch_size)
    return train_dl, test_dl
    

def image_generate(config, interpolation = Image.BICUBIC):
    """Create the image archive from the ImageNet folders if it is missing.

    Reads the file names from the second column of ``imageID_test.csv`` and
    ``imageID_training.csv``, loads every image from the ``test`` /
    ``training`` sub folder of ``config["fmri_imagenet_data"]``, passes it
    through ``image_prepare`` and stores the two stacks in
    ``config["images_npz"]`` under the keys ``test_images`` and
    ``train_images``. Returns immediately when the archive already exists.

    Args:
        config: mapping with ``"images_npz"``, ``"fmri_imagenet_data"`` and
            ``"img_size"["inet"]`` (edge length of the stored images).
        interpolation: resampling filter handed to ``image_prepare``.
    """
    out_file = config["images_npz"]
    if os.path.exists(out_file):
        return

    data_dir = config["fmri_imagenet_data"]
    size = config["img_size"]["inet"]

    def load_split(csv_name, folder):
        # One prepared image per CSV row, in CSV order.
        table = pd.read_csv(os.path.join(data_dir, csv_name), header=None)
        images = np.zeros([len(table.index), size, size, 3])
        for position, file_name in enumerate(table[1]):
            # Strip for both splits: a padded CSV cell must not break the path.
            path = os.path.join(data_dir, folder, file_name.strip())
            images[position] = image_prepare(imread(path), size, interpolation)
        return images

    test_images = load_split('imageID_test.csv', "test")
    train_images = load_split('imageID_training.csv', "training")

    print(train_images.shape)
    np.savez(out_file, train_images=train_images, test_images=test_images)


def image_prepare(img, size, interpolation):
    """Center-crop ``img``, resize it and scale it by 1/255.

    Args:
        img: image array indexed as (rows, columns[, channels]).
        size: edge length of the result.
        interpolation: resampling filter passed to ``PIL.Image.resize``.

    Returns:
        Float array of shape ``(size, size, 3)``. Single-channel images are
        replicated over the three channels; channels beyond the third are
        dropped.
    """
    r = img.shape[0]
    c = img.shape[1]

    trimSize = np.min([r, c])
    lr = int((c - trimSize) / 2)
    ud = int((r - trimSize) / 2)
    # NOTE(review): along the longer side this keeps trimSize + 1 pixels, so
    # the crop of a non-square image is one pixel off square before resizing.
    # Left unchanged so regenerated archives match existing ones.
    img = img[ud:min([(trimSize + 1), r - ud]) + ud, lr:min([(trimSize + 1), c - lr]) + lr]

    resized = np.asarray(Image.fromarray(img).resize((size, size), interpolation))

    if resized.ndim == 2:
        # Grey-scale: replicate the single plane over three channels.
        resized = np.repeat(resized[:, :, np.newaxis], 3, axis=2)
    elif resized.shape[2] > 3:
        # Presumably a trailing alpha channel - keep the first three planes so
        # the result fits the (size, size, 3) slot of the caller.
        resized = resized[:, :, :3]
    elif resized.shape[2] < 3:
        # Presumably grey + alpha - replicate the first plane.
        resized = np.repeat(resized[:, :, :1], 3, axis=2)

    return resized / 255.

