import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import random
from torch import Tensor
import argparse    
from tqdm import tqdm
import numpy as np
import sys

from model_memory import TGModel
from preprocess import load_seq_data, batch_seq_data
from SeqWeaver import batch2seq, to_device_recursive
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from torch.nn.utils.rnn import pad_sequence

def get_args(argv=None):
    """Parse the command-line options for the training script.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and notebooks). ``None`` keeps
            the original behaviour of reading the real command line.

    Returns:
        ``(args, sys.argv)``: the parsed ``argparse.Namespace`` and the raw
        process argument vector.

    Errors are left to argparse: an invalid option prints the usage plus an
    error message and exits with status 2, and ``--help`` prints the help once
    and exits with status 0. The previous bare ``except:`` intercepted that
    ``SystemExit``, printed the help a second time and always exited with 0.
    """
    # The positional argument of ArgumentParser is ``prog`` (the program name).
    parser = argparse.ArgumentParser("Temporal Transformer")
    parser.add_argument("--dataset", type=str, default="tgbl-wiki")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--bs", type=int, default=1024)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--prefix", type=int, default=1024)
    parser.add_argument("--step", type=int, default=256)
    parser.add_argument("--wd", type=float, default=1e-3)
    parser.add_argument("--hiddim", type=int, default=128)
    parser.add_argument("--num_head", type=int, default=8)
    parser.add_argument("--num_layer", type=int, default=6)
    parser.add_argument("--inter_ratio", type=float, default=1.0)
    parser.add_argument("--attn_dp", type=float, default=0.05)
    parser.add_argument("--event_dp", type=float, default=0.05)

    args = parser.parse_args(argv)
    return args, sys.argv

def train_epoch(model: TGModel, loader: DataLoader, optimizer, device, event_dp: float,
                num_nodes=None, sample_size: int = 100):
    """Run one training epoch and return the mean node and time losses.

    The model memory is reset first. For every batch, the true destination is
    placed in column 0 of ``sample_nodes`` and ``sample_size`` node ids drawn
    uniformly from ``[0, num_nodes)`` fill the remaining columns. The loss is
    cross-entropy with class 0 as the target, plus a Huber loss on the time
    prediction scaled by ``1 / model.timescale``. After each optimizer step, the
    batch's memory events are fed to ``model.update_memory`` and
    ``model.detach_memory()`` is called (presumably to cut the autograd graph
    between batches; see TGModel).

    Args:
        model: the temporal graph model being trained.
        loader: iterable of batches understood by ``batch2seq``.
        optimizer: torch optimizer over ``model``'s parameters.
        device: device that inputs, labels and sampled node ids are placed on.
        event_dp: event dropout rate forwarded to ``batch2seq``.
        num_nodes: upper bound (exclusive) for sampled negative node ids.
            Defaults to the module-level ``max_num_nodes`` that the
            ``__main__`` script defines, which is what this function always used.
        sample_size: number of random negatives per prediction.

    Returns:
        ``(mean_node_loss, mean_time_loss)`` averaged over batches, or
        ``(0.0, 0.0)`` when ``loader`` yields no batches.
    """
    if num_nodes is None:
        num_nodes = max_num_nodes  # global populated by the __main__ script
    model.train()
    model.reset_memory()

    total_node_loss = 0.0
    total_time_loss = 0.0
    num_batches = 0
    for batch in tqdm(loader):
        optimizer.zero_grad()
        inputs, label = batch2seq(batch, event_dp=event_dp)
        inputs = to_device_recursive(device, inputs)
        label = to_device_recursive(device, label)
        pos_dst, target_time = label[0], label[1]

        # NOTE(review): uniform negatives may occasionally include the true destination.
        neg_idx = torch.randint(
            0,
            num_nodes,
            (len(pos_dst), sample_size),
            dtype=torch.long,
            device=device,
        )
        # Column 0 = true destination, columns 1.. = sampled negatives.
        sample_nodes = torch.concat((pos_dst.unsqueeze(-1), neg_idx), dim=-1)

        out_node, out_time, _ = model.forward(sample_nodes, *inputs)

        # The positive candidate sits in column 0, so the target class is 0.
        # Assumes out_node is (batch, candidates, 1) so the (batch, 1) target
        # matches cross_entropy's K-dimensional form — TODO confirm in TGModel.
        node_target = torch.zeros((out_node.shape[0], 1), dtype=torch.long, device=device)
        node_loss = F.cross_entropy(out_node, node_target)
        time_loss = (1 / model.timescale) * F.huber_loss(out_time, target_time.to(out_time.dtype))
        loss = node_loss + time_loss

        loss.backward()
        optimizer.step()
        total_node_loss += node_loss.item()
        total_time_loss += time_loss.item()
        num_batches += 1

        # Feed this batch's events into memory only after it was used for prediction.
        model.update_memory(*(to_device_recursive(device, batch.mem_masked_1d_data)))
        model.detach_memory()

    if num_batches == 0:
        return 0.0, 0.0
    return total_node_loss / num_batches, total_time_loss / num_batches


