from enum import Enum
import json
import os
import pathlib
from typing import List
import re

from library import misc

# Shell script parsed by misc.get_bash_env_vars() as a fallback source of
# PROJECT_STORAGE_DIR when that environment variable is not set (see
# ExperimentConfig). It is a relative path, so it is presumably resolved
# against the current working directory — TODO confirm the expected CWD.
SCRIPT_PROJECT_VARIABLES = pathlib.Path("helper_script/project_variables.sh")


class Dataset(Enum):
    """Dataset identifiers understood by the configuration code.

    Each member's value is the lowercase string written to serialized configs
    (CustomJSONEncoder emits ``member.value``; CustomJSONDecoder maps it back).
    """

    ADULT = "adult"
    BREAST_CANCER = "breast_cancer"
    MONK1 = "monk1"
    MONK2 = "monk2"
    MONK3 = "monk3"
    MNIST = "mnist"
    FMNIST = "fmnist"
    KMNIST = "kmnist"
    QMNIST = "qmnist"
    EMNIST_BALANCED = "emnist_balanced"
    EMNIST_LETTERS = "emnist_letters"
    CUSTOM = "custom"
    CUSTOM_IMAGENET = "custom_imagenet"
    SYNTHETIC = "synthetic"
    MNIST20x20 = "mnist20x20"
    CIFAR10 = "cifar10"
    CIFAR100 = "cifar100"
    IMAGENET32 = "imagenet32"


class CPUCompiler(Enum):
    """C compiler choices; the value is the string stored in serialized configs."""

    GCC = "gcc"
    CLANG = "clang"


