import torch
import time
import json
import numpy as np
import scipy.sparse
from pathlib import Path
from pykeen.datasets import Wikidata5M
from pykeen.models import DistMult, ComplEx, TransE, RotatE
from pykeen.training import SLCWATrainingLoop
from pykeen.utils import resolve_device, set_random_seed
from pykeen.checkpoints import save_model
from pykeen.triples import TriplesFactory

from fastrp import fastrp_wrapper

# Project root: the directory one level above the one containing this file.
_PROJECT_ROOT = Path(__file__).parents[1]
_PROCESSED_KG_DIR = _PROJECT_ROOT / "data" / "kg" / "processed"

# Where trained embedding models/artifacts are written.
MODELS_DIR = _PROJECT_ROOT / "models" / "linked"

# Pre-processed Wikidata5M subsets, loaded with TriplesFactory.from_path_binary.
# NOTE(review): the "degN" directory suffix presumably encodes a degree threshold
# applied during preprocessing — confirm against the preprocessing script.
data_paths = {
    "wikidata500k": _PROCESSED_KG_DIR / "wikidata5m_deg9",
    "wikidata1m": _PROCESSED_KG_DIR / "wikidata5m_deg6",
    "wikidata2m": _PROCESSED_KG_DIR / "wikidata5m_deg4",
    "wikidata3m": _PROCESSED_KG_DIR / "wikidata5m_deg3",
}


def train_kge_embeddings(
    data="wikidata5m",
    *,
    device="cuda:1",
    num_epochs=100,
    batch_size=8192,
):
    """Train DistMult, ComplEx, RotatE and TransE on one dataset and save each.

    Each model is trained with PyKEEN's sLCWA training loop and Adam (lr=1e-3),
    seeded with 42, and written to ``MODELS_DIR / "<model>_<data>"`` as
    ``model.pt`` plus ``training_time.json`` (wall-clock training seconds).

    Args:
        data: ``"wikidata5m"`` to use the training split of PyKEEN's built-in
            ``Wikidata5M`` dataset, otherwise a key of ``data_paths`` whose
            directory is loaded with ``TriplesFactory.from_path_binary``.
        device: Device spec passed to ``resolve_device``.
        num_epochs: Number of training epochs per model.
        batch_size: Training batch size.

    Raises:
        KeyError: If ``data`` is neither ``"wikidata5m"`` nor a ``data_paths``
            key. This is raised before any output directory is created.
    """
    # (model class, embedding_dim) per model, in training order.
    # NOTE(review): ComplEx/RotatE use 150 vs 300 — presumably because their
    # embeddings are complex-valued; confirm the intended parameter budget.
    model_specs = {
        "DistMult": (DistMult, 300),
        "ComplEx": (ComplEx, 150),
        "RotatE": (RotatE, 150),
        "TransE": (TransE, 300),
    }

    # 1. Load the dataset once; every model in the sweep trains on the same triples.
    if data == "wikidata5m":
        triples_factory = Wikidata5M().training
    else:
        triples_factory = TriplesFactory.from_path_binary(path=data_paths[data])

    torch_device = resolve_device(device)

    for model_name, (model_cls, embedding_dim) in model_specs.items():
        print(f"Training {model_name} model on {data}...")
        output_directory = MODELS_DIR / f"{model_name.lower()}_{data}"
        output_directory.mkdir(parents=True, exist_ok=True)

        # 2. Re-seed per model so each run is reproducible on its own.
        set_random_seed(42)

        # 3. Initialize the model and move it to the target device.
        model = model_cls(triples_factory=triples_factory, embedding_dim=embedding_dim)
        model = model.to(torch_device)

        # 4. Create the training loop.
        training_loop = SLCWATrainingLoop(
            model=model,
            triples_factory=triples_factory,
            optimizer=torch.optim.Adam(params=model.get_grad_params(), lr=1e-3),
        )

        # 5. Train and time it.
        start_time = time.perf_counter()
        training_loop.train(
            triples_factory=triples_factory,
            num_epochs=num_epochs,
            batch_size=batch_size,
            use_tqdm_batch=True,
        )
        training_time = time.perf_counter() - start_time

        # 6. Save the model and the training time.
        save_model(model, output_directory / "model.pt")
        with open(output_directory / "training_time.json", "w") as f:
            json.dump({"training_time_seconds": training_time}, f)

        print(f"{model_name} model saved to {output_directory}")
        print(f"Training took {training_time:.2f} seconds.")

        # Drop this model and the optimizer state held by the training loop
        # before the next model is allocated. Otherwise two models would be
        # resident on the device at once.
        del training_loop, model
        if torch_device.type == "cuda":
            torch.cuda.empty_cache()


