import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import wandb


import torch

from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import (
    IdentityMessage,
    LastAggregator,
    LastNeighborLoader,
)

from bikeshare import get_cbs_data
from utils import get_config, setup_seed, save_model
from model import (
    GraphAttentionEmbedding,
    ProductLayer,
    MergeLayer,
    ReshapeLayer,
    col_RNN,
)
from train_epoch import train_epoch

from val_epoch import Validator


def _count_params(module):
    """Return the number of scalar parameters (sum of ``numel``) in ``module``."""
    return sum(p.numel() for p in module.parameters())


def _unique_parameters(*modules):
    """Return the parameters of ``modules`` de-duplicated, in first-seen order.

    ``gnn`` is constructed with ``time_enc=memory.time_enc``, so the same
    parameter tensors may be reachable from more than one module. A ``set``
    also removes duplicates, but its iteration order depends on object ids and
    changes from run to run, which makes the optimizer's parameter indices (and
    therefore its state dict) irreproducible. ``dict.fromkeys`` de-duplicates
    while preserving insertion order.
    """
    return list(dict.fromkeys(p for module in modules for p in module.parameters()))


def main():
    """Train the generative TGN model, validating and checkpointing every epoch.

    Hyper-parameters are read with ``get_config`` and mirrored into
    ``wandb.config``. The data is either a JODIE dataset (``wikipedia`` /
    ``reddit``) or the bike-share data from ``get_cbs_data``. An inter-event
    time column is prepended to the edge features. The script builds the TGN
    memory, the attention embedding, the source/destination score heads and the
    feature model, then trains for up to ``num_epochs`` epochs.

    All modules are saved whenever the training loss or the evaluation loss
    improves. ``eps`` (passed to ``train_epoch`` / ``Validator.val_loss``) is
    divided by 10 whenever the feature loss plateaus, but never below 1% of its
    initial value.
    """
    # init wandb logger
    wandb.init(
        project="generative-tgn",
        config=dict(
            name=get_config("name"),
            n_gpu=get_config("n_gpu"),
            num_neighbors=get_config("num_neighbors"),
            batch_size=get_config("batch_size"),
            n_sampled_src=get_config("n_sampled_src"),
            n_sampled_dst=get_config("n_sampled_dst"),
            memory_dim=get_config("memory_dim"),
            time_dim=get_config("time_dim"),
            embedding_dim=get_config("embedding_dim"),
            feats_model_h_dim=get_config("feats_model_h_dim"),
            data_name=get_config("data_name"),
            data_path=get_config("data_path"),
            lr=get_config("lr"),
            threshold_eps=get_config("threshold_eps"),
            threshold_epochs=get_config("threshold_epochs"),
            eps=get_config("eps"),
            num_feats=get_config("num_feats"),
            seed=get_config("seed"),
            n_comp=get_config("n_comp"),
        ),
    )

    hyperparams = wandb.config

    setup_seed(hyperparams["seed"])

    num_neighbors = hyperparams["num_neighbors"]
    batch_size = hyperparams["batch_size"]
    n_sampled_src = hyperparams["n_sampled_src"]
    n_sampled_dst = hyperparams["n_sampled_dst"]
    memory_dim = hyperparams["memory_dim"]
    time_dim = hyperparams["time_dim"]
    embedding_dim = hyperparams["embedding_dim"]
    feats_model_h_dim = hyperparams["feats_model_h_dim"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ------------------------------------------------------------------ data
    name = hyperparams["data_name"]
    print(f"Using {name} dataset. ")

    if name in ["wikipedia", "reddit"]:
        dataset = JODIEDataset(hyperparams["data_path"], name=name)
        data = dataset[0]
        # Keep only the first 12 columns of the edge features.
        data.msg = data.msg[:, :12]
    else:
        print("for sure bike data!!")
        data = get_cbs_data()

    # Optionally keep a random subset of `num_feats` feature columns.
    reduce_feats = False
    if reduce_feats:
        # randperm picks distinct columns; randint could select a column twice.
        sampled_feat_idxs = torch.randperm(data.msg.shape[1])[: hyperparams["num_feats"]]
        data.msg = data.msg[:, sampled_feat_idxs]

    # Prepend the time since the previous event (0.0 for the first event) as
    # feature column 0.
    dt = torch.cat([torch.tensor([0.0]), data.t[1:] - data.t[:-1]])
    data.msg = torch.cat([dt.unsqueeze(1), data.msg], dim=1)

    smp = data.msg
    # Column 0 (dt) is tagged "exponential", the remaining columns "gmm".
    # NOTE(review): keys run 0..smp.shape[1], i.e. one more entry than smp has
    # columns (smp already includes dt) -- confirm col_RNN tolerates this.
    col2K = {0: "exponential", **{i + 1: "gmm" for i in range(smp.shape[1])}}
    print(col2K)

    data = data.to(device)

    # Ensure to only sample actual destination nodes as negatives.
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
    min_src_idx, max_src_idx = int(data.src.min()), int(data.src.max())
    train_data, val_data, test_data = data.train_val_test_split(
        val_ratio=0.15, test_ratio=0.15
    )

    train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
    val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
    test_loader = TemporalDataLoader(test_data, batch_size=batch_size)

    neighbor_loader = LastNeighborLoader(
        data.num_nodes, size=num_neighbors, device=device
    )

    # ---------------------------------------------------------------- models
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        memory_dim,
        time_dim,
        message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=memory_dim,
        out_channels=embedding_dim,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    embd_to_score_dst = ProductLayer(in_channels=embedding_dim).to(device)

    embd_to_score_src = ReshapeLayer(in_channels=embedding_dim, out_channels=1).to(
        device
    )

    feats_model = col_RNN(
        smp.shape[1],
        col2K,
        embed_dim=1,
        hidden_size=feats_model_h_dim,
        num_layers=1,
        n_comp=hyperparams["n_comp"],
    ).to(device)

    embd_to_h0 = MergeLayer(
        in_channels=embedding_dim, out_channels=feats_model_h_dim
    ).to(device)

    # Single fixed ordering used by the optimizer, wandb.watch and checkpoints.
    models = (
        memory,
        gnn,
        embd_to_score_dst,
        embd_to_score_src,
        feats_model,
        embd_to_h0,
    )

    # OPTIM: de-duplicated (time_enc is shared) and in a deterministic order.
    optimizer = torch.optim.Adam(_unique_parameters(*models), lr=hyperparams["lr"])

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # Random buffers are drawn on the CPU and then moved, in this exact order,
    # so the seeded RNG stream is unchanged.
    all_dst = torch.arange(min_dst_idx, max_dst_idx + 1).to(device)
    len_all_dst = len(all_dst)
    rand_all_dst = torch.rand(len_all_dst).to(device)
    rand_smpdst = torch.rand(n_sampled_dst).to(device)
    tmap = torch.empty(max_dst_idx + 1, dtype=torch.long).to(device)

    all_src = torch.arange(min_src_idx, max_src_idx + 1).to(device)
    len_all_src = len(all_src)
    rand_all_src = torch.rand(len_all_src).to(device)
    rand_smpsrc = torch.rand(n_sampled_src).to(device)
    tmap_src = torch.empty(max_src_idx + 1, dtype=torch.long).to(device)

    # wandb model watch
    for module in models:
        wandb.watch(module, log_freq=1000)

    for label, module in (
        ("memory", memory),
        ("gnn", gnn),
        ("embd_to_score_dst", embd_to_score_dst),
        ("embd_to_score_src", embd_to_score_src),
        ("feats_model", feats_model),
        ("embd_to_h0", embd_to_h0),
    ):
        wandb.run.summary[f"num_params_{label}"] = _count_params(module)
    # Count each tensor once; tensors shared between modules (gnn was built
    # with memory.time_enc) would otherwise be counted twice.
    wandb.run.summary["num_params_total"] = sum(
        p.numel() for p in _unique_parameters(*models)
    )

    # ------------------------------------------------------------ train loop
    num_epochs = 1000
    results_feats = []  # mean training feature loss per epoch

    best_loss = 1e8  # threaded through train_epoch

    eps = hyperparams["eps"]
    threshold_eps = hyperparams["threshold_eps"]
    min_eps = eps * 0.01
    threshold_epochs = hyperparams["threshold_epochs"]

    my_validator = Validator(
        val_loader=val_loader,
        train_loader=train_loader,
        test_loader=test_loader,
        device=device,
    )

    best_val_loss = float("inf")
    best_train_loss = float("inf")
    # NOTE(review): never updated below; only written to the run summary.
    best_ds_val_loss = 1000
    best_ds_val_aucs = 0
    best_ds_val_aps = 0

    # Everything needed alongside the weights to rebuild/evaluate a checkpoint.
    # The values are fixed references, so the dict is built once.
    # NOTE(review): "val_data"/"loader" hold the *test* split.
    dataset_ = get_config("data_name")
    config_dict = {
        "num_neighbors": hyperparams["num_neighbors"],
        "batch_size": hyperparams["batch_size"],
        "memory_dim": hyperparams["memory_dim"],
        "time_dim": hyperparams["time_dim"],
        "embedding_dim": hyperparams["embedding_dim"],
        "feats_model_h_dim": hyperparams["feats_model_h_dim"],
        "n_comp": hyperparams["n_comp"],
        "lr": hyperparams["lr"],
        "eps": hyperparams["eps"],
        "threshold_eps": hyperparams["threshold_eps"],
        "min_eps": min_eps,
        "threshold_epochs": hyperparams["threshold_epochs"],
        "seed": hyperparams["seed"],
        "data": data,
        "min_dst_idx": min_dst_idx,
        "max_dst_idx": max_dst_idx,
        "all_src": all_src,
        "all_dst": all_dst,
        "rand_all_src": rand_all_src,
        "rand_smpsrc": rand_smpsrc,
        "tmap_src": tmap_src,
        "n_sampled_src": n_sampled_src,
        "rand_all_dst": rand_all_dst,
        "n_sampled_dst": n_sampled_dst,
        "rand_smpdst": rand_smpdst,
        "tmap": tmap,
        "assoc": assoc,
        "val_data": test_data,
        "loader": test_loader,
        "smp": smp,
        "col2K": col2K,
    }

    def _checkpoint(subdir):
        # Save all modules under ./saved_models/<subdir>/. A fresh list and a
        # shallow dict copy are passed each time, in case save_model mutates them.
        # NOTE(review): dataset_ and name are both the dataset name, so runs
        # with different run names overwrite each other's checkpoints.
        save_model(
            list(models),
            f"./saved_models/{subdir}/{dataset_}_{name}",
            dict(config_dict),
            neighbor_loader,
        )

    for epoch in range(1, num_epochs + 1):
        mean_loss, mean_loss_feats, _, best_loss = train_epoch(
            memory,
            gnn,
            embd_to_score_dst,
            feats_model,
            embd_to_h0,
            neighbor_loader,
            train_loader,
            device,
            all_src,
            rand_all_src,
            rand_smpsrc,
            tmap_src,
            n_sampled_src,
            rand_all_dst,
            n_sampled_dst,
            rand_smpdst,
            all_dst,
            min_dst_idx,
            tmap,
            optimizer,
            best_loss,
            assoc,
            data,
            eps,
            embd_to_score_src,
            train_data,
            max_dst_idx,
        )

        # NOTE(review): the "val" loss and link-prediction metrics below are
        # computed on the *test* split (test_data / loader="test" /
        # dataset="test"), and they drive checkpoint selection. If that is not
        # intended, switch to val_data and the validation loader.
        mean_val_loss, mean_val_loss_feats, _, _ = my_validator.val_loss(
            embd_to_score_src,
            embd_to_score_dst,
            memory,
            gnn,
            feats_model,
            embd_to_h0,
            neighbor_loader,
            min_dst_idx,
            max_dst_idx,
            all_src,
            all_dst,
            rand_all_src,
            rand_smpsrc,
            tmap_src,
            n_sampled_src,
            rand_all_dst,
            n_sampled_dst,
            rand_smpdst,
            tmap,
            assoc,
            data,
            eps,
            test_data,
            loader="test",
        )

        aps_val, aucs_val = my_validator.link_pred_from_emb(
            min_dst_idx=min_dst_idx,
            max_dst_idx=max_dst_idx,
            memory=memory,
            gnn=gnn,
            assoc=assoc,
            data=data,
            neighbor_loader=neighbor_loader,
            dataset="test",
        )

        if aps_val > best_ds_val_aps:
            best_ds_val_aps = aps_val
        if aucs_val > best_ds_val_aucs:
            best_ds_val_aucs = aucs_val

        if mean_loss < best_train_loss:
            best_train_loss = mean_loss
            wandb.run.summary["best_train_loss_epoch"] = epoch
            _checkpoint("train_loss")

        if mean_val_loss < best_val_loss:
            best_val_loss = mean_val_loss
            wandb.run.summary["best_val_loss_epoch"] = epoch
            _checkpoint("val_loss")

        results_feats.append(mean_loss_feats)

        # Divide eps by 10 when the feature loss improved by less than a
        # relative `threshold_eps` compared with `threshold_epochs` epochs ago.
        # The floor check is relative (not an absolute 1e-15), so it also works
        # for large eps. Builtin abs() accepts both Python floats and 0-d tensors.
        # NOTE(review): the window is not reset after a decay, so one plateau
        # may trigger decays on consecutive epochs; consider a cooldown.
        if eps > min_eps * (1.0 + 1e-9) and epoch > threshold_epochs:
            reference = results_feats[epoch - threshold_epochs]
            latest = results_feats[-1]
            if reference - latest < threshold_eps * abs(reference):
                # rescale eps
                eps = eps * 0.1
                print(
                    f"[epoch {epoch} of {num_epochs}]: \t ====== rescaling eps to {eps} "
                    + f"({reference} - {latest}) ======"
                )

        print(
            f"Epoch: {epoch:02d}:\t Train Loss: {mean_loss:.4f}\t Loss features: {mean_loss_feats:.4f}"
        )

        print(
            f"\t\t Val Loss: {mean_val_loss:.4f}\t Loss features: {mean_val_loss_feats:.4f} \t DS(aps, aucs): {aps_val:.4f}, {aucs_val:.4f}"
        )

        wandb.log(
            {
                "mean_train_loss": mean_loss,
                "mean_train_loss_feats": mean_loss_feats,
                "mean_val_loss": mean_val_loss,
                "mean_val_loss_feats": mean_val_loss_feats,
                "val_aps_ds": aps_val,
                "val_aucs_ds": aucs_val,
                "best_ds_val_aps": best_ds_val_aps,
            }
        )

        wandb.run.summary["best_val_loss"] = best_val_loss
        wandb.run.summary["best_ds_val_loss"] = best_ds_val_loss
        wandb.run.summary["best_ds_val_aucs"] = best_ds_val_aucs
        wandb.run.summary["best_ds_val_aps"] = best_ds_val_aps


if __name__ == "__main__":
    # Script entry point: run training only when this file is executed
    # directly, not when it is imported as a module.
    main()
