from library import configs
from pathlib import Path

def Get_MNIST_Config(
        experiment_id: int = 0,
        experiment_name: str = "",
        results_di: str = "results/",
        image_thresholds: int = 3, # 4
        augmentation: bool = False,
        binarize_input_train: float = 0.0,
        num_workers: int = 4,
        pin_memory: bool = False,
        valid_set_size: float = 0.2,
        batch_size_train: int = 128,
        num_epochs: int = 10,
        eval_freq: int = 2,
        learning_rate: float = 0.01,
        extensive_eval_train: bool = False,
        training_bit_count: int = 32,
        batch_size_test: int = 10000,
        store_raw_values: bool = False,
        store_logit_stats: bool = False,
        l1_regularization: float = 0.0,
        extensive_eval_test: bool = False,
        eval_compiled_model: bool = False,
        eval_binarized: float = 0.5,
        packbits_eval: bool = False,
        compile_model: bool = False,
        tau: int = 10,
        connections: str = 'unique',
        architecture: str = 'randomly_connected',
        num_neurons: int = 64_000,
        num_layers: int = 6,
        custom_layer_sizes: list = None,
        last_layer_neurons: int = 64_000,
        use_groupsum: bool = True,
        full_ffn: bool = False,
        use_mygroupsum: bool = False,
        tree_classification: bool = False,
        tree_layers: int = 0,
        full_tree_output: bool = False,
        dropout_percentage: float = 1.0,
        distanceLayer: bool = False,
        use_ffbinary: bool = False,
        use_ffn: bool = False, # Last layer a FFN to evaluate output and potential
        used_ffn: list = None,
        grad_factor: int = 1,
        device: str = "cuda", 
        seed: int = 0,
    ) -> configs.DifflogicConfig:
    """Assemble a ``configs.DifflogicConfig`` for the MNIST dataset.

    The flat keyword arguments are distributed over the experiment, data,
    train, test and model sub-configs. ``seed`` is forwarded to the
    experiment, data and model sub-configs; ``device`` to the data and model
    sub-configs; ``eval_binarized`` to both the data and train sub-configs.

    Values that are fixed here and cannot be overridden by the caller:
    the dataset (``configs.Dataset.MNIST``), ``download=True`` and the whole
    compilation sub-config (``num_bits=[16, 64]``, GCC, non-verbose,
    3 repetitions).

    NOTE(review): ``experiment_name`` is accepted but not forwarded to any
    sub-config in this function.

    Returns:
        The combined ``configs.DifflogicConfig``.

    Raises:
        AssertionError: If ``custom_layer_sizes`` is given and its length
            differs from ``num_layers``.
    """
    # Identity comparison with None: `== None` is evaluated elementwise by
    # array-like containers and would then fail inside the boolean `or`.
    # Raised explicitly (instead of via `assert`) so the check is not stripped
    # under `python -O`; the exception type is kept for backward compatibility.
    if custom_layer_sizes is not None and num_layers != len(custom_layer_sizes):
        raise AssertionError(
            f"num_layers ({num_layers}) must equal "
            f"len(custom_layer_sizes) ({len(custom_layer_sizes)})"
        )

    experiment_config = configs.ExperimentConfig(
        experiment_id=experiment_id,
        seed=seed,
        results_di=results_di,
        store_raw_values=store_raw_values,
        store_logit_stats=store_logit_stats
    )
    data_config = configs.DataConfig(
        dataset=configs.Dataset.MNIST,
        image_thresholds=image_thresholds,
        augmentation=augmentation,
        binarize_input_train=binarize_input_train,
        eval_binarized=eval_binarized,
        num_workers=num_workers,
        pin_memory=pin_memory,
        download=True,
        valid_set_size=valid_set_size,
        device=device,
        seed=seed,
    )
    train_config = configs.TrainConfig(
        batch_size=batch_size_train,
        num_epochs=num_epochs,
        eval_freq=eval_freq,
        l1_regularization=l1_regularization,
        learning_rate=learning_rate,
        extensive_eval=extensive_eval_train,
        training_bit_count=training_bit_count,
        eval_binarized=eval_binarized
    )
    test_config = configs.TestConfig(
        batch_size=batch_size_test,
        extensive_eval=extensive_eval_test,
        eval_compiled_model=eval_compiled_model,
        packbits_eval=packbits_eval,
        compile_model=compile_model,
    )
    model_config = configs.ModelConfig(
        tau=tau,
        connections=connections,
        architecture=architecture,
        num_neurons=num_neurons,
        num_layers=num_layers,
        custom_layer_sizes=custom_layer_sizes,
        last_layer_neurons=last_layer_neurons,
        grad_factor=grad_factor,
        use_groupsum=use_groupsum,
        use_mygroupsum=use_mygroupsum,
        dropout_percentage=dropout_percentage,
        distanceLayer=distanceLayer,
        use_ffbinary=use_ffbinary,
        tree_classification=tree_classification,
        tree_layers=tree_layers,
        full_tree_output=full_tree_output,
        use_ffn=use_ffn, # last layer
        used_ffn=used_ffn,
        full_ffn=full_ffn,
        device=device,
        seed=seed,
    )
    # Compilation settings are not exposed as parameters of this factory.
    compilation_config = configs.CompilationConfig(
        num_bits=[16, 64],
        cpu_compiler=configs.CPUCompiler.GCC,
        verbose=False,
        num_repetitions=3,
    )

    return configs.DifflogicConfig(
        data_config=data_config,
        model_config=model_config,
        train_config=train_config,
        test_config=test_config,
        experiment_config=experiment_config,
        compilation_config=compilation_config,
    )


