from datetime import datetime
import hashlib
import json


# Configuration container for a run.
class Control:
    """Hold every configuration parameter of a run.

    Run parameters are plain instance attributes. ``embed_setting`` holds the
    settings object of the embedding method (a ``DistMultSetting`` by
    default, matching the default ``embed_method``).
    """

    def __init__(self):
        # NOTE: assignment order matters, it fixes the key order of get_config().
        self.data: str = "yago4_lcc"
        self.create_inverse_triples: bool = True
        self.embed_dim: int = 100
        self.subgraph_max_size: float = 4e6
        self.embed_method: str = "distmult"
        self.embed_setting = DistMultSetting()
        self.seed: int = 0
        self.date = None
        self.core_prop: float = 0.05
        # one of "degree", "relation_based", "hybrid"
        self.core_selection: str = "degree"
        self.partitioning: str = "blocs"
        # one of "barycenter", "normalized_sum"
        self.propagation_type: str = "normalized_sum"
        self.propagation_lr: int = 1
        self.n_passes_by_subgraph: int = 1
        # reset core subgraph embeddings after each propagation step
        self.reset_embed: bool = True
        self.reorder_subgraphs: bool = False

    def get_config(self):
        """Return the configuration as a flat dict.

        All instance attributes are copied except ``embed_setting``, whose own
        attributes are appended afterwards under ``"embed_setting.<name>"``
        keys. The instance itself is not modified.
        """
        setting_attrs = self.embed_setting.__dict__
        flat = dict(self.__dict__)
        del flat["embed_setting"]
        for attr_name, attr_value in setting_attrs.items():
            flat[f"embed_setting.{attr_name}"] = attr_value
        return flat


# Embedding algorithm settings


class HolESetting:
    """Default settings for the "hole" embedding method."""

    def __init__(self) -> None:
        self.optimizer: str = "Adam"
        self.composition: str = "circular_correlation"
        self.lr: float = 0.001


class TuckERSetting:
    """Default settings for the "tucker" embedding method."""

    def __init__(self) -> None:
        self.optimizer: str = "Adam"
        self.composition: str = "tucker"
        self.lr: float = 0.001


class DistMultSetting:
    """Default settings for the "distmult" embedding method."""

    def __init__(self) -> None:
        self.optimizer: str = "Adam"
        self.composition: str = "multiplication"
        self.lr: float = 0.001
        # Name of the training loss, stored as a string.
        self.loss_fn: str = "CrossEntropyLoss"


class TransESetting:
    """Default settings for the "transe" embedding method."""

    def __init__(self) -> None:
        self.optimizer: str = "Adam"
        self.composition: str = "subtraction"
        self.lr: float = 0.001


class RotatESetting:
    """Default settings for the "rotate" embedding method."""

    def __init__(self) -> None:
        self.optimizer: str = "Adam"
        self.composition: str = "multiplication"
        self.lr: float = 0.001


class RandomSetting:
    """Default settings for the "random" embedding method.

    ``mean`` and ``std`` are stored as given; how they are consumed is
    decided by the code that reads this setting.
    """

    def __init__(self) -> None:
        self.mean: float = 0
        self.std: float = 0.07
        self.composition: str = "multiplication"


# Registry of the available KG embedding methods: maps the name accepted in
# ``Control.embed_method`` to its setting class. To add a method, define a new
# setting class and register it here. ("pbg_precomputed" is not listed: it is
# special-cased by set_control_params and has no setting class.)
embed_methods = dict(
    hole=HolESetting,
    distmult=DistMultSetting,
    transe=TransESetting,
    rotate=RotatESetting,
    tucker=TuckERSetting,
    random=RandomSetting,
)


# Dataset-specific parameters.
# One row per dataset; the columns are, in order:
#   core node proportion, core edge proportion, diffusion stop,
#   number of propagation steps, number of epochs, batch size,
#   number of negatives per positive, time interval.
# The public per-parameter dicts below are derived from this single table, so
# every dataset is guaranteed to appear in all of them.
# fmt: off
_DATASET_PARAMS = {
    #                              node   edge   stop  steps epochs batch  negs interval
    "mini_yago3_lcc":             (0.3,   0.05,  0.65, 5,    25,    512,   100, 0.1),
    "yago3_lcc":                  (0.6,   0.02,  0.65, 5,    25,    4096,  100, 1),
    "yago4_lcc":                  (0.01,  0.025, 0.53, 5,    40,    8192,  100, 10),
    "yago4.5_lcc":                (0.065, 0.03,  0.58, 10,   16,    16384, 100, 10),
    "yago4_with_full_ontology":   (0.005, 0.025, 0.8,  5,    40,    8192,  100, 10),
    "yago4.5_with_full_ontology": (0.055, 0.01,  0.75, 10,   16,    16384, 100, 10),
    "full_freebase_lcc":          (0.02,  0.012, 0.49, 5,    40,    8192,  100, 10),
    "wikikg90mv2_lcc":            (0.005, 0.004, 0.93, 15,   18,    16384, 100, 10),
}
# fmt: on