def _symmetric_binary_adjacency(heads, tails, num_nodes):
    """Return the symmetric 0/1 ``float32`` CSR adjacency matrix of an edge list.

    Edge direction and multiplicity are discarded: entry ``(i, j)`` is 1 iff at
    least one edge connects ``i`` and ``j`` in either direction.
    """
    heads = np.asarray(heads)
    tails = np.asarray(tails)
    adjacency = scipy.sparse.csr_matrix(
        (np.ones(heads.shape[0], dtype=np.float32), (heads, tails)),
        shape=(num_nodes, num_nodes),
        dtype=np.float32,
    )
    # CSR construction sums parallel edges, and A + A.T sums both directions.
    # Every stored value is therefore a positive count. Overwriting the stored
    # data with 1 binarises the matrix. It avoids the slow, memory-hungry
    # boolean-mask assignment `A[A > 1] = 1`.
    adjacency = (adjacency + adjacency.T).tocsr()
    adjacency.sum_duplicates()
    adjacency.data[:] = 1.0
    return adjacency


def train_fastrp_embeddings(data="wikidata5m"):
    """Compute FastRP entity embeddings for one dataset and save them.

    The knowledge graph is reduced to an undirected, unweighted entity graph
    (relation types dropped). That graph is passed to ``fastrp_wrapper``.
    Results are written to ``MODELS_DIR / "fastrp_<data>"`` as
    ``entity_embeddings.npy`` plus ``training_time.json`` (wall-clock seconds
    spent in ``fastrp_wrapper``).

    Args:
        data: ``"wikidata5m"`` to use the training split of PyKEEN's built-in
            ``Wikidata5M`` dataset, otherwise a key of ``data_paths`` whose
            directory is loaded with ``TriplesFactory.from_path_binary``.

    Raises:
        KeyError: If ``data`` is neither ``"wikidata5m"`` nor a ``data_paths``
            key. This is raised before any output directory is created.
    """
    print(f"Training FastRP on {data} dataset...")

    # 1. Load dataset
    if data == "wikidata5m":
        triples_factory = Wikidata5M().training
    else:
        triples_factory = TriplesFactory.from_path_binary(path=data_paths[data])

    output_directory = MODELS_DIR / f"fastrp_{data}"
    output_directory.mkdir(parents=True, exist_ok=True)

    # 2. FastRP configuration, forwarded unchanged to fastrp_wrapper.
    # NOTE(review): semantics of "weights", "alpha" and "C" are defined by
    # fastrp_wrapper — confirm there before tuning.
    config = {
        "projection_method": "sparse",
        "input_matrix": "adj",
        "weights": [1.0, 1.0, 7.81, 45.28],
        "normalization": False,
        "dim": 300,
        "alpha": -0.628,
        "C": 1.0,
    }

    # 3. Build the adjacency matrix from columns 0 and 2 (head, tail) of the
    # mapped triples; column 1 (relation) is ignored.
    mapped = triples_factory.mapped_triples.cpu().numpy()
    adjacency = _symmetric_binary_adjacency(
        mapped[:, 0], mapped[:, 2], triples_factory.num_entities
    )

    # 4. Train FastRP and record how long it took.
    start_time = time.perf_counter()
    embeddings = fastrp_wrapper(adjacency, config)
    training_time = time.perf_counter() - start_time
    with open(output_directory / "training_time.json", "w") as f:
        json.dump({"training_time_seconds": training_time}, f)

    # 5. Save embeddings.
    np.save(output_directory / "entity_embeddings.npy", embeddings)


if __name__ == "__main__":
    # Datasets to process. The data_paths subsets are currently switched off;
    # uncomment an entry to include it.
    selected_datasets = (
        # "wikidata500k",
        # "wikidata1m",
        # "wikidata2m",
        # "wikidata3m",
        "wikidata5m",
    )
    for dataset_name in selected_datasets:
        # KGE training is currently disabled; uncomment to also train PyKEEN models.
        # train_kge_embeddings(dataset_name)
        train_fastrp_embeddings(dataset_name)