def Get_BASE_Config(
        experiment_id: int = 0,
        experiment_name: str = "",
        results_di: str = "results/",
        image_thresholds: int = 3,
        augmentation: bool = False,
        binarize_input_train: float = 0.0,
        num_workers: int = 4,
        pin_memory: bool = False,
        valid_set_size: float = 0.2,
        batch_size_train: int = 128,
        num_epochs: int = 10,
        eval_freq: int = 2,
        learning_rate: float = 0.01,
        extensive_eval_train: bool = False,
        training_bit_count: int = 32,
        batch_size_test: int = 10000,
        store_raw_values: bool = False,
        store_logit_stats: bool = False,
        l1_regularization: float = 0.0,
        extensive_eval_test: bool = False,
        eval_compiled_model: bool = False,
        eval_binarized: float = 0.5,
        packbits_eval: bool = False,
        compile_model: bool = False,
        tau: int = 10,
        connections: str = 'unique',
        architecture: str = 'randomly_connected',
        num_neurons: int = 64_000,
        num_layers: int = 6,
        num_classes: int = 0,
        custom_layer_sizes: list = None,
        last_layer_neurons: int = 0, # Changed from 64'000 to 0 on 27.07.2025 (if not set, choose num_neurons)
        use_groupsum: bool = True,
        use_mygroupsum: bool = False,
        distanceLayer: bool = False,
        distanceLayer2: bool = False,
        distance_dimension: int = 0,
        tree_classification: bool = False,
        tree_layers: int = 0,
        full_tree_output: bool = False,
        dropout_percentage: float = 1.0,
        use_ffbinary: bool = False,
        use_ffn: bool = False, # Last layer a FFN to evaluate output and potential
        used_ffn: list = None, # List of layer sizes to be appended to the difflogic network. Usable with 'use_ffn'.
        full_ffn: bool = False,
        ffn_layer_size: int = 0, # Layer size for ffn baseline, when learning the dataset with a 'full_ffn'.
        upscale_input: int = 0,
        grad_factor: int = 1,
        device: str = "cuda", 
        save_model_on: str ="valid",
        seed: int = 0,
    ) -> configs.DifflogicConfig:
    """Assemble the base ``configs.DifflogicConfig`` from flat keyword arguments.

    The arguments are distributed over the experiment, data, train, test and
    model sub-configs. ``seed`` is forwarded to the experiment, data and model
    sub-configs; ``device`` to the data and model sub-configs;
    ``eval_binarized`` to both the data and train sub-configs.

    Values that are fixed here and cannot be overridden by the caller:
    the dataset (``configs.Dataset.MNIST``), ``download=True`` and the whole
    compilation sub-config (``num_bits=[16, 64]``, GCC, non-verbose,
    3 repetitions).

    NOTE(review): ``experiment_name`` and ``ffn_layer_size`` are accepted but
    not forwarded to any sub-config in this function -- confirm whether
    ``ffn_layer_size`` should be passed to ``configs.ModelConfig``.
    NOTE(review): despite the "BASE" name the dataset is hard-coded to MNIST;
    presumably callers override it afterwards -- verify against callers.

    Returns:
        The combined ``configs.DifflogicConfig``.

    Raises:
        AssertionError: If ``custom_layer_sizes`` is given and its length
            differs from ``num_layers``.
    """
    # Identity comparison with None: `== None` is evaluated elementwise by
    # array-like containers and would then fail inside the boolean `or`.
    # Raised explicitly (instead of via `assert`) so the check is not stripped
    # under `python -O`; the exception type is kept for backward compatibility.
    if custom_layer_sizes is not None and num_layers != len(custom_layer_sizes):
        raise AssertionError(
            f"num_layers ({num_layers}) must equal "
            f"len(custom_layer_sizes) ({len(custom_layer_sizes)})"
        )

    experiment_config = configs.ExperimentConfig(
        experiment_id=experiment_id,
        seed=seed,
        results_di=results_di,
        store_raw_values=store_raw_values,
        store_logit_stats=store_logit_stats
    )
    data_config = configs.DataConfig(
        dataset=configs.Dataset.MNIST,
        image_thresholds=image_thresholds,
        augmentation=augmentation,
        num_classes=num_classes,
        binarize_input_train=binarize_input_train,
        eval_binarized=eval_binarized,
        num_workers=num_workers,
        pin_memory=pin_memory,
        download=True,
        valid_set_size=valid_set_size,
        upscale_input=upscale_input,
        device=device,
        seed=seed,
    )
    train_config = configs.TrainConfig(
        batch_size=batch_size_train,
        num_epochs=num_epochs,
        eval_freq=eval_freq,
        l1_regularization=l1_regularization,
        learning_rate=learning_rate,
        extensive_eval=extensive_eval_train,
        training_bit_count=training_bit_count,
        eval_binarized=eval_binarized,
        save_model_on=save_model_on
    )
    test_config = configs.TestConfig(
        batch_size=batch_size_test,
        extensive_eval=extensive_eval_test,
        eval_compiled_model=eval_compiled_model,
        packbits_eval=packbits_eval,
        compile_model=compile_model,
    )
    model_config = configs.ModelConfig(
        tau=tau,
        connections=connections,
        architecture=architecture,
        num_neurons=num_neurons,
        num_layers=num_layers,
        custom_layer_sizes=custom_layer_sizes,
        last_layer_neurons=last_layer_neurons,
        grad_factor=grad_factor,
        use_groupsum=use_groupsum,
        use_mygroupsum=use_mygroupsum,
        dropout_percentage=dropout_percentage,
        distanceLayer=distanceLayer,
        distanceLayer2=distanceLayer2,
        distance_dimension=distance_dimension,
        use_ffbinary=use_ffbinary,
        tree_classification=tree_classification,
        tree_layers=tree_layers,
        full_tree_output=full_tree_output,
        use_ffn=use_ffn, # last layer
        used_ffn=used_ffn,
        full_ffn=full_ffn,
        device=device,
        seed=seed,
    )
    # Compilation settings are not exposed as parameters of this factory.
    compilation_config = configs.CompilationConfig(
        num_bits=[16, 64],
        cpu_compiler=configs.CPUCompiler.GCC,
        verbose=False,
        num_repetitions=3,
    )

    return configs.DifflogicConfig(
        data_config=data_config,
        model_config=model_config,
        train_config=train_config,
        test_config=test_config,
        experiment_config=experiment_config,
        compilation_config=compilation_config,
    )


