"""
    Main script to run the k-TVNs on the TGB benchmark.
    This is the official implementation of our paper: "Virtual Nodes Go Temporal".
"""

from hmac import new
import math
import timeit
import random
import argparse

import os
import os.path as osp
from pathlib import Path
import numpy as np
from tqdm import tqdm as tk
from datetime import datetime
import csv

import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
import torch.nn.functional as F

from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader

from torch_geometric.nn import TransformerConv

from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
import tqdm
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import  EarlyStopMonitor

from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset

from vn_mixer import build_vn_graph_and_mix, VNMixer, MeanDegMixer



class RunningSnapshot:
    """
        Running record of every (src, dst) edge inserted so far, exposed as an
        ``edge_index`` tensor for the VN clustering step.

        All inserted edges are kept (including repeats), so storage grows with the
        number of inserted events; de-duplication happens lazily in ``edge_index``.
    """
    def __init__(self, num_nodes: int, device):
        self.num_nodes = num_nodes
        self.device = device
        # Flat, parallel lists of edge endpoints.
        self._src = torch.empty(0, dtype=torch.long, device=device)
        self._dst = torch.empty(0, dtype=torch.long, device=device)

    @torch.no_grad()
    def insert(self, src: torch.Tensor, dst: torch.Tensor):
        """Append the edges ``src[i] -> dst[i]`` (expected: 1-D long tensors on ``device``)."""
        self._src, self._dst = (
            torch.cat((self._src, src.detach())),
            torch.cat((self._dst, dst.detach())),
        )

    @torch.no_grad()
    def edge_index(self, undirected: bool = True, dedup: bool = True) -> torch.Tensor:
        """
            Return the recorded edges as a ``(2, E)`` long tensor.

            ``undirected`` adds the reverse of every edge; ``dedup`` removes repeated
            pairs (the result is then sorted by ``src * num_nodes + dst``). With no
            edges recorded yet, one self-loop per node is returned instead so the
            clustering never receives an empty graph.
        """
        n = self.num_nodes
        if not self._src.numel():
            loops = torch.arange(n, device=self.device)
            return torch.stack((loops, loops))

        if undirected:
            rows = torch.cat((self._src, self._dst))
            cols = torch.cat((self._dst, self._src))
        else:
            rows, cols = self._src, self._dst

        if not dedup:
            return torch.stack((rows, cols))

        # Encode each pair as a single integer key, deduplicate, then decode.
        keys = torch.unique(rows * n + cols)
        return torch.stack(((keys // n).long(), (keys % n).long()))



def train():
    """
        Train the k-TVN model for one epoch over ``train_loader``.
        ---
        Compared to plain TGN, the memory of every batch node that was already
        observed in an earlier batch is replaced by its virtual-node (VN) mixed
        embedding before the GNN step. After each optimizer step the VN embeddings
        are rebuilt (via ``build_vn_graph_and_mix`` with ``vn_mixer``) from the
        current memory bank and the running ``snapshot`` of training edges.

        Returns:
            float: the loss averaged over all events in ``train_data``.
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    # The VN mixer can be learnable (e.g. attention-based), so switch its mode too.
    # NOTE: the mixed embeddings are only consumed through `.detach()` below, so
    # this loss never produces gradients for the mixer's parameters.
    vn_mixer.train()

    model['memory'].reset_state()
    neighbor_loader.reset_state()

    # Nodes observed in earlier batches, and the VN-mixed memory indexed by node
    # id (a single placeholder row until the first rebuild).
    n_id_obs = torch.empty(0, dtype=torch.long, device=device)
    z_exp_obs = torch.zeros(1, MEM_DIM, device=device)

    total_loss = 0
    for batch in tk(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Uniformly random negative destinations, one per positive edge.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        # Batch nodes already observed in an earlier batch receive VN embeddings.
        seen_mask = torch.isin(n_id, n_id_obs)
        n_id_seen = n_id[seen_mask]
        # NOTE(review): sampled negatives are marked as observed here as well,
        # whereas test() only marks positive endpoints - confirm this is intended.
        n_id_obs = torch.cat((n_id_obs, n_id[~seen_mask]), dim=0).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)

        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        z, last_update = model['memory'](n_id)
        # Overwrite the memory of previously observed nodes with their VN embedding.
        z[assoc[n_id_seen]] = z_exp_obs[n_id_seen].detach()

        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )

        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)
        snapshot.insert(src, pos_dst)

        loss.backward()
        optimizer.step()

        # Rebuild the VN embeddings for the next batch from the updated memory bank
        # and the snapshot of all training edges seen so far. The result is only
        # ever used detached, so recording an autograd graph here is pure overhead
        # (and may retain graph state until the next rebuild).
        with torch.no_grad():
            z_exp_obs, _, _ = build_vn_graph_and_mix(
                memory_bank=model['memory'].memory,
                edge_index_snapshot=snapshot.edge_index().to(device),
                num_nodes=data.num_nodes,
                k=K_COMMUNITIES,
                in_dim=MEM_DIM,
                out_dim=MEM_DIM,
                method=ASSIGN_METHOD,
                connect=VN_CONNECT,
                mixer=vn_mixer,
                method_kwargs=ASSIGN_KW,
                seed=SEED,
            )

        model['memory'].detach()
        total_loss += float(loss) * batch.num_events

    return total_loss / train_data.num_events


@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    """
    Evaluate the k-TVN model on one split and return the mean ranking metric.

    Memory and neighbor loader are NOT reset, so evaluation continues from their
    current state. For each positive edge, the positive destination and the
    negatives returned by ``neg_sampler.query_batch`` are scored; the memory of
    nodes observed in earlier batches of this split is replaced by their VN
    embedding. After scoring a batch, the memory, neighbor loader and an
    eval-only edge snapshot are updated and the VN embeddings are rebuilt.

    Args:
        loader: TemporalDataLoader over the split to evaluate.
        neg_sampler: negative sampler, queried with ``split_mode``.
        split_mode: split name forwarded to the sampler ("val" or "test").

    Returns:
        float: mean of the per-query ``metric`` values (NaN if nothing was scored).
    """

    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()
    vn_mixer.eval()

    perf_list = []

    # Nodes observed in earlier batches of this split, and the VN-mixed memory
    # indexed by node id (a single placeholder row until the first rebuild).
    n_id_obs = torch.empty(0, dtype=torch.long, device=device)
    z_exp_obs = torch.zeros(1, MEM_DIM, device=device)

    # Separate edge snapshot for VN clustering during evaluation: it starts empty
    # and only accumulates the positive edges of this split.
    snapshot_eval = RunningSnapshot(num_nodes=data.num_nodes, device=device)

    for pos_batch in tk(loader):
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        # === Score this batch ===
        neg_batch_list = neg_sampler.query_batch(
            pos_src, pos_dst, pos_t, split_mode=split_mode
        )

        # Copy the positive destinations to the host once per batch rather than
        # once per query inside the loop below.
        pos_dst_np = pos_dst.cpu().numpy()

        for idx, neg_batch in enumerate(neg_batch_list):
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst_np[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id_seen_all = n_id[torch.isin(n_id, n_id_obs)]

            n_id, edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get the memory states
            z, last_update = model['memory'](n_id)

            # Substitute the VN embeddings of already-observed nodes, once VN
            # embeddings exist and cover those node ids.
            if z_exp_obs.size(0) > 1 and n_id_seen_all.numel() > 0:
                max_idx = int(n_id_seen_all.max().item())
                if max_idx < z_exp_obs.size(0):
                    z_exp = z_exp_obs[n_id_seen_all].detach()
                    z[assoc[n_id_seen_all]] = z_exp

            # GNN update
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )

            # Score: index 0 is the positive, rest are negatives
            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # Update temporal memory and neighbor loader
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)

        # Update the snapshot for VN mixing
        snapshot_eval.insert(pos_src, pos_dst)

        # Recompute the VN embeddings from the updated memory and snapshot.
        x_obs = model['memory'].memory
        snapshot_edge_index = snapshot_eval.edge_index()

        z_exp_obs, _, _ = build_vn_graph_and_mix(
            memory_bank=x_obs,
            edge_index_snapshot=snapshot_edge_index.to(device),
            num_nodes=data.num_nodes,
            k=K_COMMUNITIES,
            in_dim=MEM_DIM,
            out_dim=MEM_DIM,
            method=ASSIGN_METHOD,
            connect=VN_CONNECT,
            mixer=vn_mixer,
            method_kwargs=ASSIGN_KW,
            seed=SEED,
        )

        # Only endpoints of positive edges become "observed" for later batches.
        n_id_pos = torch.cat([pos_src, pos_dst]).unique()
        new_nodes = n_id_pos[~torch.isin(n_id_pos, n_id_obs)]
        n_id_obs = torch.cat((n_id_obs, new_nodes), dim=0).unique()

    # Simple aggregation of the metrics
    perf_metrics = float(torch.tensor(perf_list).mean())
    return perf_metrics


start_overall = timeit.default_timer()

args, _ = get_args()
print("INFO: Arguments:", args)

# ---- Hyper-parameters read from the command line ----
DATA, LR, BATCH_SIZE = args.data, args.lr, args.bs
K_VALUE = args.k_value  # NOTE(review): not referenced anywhere else in this script
NUM_EPOCH, SEED = args.num_epoch, args.seed
MEM_DIM, TIME_DIM, EMB_DIM = args.mem_dim, args.time_dim, args.emb_dim
TOLERANCE, PATIENCE, NUM_RUNS = args.tolerance, args.patience, args.num_run

# ---- k-TVN settings ----
# Number of clusters used when building the virtual nodes.
K_COMMUNITIES = args.n_communities
# VN connectivity passed to build_vn_graph_and_mix: "clique" | "star" | "ring".
VN_CONNECT = "clique"
# Clustering method for the node-to-VN assignment (k-TVN default: "kmeans_adj";
# "random" and "louvain" can also be used).
ASSIGN_METHOD = "kmeans_adj"
# Extra options for the clustering method; "iters" is the number of clustering
# iterations taken from the command line.
ASSIGN_KW = {"dproj": 64, "iters": args.num_iters}

# ---- Fixed model settings ----
NUM_NEIGHBORS = 10
MODEL_NAME = 'GAT'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask, val_mask, test_mask = dataset.train_mask, dataset.val_mask, dataset.test_mask

# Full temporal graph on the target device; the per-split subsets are indexed from it.
data = dataset.get_TemporalData().to(device)
metric = dataset.eval_metric

train_data, val_data, test_data = (data[mask] for mask in (train_mask, val_mask, test_mask))

train_loader, val_loader, test_loader = (
    TemporalDataLoader(split, batch_size=BATCH_SIZE)
    for split in (train_data, val_data, test_data)
)

# Range of destination ids, used by train() to sample random negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

memory = TGNMemory(
    data.num_nodes,
    data.msg.size(-1),
    MEM_DIM,
    TIME_DIM,
    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
    aggregator_module=LastAggregator(),
).to(device)

# The GNN re-uses the memory's time encoder (`time_enc=memory.time_enc`).
gnn = GraphAttentionEmbedding(
    in_channels=MEM_DIM,
    out_channels=EMB_DIM,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

# VN mixer: weighted degree-average. An attention-based alternative is
#   vn_mixer = VNMixer(in_dim=MEM_DIM, out_dim=MEM_DIM, heads=2, layers=1).to(device)
# NOTE(review): with project=True this mixer presumably owns a projection weight,
# i.e. it does have parameters.
vn_mixer = MeanDegMixer(in_dim=MEM_DIM, out_dim=MEM_DIM, project=True, bias=False).to(device)

link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

model = {'memory': memory,
         'gnn': gnn,
         'link_pred': link_pred}

# Collect the trainable parameters of all modules, de-duplicated in order.
# Because `gnn` shares `memory.time_enc` (presumably registered as a submodule of
# both), plain list concatenation would list those parameters twice; torch.optim
# then warns about duplicate parameters and Adam updates them twice per step.
optimizer = torch.optim.Adam(
    list(dict.fromkeys(
        param
        for module in (model['memory'], model['gnn'], model['link_pred'], vn_mixer)
        for param in module.parameters()
    )),
    lr=LR,
)

criterion = torch.nn.BCEWithLogitsLoss()

# Scratch map from global node id to its row in the current batch's sub-graph.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

print("==========================================================")
print(f"=================*** k-TVN-{MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# Results directory next to this script. A single idempotent mkdir replaces the
# previous os.mkdir + Path.mkdir pair; the message is printed only when it is new.
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
_results_dir_is_new = not osp.exists(results_path)
Path(results_path).mkdir(parents=True, exist_ok=True)
if _results_dir_is_new:
    print('INFO: Create directory {}'.format(results_path))
# NOTE(review): not referenced elsewhere in this script (results are saved as CSV).
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'


# To track the val and test values.
runs_val_mrr, runs_test_mrr = [], []

# To track the running train, val and test times
runs_train_time_mean, runs_val_time_mean, runs_test_time_mean = [], [], []

# Persist the aggregated results of all runs as one CSV row.
def _save_results_csv(
    runs_val_mrr, runs_test_mrr,
    runs_train_time_mean, runs_val_time_mean, runs_test_time_mean,
    out_dir="saved_results",
    filename_prefix=""):
    """
    This function is used to save the results as a csv file including the val/test MRR
    and also the runing time.
    It produces <out_dir>/<filename_prefix>mrr.csv
    """
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    csv_path = Path(out_dir) / f"{filename_prefix}mrr.csv"

    import math
    def mean_std(xs):
        if not xs:
            return float("nan"), float("nan")
        m = sum(xs) / len(xs)
        var = sum((x - m) ** 2 for x in xs) / (len(xs) - 1) if len(xs) > 1 else 0.0
        return m, math.sqrt(var)

    val_mrr_mean, val_mrr_std   = mean_std(runs_val_mrr)
    test_mrr_mean, test_mrr_std = mean_std(runs_test_mrr)

    tr_time_mean, tr_time_std = mean_std(runs_train_time_mean)
    va_time_mean, va_time_std = mean_std(runs_val_time_mean)
    te_time_mean, te_time_std = mean_std(runs_test_time_mean)

    header = [
        "timestamp",
        "experiment",
        "type_clustering",
        "num_clusters",
        "num_iterations_clustering",
        "num_runs",
        # MRR stats
        "val_mrr_mean", "val_mrr_std",
        "test_mrr_mean", "test_mrr_std",
        "val_mrr_per_run", "test_mrr_per_run",
        # Time stats
        "train_time_mean_per_epoch_mean", "train_time_mean_per_epoch_std",
        "val_time_mean_per_epoch_mean",   "val_time_mean_per_epoch_std",
        "test_time_mean_per_epoch_mean",  "test_time_mean_per_epoch_std",
        "train_time_mean_per_epoch_per_run",
        "val_time_mean_per_epoch_per_run",
        "test_time_mean_per_epoch_per_run",
    ]

    row = [
        datetime.now().isoformat(timespec="seconds"),
        filename_prefix.rstrip("_"),
        ASSIGN_METHOD,
        K_COMMUNITIES,
        ASSIGN_KW["iters"],
        len(runs_val_mrr),
        # MRR
        f"{val_mrr_mean:.6f}", f"{val_mrr_std:.6f}",
        f"{test_mrr_mean:.6f}", f"{test_mrr_std:.6f}",
        ";".join(f"{x:.6f}" for x in runs_val_mrr),
        ";".join(f"{x:.6f}" for x in runs_test_mrr),
        # Times
        f"{tr_time_mean:.6f}", f"{tr_time_std:.6f}",
        f"{va_time_mean:.6f}", f"{va_time_std:.6f}",
        f"{te_time_mean:.6f}", f"{te_time_std:.6f}",
        ";".join(f"{x:.6f}" for x in runs_train_time_mean),
        ";".join(f"{x:.6f}" for x in runs_val_time_mean),
        ";".join(f"{x:.6f}" for x in runs_test_time_mean),
    ]

    file_exists = csv_path.exists()
    with csv_path.open("a", newline="") as f:
        w = csv.writer(f)
        if not file_exists:
            w.writerow(header)
        w.writerow(row)

    print(f"[Saved CSV] {csv_path}")


# Test evaluation is only triggered once the validation metric beats both this
# threshold and the best validation value so far. Set to 0 when no target
# validation MRR is known.
VAL_TEST_THRESHOLD = 0.3
# Early-stopping patience (epochs without a triggered test); 5 following our paper.
EARLY_STOP_PATIENCE = 5


def _mean(xs):
    """Return the mean of ``xs`` as a float (seconds per epoch), or NaN when empty."""
    return float(np.mean(xs)) if len(xs) > 0 else float("nan")


# Let's run the model now
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # Seed each run differently: with one shared seed every run would be an
    # identical replicate and the std reported in the CSV would be meaningless.
    run_seed = SEED + run_idx
    torch.manual_seed(run_seed)
    set_random_seed(run_seed)

    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    # NOTE(review): `early_stopper` is never consulted below; early stopping is
    # done manually with EARLY_STOP_PATIENCE.
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    dataset.load_val_ns()
    dataset.load_test_ns()

    val_perf_list = []
    max_val_perf = 0.0  # best validation value that triggered a test evaluation
    max_test_perf = 0
    best_epoch = 0
    count = 0
    train_time_list = []
    val_time_list = []
    test_time_list = []

    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        start_epoch_train = timeit.default_timer()

        # train() resets memory and neighbors every epoch, so the edge snapshot
        # used for VN clustering restarts as well (read by train() as a global).
        snapshot = RunningSnapshot(num_nodes=data.num_nodes, device=device)

        loss = train()
        train_time = timeit.default_timer() - start_epoch_train
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {train_time: .4f}"
        )
        train_time_list.append(train_time)

        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        val_time = timeit.default_timer() - start_val
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {val_time: .4f}")
        val_perf_list.append(perf_metric_val)
        val_time_list.append(val_time)

        if perf_metric_val > max(max_val_perf, VAL_TEST_THRESHOLD):
            max_val_perf = perf_metric_val
            start_test = timeit.default_timer()
            perf_metric_test = test(test_loader, neg_sampler, split_mode="test")
            print(f"\tTest: {metric}: {perf_metric_test: .4f}")
            test_time = timeit.default_timer() - start_test
            print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
            test_time_list.append(test_time)
            count = 0
            best_epoch = epoch
            max_test_perf = perf_metric_test
        else:
            count += 1
            if count == EARLY_STOP_PATIENCE:
                break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    if best_epoch == 0:
        # The threshold was never beaten, so no test evaluation happened. Report
        # the best validation value actually observed, not the threshold itself.
        max_val_perf = max(val_perf_list, default=0.0)
        print(f"WARNING: Validation {metric} never exceeded {VAL_TEST_THRESHOLD}; test was not evaluated.")

    print(f"Best epoch: {best_epoch}, Max Validation {metric}: {max_val_perf: .4f}, Test {metric}: {max_test_perf: .4f}")

    # Storing the MRR values of the Val and Test
    runs_val_mrr.append(float(max_val_perf))
    runs_test_mrr.append(float(max_test_perf))

    # Per-run mean times (seconds per epoch)
    runs_train_time_mean.append(_mean(train_time_list))
    runs_val_time_mean.append(_mean(val_time_list))
    runs_test_time_mean.append(_mean(test_time_list))

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

# Save all the results into a CSV
_save_results_csv(
    runs_val_mrr,
    runs_test_mrr,
    runs_train_time_mean,
    runs_val_time_mean,
    runs_test_time_mean,
    out_dir=results_path,
    filename_prefix=f"{MODEL_NAME}_{DATA}_"
)

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
