import os

import brainpy.math as bm
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torch
import numpy as np


def linear(x):
    """Return ``x`` unchanged (identity mapping).

    NOTE(review): presumably meant as a drop-in alternative to the
    activation returned as ``f`` by ``load_config`` -- no caller is visible
    in this file, so confirm before relying on that.
    """
    return x


# ---------------------------------------------------------------------------
# File-name label parsers.
#
# Each parser receives ``os.path.basename(path).split('_')`` and returns a
# tuple of labels that is appended after the image in a dataset sample.
# They live at module level (not inside ``load_config``) so that datasets
# holding a reference to them remain picklable.
# ---------------------------------------------------------------------------

def _street_xy_labels(fields):
    """Return ``(labelx, labely)`` as ints from four underscore-separated fields.

    ``labelx = f3 + 3 * f2`` and ``labely = f1 + 10 * f0`` where ``f3`` is the
    fourth field with everything from the first ``.`` onwards removed.
    Raises ``IndexError`` / ``ValueError`` if the name does not have four
    integer fields.
    """
    a, b, c = fields[0], fields[1], fields[2]
    d = fields[3].split('.')[0]  # drop the file extension
    label_x = int(d) + 3 * int(c)
    label_y = int(b) + 10 * int(a)
    return label_x, label_y


def _street_1800_label(fields):
    """Return a 1-tuple holding ``f3 + 3*f2 + 18*f1 + 180*f0`` (int)."""
    a, b, c = fields[0], fields[1], fields[2]
    d = fields[3].split('.')[0]  # drop the file extension
    return (int(d) + 3 * int(c) + 18 * int(b) + 180 * int(a),)


def _street_3600_label(fields):
    """Return a 1-tuple holding ``f3 + 4*f2 + 36*f1 + 180*f0`` (int).

    NOTE(review): the weights are reproduced verbatim from the original
    code; the value ranges of the individual fields are not visible here.
    """
    a, b, c = fields[0], fields[1], fields[2]
    d = fields[3].split('.')[0]  # drop the file extension
    return (int(d) + 4 * int(c) + 36 * int(b) + 180 * int(a),)


def _lab_x_label(fields):
    """Return a 1-tuple with the third field (extension stripped) as a string."""
    return (fields[2].split('.')[0],)


def _lab_yx_labels(fields):
    """Return ``(labely, labelx)`` as strings: second field, then third field.

    The third field has everything from the first ``.`` onwards removed.
    The labels are intentionally left as strings, as in the original code.
    """
    label_x = fields[2].split('.')[0]
    label_y = fields[1]
    return label_y, label_x


