import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import random
from torch import Tensor
import argparse    
from tqdm import tqdm
import numpy as np
import sys
from model import TGModel
from preprocess import load_seq_data, batch_seq_data
from SeqWeaver import batch2seq, to_device_recursive
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from torch.nn.utils.rnn import pad_sequence


def get_args():
    """Parse the command-line options of the training script.

    Returns:
        A tuple ``(args, argv)``: the parsed ``argparse.Namespace`` and
        ``sys.argv``.

    Raises:
        SystemExit: raised by argparse. ``--help`` exits with status 0;
            invalid arguments exit with argparse's non-zero status after the
            full help text has been printed.
    """
    parser = argparse.ArgumentParser("Temporal Transformer")
    parser.add_argument("--dataset", type=str, default="tgbl-wiki")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--bs", type=int, default=1024)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--device", type=str, default=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    parser.add_argument("--wd", type=float, default=3e-5)
    parser.add_argument("--prefix", type=int, default=1024)
    parser.add_argument("--step", type=int, default=256)
    parser.add_argument("--graph_size", type=int, default=1024)
    parser.add_argument("--hiddim", type=int, default=256)
    parser.add_argument("--num_head", type=int, default=8)
    parser.add_argument("--num_layer", type=int, default=6)
    parser.add_argument("--inter_ratio", type=float, default=1.0)
    parser.add_argument("--attn_dp", type=float, default=0.10)
    parser.add_argument("--event_dp", type=float, default=0.05)

    try:
        args = parser.parse_args()
    except SystemExit as exc:
        # argparse reports problems by raising SystemExit. Show the full help
        # on a real error, but keep argparse's exit status instead of forcing
        # 0, and do not print the help a second time for ``--help``.
        if exc.code not in (0, None):
            parser.print_help()
        raise

    # Kept outside the try block so unrelated failures are never swallowed.
    print(args)
    return args, sys.argv

# Parsed at import time: `args` is read as a module-level global by the
# helper functions defined below (e.g. args.step, args.device).
args, _ = get_args()


def choose_time_window(batch_idx, prev_len, quantile=0.8, M=1300):
    """Derive a time-window length from recent inter-event gaps.

    Looks at up to ``M`` entries of the global ``dataset.ts`` that end at
    index ``batch_idx * args.step + prev_len``, takes the ``quantile`` of the
    differences between consecutive timestamps and scales it by ``M``.

    Returns:
        ``quantile(gaps) * M`` as a Python float, or ``args.graph_size`` when
        fewer than two timestamps are available.
    """
    window_stop = prev_len + args.step * batch_idx
    window_start = max(0, window_stop - M)
    history_ts = dataset.ts[window_start:window_stop]

    if history_ts.numel() >= 2:
        # Gap between each timestamp and its predecessor.
        gaps = history_ts[1:] - history_ts[:-1]
        gap_at_quantile = torch.quantile(gaps.float(), quantile).item()
        return gap_at_quantile * M

    # Not enough history to estimate a gap distribution.
    return args.graph_size

def check_temporal_leakage(start_prev, end_prev, dataset, stage="train"):
    """Print whether any history timestamp exceeds the last history timestamp.

    Args:
        start_prev: start index (inclusive) of the history slice of ``dataset.ts``.
        end_prev: end index (exclusive) of the history slice.
        dataset: object exposing a ``ts`` tensor.
        stage: label that is upper-cased and shown in the printed message.

    The maximum of ``dataset.ts[start_prev:end_prev]`` is compared against
    ``dataset.ts[end_prev - 1]``. Nothing is returned; the result is printed.
    """
    history = dataset.ts[start_prev:end_prev]

    # Guard: nothing to compare when the slice is empty.
    if history.numel() == 0:
        print(f"[{stage.upper()}] No history edges for batch {end_prev}, skipping check.")
        return

    max_history_ts = history.max().item()
    current_ts = dataset.ts[end_prev - 1].item()

    if not (max_history_ts > current_ts):
        print(f"[{stage.upper()} OK] No leak: history ≤ target (max_ts={max_history_ts}, target_ts={current_ts})")
        return

    print(f"[{stage.upper()} WARNING] TIME LEAK DETECTED!")
    print(f"  → History max timestamp = {max_history_ts}")
    print(f"  → Current target timestamp = {current_ts}")

