import torch
import os


class Config:
    """Experiment configuration selected by an integer preset number.

    ``Config(num)`` populates data, model, training and output-path settings
    for one of the presets 0-5. Settings shared by every preset (seed,
    dataset name, batch size, learning rate, ...) are assigned once in
    ``__init__``; the per-preset differences live in ``_presets``.

    Not every preset defines every attribute: preset 1 ('vit_fsq') has no
    ``embedding_dim``/``num_embeddings``/``compress_rate``/``p_in_decoder``,
    preset 2 ('resnet_bump') has no ``dim``/``depth``, and only preset 4
    defines ``model1_name``/``model1_dir``.
    """

    def __init__(self, num):
        """Build the configuration for preset ``num``.

        Args:
            num: preset identifier; one of 0, 1, 2, 3, 4, 5.

        Raises:
            ValueError: if ``num`` is not a known preset. (An unknown value
                used to yield an object carrying only ``num``, which then
                failed later with an AttributeError far from the cause.)
        """
        presets = self._presets()
        if num not in presets:
            raise ValueError(
                f"Unknown Config preset {num!r}; expected one of {sorted(presets)}"
            )
        self.num = num
        preset = presets[num]

        # Values used to derive other attributes below; everything left in
        # ``preset`` afterwards is applied verbatim.
        cuda_device = preset.pop('cuda_device')
        data_root = preset.pop('data_root')
        model_name = preset.pop('model_name')
        model1_name = preset.pop('model1_name', None)

        # Basic settings
        self.seed = 3407
        # Falls back to the CPU when CUDA is unavailable.
        self.device = torch.device(cuda_device if torch.cuda.is_available() else "cpu")

        # Dataset settings
        self.dataset = 'streetview'
        self.data_root = data_root
        self.train_data_path = os.path.join(self.data_root, "train")  # training split
        self.test_data_path = os.path.join(self.data_root, "test")  # test split
        # Original flat data path, kept as a fallback for backward compatibility.
        self.data_path = "input_data/5181010_small/"

        self.batch_size = 1
        self.num_workers = 4
        self.num_sequences = 10000

        # Training settings
        self.beta = 1
        self.learning_rate = 1e-4
        self.save_interval = 2  # save a checkpoint every 2 epochs

        # Preset-specific: image size, sequence length, model type,
        # architecture hyper-parameters and number of epochs.
        for key, value in preset.items():
            setattr(self, key, value)

        # Output paths
        self.model_name = model_name
        self.results_dir = "results"
        self.model_dir = "results/models/" + self.model_name
        self.img_dir = "results/images/" + self.model_name

        if model1_name is not None:
            # Only preset 4 sets this; 'vit_1024' is preset 0's model_name.
            # NOTE(review): presumably a previously trained model this run
            # loads from model1_dir — confirm against the training script.
            self.model1_name = model1_name
            self.model1_dir = "results/models/" + self.model1_name

    @staticmethod
    def _presets():
        """Return the table ``{num: per-preset settings}``.

        The table is rebuilt on every call so each ``Config`` gets its own
        copies of mutable values (e.g. ``hidden_dims``) and ``__init__`` may
        pop keys from the selected entry freely.
        """
        full_root = '/home/xiangyuanpeng/VQforCIFAR/data/street/5181010'
        split_root = '/home/xiangyuanpeng/VQ4STREET/input_data/5181010_small_split'

        def vit_vq(dim):
            # VQ codebook + ViT settings shared by presets 0, 3, 4 and 5.
            return dict(
                embedding_dim=64,
                num_embeddings=512,
                dim=dim,
                depth=3,
                compress_rate=4,
                p_in_encoder=32,
                p_in_decoder=16,
            )

        return {
            0: dict(
                cuda_device="cuda:0",
                data_root=full_root,
                img_height=128,
                img_width=128,
                sequence_length=128,
                model_type='vit_bump',
                **vit_vq(1024),
                num_epochs=40,
                model_name='vit_1024',
            ),
            1: dict(
                cuda_device="cuda:0",
                data_root=split_root,
                img_height=40,
                img_width=80,
                sequence_length=64,
                model_type='vit_fsq',
                d_out=16,
                dim=768,
                depth=3,
                p_in_encoder=8,
                num_epochs=20,
                model_name='vit_FSQ',
            ),
            2: dict(
                cuda_device="cuda:2",
                data_root=split_root,
                img_height=40,
                img_width=80,
                sequence_length=128,
                model_type='resnet_bump',
                embedding_dim=64,
                num_embeddings=512,
                residual_blocks=2,
                hidden_dims=[256, 256],
                num_epochs=20,
                model_name='resnet_vq',
            ),
            3: dict(
                cuda_device="cuda:3",
                data_root=full_root,
                img_height=128,
                img_width=128,
                sequence_length=256,
                model_type='vit_vq',
                **vit_vq(768),
                num_epochs=30,
                model_name='vit_vq',
            ),
            4: dict(
                cuda_device="cuda:2",
                data_root=full_root,
                img_height=128,
                img_width=128,
                sequence_length=4,
                model_type='vit_bump',
                **vit_vq(1024),
                num_epochs=30,
                model_name='vit_1024_encoder',
                model1_name='vit_1024',
            ),
            5: dict(
                cuda_device="cuda:1",
                data_root=full_root,
                img_height=128,
                img_width=128,
                sequence_length=128,
                model_type='vit_bump',
                **vit_vq(1024),
                num_epochs=40,
                model_name='vit_1024_2',
            ),
        }
