import os.path
import glob
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image
import scipy.io as scio
import numpy as np

def get_transform(opt):
    """Build the torchvision preprocessing pipeline selected by ``opt``.

    Reads ``opt.resize_or_crop``, ``opt.fineSize``, ``opt.loadSize``,
    ``opt.isTrain`` and ``opt.no_flip``. The resize/crop stage depends on
    ``opt.resize_or_crop``:

    * ``'resize_and_crop'`` -- bicubic resize to
      ``[int(400 * zoom), int(600 * zoom)]``, then a random ``fineSize`` crop.
    * ``'crop'`` -- random ``fineSize`` crop only.
    * ``'scale_width'`` -- scale to width ``fineSize``, keeping aspect ratio.
    * ``'scale_width_and_crop'`` -- scale to width ``loadSize``, then a random
      ``fineSize`` crop.
    * any other value -- no resize/crop stage.

    A random horizontal flip is appended when ``opt.isTrain`` is true and
    ``opt.no_flip`` is false. Every pipeline ends with ``ToTensor`` followed by
    per-channel normalisation with mean 0.5 and std 0.5.

    Returns:
        A ``transforms.Compose`` of the selected transforms.
    """
    # Local import: the module header does not import ``random``. The original
    # referenced a misspelled ``radom`` and raised NameError.
    import random

    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        # NOTE: zoom is drawn once, when the pipeline is built. Every image
        # passed through this pipeline therefore uses the same zoom factor.
        zoom = 1 + 0.1 * random.randint(0, 4)
        osize = [int(400 * zoom), int(600 * zoom)]
        # ``Resize`` supersedes ``Scale``, which was deprecated and removed.
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def __scale_width(img, target_width):
    """Return ``img`` rescaled to ``target_width`` pixels wide.

    The height is scaled proportionally and truncated to an int. Bicubic
    resampling is used. If the width already matches, the same image object
    is returned unchanged.
    """
    orig_w, orig_h = img.size
    if orig_w != target_width:
        new_h = int(target_width * orig_h / orig_w)
        img = img.resize((target_width, new_h), Image.BICUBIC)
    return img


# File-name suffixes (compared case-insensitively) that make_dataset treats
# as images.
_IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp')


def _is_image_file(filename):
    """Return True if ``filename`` ends with one of ``_IMG_EXTENSIONS``."""
    return filename.lower().endswith(_IMG_EXTENSIONS)


def make_dataset(dir):
    """Recursively collect the paths of image files under ``dir``.

    Args:
        dir: root directory to walk.

    Returns:
        A list of file paths (``os.path.join(root, fname)``) whose names end
        with an image extension. Directories are visited in sorted order, and
        file names are sorted within each directory.

    Raises:
        AssertionError: if ``dir`` is not an existing directory. This is the
            same exception type the original ``assert`` produced.
    """
    if not os.path.isdir(dir):
        # Raised explicitly rather than via ``assert`` so the check still
        # runs under ``python -O``.
        raise AssertionError('%s is not a valid directory' % dir)

    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in sorted(fnames):
            if _is_image_file(fname):
                images.append(os.path.join(root, fname))
    return images