# Decay rates used by get_prev_graph to build exponential time-decay edge
# features: one feature column per rate.
decay_alphas = torch.tensor([0.001, 0.01, 0.1, 1.0], device=args.device) 
def get_prev_graph(batch_idx, prev_len=0):
    """Build the bidirectional history graph that precedes a batch.

    Reads the module-level ``dataset``, ``args`` and ``decay_alphas``. The
    history ends (exclusively) at index
    ``min(batch_idx * args.step + prev_len, len(dataset))`` and starts at the
    first event whose timestamp is not older than the last history timestamp
    minus the window returned by ``choose_time_window``.

    Args:
        batch_idx: index of the batch whose history is requested.
        prev_len: number of events to add to the history end index.

    Returns:
        A tuple ``(edge_index, edge_features, edge_feat)`` on ``args.device``:
        ``edge_index`` has shape ``(2, 2E)`` (each edge followed later by its
        reverse), ``edge_features`` has shape ``(2E, len(decay_alphas))`` and
        holds ``exp(-alpha * (t_max - t))`` per edge and decay rate, and
        ``edge_feat`` is the raw edge feature matrix repeated for both
        directions. With no history, ``E`` is 0 and empty tensors are returned.
    """
    end_prev = min(batch_idx * args.step + prev_len, len(dataset))
    if end_prev > 0:
        time_window = choose_time_window(batch_idx, prev_len)
        oldest_allowed_ts = dataset.ts[end_prev - 1] - time_window
        start_prev = torch.searchsorted(dataset.ts, oldest_allowed_ts).item()
    else:
        # No events precede this batch. Without this guard, ts[end_prev - 1]
        # would wrap around to the *last* timestamp of the dataset.
        start_prev = end_prev = 0

    gsrc, gdst = dataset.src[start_prev:end_prev], dataset.dst[start_prev:end_prev]
    timestamps_original = dataset.ts[start_prev:end_prev].to(args.device)
    edge_feat_original = dataset.edge_feat[start_prev:end_prev].to(args.device)

    # torch.stack gives shape (2, E) even when E == 0. The previous
    # cat(...).reshape(2, -1) raises on a zero-element tensor, which made the
    # empty-history branch below unreachable.
    edge_index = torch.stack([gsrc, gdst], dim=0).to(args.device)
    # Append the reversed edges so the graph is bidirectional.
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
    timestamps = torch.cat([timestamps_original, timestamps_original], dim=0)
    edge_feat = torch.cat([edge_feat_original, edge_feat_original], dim=0)

    if timestamps_original.numel() == 0:
        num_decay_rates = decay_alphas.size(0)
        return (
            edge_index,
            torch.empty(0, num_decay_rates, device=args.device),
            torch.empty(0, edge_feat.size(-1), device=args.device),
        )

    # Age of every edge relative to the newest edge in the window.
    current_time = timestamps.max()
    time_diff = (current_time - timestamps).unsqueeze(-1)
    # Broadcast (2E, 1) against (1, K) to get one decayed weight per rate.
    exponent = -decay_alphas.unsqueeze(0) * time_diff
    edge_features = torch.exp(exponent)
    return edge_index, edge_features, edge_feat