class _NpyFrameDataset(Dataset):
    """Dataset over a ``.npy`` array, served as float32 ``(8, 32, 32)`` tensors."""

    def __init__(self, numpy_file):
        # Load the whole array into memory.
        raw = np.load(numpy_file)
        # Keep only index 0 along the second axis.
        raw = raw[:, 0, :, :]
        # Regroup into samples of shape (8, 32, 32). The original comment
        # stated the result is (500, 8, 32, 32) -- TODO confirm against the
        # actual contents of the data file.
        self.data = raw.reshape(-1, 8, 32, 32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # One sample, converted to a float32 torch tensor (no label).
        return torch.tensor(self.data[idx], dtype=torch.float32)


class _FileNameLabeledImages(Dataset):
    """Images from one folder whose labels are encoded in the file names.

    Parameters
    ----------
    folder_path : str
        Directory to scan; every regular file directly inside it is a sample.
    label_fn : callable
        Receives the ``'_'``-split base name of the file and returns a tuple
        of labels.
    transform : callable, optional
        Applied to the RGB ``PIL`` image when truthy.
    """

    def __init__(self, folder_path, label_fn, transform=None):
        self.folder_path = folder_path
        self.label_fn = label_fn
        self.transform = transform
        # Order follows os.listdir(), exactly as in the original code.
        self.image_paths = [
            os.path.join(folder_path, entry)
            for entry in os.listdir(folder_path)
            if os.path.isfile(os.path.join(folder_path, entry))
        ]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, *labels)`` for the file at position ``idx``."""
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        labels = self.label_fn(os.path.basename(image_path).split('_'))

        if self.transform:
            image = self.transform(image)

        return (image,) + tuple(labels)


# Dataset specs shared by several configurations. The first element selects
# the branch in ``_build_training_data``; the rest are its arguments.
# Directory names (including the 'iuput' spelling) are runtime paths and are
# kept exactly as in the original code.
_FASHION_MNIST = ('fashion_mnist', 'iuput_data')
_CIFAR10 = ('cifar10', 'iuput_data_CIFAR')
_STREET_1800 = ('folder', 'old_1800_tran', _street_1800_label)
_LAB_PIC = ('folder', 'pic', _lab_yx_labels)

# name -> (neuron_size, duration, eta, l_duration, dt,
#          model_name, using_epoch, dataset spec)
# neuron_size is stored as a tuple and copied into a fresh list per call so
# callers can mutate their result without corrupting this table.
# NOTE(review): a few rows look surprising ('MNIST_rio' loads FashionMNIST,
# 'Lab_135' reads 'pic_empty' while 'Lab_empty' reads 'pic', 'street_theta'
# is named 'street_1800'); they are reproduced verbatim -- confirm intent.
_CONFIGS = {
    'taxib': ((0, 2000, 2000, 1000), 10, 0.1, 0.3, 0.05,
              'taxib_net', 10, ('npy', 'data.npy')),
    'MNIST_rio': ((0, 2000, 2000, 1000), 10, 0.1, 0.3, 0.05,
                  'MNIST_rio_net_layer3', 10, _FASHION_MNIST),
    'street_xy': ((0, 4000, 2000, 2000, 200), 5, 0.05, 0.1, 0.02,
                  'street_xy', 27,
                  ('folder', 'old_1800_tran', _street_xy_labels)),
    'street_3600': ((0, 4000, 2000, 2000, 3800), 5, 0.03, 0.1, 0.02,
                    'street_3600', 27,
                    ('folder', 'argen_tran', _street_3600_label)),
    'street_layer1': ((0, 1800), 5, 0.01, 0.5, 0.05,
                      'street_layer1_1800', 27, _STREET_1800),
    'street_layer2': ((0, 4000, 2000), 5, 0.1, 0.2, 0.05,
                      'street_layer2', 27, _STREET_1800),
    'street_layer3': ((0, 700, 700, 700), 5, 0.1, 0.2, 0.05,
                      'street_layer3', 27, _STREET_1800),
    'street_layer4': ((0, 200, 200, 200, 200), 7, 0.1, 0.2, 0.05,
                      'street_layer4', 27, _STREET_1800),
    'street_theta': ((0, 4000, 2000, 2000, 2000), 5, 0.05, 0.1, 0.1,
                     'street_1800', 27, _STREET_1800),
    'Lab_135': ((0, 1024, 512, 256), 10, 0.1, 0.4, 0.05,
                'Lab_35_net', 27, ('folder', 'pic_empty', _lab_x_label)),
    'Lab_empty': ((0, 4000, 2000, 2000, 2500), 10, 0.01, 0.2, 0.02,
                  'Lab_empty_net', 27, _LAB_PIC),
    'Lab_layer3': ((0, 200, 200, 200), 10, 0.01, 0.2, 0.02,
                   'Lab_layer3_net', 27, _LAB_PIC),
    'Lab_layer4': ((0, 500, 500, 500, 500), 10, 0.01, 0.2, 0.02,
                   'Lab_layer4_net', 27, _LAB_PIC),
    'CIFAR10_128': ((0, 1024, 1024, 1024), 10, 0.1, 0.3, 0.05,
                    'CIFAR10_128_net', 27, _CIFAR10),
    'CIFAR10_32': ((0, 512, 256, 128), 10, 0.1, 0.3, 0.05,
                   'CIFAR10_32_net', 27, _CIFAR10),
    'CIFAR10_16_lay2': ((0, 500, 100), 10, 0.1, 0.3, 0.05,
                        'CIFAR10_16_lay2_net', 27, _CIFAR10),
    'CIFAR10_64_lay3': ((0, 1000, 1000, 1000), 10, 0.1, 0.3, 0.05,
                        'CIFAR10_64_net_layer3', 27, _CIFAR10),
    'CIFAR10_64_lay2': ((0, 1500, 1500), 5, 0.1, 0.3, 0.05,
                        'CIFAR10_64_net_layer2', 27, _CIFAR10),
    'CIFAR10_64_lay4': ((0, 750, 750, 750, 750), 10, 0.1, 0.3, 0.05,
                        'CIFAR10_64_net_layer4', 27, _CIFAR10),
    'CIFAR10_neuron_lay3': ((0, 200, 200, 200), 10, 0.1, 0.3, 0.05,
                            'CIFAR10_neuron200_lay3_net', 27, _CIFAR10),
    'Fashionmnist_layer3': ((0, 1000, 1000, 1000), 10, 0.1, 0.3, 0.05,
                            'Fashionmnist_net_layer3', 10, _FASHION_MNIST),
}


def _build_training_data(data_spec, transform):
    """Instantiate the dataset described by one ``_CONFIGS`` dataset spec."""
    kind = data_spec[0]
    if kind == 'npy':
        # The .npy dataset takes no transform, matching the original code.
        return _NpyFrameDataset(data_spec[1])
    if kind == 'fashion_mnist':
        return datasets.FashionMNIST(root=data_spec[1],
                                     train=True,
                                     download=True,
                                     transform=transform)
    if kind == 'cifar10':
        return datasets.CIFAR10(root=data_spec[1],
                                train=True,
                                download=True,
                                transform=transform)
    if kind == 'folder':
        _, folder, label_fn = data_spec
        return _FileNameLabeledImages(folder_path=folder,
                                      label_fn=label_fn,
                                      transform=transform)
    raise ValueError('Unknown dataset kind {!r}'.format(kind))


def load_config(config):
    """Return hyper-parameters, dataset and transforms for a named experiment.

    Parameters
    ----------
    config : str
        One of the keys of ``_CONFIGS`` (e.g. ``'taxib'``, ``'street_xy'``,
        ``'CIFAR10_32'``).

    Returns
    -------
    tuple
        ``(neuron_size, l_duration, duration, eta, dt, f, noise, model_name,
        using_epoch, training_data, normalize, inv_normalize, _transforms)``

        * ``neuron_size`` -- fresh ``list`` of ints, first entry always 0.
        * ``l_duration``, ``duration``, ``eta``, ``dt`` -- numeric settings
          whose interpretation belongs to the consumer of this function.
        * ``f`` -- ``bm.leaky_relu`` for every configuration.
        * ``noise`` -- ``0`` for every configuration.
        * ``model_name`` -- str; ``using_epoch`` -- int.
        * ``training_data`` -- the dataset object; constructing it touches
          the file system and, for the torchvision datasets, may download.
        * ``normalize`` / ``inv_normalize`` -- ``transforms.Normalize`` with
          mean 0 and std 1 and its algebraic inverse.
        * ``_transforms`` -- ``ToTensor`` followed by ``normalize``.

    Raises
    ------
    ValueError
        If ``config`` is not a known configuration name. (Previously this
        surfaced as an ``UnboundLocalError`` at the ``return`` statement.)
    """
    # Validate first so a typo fails fast, before any I/O or downloads.
    try:
        spec = _CONFIGS[config]
    except (KeyError, TypeError):
        raise ValueError(
            'Unknown config {!r}; expected one of: {}'.format(
                config, ', '.join(sorted(_CONFIGS)))) from None

    (layer_sizes, duration, eta, l_duration, dt,
     model_name, using_epoch, data_spec) = spec

    mean, std = 0, 1
    normalize = transforms.Normalize(mean=[mean], std=[std])
    # Inverse of (x - mean) / std is (x - (-mean / std)) / (1 / std).
    inv_normalize = transforms.Normalize(mean=[-mean / std], std=[1 / std])
    _transforms = transforms.Compose([transforms.ToTensor(), normalize])

    neuron_size = list(layer_sizes)
    f = bm.leaky_relu
    noise = 0
    training_data = _build_training_data(data_spec, _transforms)

    return neuron_size, l_duration, duration, eta, dt, f, noise, \
        model_name, using_epoch, training_data, normalize, inv_normalize, _transforms
    