def Get_CIFER10_Config(
        experiment_id: int = 0,
        image_thresholds: int = 3,
        num_workers: int = 4,
        pin_memory: bool = False,
        valid_set_size: float = 0.2,
        batch_size_train: int = 128,
        num_epochs: int = 10,
        eval_freq: int = 2,
        learning_rate: float = 0.01,
        extensive_eval_train: bool = False,
        training_bit_count: int = 32,
        batch_size_test: int = 10000,
        extensive_eval_test: bool = False,
        eval_compiled_model: bool = False,
        packbits_eval: bool = False,
        compile_model: bool = False,
        tau: int = 10,
        connections: str = 'unique',
        architecture: str = 'randomly_connected',
        num_neurons: int = 128_000,
        num_layers: int = 4,
        grad_factor: int = 1,
        device: str = "cuda", 
        seed: int = 0,
    ) -> configs.DifflogicConfig:
    """Assemble a ``configs.DifflogicConfig`` for the CIFAR-10 dataset.

    The flat keyword arguments are grouped into the experiment, data, train,
    test and model sub-configs. The dataset (``configs.Dataset.CIFAR10``),
    ``download=True`` and the compilation sub-config (``num_bits=[16, 64]``,
    GCC, non-verbose, 3 repetitions) are fixed by this function.

    The "CIFER10" spelling in the function name is kept as-is so existing
    callers keep working.

    Returns:
        The combined ``configs.DifflogicConfig``.
    """
    # `device` and `seed` are passed identically to the data and model configs.
    shared = {"device": device, "seed": seed}

    experiment = configs.ExperimentConfig(experiment_id=experiment_id, seed=seed)

    data = configs.DataConfig(
        dataset=configs.Dataset.CIFAR10,
        image_thresholds=image_thresholds,
        num_workers=num_workers,
        pin_memory=pin_memory,
        download=True,
        valid_set_size=valid_set_size,
        **shared,
    )

    training = configs.TrainConfig(
        batch_size=batch_size_train,
        num_epochs=num_epochs,
        eval_freq=eval_freq,
        learning_rate=learning_rate,
        extensive_eval=extensive_eval_train,
        training_bit_count=training_bit_count,
    )

    testing = configs.TestConfig(
        batch_size=batch_size_test,
        extensive_eval=extensive_eval_test,
        eval_compiled_model=eval_compiled_model,
        packbits_eval=packbits_eval,
        compile_model=compile_model,
    )

    model = configs.ModelConfig(
        tau=tau,
        connections=connections,
        architecture=architecture,
        num_neurons=num_neurons,
        num_layers=num_layers,
        grad_factor=grad_factor,
        **shared,
    )

    # Compilation settings are not exposed as parameters of this factory.
    compilation = configs.CompilationConfig(
        num_bits=[16, 64],
        cpu_compiler=configs.CPUCompiler.GCC,
        verbose=False,
        num_repetitions=3,
    )

    sections = {
        "data_config": data,
        "model_config": model,
        "train_config": training,
        "test_config": testing,
        "experiment_config": experiment,
        "compilation_config": compilation,
    }
    return configs.DifflogicConfig(**sections)