# Single decay rate used by get_prev_graph_old for its scalar edge weights.
decay_alpha = 0.0001
def get_prev_graph_old(batch_idx, prev_len=0):
    """Build a bidirectional history graph with scalar time-decay weights.

    Reads the module-level ``dataset``, ``args`` and ``decay_alpha``.

    Args:
        batch_idx: index of the batch whose history is requested.
        prev_len: number of events added to both slice bounds (see note below).

    Returns:
        A tuple ``(edge_index, edge_weight)`` on ``args.device``.
        ``edge_index`` has shape ``(2, 2E)`` and ``edge_weight`` has ``2E``
        entries equal to ``exp(-decay_alpha * (t_max - t))``. With no history
        both are empty (``E == 0``).
    """
    end_prev = batch_idx * args.step + prev_len
    # NOTE(review): prev_len is already part of end_prev, so adding it again
    # pushes start_prev past end_prev whenever prev_len exceeds the window and
    # yields an empty history. Left unchanged here; confirm the intent.
    start_prev = max(0, end_prev - args.graph_size) + prev_len
    gsrc, gdst = dataset.src[start_prev:end_prev], dataset.dst[start_prev:end_prev]

    # torch.stack gives shape (2, E) even when E == 0. The previous
    # cat(...).reshape(2, -1) raises on a zero-element tensor, which made the
    # empty-history branch below unreachable.
    edge_index = torch.stack([gsrc, gdst], dim=0)
    # Append the reversed edges so the graph is bidirectional.
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1).to(args.device)

    timestamps = dataset.ts[start_prev:end_prev].to(args.device)
    if timestamps.numel() == 0:
        return edge_index, torch.ones(edge_index.size(1), device=args.device)

    # One timestamp per directed edge (forward copies, then reversed copies).
    timestamps = torch.cat([timestamps, timestamps], dim=0)
    current_time = timestamps.max()
    edge_weight = torch.exp(-decay_alpha * (current_time - timestamps))
    return edge_index, edge_weight

def get_neg_sample(node_num, sample_size, neg_type="all"):
    """Draw random negative node indices for ``node_num`` prediction targets.

    Reads the module-level ``device``, ``max_num_nodes`` and the
    ``min/max_src_idx`` / ``min/max_dst_idx`` bounds.

    Args:
        node_num: number of rows (targets) to sample negatives for.
        sample_size: number of negatives per row.
        neg_type: ``"all"`` samples every row uniformly from
            ``[0, max_num_nodes)``. ``"target"`` samples odd rows from
            ``[min_dst_idx, max_dst_idx]`` and even rows from
            ``[min_src_idx, max_src_idx]``.

    Returns:
        A ``(node_num, sample_size)`` long tensor on ``device``.

    Raises:
        ValueError: if ``neg_type`` is neither ``"all"`` nor ``"target"``.
    """
    # Validate first: previously an unknown mode fell through to the return
    # statement and failed with an UnboundLocalError.
    if neg_type not in ("all", "target"):
        raise ValueError(
            f"Unknown neg_type {neg_type!r}; expected 'all' or 'target'."
        )

    if neg_type == "all":
        return torch.randint(
            0,
            max_num_nodes,
            (node_num, sample_size),
            dtype=torch.long,
            device=device,
        )

    # Rows 1, 3, 5, ... take destination-range negatives and rows 0, 2, 4, ...
    # take source-range negatives. For odd node_num the even rows outnumber
    # the odd rows by one, so the two counts are computed separately; sampling
    # node_num // 2 rows for both raised a shape mismatch.
    num_odd_rows = node_num // 2
    num_even_rows = node_num - num_odd_rows

    neg_idx = torch.zeros((node_num, sample_size), dtype=torch.long, device=device)
    neg_dst_idx = torch.randint(
        min_dst_idx,
        max_dst_idx + 1,
        (num_odd_rows, sample_size),
        dtype=torch.long,
        device=device,
    )
    neg_src_idx = torch.randint(
        min_src_idx,
        max_src_idx + 1,
        (num_even_rows, sample_size),
        dtype=torch.long,
        device=device,
    )
    neg_idx[1::2] = neg_dst_idx
    neg_idx[::2] = neg_src_idx
    return neg_idx
    

