import os
import logging
from abc import ABC, abstractmethod
from pathlib import Path
import json


class BaseExperimentParameters:
    """Settings shared by every experiment type, serialisable to JSON.

    Each constructor argument is kept verbatim as an instance attribute of the
    same name. ``save`` writes all instance attributes (including any that a
    subclass adds) to ``experiment_parameters.json``.
    """

    def __init__(
        self,
        num_training_samples: int,
        data_dimension: int,
        data_generating_function_name: str,
        noise_level: float,
        kernel: str,
        number_simulations_per_size: int,
    ):
        # The assignment order below determines the key order in the saved JSON.
        self.num_training_samples = num_training_samples
        self.data_dimension = data_dimension
        self.data_generating_function_name = data_generating_function_name
        self.noise_level = noise_level
        self.kernel = kernel
        self.number_simulations_per_size = number_simulations_per_size

    def save(self, results_path):
        """Write every instance attribute as JSON indented by four spaces.

        Args:
            results_path: Existing directory (str or path-like) that receives
                ``experiment_parameters.json``; an existing file is overwritten.
        """
        target = Path(results_path) / "experiment_parameters.json"
        with target.open("w") as handle:
            json.dump(vars(self), handle, indent=4)


class SubtermExperimentParameters(BaseExperimentParameters):
    """Parameters for a subterm experiment.

    Adds ``setting_to_change``, ``start``, ``end``, ``step`` and ``random_seed``
    on top of the shared base parameters. Presumably the named setting is
    varied from ``start`` to ``end`` in increments of ``step`` by the code
    that consumes these parameters; that logic is not part of this class.
    """

    def __init__(
        self,
        setting_to_change: str,
        start: float,
        end: float,
        step: float,
        num_training_samples: int,
        data_dimension: int,
        data_generating_function_name: str,
        noise_level: float,
        kernel: str,
        number_simulations_per_size: int,
        random_seed: int,
    ):
        super().__init__(
            num_training_samples=num_training_samples,
            data_dimension=data_dimension,
            data_generating_function_name=data_generating_function_name,
            noise_level=noise_level,
            kernel=kernel,
            number_simulations_per_size=number_simulations_per_size,
        )
        # These fields are assigned after the shared ones so that they follow
        # them in the JSON written by save().
        self.setting_to_change = setting_to_change
        self.start = start
        self.end = end
        self.step = step
        self.random_seed = random_seed


class RandomFeatureModelExperimentParameters(BaseExperimentParameters):
    """Parameters for a random-feature-model experiment.

    Adds ``num_features_per_model``, ``max_num_models``,
    ``activation_function``, ``random_weights_distribution`` and
    ``case_type`` on top of the shared base parameters. How these values are
    interpreted is decided by the experiment code that reads them, not here.
    """

    def __init__(
        self,
        num_features_per_model: int,
        max_num_models: int,
        num_training_samples: int,
        data_dimension: int,
        data_generating_function_name: str,
        noise_level: float,
        kernel: str,
        activation_function: str,
        random_weights_distribution: str,
        case_type: str,
        number_simulations_per_size: int,
    ):
        super().__init__(
            num_training_samples=num_training_samples,
            data_dimension=data_dimension,
            data_generating_function_name=data_generating_function_name,
            noise_level=noise_level,
            kernel=kernel,
            number_simulations_per_size=number_simulations_per_size,
        )
        # These fields are assigned after the shared ones so that they follow
        # them in the JSON written by save().
        self.num_features_per_model = num_features_per_model
        self.max_num_models = max_num_models
        self.activation_function = activation_function
        self.random_weights_distribution = random_weights_distribution
        self.case_type = case_type


class Experiment(ABC):
    """Base class for an experiment that is run, logged and visualized.

    Constructing an experiment creates its output directory and saves the
    experiment parameters there. Subclasses implement
    ``_get_experiment_dir_name``, ``_run_experiment`` and
    ``_visualize_results``.
    """

    def __init__(self, results_path, experiment_parameters, experiment_number):
        """Create the experiment directory and save the parameters into it.

        Args:
            results_path: Root directory (str or path-like) under which this
                experiment's directory is created. ``None`` selects
                ``"results"`` relative to the current working directory.
            experiment_parameters: Object with a ``save(directory)`` method,
                e.g. a ``BaseExperimentParameters`` instance.
            experiment_number: Identifier used in this experiment's logger name.
        """
        self.experiment_parameters = experiment_parameters
        self.experiment_number = experiment_number

        # Create the experiment directory under the caller-supplied root.
        root = Path("results") if results_path is None else Path(results_path)
        experiment_dir = root / self._get_experiment_dir_name()
        experiment_dir.mkdir(parents=True, exist_ok=True)
        self.experiment_dir = experiment_dir

        # Save parameters next to the experiment's other outputs.
        self.experiment_parameters.save(experiment_dir)

    def run_and_visualize_experiment(self):
        """Run the experiment with a freshly configured logger, then visualize it."""
        logger = self._setup_logger(self.experiment_dir)

        # The value returned by the run step is handed to the visualization step.
        results = self._run_experiment(logger)
        self._visualize_results(results)

    def _get_experiment_dir_name(self):
        """Return this experiment's directory name, relative to the results root.

        Subclasses must override this; ``__init__`` calls it. It is a plain
        method rather than an ``abstractmethod`` so that subclasses which
        override ``__init__`` without providing it remain instantiable.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_experiment_dir_name()"
        )

    @abstractmethod
    def _run_experiment(self, logger):
        """Run the experiment, logging through *logger*, and return its results."""

    @abstractmethod
    def _visualize_results(self, results):
        """Visualize the value returned by ``_run_experiment``."""

    def _setup_logger(self, experiment_dir):
        """Configure and return the logger ``experiment_<experiment_number>``.

        DEBUG and above are appended to ``experiment_dir / "experiment.log"``.
        INFO and above are written to the console via a ``StreamHandler``.

        Any handlers already attached to this logger are removed and closed
        first. ``logging.getLogger`` returns the same object for the same name,
        so re-running the experiment (or creating another experiment with the
        same number) would otherwise duplicate every log line and leak open
        log files.
        """
        logger = logging.getLogger(f"experiment_{self.experiment_number}")
        logger.setLevel(logging.DEBUG)

        # Detach handlers from an earlier call; iterate over a copy because
        # removeHandler mutates logger.handlers.
        for stale_handler in list(logger.handlers):
            logger.removeHandler(stale_handler)
            stale_handler.close()

        # File handler that records everything, including debug messages.
        fh = logging.FileHandler(Path(experiment_dir) / "experiment.log")
        fh.setLevel(logging.DEBUG)

        # Console handler with a higher threshold to keep terminal output short.
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)

        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        fh.setFormatter(formatter)
        ch.setFormatter(formatter)

        logger.addHandler(fh)
        logger.addHandler(ch)

        return logger
