import math
import timeit

import os
import os.path as osp
from pathlib import Path
import numpy as np
from tqdm import tqdm

from dataclasses import dataclass

import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch.utils.data import DataLoader, Dataset

from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv

# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset


class TemporalSequentialDataset(Dataset):
    """Overlapping fixed-length windows over a temporal event stream.

    Item ``i`` is the slice ``[i * step, i * step + seq_len)`` of ``temporaldata``
    (truncated at the end of the data).  Each item also carries two boolean masks:

    * an evaluation mask that keeps only positions from index ``seq_len - step``
      of the window onwards, and
    * a memory mask that keeps only the first ``step`` positions of the window.

    Rows among the first ``prefix`` rows of ``temporaldata`` are False in both
    masks, so they act purely as context.  The ``+ prefix`` term in the length
    adds extra trailing windows; NOTE(review): presumably so the memory mask
    eventually covers the tail of the stream — confirm with the training loop.
    """

    def __init__(self, temporaldata, seq_len, step=256, split="train", prefix=0):
        if split in ("val", "test"):
            # evaluation splits must come with at least one window of history
            assert prefix >= seq_len
        self.data = temporaldata
        self.seq_len = seq_len
        self.step = step
        self.prefix = prefix
        self.split = split

        n_events = len(self.data)
        self.length = (n_events - seq_len + prefix) // step + 1

        # Both masks start out identical: everything except the prefix rows.
        self.eval_mask = torch.ones(n_events, dtype=torch.bool)
        self.eval_mask[:prefix] = False
        self.mem_mask = self.eval_mask.clone()

    def __getitem__(self, index):
        start = index * self.step
        window = slice(start, start + self.seq_len)
        events = self.data[window]

        # Score only from position seq_len - step of the window onwards.
        eval_mask = self.eval_mask[window].clone()
        eval_mask[: self.seq_len - self.step] = False

        # Feed only the first `step` positions of the window into memory.
        mem_mask = self.mem_mask[window].clone()
        mem_mask[self.step:] = False

        fields = (events.src, events.dst, events.t, events.msg, eval_mask, mem_mask)
        return [index, *(field.unsqueeze(0) for field in fields)]

    def __len__(self):
        return self.length

def load_temporal_data(dataset_name, batch_size):
    """Load a TGB link-prediction dataset and wrap each split in a TemporalDataLoader.

    Args:
        dataset_name: TGB dataset name, loaded from/into the ``datasets`` root.
        batch_size: number of events per loader batch.

    Returns:
        ``(train_loader, val_loader, test_loader)`` and
        ``(metric, max_num_nodes, timescale)``, where ``max_num_nodes`` is one
        more than the largest node id in ``src``/``dst`` and ``timescale`` is the
        mean gap between consecutive event timestamps.
    """
    dataset = PyGLinkPropPredDataset(name=dataset_name, root="datasets")
    split_masks = (dataset.train_mask, dataset.val_mask, dataset.test_mask)
    data = dataset.get_TemporalData()
    metric = dataset.eval_metric

    split_data = [data[split_mask] for split_mask in split_masks]
    loaders = tuple(
        TemporalDataLoader(subset, batch_size=batch_size) for subset in split_data
    )

    largest_id = max(data.src.max().item(), data.dst.max().item())
    num_nodes = largest_id + 1
    # (last - first) / (n - 1) telescopes to the mean consecutive time gap
    time_span = (data.t[-1] - data.t[0]).item()
    timescale = time_span / (data.t.shape[0] - 1)

    return loaders, (metric, num_nodes, timescale)