def train_epoch(model: TGModel, loader: DataLoader, optimizer, device, event_dp: float):
    """Train ``model`` for one pass over ``loader``.

    Each batch is scored against its true target node plus 200 sampled
    negatives. The node loss is a cross-entropy whose correct class is column
    0, and the time loss is a Huber loss scaled by ``1 / model.timescale``.
    The module-level ``scheduler`` is stepped once per batch.

    Returns:
        ``(mean_node_loss, mean_time_loss)`` averaged over ``len(loader)``.
    """
    model.train()

    num_negatives = 200
    node_loss_sum = 0
    time_loss_sum = 0

    for batch_idx, batch in tqdm(loader):
        optimizer.zero_grad()

        inputs, label, _ = batch2seq(batch, event_dp=event_dp)
        inputs = to_device_recursive(device, inputs)
        label = to_device_recursive(device, label)
        target_nodes, target_times = label[0], label[1]

        edge_index, edge_weight, edge_feat = get_prev_graph(batch_idx)

        # Candidate set per target: the true node first, then the negatives.
        negatives = get_neg_sample(len(target_nodes), num_negatives, "target")
        candidates = torch.cat((target_nodes.unsqueeze(-1), negatives), dim=-1)
        node_logits, time_pred, _ = model.forward(edge_index, edge_weight, edge_feat, candidates, *inputs)

        # The positive sits in column 0, so the target class is always 0.
        positive_class = torch.zeros((node_logits.shape[0], 1), dtype=torch.long, device=device)
        node_loss = F.cross_entropy(node_logits, positive_class)
        time_loss = (1 / model.timescale) * F.huber_loss(time_pred, target_times.to(time_pred.dtype))

        (node_loss + time_loss).backward()
        optimizer.step()
        scheduler.step()

        node_loss_sum += node_loss.item()
        time_loss_sum += time_loss.item()

    num_batches = len(loader)
    return node_loss_sum / num_batches, time_loss_sum / num_batches


def evaluation_mrr(prev_graph_length: int, model: TGModel, loader: DataLoader, device, evaluator: Evaluator, metric: str="mrr", split_mode="val"):
    """Evaluate ``model`` on ``loader`` and return the mean ranking metric.

    For every batch with at least one masked-in event, the true destination is
    scored against the negatives returned by the module-level ``neg_sampler``.
    Negative lists are padded with ``-1``; padded positions get the score
    ``-0.1`` so that ``calculate_mrr`` drops them.

    Args:
        prev_graph_length: passed to ``get_prev_graph`` as ``prev_len``.
        model: model to evaluate (switched to eval mode).
        loader: yields ``(batch_idx, batch)`` pairs.
        device: device that inputs and labels are moved to.
        evaluator: forwarded to ``calculate_mrr``.
        metric: metric name forwarded to ``calculate_mrr``.
        split_mode: split name forwarded to ``neg_sampler.query_batch``.

    Returns:
        Mean of the per-event metric values over all evaluated batches.
    """
    model.eval()
    per_batch_scores = []

    for batch_idx, batch in tqdm(loader):
        mask = batch.mask
        # Skip batches that contain no event to evaluate.
        if not torch.any(mask):
            continue

        eval_src, eval_dst, eval_t = batch.src[mask], batch.dst[mask], batch.t[mask]
        negatives = [
            torch.tensor(neg)
            for neg in neg_sampler.query_batch(eval_src, eval_dst, eval_t, split_mode=split_mode)
        ]
        padded_negatives = pad_sequence(negatives, batch_first=True, padding_value=-1)
        # Column 0 is the true destination, the rest are (padded) negatives.
        sample_nodes = torch.cat((eval_dst.reshape(-1, 1), padded_negatives), dim=-1)

        inputs, label, _ = batch2seq(batch, event_dp=0.00, pred_dst_only=True, pred_mask=mask.squeeze(0))
        inputs = to_device_recursive(device, inputs)
        label = to_device_recursive(device, label)

        with torch.no_grad():
            edge_index, edge_weight, edge_feat = get_prev_graph(batch_idx, prev_len=prev_graph_length)
            node_logits, _, _ = model.forward(edge_index, edge_weight, edge_feat, sample_nodes, *inputs)
            scores = torch.sigmoid(node_logits).squeeze(-1)
            # Sigmoid outputs are non-negative, so -0.1 marks padding.
            scores[sample_nodes < 0] = -0.1
            per_batch_scores.append(calculate_mrr(scores.cpu(), evaluator, metric))

    return np.concatenate(per_batch_scores).mean()

