import numpy as np
import torch
import torch.utils.data

from transformer import Constants


class EventData(torch.utils.data.Dataset):
    """ Dataset over a collection of event streams. """

    def __init__(self, data):
        """
        `data` is a list of event streams. An event stream is a list of
        dictionaries, and every dictionary provides the keys
        time_since_start, time_since_last_event and type_event.
        """

        def field(name):
            # One inner list per stream, holding that field for every event.
            return [[event[name] for event in stream] for stream in data]

        self.time = field('time_since_start')
        self.time_gap = field('time_since_last_event')
        # Event types are shifted up by one: a raw type of 0 is possible,
        # while 0 is the value used for padding.
        self.event_type = [
            [raw_type + 1 for raw_type in stream]
            for stream in field('type_event')
        ]

        self.length = len(data)

    def __len__(self):
        """ Number of event streams held by the dataset. """
        return self.length

    def __getitem__(self, idx):
        """ Return the (time, time_gap, event_type) lists of stream `idx`. """
        sample = (self.time[idx], self.time_gap[idx], self.event_type[idx])
        return sample


def pad_time(insts):
    """ Right-pad every sequence to the longest one in the batch.

    Shorter sequences are extended with Constants.PAD, and the padded batch
    is returned as a float32 tensor with one row per sequence.
    """

    longest = max(map(len, insts))

    padded_rows = []
    for seq in insts:
        filler = [Constants.PAD] * (longest - len(seq))
        padded_rows.append(seq + filler)

    stacked = np.array(padded_rows)
    return torch.tensor(stacked, dtype=torch.float32)


def pad_type(insts):
    """ Right-pad every sequence to the longest one in the batch.

    Shorter sequences are extended with Constants.PAD, and the padded batch
    is returned as a long (int64) tensor with one row per sequence.
    """

    longest = max(map(len, insts))

    padded_rows = []
    for seq in insts:
        filler = [Constants.PAD] * (longest - len(seq))
        padded_rows.append(seq + filler)

    stacked = np.array(padded_rows)
    return torch.tensor(stacked, dtype=torch.long)


def collate_fn(insts):
    """ Collate function, as required by PyTorch.

    `insts` is a list of (time, time_gap, event_type) samples. The samples
    are regrouped by field, and each field is padded into one batch tensor:
    times and gaps through pad_time, event types through pad_type.
    """

    # Transpose: list of per-sample triples -> one group per field.
    times, gaps, types = zip(*insts)
    return pad_time(times), pad_time(gaps), pad_type(types)

def load_univariate_point_process(pkl_file):
    """ Load a univariate point-process dataset and convert it to event streams.

    The file is read with ``torch.load`` and must contain a mapping with a
    ``"sequences"`` entry. Every sequence is a mapping whose
    ``"arrival_times"`` entry is an iterable of event times.

    Args:
        pkl_file: Path of the dataset file; ``.pkl`` is appended when the
            path does not already end with it.

    Returns:
        A list of event streams. Each stream is a list of dictionaries with
        the keys ``time_since_start`` (the arrival time),
        ``time_since_last_event`` (gap to the previous arrival, measured
        from 0.0 for the first event) and ``type_event`` (always 0, as there
        is a single event type). Times are plain Python floats.
    """
    if not pkl_file.endswith('.pkl'):
        pkl_file += '.pkl'
    # NOTE(security): torch.load is pickle-based; only load files from a
    # trusted source.
    data = torch.load(pkl_file)
    sequences = data["sequences"]

    event_streams = []
    for seq in sequences:
        stream = []
        prev_time = 0.0
        for t in seq["arrival_times"]:
            # Arrival times may be stored as a tensor or NumPy array, in
            # which case iteration yields 0-d tensors / NumPy scalars.
            # Coerce to a builtin float so every event holds plain numbers.
            t = float(t)
            stream.append({
                "time_since_start": t,
                "time_since_last_event": t - prev_time,
                "type_event": 0,
            })
            prev_time = t
        event_streams.append(stream)
    return event_streams



def get_dataloader(data, batch_size, shuffle=True, num_workers=2):
    """ Prepare dataloader.

    Args:
        data: List of event streams, in the format accepted by EventData.
        batch_size: Number of event streams per batch.
        shuffle: Whether the DataLoader shuffles the streams.
        num_workers: Number of worker processes used for loading. Defaults
            to 2 (the previously hard-coded value); pass 0 to load in the
            main process.

    Returns:
        A torch.utils.data.DataLoader over EventData(data) that batches
        samples with collate_fn.
    """

    ds = EventData(data)
    dl = torch.utils.data.DataLoader(
        ds,
        num_workers=num_workers,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=shuffle
    )
    return dl


def split_event_streams(event_streams, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2, shuffle=True):
    """ Split event streams into train / validation / test subsets.

    Args:
        event_streams: Sequence of event streams to split. It is never
            modified; shuffling is done on a copy.
        train_ratio: Fraction of streams assigned to the training subset.
        val_ratio: Fraction of streams assigned to the validation subset.
        test_ratio: Fraction for the test subset. It only takes part in the
            sum check; the test subset is whatever remains after the other
            two are taken.
        shuffle: Whether to shuffle (with NumPy's global random state)
            before splitting.

    Returns:
        A (train_stream, val_stream, test_stream) tuple.

    Raises:
        ValueError: If the three ratios do not sum to 1.
    """
    if not np.isclose(train_ratio + val_ratio + test_ratio, 1.0):
        raise ValueError("sum of proportions must be 1.")
    n = len(event_streams)
    if shuffle:
        # Shuffle a shallow copy: np.random.shuffle works in place and would
        # otherwise reorder the caller's list as a hidden side effect.
        event_streams = list(event_streams)
        np.random.shuffle(event_streams)
    # Subset sizes are truncated towards zero; the remainder goes to test.
    train_end = int(n * train_ratio)
    val_end = train_end + int(n * val_ratio)
    train_stream = event_streams[:train_end]
    val_stream = event_streams[train_end:val_end]
    test_stream = event_streams[val_end:]
    return train_stream, val_stream, test_stream
