import torch
import itertools

class Rule_Config():
    """Settings for a rule-based model.

    Args:
        n_features: Stored as ``self.n_features``.
        n_rules: Stored as ``self.n_rules``.
        predicate_temperature: Starting predicate temperature (default 0.2).
        random_cutpoints: Stored as ``self.random_cutpoints`` (default False).
        temp_reduction_predicate: Factor the predicate temperature is divided
            by to obtain the final temperature of the schedule (default 10).
            A value of 0 raises ZeroDivisionError.
        init_set_size: Stored as ``self.init_set_size`` (default -1).

    ``schedule_predicate_temperature`` describes a linear schedule from
    ``predicate_temperature`` down to
    ``predicate_temperature / temp_reduction_predicate``.
    """

    def __init__(self, n_features, n_rules, predicate_temperature=0.2, random_cutpoints=False,
                 temp_reduction_predicate=10, init_set_size=-1):
        self.n_features = n_features
        self.n_rules = n_rules
        self.predicate_temperature = predicate_temperature
        self.random_cutpoints = random_cutpoints
        self.temp_reduction_predicate = temp_reduction_predicate

        final_temperature = predicate_temperature / temp_reduction_predicate
        self.schedule_predicate_temperature = dict(
            start=predicate_temperature,
            end=final_temperature,
            progress="linear",
        )

        self.init_set_size = init_set_size

class Method_Config():
    """Base class for a method's hyper-parameter grid.

    Subclasses must set these attributes in their own ``__init__`` and must
    not call this class's ``__init__`` (it raises):

    * ``parameter_range_dict`` -- maps each hyper-parameter name to the list
      of candidate values.
    * ``default_index`` -- position in :meth:`get_all_configs` of the default
      configuration.
    * ``setting_to_index`` -- maps a setting name to a position in
      :meth:`get_all_configs`.
    """

    def __init__(self):
        raise NotImplementedError("This is an abstract class. Please implement the method_config class.")

    def get_all_configs(self):
        """Return every configuration of the hyper-parameter grid.

        The grid is the Cartesian product of the value lists in
        ``parameter_range_dict``, enumerated in ``itertools.product`` order
        (the last key varies fastest). Each configuration is a
        ``{name: value}`` dict. An empty ``parameter_range_dict`` yields a
        single empty configuration.
        """
        names = list(self.parameter_range_dict)
        value_lists = self.parameter_range_dict.values()
        return [dict(zip(names, values)) for values in itertools.product(*value_lists)]

    def _config_at(self, index, label):
        """Return the grid configuration at ``index`` (negative indices allowed).

        Raises:
            IndexError: if ``index`` lies outside the grid; the message names
                the class, ``label`` and the grid size so a misconfigured
                index is easy to trace.
        """
        configs = self.get_all_configs()
        size = len(configs)
        if not -size <= index < size:
            raise IndexError(
                f"{type(self).__name__}: {label} index {index} is out of range "
                f"for a grid of {size} configurations"
            )
        return configs[index]

    def get_default_config(self):
        """Return the configuration selected by ``default_index``.

        Raises:
            IndexError: if ``default_index`` is outside the grid.
        """
        return self._config_at(self.default_index, "default")

    def get_setting_config(self, setting):
        """Return the configuration registered for ``setting``.

        Raises:
            ValueError: if ``setting`` is not a key of ``setting_to_index``.
            IndexError: if the registered index is outside the grid.
        """
        if setting not in self.setting_to_index:
            raise ValueError(f"Setting '{setting}' not recognized. Available settings: {list(self.setting_to_index.keys())}")
        return self._config_at(self.setting_to_index[setting], f"setting '{setting}'")

class Subcon_Config(Method_Config):
    """Hyper-parameter grid for the Subcon method.

    Grid size: 3 * 2 * 2 * 2 = 24 configurations. In ``itertools.product``
    order (last key varies fastest), index 9 -- used as the default and for
    every setting -- selects lambd=0.5, gamma=0.1, n_epochs=500,
    lr_classifier=0.005.
    """

    def __init__(self):
        # Method_Config.__init__ raises NotImplementedError, so it is not called.
        grid = dict(
            lambd=[0.1, 0.5, 1.0],
            gamma=[0.1, 0.5],
            n_epochs=[500, 1000],
            lr_classifier=[0.001, 0.005],
        )
        chosen = 9
        self.parameter_range_dict = grid
        self.default_index = chosen
        self.setting_to_index = dict.fromkeys(
            ("observational", "interventional", "demographic"), chosen
        )
    
    