class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that additionally handles Enum members and pathlib paths.

    Enum members are written as their ``.value`` and paths as their string
    form. Any other unsupported object is handed to the base encoder, which
    raises TypeError.
    """

    def default(self, obj):
        if not isinstance(obj, (Enum, pathlib.Path)):
            # Not a type we know how to serialize: let json.JSONEncoder raise.
            return super().default(obj)
        return obj.value if isinstance(obj, Enum) else str(obj)


class CustomJSONDecoder(json.JSONDecoder):
    """JSON decoder that restores the Enum and Path values written by CustomJSONEncoder.

    Every decoded JSON object, including nested ones, is passed through
    ``object_hook``. That hook rewrites well-known config keys back into
    their Python types.
    """

    # Keys whose string values are converted back into pathlib.Path objects.
    # Note: an empty string becomes Path('.'), which is how pathlib treats "".
    _PATH_KEYS = ("results_di", "project_storage_dir", "library_dir", "data_dir")

    def __init__(self, *args, **kwargs):
        super().__init__(object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, obj):
        """Convert known keys of one decoded JSON object in place and return it.

        - ``dataset`` / ``cpu_compiler`` become Dataset / CPUCompiler members
          when the value is a valid member value. Any other value is
          deliberately left unchanged.
        - Path-valued keys become pathlib.Path. A JSON ``null`` is kept as
          None instead of crashing inside ``pathlib.Path(None)``.
        """
        # Enum classes are only looked up when their key is present.
        if "dataset" in obj:
            try:
                obj["dataset"] = Dataset(obj["dataset"])
            except ValueError:
                pass  # best-effort: keep non-member dataset names as given
        if "cpu_compiler" in obj:
            try:
                obj["cpu_compiler"] = CPUCompiler(obj["cpu_compiler"])
            except ValueError:
                pass  # best-effort: keep unknown compiler names as given
        for key in self._PATH_KEYS:
            value = obj.get(key)
            if value is not None:
                obj[key] = pathlib.Path(value)
        return obj


class ExperimentConfig:
    """Experiment-level settings: identifiers, result location, seed, and storage root.

    ``project_storage_dir`` is always stored as a ``pathlib.Path``. When it is
    not given, it comes from the ``PROJECT_STORAGE_DIR`` environment variable.
    Failing that, it comes from the variables that ``misc.get_bash_env_vars``
    reads out of ``SCRIPT_PROJECT_VARIABLES``.

    ``results_di`` is stored exactly as given (a plain string by default).
    """

    def __init__(
            self,
            experiment_id: int = None,
            experiment_name: str = None,
            results_di: str = "",
            seed: int = 0,
            store_raw_values: bool = False,
            store_logit_stats: bool = False,
            project_storage_dir: pathlib.Path = None,
        ) -> None:
        self.experiment_id = experiment_id
        self.experiment_name = experiment_name
        self.results_di = results_di
        self.store_raw_values = store_raw_values
        self.store_logit_stats = store_logit_stats
        self.seed = seed
        if project_storage_dir is None:
            project_storage_dir = self._default_project_storage_dir()
        # Normalize to Path so downstream code can rely on Path methods even
        # when a caller passes a plain string.
        self.project_storage_dir = pathlib.Path(project_storage_dir)

    @staticmethod
    def _default_project_storage_dir():
        """Return PROJECT_STORAGE_DIR from the environment, else from the helper script.

        Raises:
            KeyError: if the variable is set in neither place.
        """
        env_value = os.environ.get('PROJECT_STORAGE_DIR')
        if env_value is not None:
            return env_value
        # Assumes get_bash_env_vars returns a mapping of variable names to
        # values — TODO confirm against library.misc.
        script_env = misc.get_bash_env_vars(SCRIPT_PROJECT_VARIABLES)
        try:
            return script_env['PROJECT_STORAGE_DIR']
        except KeyError:
            raise KeyError(
                f"PROJECT_STORAGE_DIR is neither set in the environment nor "
                f"defined in {SCRIPT_PROJECT_VARIABLES}"
            ) from None

    def from_dict(self, d: dict):
        """Overwrite/add attributes from ``d`` (no validation or type conversion)."""
        self.__dict__.update(d)

    def to_dict(self):
        """Return the live attribute dict (not a copy)."""
        return self.__dict__


class CompilationConfig:
    """Settings for the model-compilation step.

    These cover the compiler choice, bit widths, output library directory,
    verbosity and repetition count. The exact meaning of ``num_bits`` and
    ``num_repetitions`` is defined by the compilation code elsewhere —
    presumably the bit widths to compile for and the number of
    timing/eval repeats; TODO confirm.
    """

    def __init__(
            self,
            num_bits: List[int] = None,
            cpu_compiler: CPUCompiler = CPUCompiler.GCC,
            verbose: bool = False,
            library_dir: pathlib.Path = pathlib.Path('library'),
            num_repetitions: int = 3,
        ) -> None:
        # A list literal as a default would be shared (and mutable) across all
        # instances; build a fresh default list per instance instead.
        self.num_bits = [16, 64] if num_bits is None else num_bits
        self.cpu_compiler = cpu_compiler
        self.verbose = verbose
        self.library_dir = library_dir
        self.num_repetitions = num_repetitions

    def from_dict(self, d: dict):
        """Overwrite/add attributes from ``d`` (no validation or type conversion)."""
        self.__dict__.update(d)

    def to_dict(self):
        """Return the live attribute dict (not a copy)."""
        return self.__dict__


class DataConfig:
    """Dataset selection and data-loading settings.

    ``dataset`` may be a :class:`Dataset` member or a plain string.

    When ``image_thresholds`` is 0 and the dataset *name* contains
    ``'thresholds'`` (e.g. ``'mnist_3_thresholds'``), the number of
    thresholds is parsed out of the name.

    Raises:
        ValueError: if the name mentions 'thresholds' but no number can be
            extracted.
        AssertionError: if ``valid_set_size`` is outside [0, 1].
    """

    # Digits followed by any single character and then "thresholds"
    # (e.g. the "3" in "mnist_3_thresholds" or "mnist_3.thresholds").
    _THRESHOLDS_PATTERN = re.compile(r'\d+(?=.thresholds)')

    def __init__(
            self,
            dataset: "Dataset | str" = 'mnist',
            data_dir: pathlib.Path = pathlib.Path('data'),
            image_thresholds: int = 0,
            augmentation: bool = False,
            binarize_input_train: float = 0.0,
            eval_binarized: float = 0.0,
            num_classes: int = 0,
            num_workers: int = 4,
            pin_memory: bool = False,
            download: bool = True,
            valid_set_size: float = 0,
            upscale_input: int = 0,
            device: str = 'cuda',  # TODO: Change to Enum or torch.device
            seed: int = 0,
        ) -> None:
        self.dataset = dataset
        self.data_dir = data_dir
        # Enum members do not support the `in` operator, so inspect their
        # string value. Non-string names (e.g. None) skip the thresholds parsing.
        dataset_name = dataset.value if isinstance(dataset, Enum) else dataset
        if image_thresholds == 0 and isinstance(dataset_name, str) and 'thresholds' in dataset_name:
            match = self._THRESHOLDS_PATTERN.search(dataset_name)
            if match is None:
                raise ValueError(f"Could not extract the number of thresholds from the dataset name '{dataset}'")
            self.image_thresholds = int(match.group(0))
        else:
            self.image_thresholds = image_thresholds

        self.augmentation = augmentation
        self.binarize_input_train = binarize_input_train
        self.eval_binarized = eval_binarized
        self.num_workers = num_workers
        self.num_classes = num_classes
        self.pin_memory = pin_memory
        self.download = download
        assert 0 <= valid_set_size <= 1
        self.valid_set_size: float = valid_set_size
        self.upscale_input = upscale_input
        self.device = device
        self.seed = seed

    def from_dict(self, d: dict):
        """Overwrite/add attributes from ``d``.

        Thresholds are not re-derived and no validation is performed.
        """
        self.__dict__.update(d)

    def to_dict(self):
        """Return the live attribute dict (not a copy)."""
        return self.__dict__


class TrainConfig:
    """Training-loop hyper-parameters.

    Every constructor argument is stored verbatim as an attribute of the same
    name. ``to_dict`` exposes them in construction order.
    """

    def __init__(
            self,
            batch_size = 128,
            num_epochs = 10,
            eval_freq = 2,
            learning_rate = 0.01,
            extensive_eval = False,
            eval_binarized = 0.0,
            training_bit_count = 32,
            decrease_tau = None,
            l1_regularization = 0.0,
            save_model_on="valid"
        ) -> None:
        # The keyword order here fixes the key order of to_dict() / JSON output.
        vars(self).update(
            batch_size=batch_size,
            num_epochs=num_epochs,
            eval_freq=eval_freq,
            learning_rate=learning_rate,
            l1_regularization=l1_regularization,
            extensive_eval=extensive_eval,
            eval_binarized=eval_binarized,
            training_bit_count=training_bit_count,
            # Per the original author: tau is reduced every 5 epochs when set
            # (that logic lives outside this class).
            decrease_tau=decrease_tau,
            save_model_on=save_model_on,
        )

    def from_dict(self, d: dict):
        """Merge ``d`` into the instance attributes without validation."""
        vars(self).update(d)

    def to_dict(self):
        """Return the instance attribute dict itself (mutations are shared)."""
        return vars(self)


class TestConfig:
    """Evaluation settings, plus a ``dataloaders`` dict that starts empty.

    NOTE(review): ``to_dict`` returns the live attribute dict, which includes
    ``dataloaders``. If that dict is populated with non-JSON-serializable
    objects, serializing the enclosing config would fail — confirm that
    callers clear it before saving.
    """

    def __init__(
            self,
            batch_size = 128,
            extensive_eval = False,
            eval_compiled_model = False,
            packbits_eval = False,
            compile_model = False,
        ) -> None:
        settings = (
            ('batch_size', batch_size),
            ('extensive_eval', extensive_eval),
            ('eval_compiled_model', eval_compiled_model),
            ('packbits_eval', packbits_eval),
            ('compile_model', compile_model),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)

        # Fresh, per-instance container filled in by other code.
        self.dataloaders: dict = {}

    def from_dict(self, d: dict):
        """Merge ``d`` into the instance attributes without validation."""
        vars(self).update(d)

    def to_dict(self):
        """Return the instance attribute dict itself (not a copy)."""
        return vars(self)


class ModelConfig:
    """Model architecture and hyper-parameter settings.

    Every constructor argument is stored as an attribute of the same name.
    Their semantics are defined by the model code elsewhere.
    """

    def __init__(
            self, 
            tau = 10,
            connections = 'unique',
            architecture = 'randomly_connected',
            num_neurons = 64_000, 
            num_layers = 6,
            custom_layer_sizes = None,
            last_layer_neurons = 64_000,
            grad_factor = 1,
            device = 'cuda',
            seed = 0,
            use_groupsum = True,
            use_mygroupsum = False,
            dropout_percentage = 1.0,
            distanceLayer = False,
            distanceLayer2 = False,
            distance_dimension = 0,
            tree_classification = False,
            tree_layers = 0,
            full_tree_output = False,
            use_ffn = False,
            used_ffn = None,
            full_ffn = False,
            use_ffbinary = False
        ) -> None:
        self.tau = tau
        self.connections = connections
        self.architecture = architecture
        self.num_neurons = num_neurons
        self.num_layers = num_layers
        self.custom_layer_sizes = custom_layer_sizes
        self.last_layer_neurons = last_layer_neurons
        self.grad_factor = grad_factor
        self.device = device
        self.seed = seed
        self.use_groupsum = use_groupsum
        self.use_mygroupsum = use_mygroupsum
        self.dropout_percentage = dropout_percentage
        self.tree_classification = tree_classification
        self.tree_layers = tree_layers
        self.full_tree_output = full_tree_output
        self.distanceLayer = distanceLayer
        self.distanceLayer2 = distanceLayer2
        self.distance_dimension = distance_dimension
        self.use_ffn = use_ffn
        self.used_ffn = used_ffn
        self.full_ffn = full_ffn
        # Previously accepted but never stored, so the setting was silently lost.
        self.use_ffbinary = use_ffbinary

    def from_dict(self, d: dict):
        """Overwrite/add attributes from ``d`` (no validation or type conversion)."""
        self.__dict__.update(d)

    def to_dict(self):
        """Return the live attribute dict (not a copy)."""
        return self.__dict__


class DifflogicConfig:
    """Bundle of all sub-configurations for a DiffLogic run, with JSON round-tripping.

    On construction, relative ``data_config.data_dir`` and
    ``compilation_config.library_dir`` are prefixed with
    ``experiment_config.project_storage_dir``; both end up as ``pathlib.Path``.
    """

    def __init__(
            self, 
            data_config: DataConfig, 
            train_config: TrainConfig,
            test_config: TestConfig,
            model_config: ModelConfig,
            experiment_config: ExperimentConfig,
            compilation_config: CompilationConfig,
        ) -> None:
        # Settings duplicated across sub-configs must agree.
        assert data_config.device == model_config.device, (
            f"data_config.device ({data_config.device!r}) != model_config.device ({model_config.device!r})"
        )
        assert model_config.seed == experiment_config.seed, (
            f"model_config.seed ({model_config.seed!r}) != experiment_config.seed ({experiment_config.seed!r})"
        )

        self.experiment_config = experiment_config
        self.data_config = data_config
        self.train_config = train_config
        self.test_config = test_config
        self.model_config = model_config
        self.compilation_config = compilation_config

        storage_dir = self.experiment_config.project_storage_dir
        self.data_config.data_dir = self._resolve_dir(storage_dir, self.data_config.data_dir)
        self.compilation_config.library_dir = self._resolve_dir(storage_dir, self.compilation_config.library_dir)
        # NOTE: experiment_config.results_di is intentionally not resolved here;
        # it defaults to a plain string rather than a path.

    @staticmethod
    def _resolve_dir(root, path):
        """Return ``path`` as a Path, prefixed with ``root`` when it is relative.

        Accepts str or Path for ``path``, so string directories no longer fail
        on ``.is_absolute()``.
        """
        path = pathlib.Path(path)
        return path if path.is_absolute() else root / path

    def from_dict(self, d: dict):
        """Update every sub-config from the matching section of ``d``.

        Paths are not re-resolved against the storage directory.
        """
        self.experiment_config.from_dict(d["experiment_config"])
        self.data_config.from_dict(d["data_config"])
        self.train_config.from_dict(d["train_config"])
        self.test_config.from_dict(d["test_config"])
        self.model_config.from_dict(d["model_config"])
        self.compilation_config.from_dict(d["compilation_config"])

    def to_dict(self):
        """Return a dict of each sub-config's (live) attribute dict."""
        return {
            "experiment_config": self.experiment_config.to_dict(),
            "data_config": self.data_config.to_dict(),
            "train_config": self.train_config.to_dict(),
            "test_config": self.test_config.to_dict(),
            "model_config": self.model_config.to_dict(),
            "compilation_config": self.compilation_config.to_dict(),
        }

    def from_json(self, path):
        """Load a JSON file written by ``to_json`` and apply it via ``from_dict``."""
        with open(path, 'r', encoding='utf-8') as f:
            d = json.load(f, cls=CustomJSONDecoder)
        self.from_dict(d)

    def to_json(self, path):
        """Write all sub-configs to ``path`` as JSON (Enums/Paths via CustomJSONEncoder)."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, cls=CustomJSONEncoder)

    def __str__(self):
        return str(self.to_dict())