from torch import Tensor
from typing import Tuple
from pathlib import Path
import numpy as np
from tqdm import tqdm

import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear

from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv

from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset, TemporalData

from preprocess import batch_seq_data

def to_device_recursive(device: torch.DeviceObjType, obj):
    """Return a copy of ``obj`` with every contained tensor moved to ``device``.

    The structure is walked recursively. Supported nodes are tensors,
    tuples, dicts and ``None``:

    * a tensor is transferred with ``Tensor.to(device, non_blocking=True)``;
    * a tuple is rebuilt as a plain ``tuple`` of converted elements;
    * a dict is rebuilt as a plain ``dict`` with the same keys and
      converted values;
    * ``None`` is passed through untouched.

    Raises:
        NotImplementedError: if ``obj`` (or any nested element) is of any
            other type, e.g. a ``list``. The offending type is the
            exception argument.
    """
    if obj is None:
        return None
    if isinstance(obj, Tensor):
        return obj.to(device, non_blocking=True)
    if isinstance(obj, tuple):
        return tuple(to_device_recursive(device, element) for element in obj)
    if isinstance(obj, dict):
        return {
            name: to_device_recursive(device, value)
            for name, value in obj.items()
        }
    # Anything else is rejected explicitly rather than returned unmoved.
    raise NotImplementedError(type(obj))


def batch2seq(batch: batch_seq_data, event_dp: float = 0.00, pred_dst_only: bool = False, pred_mask: Tensor = None):
    """Flatten a batch of temporal edges into a next-token prediction sequence.

    Every retained event ``i`` contributes three consecutive tokens
    ``(-1, src_i, dst_i)``, where ``-1`` is a placeholder marking the slot
    that carries the event time. The flattened stream is then shifted by
    one position to form ``(tokens, labels)`` pairs, so all returned
    per-position tensors have length ``3 * n_events - 1``.

    Args:
        batch: object exposing ``src``, ``dst``, ``t`` and ``msg`` tensors,
            indexed along their first dimension by event.
        event_dp: event dropout threshold. An event is kept when a uniform
            sample in ``[0, 1)`` is ``>= event_dp``, so ``0.0`` keeps
            everything (the random mask is still drawn).
        pred_dst_only: if true, only positions whose label is a destination
            node are reported in ``out_nodeidx``; otherwise every position
            whose label is a node (source or destination) is reported.
        pred_mask: optional selector applied to the destination positions
            (``out_nodeidx[pred_mask]``) and, shifted by one, to the time
            positions (``out_timeidx[pred_mask[1:]]``). Presumably a boolean
            mask with one entry per retained event - TODO confirm with
            callers.
            NOTE(review): with ``event_dp > 0`` the number of retained
            events is random, so a fixed-length mask may not line up.

    Returns:
        A 3-tuple ``(inputs, targets, event_dp_mask)`` where

        * ``inputs`` is ``(indict, basetime, realtime, out_nodeidx,
          out_timeidx)``. ``indict`` holds ``"tokens"`` / ``"labels"``
          (placeholders clamped to 0), ``"timeidx"`` (positions of the time
          placeholders in ``tokens``), ``"times"`` (event time at each such
          position), ``"msgidx"`` (event index of each position), ``"msgs"``
          (messages of the retained events) and ``"in_type"`` (0 = time
          slot, 1 = source, 2 = destination). ``realtime`` is the time of
          the event a position belongs to and ``basetime`` the time of the
          preceding event (0 for the first). Times are relative to the
          smallest timestamp in ``batch.t``.
        * ``targets`` is ``(out_node, out_time)``: the node labels at
          ``out_nodeidx`` and the event times at ``out_timeidx``.
        * ``event_dp_mask`` is the boolean keep-mask over the original
          events.

    Raises:
        NotImplementedError: if ``pred_mask`` is given while
            ``pred_dst_only`` is false.
    """
    if pred_mask is not None and not pred_dst_only:
        raise NotImplementedError ("allow mask pred only when pred dst only")

    s = batch.src
    d = batch.dst
    # Work with times relative to the earliest event of the (undropped) batch.
    t = batch.t - batch.t.min()
    msg = batch.msg

    # Random event dropout; the mask is returned so callers can align
    # per-event side information with the retained events.
    event_dp_mask = torch.rand_like(s, dtype=torch.float) >= event_dp
    s, d, t, msg = s[event_dp_mask], d[event_dp_mask], t[event_dp_mask], msg[event_dp_mask]

    # Interleave per event: [-1, src, dst, -1, src, dst, ...].
    tokens = torch.stack((-torch.ones_like(s), s, d), dim=-1).flatten()
    # Slot type of every token: 0 = time placeholder, 1 = source, 2 = destination.
    in_type = torch.stack(
        (torch.zeros_like(d), torch.ones_like(s), 2 * torch.ones_like(d)), dim=-1
    ).flatten()

    # Event index of every token. It must live on the same device as the
    # batch: it is later indexed with indices derived from ``tokens``, and
    # indexing a CPU tensor with an accelerator-resident index raises.
    msg_idx = torch.arange(s.shape[0], device=s.device).unsqueeze(-1).expand(-1, 3).flatten()
    realtime = t[msg_idx]
    # Time of the previous event (0 before the first one).
    basetime = torch.cat((torch.zeros_like(t[:1]), t), dim=0)[msg_idx]

    # Shift by one: position k predicts the token at position k + 1.
    tokens, labels = tokens[:-1], tokens[1:]
    in_type, out_type = in_type[:-1], in_type[1:]

    timeidx = torch.nonzero(tokens == -1).flatten()
    times = t[msg_idx[timeidx]]

    out_timeidx = torch.nonzero(labels == -1).flatten()
    if pred_dst_only:
        out_nodeidx = torch.nonzero(out_type == 2).flatten()
    else:
        out_nodeidx = torch.nonzero(labels != -1).flatten()

    if pred_mask is not None:
        out_nodeidx = out_nodeidx[pred_mask]
        # The first event has no time label (its placeholder is never a
        # label after the shift), hence the mask is offset by one here.
        out_timeidx = out_timeidx[pred_mask[1:]]

    # ``realtime[1:]`` is aligned with ``labels``.
    out_time = realtime[1:][out_timeidx]
    out_node = labels[out_nodeidx]

    # Trim the per-token tensors to the input (unshifted) length.
    msg_idx = msg_idx[:-1]
    realtime = realtime[:-1]
    basetime = basetime[:-1]

    indict = {
        "tokens": tokens.clamp_min(0),
        "labels": labels.clamp_min(0),
        "timeidx": timeidx,
        "times": times,
        "msgidx": msg_idx,
        "msgs": msg,
        "in_type": in_type,
    }

    return (indict, basetime, realtime, out_nodeidx, out_timeidx), (out_node, out_time), event_dp_mask
    

    