import json
from pathlib import Path

import numpy as np
import torch
from lightning import LightningDataModule
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class EasyTPPDataset(Dataset):
    """
    Real-world TPP datasets.

    Reads ``<data_dir>/<name>/<split>.json``. The file must contain a list of
    records, each providing the keys ``seq_len``, ``dim_process``,
    ``time_since_start``, ``time_since_last_event`` and ``type_event``.

    Args:
        data_dir: Directory that contains one sub-directory per dataset
        name: Name of the dataset
            Choices: "amazon", "earthquake", "retweet", "stackoverflow", 'taobao', "taxi", "volcano"
        split: Which data split to use
            Choices: "train", "dev", "test"
        min_len: Length threshold, default=2
            Only sequences whose ``seq_len`` is strictly greater than this value are kept

    Raises:
        ValueError: If no sequence of the split passes the length filter
    """
    def __init__(
        self,
        data_dir: str,
        name: str,
        split: str,
        min_len: int = 2,
    ):
        super().__init__()

        data_dir = Path(data_dir).resolve()

        # Load data; the encoding is pinned so decoding does not depend on the platform locale
        path = data_dir / name / f'{split}.json'
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Drop short sequences (note the strict comparison)
        data = [x for x in data if x['seq_len'] > min_len]
        if not data:
            raise ValueError(
                f'No sequence with seq_len > {min_len} found in {path}'
            )

        # Set the process dimension, taken from the first retained record
        self.dim = data[0]['dim_process']

        # Extract tensor lists: times as float32, marks as int64
        self.arrival_times = [torch.from_numpy(np.array(x['time_since_start'])).float() for x in data]
        self.delta_times = [torch.from_numpy(np.array(x['time_since_last_event'])).float() for x in data]
        self.marks = [torch.from_numpy(np.array(x['type_event'])).long() for x in data]

    def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
        """Return the ``(delta_times, marks)`` pair of the sequence at ``index``."""
        return (
            self.delta_times[index],
            self.marks[index],
        )

    def __len__(self) -> int:
        """Return the number of retained sequences."""
        return len(self.delta_times)


def collate_fn(
    data: list[tuple[Tensor, Tensor]],
) -> tuple[Tensor, Tensor, Tensor]:
    """
    Pad a batch of variable-length sequences to the length of the longest one.

    Args:
        data: Batch of ``(delta_times, marks)`` pairs, one pair per sequence

    Returns:
        marks: Marks padded with 0, shape (batch, seq_len) for 1-D sequences
        delta_times: Delta times padded with 0, shape (batch, seq_len) for 1-D sequences
        mask: Mask indicating observed (1) and missing values (0), same shape and dtype
            as ``delta_times``
    """
    delta_times = [sample[0] for sample in data]
    marks = [sample[1] for sample in data]
    # Ones over the observed positions; padding below fills the rest with zeros
    mask = [torch.ones_like(x) for x in delta_times]

    # The padding value 0 may coincide with a real mark / delta time,
    # so padded positions can only be told apart through `mask`
    marks = pad_sequence(marks, batch_first=True)
    delta_times = pad_sequence(delta_times, batch_first=True, padding_value=0)
    mask = pad_sequence(mask, batch_first=True)

    return marks, delta_times, mask