def calculate_mrr(out_node: Tensor, evaluator: Evaluator, metric: str="mrr"):
    """Compute ``metric`` for every row of ``out_node``.

    Negative scores in a row are discarded. Of the remaining scores the first
    is passed to ``evaluator.eval`` as the positive prediction and the rest as
    the negative predictions.

    Returns:
        A NumPy array with one metric value per row of ``out_node``.
    """

    def _row_metric(row):
        # Keep only non-negative scores; negative entries are dropped.
        kept = row[row >= 0]
        assert kept.dim() == 1
        result = evaluator.eval(
            {
                "y_pred_pos": kept[0:1],
                "y_pred_neg": kept[1:],
                "eval_metric": [metric],
            }
        )
        return result[metric]

    return np.array([_row_metric(row) for row in out_node])

# Script entry point: load data, build the model, then alternate training and
# validation, evaluating on the test split whenever validation improves.
# The names assigned here (dataset, device, scheduler, neg_sampler and the
# min/max index bounds) are module-level globals that the functions above read.
if __name__ == "__main__":
    dataset = PyGLinkPropPredDataset(name=args.dataset, root="datasets")
    loaders, data_props = load_seq_data(dataset, args.bs)
    train_loader, val_loader, test_loader = loaders
    # NOTE(review): `metric` is unpacked here but not passed to evaluation_mrr
    # below, which therefore uses its default metric name — confirm intended.
    metric, min_src_idx, max_src_idx, min_dst_idx, max_dst_idx, max_num_nodes, timescale, msgdim, neg_sampler = data_props
    evaluator = Evaluator(name=args.dataset)
    device = args.device
    model = TGModel(timescale, max_num_nodes, msgdim, args.hiddim, args.num_head, args.num_layer, inter_ratio=args.inter_ratio, attn_dp=args.attn_dp).to(device)
    optimizer = torch.optim.AdamW(list(model.parameters()), lr=args.lr, weight_decay=args.wd)
    # train_epoch steps the scheduler once per batch, so T_0 below corresponds
    # to three passes over train_loader between warm restarts.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0 = 3 * len(train_loader),    
        T_mult = 1,       
        eta_min = args.lr * 0.01
    )
    # Added to the split sizes below to form the prev_len handed to
    # get_prev_graph during evaluation.
    # NOTE(review): presumably aligns loader batch indices with dataset
    # positions — verify against load_seq_data.
    offset =  - args.prefix + args.bs - args.step
    dataset.load_val_ns()
    best_val_mrr = 0
    best_test_mrr = 0
    best_epoch = -1
    save_path = f"{args.dataset}.pt"
    for epoch in range(args.epochs):
        nodeloss, timeloss = train_epoch(model, train_loader, optimizer, device, args.event_dp)
        print(f"epoch {epoch} nodeloss {nodeloss:.4f} timeloss {timeloss:.4f}")
        val_mrr = evaluation_mrr(len(dataset[dataset.train_mask]) + offset, 
                                 model, val_loader, device, evaluator, split_mode="val")
        print(f"epoch {epoch} mrr {val_mrr:.4f}")
        # Test evaluation and checkpointing happen only on a new best
        # validation score.
        if val_mrr > best_val_mrr:
            best_val_mrr = val_mrr
            best_epoch = epoch
            dataset.load_test_ns()
            test_mrr = evaluation_mrr(len(dataset[dataset.train_mask|dataset.val_mask]) + offset, 
                                      model, test_loader, device, evaluator, split_mode="test")
            best_test_mrr = test_mrr
            print(f"epoch {epoch} (best so far) test mrr {test_mrr:.4f}")
            torch.save(model.state_dict(), save_path)    

    print(f"Best Epoch: {best_epoch}, Val MRR: {best_val_mrr:.4f}, Test MRR at Best Val: {best_test_mrr:.4f}")