core_node_proportions = {name: row[0] for name, row in _DATASET_PARAMS.items()}
core_edge_proportions = {name: row[1] for name, row in _DATASET_PARAMS.items()}
diffusion_stop = {name: row[2] for name, row in _DATASET_PARAMS.items()}
n_propagation_steps = {name: row[3] for name, row in _DATASET_PARAMS.items()}
num_epochs = {name: row[4] for name, row in _DATASET_PARAMS.items()}
batch_sizes = {name: row[5] for name, row in _DATASET_PARAMS.items()}
num_negs_per_pos = {name: row[6] for name, row in _DATASET_PARAMS.items()}
time_intervals = {name: row[7] for name, row in _DATASET_PARAMS.items()}


def set_control_params(data, gpu, **kwargs):
    """Build a Control instance configured for a dataset.

    Args:
        data (str): dataset name; must be a key of the dataset-specific
            tables (``diffusion_stop``, ``n_propagation_steps``, ...).
        gpu (int): gpu id (if -1, use cpu)
        **kwargs: attributes assigned on the Control instance after the
            dataset-specific defaults, so they override them. When present,
            ``batch_size``, ``num_epochs`` and ``num_negs_per_pos`` replace the
            dataset defaults for those values. ``core_lr`` additionally sets
            ``embed_setting.lr``, except for the "random" and
            "pbg_precomputed" methods.

    Returns:
        Control: control instance. ``ctrl.id`` is the SHA-256 hex digest of
        ``ctrl.get_config()`` serialized as sorted-key JSON. It is computed
        before ``gpu``, ``device``, ``date`` and ``time_interval`` are
        assigned.

    Raises:
        KeyError: if ``data`` is unknown, or if ``embed_method`` is neither
            "pbg_precomputed" nor a key of ``embed_methods``.
        ValueError: if ``core_selection`` or ``propagation_type`` is not one
            of the supported values.
        TypeError: if a configuration value is not JSON-serializable.
    """
    # Initialize ctrl
    ctrl = Control()

    # Dataset-specific defaults, assigned before kwargs so kwargs override them.
    ctrl.data = data
    ctrl.diffusion_stop = diffusion_stop[ctrl.data]
    ctrl.n_propagation_steps = n_propagation_steps[ctrl.data]
    ctrl.core_node_proportions = core_node_proportions[ctrl.data]
    ctrl.core_edge_proportions = core_edge_proportions[ctrl.data]

    # User overrides
    for key, value in kwargs.items():
        setattr(ctrl, key, value)

    # Pick the core proportion matching the selection strategy. The unused
    # proportion is set to None, so its value does not affect ctrl.id.
    if ctrl.core_selection == "degree":
        ctrl.core_prop = ctrl.core_node_proportions
        ctrl.core_edge_proportions = None
    elif ctrl.core_selection == "relation_based":
        ctrl.core_prop = ctrl.core_edge_proportions
        ctrl.core_node_proportions = None
    elif ctrl.core_selection == "hybrid":
        ctrl.core_prop = None
    else:
        # Previously an unknown value silently kept the generic class default
        # for core_prop instead of a dataset-specific proportion.
        raise ValueError(
            f"Unknown core_selection {ctrl.core_selection!r}; expected "
            "'degree', 'relation_based' or 'hybrid'"
        )

    if ctrl.propagation_type == "barycenter":
        ctrl.normalize_embeddings = False
    elif ctrl.propagation_type == "normalized_sum":
        ctrl.normalize_embeddings = True
    else:
        # Previously an unknown value left normalize_embeddings unset.
        raise ValueError(
            f"Unknown propagation_type {ctrl.propagation_type!r}; expected "
            "'barycenter' or 'normalized_sum'"
        )

    # Set embedding method's parameters
    if ctrl.embed_method == "pbg_precomputed":
        ctrl.reset_embed = True
    else:
        ctrl.embed_setting = embed_methods[ctrl.embed_method]()
        if ctrl.embed_method == "random":
            ctrl.reset_embed = False
        else:
            # Training hyper-parameters: dataset defaults unless given in kwargs.
            if "batch_size" not in kwargs:
                ctrl.batch_size = batch_sizes[ctrl.data]
            if "num_epochs" not in kwargs:
                ctrl.num_epochs = num_epochs[ctrl.data]
            if "num_negs_per_pos" not in kwargs:
                ctrl.num_negs_per_pos = num_negs_per_pos[ctrl.data]
            if "core_lr" in kwargs:
                ctrl.embed_setting.lr = kwargs["core_lr"]
            ctrl.reset_embed = True

    # Compute the config id from the configuration set so far; attributes
    # assigned below are not part of the hash.
    config = ctrl.get_config()
    ctrl.id = hashlib.sha256(
        json.dumps(config, sort_keys=True).encode("ascii")
    ).hexdigest()

    # Set other parameters
    ctrl.gpu = gpu
    if ctrl.gpu != -1:
        ctrl.device = f"cuda:{ctrl.gpu}"
    else:
        ctrl.device = "cpu"
    ctrl.date = datetime.today().strftime("%Y-%m-%d")
    ctrl.time_interval = time_intervals[ctrl.data]
    return ctrl