class EasyTPPDataModule(LightningDataModule):
    """
    Data module exposing the "train", "dev" and "test" splits of an EasyTPP dataset.

    Args:
        data_dir: Passed to ``EasyTPPDataset``
        name: Name of the dataset, passed to ``EasyTPPDataset``
        batch_size: Batch size of the train and validation loaders
        test_batch_size: Batch size of the test loader, default=1
        num_workers: Forwarded to every ``DataLoader``, default=0
        min_len: Passed to ``EasyTPPDataset``, default=2
    """
    def __init__(
        self,
        data_dir: str,
        name: str,
        batch_size: int,
        test_batch_size: int = 1,
        num_workers: int = 0,
        min_len: int = 2,
    ):
        super().__init__()

        # All three splits are loaded eagerly, in train/dev/test order
        self.trainset, self.valset, self.testset = (
            EasyTPPDataset(data_dir, name, split, min_len)
            for split in ('train', 'dev', 'test')
        )

        # The process dimension is read off the training split
        self.dim = self.trainset.dim

        self.batch_size = batch_size
        self.test_batch_size = test_batch_size

        # Keyword arguments shared by all loaders
        self.dl_kwargs = {
            'num_workers': num_workers,
            'collate_fn': collate_fn,
        }

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled loader over the training split."""
        return DataLoader(
            self.trainset,
            batch_size=self.batch_size,
            shuffle=True,
            **self.dl_kwargs,
        )

    def val_dataloader(self) -> DataLoader:
        """Return the loader over the validation ("dev") split."""
        return DataLoader(
            self.valset,
            batch_size=self.batch_size,
            **self.dl_kwargs,
        )

    def test_dataloader(self) -> DataLoader:
        """Return the loader over the test split, batched with ``test_batch_size``."""
        return DataLoader(
            self.testset,
            batch_size=self.test_batch_size,
            **self.dl_kwargs,
        )


class CustomTPPDataset(Dataset):
    """
    Dataset over user-provided sequences.

    Args:
        delta_times: One list of delta times per sequence, stored as float32 tensors
        marks: One list of marks per sequence, stored as int64 tensors
    """
    def __init__(
        self,
        delta_times: list[list[float]],
        marks: list[list[int]],
    ):
        super().__init__()
        self.delta_times = [self._to_tensor(seq, torch.float32) for seq in delta_times]
        self.marks = [self._to_tensor(seq, torch.int64) for seq in marks]

    @staticmethod
    def _to_tensor(values, dtype: torch.dtype) -> Tensor:
        """Convert one sequence to a tensor of the given dtype, going through NumPy."""
        as_array = np.array(values)
        return torch.from_numpy(as_array).to(dtype)

    def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
        """Return the ``(delta_times, marks)`` pair of the sequence at ``index``."""
        deltas = self.delta_times[index]
        sequence_marks = self.marks[index]
        return deltas, sequence_marks

    def __len__(self) -> int:
        """Return the number of sequences, as counted from ``delta_times``."""
        return len(self.delta_times)


class CustomTPPDatamodule(LightningDataModule):
    """
    Data module built from in-memory sequences.

    The sequences are split in the order given (no shuffling before the split):
    the first ``int(N * train_ratio)`` sequences form the training set, the next
    ``int(N * val_ratio)`` the validation set, and all remaining ones the test set.

    Args:
        delta_times: One list of delta times per sequence
        marks: One list of marks per sequence, aligned with ``delta_times``
        train_ratio: Fraction of the sequences used for training
        val_ratio: Fraction of the sequences used for validation
        batch_size: Batch size of the train and validation loaders
        test_batch_size: Batch size of the test loader, default=1
        num_workers: Forwarded to every ``DataLoader``, default=0

    Raises:
        ValueError: If ``delta_times`` and ``marks`` differ in length, or if
            ``marks`` holds no event from which the dimension could be inferred
    """
    def __init__(
        self,
        delta_times: list[list[float]],
        marks: list[list[int]],
        train_ratio: float,
        val_ratio: float,
        batch_size: int,
        test_batch_size: int = 1,
        num_workers: int = 0,
    ):
        super().__init__()

        # The two lists are sliced in parallel below; a mismatch would misalign the data
        if len(delta_times) != len(marks):
            raise ValueError(
                f'delta_times and marks must have the same length, '
                f'got {len(delta_times)} and {len(marks)}'
            )

        self.batch_size = batch_size
        self.test_batch_size = test_batch_size

        # Number of marks = largest mark + 1; empty sequences carry no marks and are skipped
        largest_marks = [max(x) for x in marks if len(x) > 0]
        if not largest_marks:
            raise ValueError('Cannot infer dim: marks contains no events')
        self.dim = max(largest_marks) + 1

        data_size = len(delta_times)
        train_ind = int(data_size * train_ratio)
        val_ind = int(data_size * val_ratio) + train_ind

        self.trainset = CustomTPPDataset(delta_times[:train_ind], marks[:train_ind])
        self.valset = CustomTPPDataset(delta_times[train_ind:val_ind], marks[train_ind:val_ind])
        self.testset = CustomTPPDataset(delta_times[val_ind:], marks[val_ind:])

        # Keyword arguments shared by all loaders
        self.dl_kwargs = dict(
            num_workers=num_workers,
            collate_fn=collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled loader over the training set."""
        return DataLoader(self.trainset, shuffle=True, batch_size=self.batch_size, **self.dl_kwargs)

    def val_dataloader(self) -> DataLoader:
        """Return the loader over the validation set."""
        return DataLoader(self.valset, batch_size=self.batch_size, **self.dl_kwargs)

    def test_dataloader(self) -> DataLoader:
        """Return the loader over the test set, batched with ``test_batch_size``."""
        return DataLoader(self.testset, batch_size=self.test_batch_size, **self.dl_kwargs)