def evaluation_mrr(model: TGModel, loader: DataLoader, device, evaluator: Evaluator, metric: str="mrr", split_mode="val",
                   sampler=None):
    """Evaluate ``model`` with one-vs-many ranking and return the mean metric.

    Memory is NOT reset here, so scoring continues from the model's current
    memory state. For each batch, the rows selected by ``batch.mask`` are scored.
    The true destination is placed in column 0, and the negatives returned by
    ``sampler.query_batch`` fill the remaining columns; ragged negative lists
    are right-padded with -1. Scores go through a sigmoid, and padded slots are
    overwritten with -0.1 so ``calculate_mrr`` can discard them. After scoring,
    the batch's memory events are fed to ``model.update_memory`` (when
    ``batch.mem_mask`` has any entry), so later batches see them. Batches with
    no rows to score only update memory. Everything runs under
    ``torch.no_grad()``, including those memory updates.

    Args:
        model: the temporal graph model to evaluate.
        loader: iterable of batches exposing ``mask``, ``mem_mask``, ``src``,
            ``dst``, ``t`` and ``mem_masked_1d_data``.
        device: device the model lives on; inputs and candidate ids are moved there.
        evaluator: TGB evaluator passed to ``calculate_mrr``.
        metric: metric key passed to ``calculate_mrr``.
        split_mode: ``"val"`` or ``"test"``; forwarded to the negative sampler.
        sampler: negative sampler exposing ``query_batch``. Defaults to the
            module-level ``neg_sampler`` defined by the ``__main__`` script.

    Returns:
        Mean of the per-query metric values, or NaN when no batch contained
        any row to score.
    """
    if sampler is None:
        sampler = neg_sampler  # global populated by the __main__ script
    model.eval()
    mrr_list = []
    with torch.no_grad():
        for batch in tqdm(loader):
            mask = batch.mask
            if torch.all(torch.logical_not(mask)):
                # Nothing to score, but the batch's events must still advance the memory.
                if torch.any(batch.mem_mask):
                    model.update_memory(*(to_device_recursive(device, batch.mem_masked_1d_data)))
                continue

            eval_src, eval_dst, eval_t = batch.src[mask], batch.dst[mask], batch.t[mask]
            neg_batch_list = sampler.query_batch(eval_src, eval_dst, eval_t, split_mode=split_mode)
            # Explicit long dtype: an empty negative list would otherwise become float32.
            neg_tensors = [torch.tensor(negs, dtype=torch.long) for negs in neg_batch_list]
            padded_negs = pad_sequence(neg_tensors, batch_first=True, padding_value=-1)
            # Column 0 = true destination, columns 1.. = negatives (-1 = padding).
            # Moved to `device` so it sits next to the model and the other inputs.
            sample_nodes = torch.concat((eval_dst.reshape(-1, 1), padded_negs), dim=-1).to(device)

            inputs, _ = batch2seq(batch, event_dp=0.00, pred_dst_only=True, pred_mask=mask.squeeze(0))
            inputs = to_device_recursive(device, inputs)

            out_node, _, _ = model.forward(sample_nodes, *inputs)
            out_node = torch.sigmoid(out_node).squeeze(-1)
            # Negative sentinel marks padded candidates; calculate_mrr drops scores < 0.
            out_node[sample_nodes < 0] = -0.1
            mrr_list.append(calculate_mrr(out_node.cpu(), evaluator, metric))

            # Update memory only after scoring so a batch never sees its own events.
            if torch.any(batch.mem_mask):
                model.update_memory(*(to_device_recursive(device, batch.mem_masked_1d_data)))

    if not mrr_list:
        return float("nan")
    return np.concatenate(mrr_list).mean()

