from datetime import datetime
import hashlib
import json


# A Control instance stores the configuration information.
class Control:
    """Holder for the configuration of a run.

    Every parameter is a plain instance attribute initialised to a default
    value in ``__init__``; ``get_config`` flattens them into a single dict.
    """

    def __init__(self):
        self.data = "yago4_lcc"
        self.create_inverse_triples = True
        self.embed_dim = 100
        self.subgraph_max_size = 3e6
        self.embed_method = "distmult"
        self.embed_setting = DistMultSetting()
        self.seed = 0
        self.date = None
        self.core_prop = 0.05
        # Either "degree" or "relation_based".
        self.core_selection = "degree"
        self.partitioning = "sepal_subgraph"
        self.n_propagation_steps = 20
        # Either "barycenter" or "normalized_sum".
        self.propagation_type = "normalized_sum"
        # Reset core subgraph embeddings after each propagation step.
        self.reset_embed = True

    def get_config(self):
        """Return the parameters as one flat dict.

        The nested ``embed_setting`` object does not appear in the result;
        each attribute it holds is reported under the key
        ``"embed_setting.<attribute name>"`` instead.
        """
        # Work on a copy so the instance's own attributes stay untouched.
        config = dict(self.__dict__)
        for name, value in self.embed_setting.__dict__.items():
            config["embed_setting." + name] = value
        config.pop("embed_setting")
        return config


# Embedding algorithms settings

class HolESetting:
    """Configuration parameters for the "hole" embedding method."""

    def __init__(self) -> None:
        self.optimizer, self.composition, self.lr = (
            "Adam",
            "circular_correlation",
            1e-3,
        )


class TuckERSetting:
    """Configuration parameters for the "tucker" embedding method."""

    def __init__(self) -> None:
        self.optimizer, self.composition, self.lr = (
            "Adam",
            "tucker",
            1e-3,
        )


class DistMultSetting:
    """Configuration parameters for the "distmult" embedding method."""

    def __init__(self) -> None:
        self.optimizer, self.composition, self.lr = (
            "Adam",
            "multiplication",
            1e-3,
        )


class TransESetting:
    """Configuration parameters for the "transe" embedding method."""

    def __init__(self) -> None:
        self.optimizer, self.composition, self.lr = (
            "Adam",
            "subtraction",
            1e-3,
        )


class RandomSetting:
    """Configuration parameters for the "random" embedding method.

    Unlike the other setting classes, this one carries ``mean`` and ``std``
    and has no ``optimizer`` or ``lr`` attribute.
    """

    def __init__(self):
        self.mean, self.std = 0, 0.07
        self.composition = "multiplication"


# Registry of the available KG embedding methods: method name -> setting class.
# To add a new method, define its setting class and register it here.
embed_methods = dict(
    hole=HolESetting,
    distmult=DistMultSetting,
    transe=TransESetting,
    tucker=TuckERSetting,
    random=RandomSetting,
)


# Dataset-specific parameters
#
# One row per dataset. Columns, in order:
#   dataset name, core node proportion, core edge proportion, diffusion stop,
#   number of propagation steps, number of epochs, batch size, time interval
_DATASET_TABLE = (
    ("mini_yago3_lcc", 0.05, 0.01, 0.8, 5, 60, 512, 0.1),
    ("yago3_lcc", 0.05, 0.025, 0.77, 15, 50, 2048, 1),
    ("yago4_lcc", 0.06, 0.01, 0.55, 20, 75, 65536, 10),
    ("yago4.5_lcc", 0.03, 0.02, 0.6, 50, 75, 8192, 10),
    ("yago4_with_full_ontology", 0.04, 0.01, 0.8, 20, 75, 65536, 10),
)

# Per-parameter lookup tables (dataset name -> value), one per table column.
core_node_proportions = {row[0]: row[1] for row in _DATASET_TABLE}

core_edge_proportions = {row[0]: row[2] for row in _DATASET_TABLE}

diffusion_stop = {row[0]: row[3] for row in _DATASET_TABLE}

n_propagation_steps = {row[0]: row[4] for row in _DATASET_TABLE}

num_epochs = {row[0]: row[5] for row in _DATASET_TABLE}

batch_sizes = {row[0]: row[6] for row in _DATASET_TABLE}

time_intervals = {row[0]: row[7] for row in _DATASET_TABLE}


# Function to set the parameters


def set_control_params(data, gpu, **kwargs):
    """Build and return a fully populated ``Control`` object.

    Defaults come from ``Control.__init__`` and from the dataset-specific
    lookup tables; any keyword argument is then set as an attribute on the
    control object and takes precedence over those defaults.

    Parameters
    ----------
    data : str
        Dataset name. Must be a key of the dataset-specific lookup tables.
    gpu : int
        GPU index. ``-1`` selects ``"cpu"``; any other value yields the
        device string ``f"cuda:{gpu}"``.
    **kwargs
        Attribute overrides, applied with ``setattr``. Explicit values for
        ``diffusion_stop``, ``n_propagation_steps``, ``core_prop``,
        ``batch_size`` and ``num_epochs`` win over the table defaults.
        ``embed_setting``, ``normalize_embeddings`` and ``reset_embed`` are
        always derived from the other parameters.

    Returns
    -------
    Control
        The configured object. Its ``id`` attribute is the SHA-256 hex digest
        of the JSON-serialised config (sorted keys).

    Raises
    ------
    KeyError
        If ``data`` or the embedding method is missing from a lookup table.
    """
    # Initialize ctrl
    ctrl = Control()

    # Set data-specific defaults (overridable through kwargs below)
    ctrl.data = data
    ctrl.diffusion_stop = diffusion_stop[ctrl.data]
    ctrl.n_propagation_steps = n_propagation_steps[ctrl.data]

    # Apply the caller's overrides
    for key, value in kwargs.items():
        setattr(ctrl, key, value)

    # The default core proportion depends on core_selection, which may have
    # just been overridden, so it is resolved after the kwargs loop. An
    # explicitly passed core_prop is kept rather than silently replaced.
    if "core_prop" not in kwargs:
        if ctrl.core_selection == "degree":
            ctrl.core_prop = core_node_proportions[ctrl.data]
        elif ctrl.core_selection == "relation_based":
            ctrl.core_prop = core_edge_proportions[ctrl.data]

    if ctrl.propagation_type == "barycenter":
        ctrl.normalize_embeddings = False
    elif ctrl.propagation_type == "normalized_sum":
        ctrl.normalize_embeddings = True

    # Set embedding method's parameters
    ctrl.embed_setting = embed_methods[ctrl.embed_method]()
    if ctrl.embed_method == "random":
        ctrl.reset_embed = False
    else:
        # As with core_prop, the table values are only defaults.
        if "batch_size" not in kwargs:
            ctrl.batch_size = batch_sizes[ctrl.data]
        if "num_epochs" not in kwargs:
            ctrl.num_epochs = num_epochs[ctrl.data]
        ctrl.reset_embed = True

    # Compute config id. This happens before gpu, device, date and
    # time_interval are assigned, so none of them influence the id.
    config = ctrl.get_config()
    ctrl.id = hashlib.sha256(
        json.dumps(config, sort_keys=True).encode("ascii")
    ).hexdigest()

    # Set other parameters
    ctrl.gpu = gpu
    if ctrl.gpu != -1:
        ctrl.device = f"cuda:{ctrl.gpu}"
    else:
        ctrl.device = "cpu"
    ctrl.date = datetime.today().strftime("%Y-%m-%d")
    ctrl.time_interval = time_intervals[ctrl.data]
    return ctrl