class Syflow_Config(Method_Config):
    """Hyper-parameter grid for the Syflow method.

    Grid size: 3 * 3 * 3 = 27 configurations. In ``itertools.product`` order
    (last key varies fastest):

    * index 20 (default, "observational"): alpha=0.5, lr_classifier=1e-3,
      subgroup_train_epochs=1000
    * index 2 ("interventional"): alpha=0.1, lr_classifier=1e-3,
      subgroup_train_epochs=1000
    * index 3 ("demographic"): alpha=0.1, lr_classifier=1e-4,
      subgroup_train_epochs=200
    """

    def __init__(self):
        alphas = [0.1, 0.3, 0.5]
        learning_rates = [1e-3, 1e-4, 1e-5]
        epoch_counts = [200, 500, 1000]
        self.parameter_range_dict = {
            "alpha": alphas,
            "lr_classifier": learning_rates,
            "subgroup_train_epochs": epoch_counts,
        }
        self.default_index = 20
        self.setting_to_index = dict(observational=20, interventional=2, demographic=3)

class PySubgroup_Config(Method_Config):
    """Hyper-parameter grid for the PySubgroup method.

    Grid size: 3 * 3 * 3 = 27 configurations. In ``itertools.product`` order
    (last key varies fastest):

    * index 2 (default): beam_width=50, n_bins=5, alpha=1.0
    * index 26 ("observational", "demographic"): beam_width=200, n_bins=20,
      alpha=1.0
    * index 11 ("interventional"): beam_width=100, n_bins=5, alpha=1.0
    """

    def __init__(self):
        beam_widths = [50, 100, 200]
        bin_counts = [5, 10, 20]
        alphas = [0.2, 0.5, 1.0]
        self.parameter_range_dict = dict(beam_width=beam_widths, n_bins=bin_counts, alpha=alphas)
        # NOTE(review): the default (2) differs from the "observational" index
        # (26); the other method configs in this file use the same index for
        # both -- confirm this is intended.
        self.default_index = 2
        self.setting_to_index = dict(observational=26, interventional=11, demographic=26)

class CausalTree_Config(Method_Config):
    """Hyper-parameter grid for the causal tree method.

    Grid size: 5 * 4 = 20 configurations. In ``itertools.product`` order
    (last key varies fastest):

    * index 1 (default, "observational"): min_samples_leaf=0.01, max_depth=3
    * index 2 ("interventional"): min_samples_leaf=0.01, max_depth=5
    * index 6 ("demographic"): min_samples_leaf=0.05, max_depth=5
    """

    def __init__(self):
        min_leaf_options = [0.01, 0.05, 0.1, 0.2, 0.3]
        depth_options = [2, 3, 5, None]
        self.parameter_range_dict = {"min_samples_leaf": min_leaf_options, "max_depth": depth_options}
        self.default_index = 1
        self.setting_to_index = dict(observational=1, interventional=2, demographic=6)

class HonestTree_Config(Method_Config):
    """Hyper-parameter grid for the honest tree method.

    Grid size: 5 * 4 = 20 configurations. In ``itertools.product`` order
    (last key varies fastest):

    * index 9 (default, "observational"): min_samples_leaf=0.1, max_depth=3
    * index 8 ("interventional"): min_samples_leaf=0.1, max_depth=2
    * index 18 ("demographic"): min_samples_leaf=0.3, max_depth=5
    """

    def __init__(self):
        grid = dict(
            min_samples_leaf=[0.01, 0.05, 0.1, 0.2, 0.3],
            max_depth=[2, 3, 5, None],
        )
        self.parameter_range_dict = grid
        self.default_index = 9
        self.setting_to_index = dict(observational=9, interventional=8, demographic=18)