class YuanDaiDataset(data.Dataset):
    """Binary spike-raster dataset built from per-recording ``.npy`` dumps.

    Every ``*.npy`` file directly under ``root_path`` is loaded with
    ``np.load(..., allow_pickle=True)``. Its ``.tolist()`` must be a dict
    mapping a stimulus index to a dict of per-channel spike times.

    Each stimulus becomes one sample ``[raster, label]``:

    * ``raster`` is a ``(T, NUM_CHANNELS)`` float array of 0/1 flags.
    * ``label`` is 0 for file names containing ``'L1'`` and 1 for ``'X2'``.
    """

    # Number of columns in every raster. A recording whose channel index
    # reaches this value while holding an in-window spike raises IndexError.
    NUM_CHANNELS = 59
    # Stimulus-onset table used when ``stim_path`` is not given. This is the
    # path the original implementation hard-coded.
    DEFAULT_STIM_PATH = ('/data2/s503-7/SNN_VisualOrgan/yuandai/dataset/'
                         'stimTimePoints.mat')

    def __init__(self, T, root_path, stim_path=None):
        """Index ``root_path`` and build all samples eagerly.

        Args:
            T: number of equal-width time bins per sample.
            root_path: directory searched (non-recursively) for ``*.npy`` files.
            stim_path: ``.mat`` file containing a ``stimTimePoints`` variable
                with stimulus onsets. Defaults to ``DEFAULT_STIM_PATH``.
        """
        self.root_path = root_path
        self.npy_path = os.path.join(self.root_path, '*.npy')
        # Sorted so sample order does not depend on filesystem listing order.
        self.npy_list = sorted(glob.glob(self.npy_path))
        self.T = T
        self.stim_path = (self.DEFAULT_STIM_PATH if stim_path is None
                          else stim_path)
        print(self.npy_list)
        self.make_data_list()

    @staticmethod
    def _label_for(npy_file):
        """Return the class label encoded in ``npy_file``'s file name.

        Raises:
            ValueError: if the name contains neither ``'L1'`` nor ``'X2'``.
                Previously the first such file raised UnboundLocalError, and
                later ones silently reused the previous file's label.
        """
        # Only the file name is inspected, so directory names cannot leak
        # into the label.
        fname = os.path.basename(npy_file)
        if 'L1' in fname:
            return 0
        if 'X2' in fname:
            return 1
        raise ValueError('cannot infer label (expected "L1" or "X2") from '
                         'file name: %s' % npy_file)

    @staticmethod
    def _bin_spikes(rel_times, edges):
        """Return the bin index hit by each spike in ``rel_times``.

        This reproduces the original per-spike scalar loop exactly:

        * Bins are closed intervals ``[edges[k], edges[k + 1]]``; a spike on
          a shared edge goes to the lower bin.
        * Spikes after ``edges[-1]`` are clamped into the last bin.
        * Spikes before ``edges[0]``, and NaN values, are dropped.
        """
        n_bins = len(edges) - 1
        if n_bins <= 0:
            return np.empty(0, dtype=int)
        times = np.asarray(rel_times, dtype=float).ravel()
        # ``>=`` is False for NaN, so NaNs are dropped here as well.
        times = times[times >= edges[0]]
        # side='left' yields i with edges[i-1] < t <= edges[i], i.e. bin i-1.
        bins = np.searchsorted(edges, times, side='left') - 1
        return np.clip(bins, 0, n_bins - 1)

    def make_data_list(self):
        """Populate ``self.data_list`` with ``[raster, label]`` pairs.

        For every stimulus, spike times are shifted by the stimulus onset
        ``stimTimePoints[stimulus]``. Bin ``t`` of a channel's column is set
        to 1 when that channel has a spike in the t-th of ``T`` equal bins of
        the window ``[0, 1]``.

        NOTE(review): the next onset (``stimTimePoints[stimulus + 1]``) was
        read but never used by the original code, so the window length is
        fixed at 1 time unit. Confirm this matches the stimulus duration.

        Raises:
            ValueError: if an ``.npy`` file name carries no known label.
        """
        self.data_list = []
        stim_path = getattr(self, 'stim_path', None)
        if stim_path is None:
            stim_path = self.DEFAULT_STIM_PATH
        onsets = scio.loadmat(stim_path)['stimTimePoints']
        # Bin edges of the observation window, relative to stimulus onset.
        edges = np.linspace(start=0, stop=1, num=self.T + 1)

        for npy_file in self.npy_list:
            target = self._label_for(npy_file)
            # allow_pickle=True unpickles arbitrary objects; only load trusted
            # dataset files.
            time_dict = np.load(npy_file, allow_pickle=True)
            for time_stamp, mea_dict in time_dict.tolist().items():
                start_time = onsets[int(time_stamp)]
                raster = np.zeros([self.T, self.NUM_CHANNELS])
                # Column index follows the channel dict's iteration order.
                for channel, spike_value in enumerate(mea_dict.values()):
                    bins = self._bin_spikes(spike_value - start_time, edges)
                    if bins.size:
                        raster[bins, channel] = 1
                self.data_list.append([raster, target])

    def __getitem__(self, index):
        return self.data_list[index][0], self.data_list[index][1]

    def __len__(self):
        return len(self.data_list)

    def name(self):
        return 'YuanDai_Dataset'
    
class SimulatorDataset(data.Dataset):
    """Dataset read from a single pre-built ``.npy`` split file.

    Among the ``*.npy`` files directly under ``root_path`` (sorted by path),
    the first whose file name contains ``phase`` is loaded with
    ``np.load(..., allow_pickle=True)``. Each loaded entry is accessed as
    ``entry[0]`` (sample) and ``entry[1]`` (label). Labels are converted to
    Python ``int`` in place.
    """

    def __init__(self, root_path, phase):
        """Locate and load the split file.

        Args:
            root_path: directory searched (non-recursively) for ``*.npy`` files.
            phase: substring identifying the split's file name, e.g. ``'train'``.

        Raises:
            FileNotFoundError: if no ``.npy`` file name contains ``phase``.
        """
        self.root_path = root_path
        self.npy_path = os.path.join(self.root_path, '*.npy')
        # Sorted so the chosen file is deterministic when several match.
        self.npy_list = sorted(glob.glob(self.npy_path))
        self.phase = phase
        print(self.npy_list)
        self.make_data_list()

    def make_data_list(self):
        """Load the split for ``self.phase`` into ``self.data_list``.

        Only the file name is matched against ``self.phase``. A directory
        name that happens to contain the phase string therefore cannot
        select the wrong file.

        Raises:
            FileNotFoundError: if no ``.npy`` file name contains ``self.phase``.
        """
        candidates = [p for p in self.npy_list
                      if self.phase in os.path.basename(p)]
        if not candidates:
            # Without this check the empty result surfaced later as an
            # opaque IndexError on ``data_list[0]``.
            raise FileNotFoundError(
                'no .npy file with %r in its name under %s'
                % (self.phase, self.root_path))
        npy_file = candidates[0]
        print("load npy:", npy_file)
        # allow_pickle=True unpickles arbitrary objects; only load trusted files.
        self.data_list = np.load(npy_file, allow_pickle=True)

        # Normalise labels (possibly stored as floats / numpy scalars) to int.
        # Assigning through each entry writes back into the loaded array.
        for entry in self.data_list:
            entry[1] = int(entry[1])
        if len(self.data_list):
            print(type(self.data_list), type(self.data_list[0][0]),
                  type(self.data_list[0][1]))

    def __getitem__(self, index):
        return self.data_list[index][0], self.data_list[index][1]

    def __len__(self):
        return len(self.data_list)

    def name(self):
        # NOTE(review): identical to YuanDaiDataset.name(); looks copy-pasted.
        # Kept unchanged in case callers compare against this string.
        return 'YuanDai_Dataset'

if __name__ == '__main__':
    # Manual smoke run: build the training split from a local simulator dump.
    SimulatorDataset(phase='train', root_path='/data2/s503-7/8.15对齐code/flie')