def calculate_mrr(out_node: Tensor, evaluator: Evaluator, metric: str="mrr"):
    """Compute ``metric`` for every row of ``out_node`` using ``evaluator``.

    Each row holds the positive candidate's score first, followed by the
    negative candidates' scores. Entries below zero are discarded before
    scoring; ``evaluation_mrr`` writes -0.1 into padded slots for exactly
    this purpose. Rows are expected to be (num_queries, num_candidates) —
    TODO confirm for other callers.

    Args:
        out_node: score tensor, one row per query.
        evaluator: object whose ``eval`` accepts the TGB-style input dict and
            returns a mapping keyed by metric name.
        metric: metric name requested from, and read back from, the evaluator.

    Returns:
        NumPy array containing one metric value per row.
    """
    def _score_row(row: Tensor):
        kept = row[row >= 0]  # drop padding sentinels
        assert kept.dim() == 1
        result = evaluator.eval({
            "y_pred_pos": kept[0:1],
            "y_pred_neg": kept[1:],
            "eval_metric": [metric],
        })
        return result[metric]

    return np.array([_score_row(row) for row in out_node])

if __name__ == "__main__":
    args, _ = get_args()

    dataset = PyGLinkPropPredDataset(name=args.dataset, root="datasets")
    loaders, data_props = load_seq_data(dataset, args.bs, prefix=args.prefix, step=args.step)
    train_loader, val_loader, test_loader = loaders
    # max_num_nodes and neg_sampler deliberately remain module-level globals:
    # train_epoch and evaluation_mrr fall back to them by name.
    metric, max_num_nodes, timescale, msgdim, neg_sampler = data_props
    evaluator = Evaluator(name=args.dataset)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TGModel(timescale, max_num_nodes, msgdim, args.hiddim, args.num_head, args.num_layer, inter_ratio=args.inter_ratio, attn_dp=args.attn_dp).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # Load the validation and test negative sets once, instead of reloading the
    # test set every time a new best validation score appears.
    dataset.load_val_ns()
    dataset.load_test_ns()

    # Start below any achievable score so the first finite result is checkpointed.
    best_val_mrr = float("-inf")
    best_test_mrr = 0
    best_epoch = -1
    save_path = f"best_model_memory_{args.dataset}.pt"
    for epoch in range(args.epochs):
        nodeloss, timeloss = train_epoch(model, train_loader, optimizer, device, args.event_dp)
        print(f"epoch {epoch} nodeloss {nodeloss:.4f} timeloss {timeloss:.4f}")
        # Validation continues from the memory state left by training; the test
        # pass below likewise continues from the state left by validation.
        val_mrr = evaluation_mrr(model, val_loader, device, evaluator, metric=metric, split_mode="val")
        print(f"epoch {epoch} mrr {val_mrr:.4f}")
        if val_mrr > best_val_mrr:
            best_val_mrr = val_mrr
            best_epoch = epoch
            test_mrr = evaluation_mrr(model, test_loader, device, evaluator, metric=metric, split_mode="test")
            best_test_mrr = test_mrr
            print(f"epoch {epoch} (best so far) test mrr {test_mrr:.4f}")
            torch.save(model.state_dict(), save_path)

    print(f"Best Epoch: {best_epoch}, Val MRR: {best_val_mrr:.4f}, Test MRR at Best Val: {best_test_mrr:.4f}")