def load_seq_data(dataset:PyGLinkPropPredDataset, seq_len, prefix=1024, step=256, batch_size=1):
    """Build sliding-window DataLoaders over the train/val/test splits of ``dataset``.

    The val and test splits are each extended backwards by ``prefix`` events so
    their first windows have history to condition on.
    ``TemporalSequentialDataset`` then excludes those prefix rows from both its
    evaluation mask and its memory mask.  The index arithmetic assumes the splits
    are contiguous and chronological (train, then val, then test) in the event
    stream — NOTE(review): confirm this holds for every dataset used.

    Args:
        dataset: loaded TGB link-prediction dataset.
        seq_len: number of events per window.
        prefix: number of preceding events prepended to the val and test splits.
        step: stride between consecutive windows.
        batch_size: DataLoader batch size; ``collate_fn`` asserts it is 1.

    Returns:
        ``(train_loader, val_loader, test_loader)`` and
        ``(metric, min_src_idx, max_src_idx, min_dst_idx, max_dst_idx,
        max_num_nodes, timescale, msg_dim, negative_sampler)``.

    Raises:
        ValueError: if ``prefix`` is negative or exceeds the number of training
            events (the backwards extension would otherwise index from the end of
            the stream and pull test events into the validation split).
    """
    train_mask = dataset.train_mask
    train_len = train_mask.sum().item()
    # Clone before extending: dataset.val_mask / test_mask may be the dataset's
    # own stored tensors, and writing into them would permanently change its splits.
    val_mask = dataset.val_mask.clone()
    val_len = val_mask.sum().item()
    test_mask = dataset.test_mask.clone()

    if prefix < 0 or prefix > train_len:
        raise ValueError(
            f"prefix must be in [0, {train_len}] (number of training events), got {prefix}"
        )

    data = dataset.get_TemporalData()
    metric = dataset.eval_metric

    # Prepend the `prefix` events immediately preceding each evaluation split.
    val_mask[train_len - prefix:train_len] = True
    val_end = train_len + val_len
    test_mask[val_end - prefix:val_end] = True

    train_data = data[train_mask]
    val_data = data[val_mask]
    test_data = data[test_mask]

    train_dataset = TemporalSequentialDataset(train_data, seq_len, step=step, split="train")
    val_dataset = TemporalSequentialDataset(val_data, seq_len, step=step, split="val", prefix=prefix)
    test_dataset = TemporalSequentialDataset(test_data, seq_len, step=step, split="test", prefix=prefix)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    min_src_idx, max_src_idx = int(data.src.min()), int(data.src.max())
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
    max_num_nodes = max(data.src.max().item(), data.dst.max().item()) + 1
    # (last - first) / (n - 1) is the mean gap between consecutive timestamps
    timescale = (data.t[-1] - data.t[0]).item() / (data.t.shape[0] - 1)

    return (train_loader, val_loader, test_loader), \
        (metric, min_src_idx, max_src_idx, min_dst_idx, max_dst_idx, max_num_nodes,
         timescale, data.msg.shape[-1], dataset.negative_sampler)

def collate_fn(batch):
    """Collate ``TemporalSequentialDataset`` items into ``(index, batch_seq_data)``.

    Each item is ``[index, src, dst, t, msg, mask, mem_mask]``.  The tensor
    fields of all items are concatenated along dim 0.  Only single-item batches
    are accepted; the assert fires after concatenation otherwise.
    """
    indices, *field_groups = zip(*batch)
    src, dst, t, msg, mask, mem_mask = [torch.cat(group, dim=0) for group in field_groups]
    assert len(indices) == 1
    return indices[0], batch_seq_data(src, dst, t, msg, mask, mem_mask)

@dataclass
class batch_seq_data:
    """One collated window of events plus its evaluation and memory masks.

    Every field has a leading batch dimension.  The ``*_1d_data`` accessors
    assert that dimension is 1 and return the selected rows as
    ``(src, dst, t, msg)``: ``src``/``dst``/``t`` are 1-D and ``msg`` is 2-D
    ``(num_selected, feature_width)``.
    """
    src: torch.Tensor
    dst: torch.Tensor
    t: torch.Tensor
    msg: torch.Tensor
    # boolean mask of positions selected by masked_1d_data
    mask: torch.Tensor
    # boolean mask of positions selected by mem_masked_1d_data
    mem_mask: torch.Tensor

    def _select(self, mask):
        """Return ``(src, dst, t, msg)`` restricted to positions where ``mask`` is True."""
        assert self.src.shape[0] == 1
        src = self.src[mask].view(-1)
        dst = self.dst[mask].view(-1)
        t = self.t[mask].view(-1)
        msg = self.msg[mask]
        # `.view(k, -1)` raises when k == 0 (the -1 is ambiguous for an empty
        # tensor), which happens for all-False masks; give the width explicitly.
        msg = msg.reshape(msg.shape[0], math.prod(msg.shape[1:]))
        return src, dst, t, msg

    @property
    def masked_1d_data(self):
        """Rows selected by ``mask`` as ``(src, dst, t, msg)``."""
        return self._select(self.mask)

    @property
    def mem_masked_1d_data(self):
        """Rows selected by ``mem_mask`` as ``(src, dst, t, msg)``."""
        return self._select(self.mem_mask)
