import torch

import numpy as np

try:
    import torch
except ImportError:
    print("Pytorch not installed. Please install pytorch first.")

class PadTensors:
    """Collate function that zero-pads variable-length samples to a common length.

    Adapted from Tonic:
    URL - https://github.com/neuromorphs/tonic/blob/develop/tonic/collation.py

    This is a custom collate function for a PyTorch DataLoader that loads
    several event recordings at once. It is intended to be used with sparse or
    dense tensors. Every sample is extended along its first dimension to the
    length of the longest sample in the batch, i.e. the longest recording:
    - dense samples are concatenated with zeros;
    - sparse samples are resized in place.

    String targets are mapped to integer class indices in the order they are
    first seen; the mapping is kept on the instance in ``label_map``.

    Example:
        >>> dataloader = torch.utils.data.DataLoader(dataset,
        >>>                                          batch_size=10,
        >>>                                          collate_fn=PadTensors(),
        >>>                                          shuffle=True)

    Args:
        batch_first: If True, samples are stacked along dim 0 (batch first);
            otherwise along dim 1.
        label_map: Optional initial mapping from string labels to class
            indices. It is copied; unseen labels get indices above the largest
            existing one. Supplying the complete mapping keeps class indices
            consistent when the DataLoader uses ``num_workers > 0``: collation
            then runs in worker processes, and labels discovered on the fly
            may be numbered differently in each worker.
    """

    def __init__(self, batch_first: bool = True, label_map=None):
        self.batch_first = batch_first
        # String label -> integer class index, extended in first-seen order.
        self.label_map = dict(label_map) if label_map is not None else dict()
        # Next index to hand out to an unseen string label.
        self.class_num = max(self.label_map.values(), default=-1) + 1

    def __call__(self, batch):
        """Pad and stack a batch of ``(sample, target)`` pairs.

        Args:
            batch: Sequence of ``(sample, target)`` pairs.
                - Samples may be tensors or anything ``torch.tensor`` accepts;
                  they must have at least one dimension.
                - Targets may be tensors, ``torch.tensor``-convertible values,
                  or strings (mapped through ``label_map``).

        Returns:
            ``(samples, targets)`` where:
            - ``samples`` is the padded samples stacked along dim 0
              (``batch_first``) or dim 1;
            - ``targets`` is a 1-d tensor with one entry per item when every
              target holds a single value. Otherwise the targets are stacked
              along dim 0 (``batch_first``) or the last dim.

        Raises:
            ValueError: If ``batch`` is empty.
        """
        batch = list(batch)
        if not batch:
            raise ValueError("PadTensors received an empty batch")

        # Convert before measuring so max_length works for non-tensor samples too.
        pairs = [
            (self._to_tensor(sample), self._encode_target(target))
            for sample, target in batch
        ]
        max_length = max(sample.shape[0] for sample, _ in pairs)

        samples_output = [self._pad_sample(sample, max_length) for sample, _ in pairs]
        targets_output = [target for _, target in pairs]

        samples_output = torch.stack(samples_output, 0 if self.batch_first else 1)

        first_target = targets_output[0]
        if first_target.dim() > 1 or first_target.numel() > 1:
            # Multi-valued targets: stack them (last dim when not batch-first).
            targets_output = torch.stack(targets_output, 0 if self.batch_first else -1)
        else:
            # Single-valued targets (0-d or one-element): flatten to shape (batch,).
            targets_output = torch.tensor(
                targets_output, device=targets_output[-1].device
            )
        return (samples_output, targets_output)

    @staticmethod
    def _to_tensor(sample):
        """Return ``sample`` unchanged if it is a tensor, else ``torch.tensor(sample)``."""
        return sample if isinstance(sample, torch.Tensor) else torch.tensor(sample)

    def _encode_target(self, target):
        """Return ``target`` as a tensor, mapping string labels to class indices."""
        if isinstance(target, torch.Tensor):
            return target
        if isinstance(target, str):
            if target not in self.label_map:
                self.label_map[target] = self.class_num
                self.class_num += 1
            target = self.label_map[target]
        return torch.tensor(target)

    @staticmethod
    def _pad_sample(sample, max_length):
        """Extend ``sample`` along dim 0 to ``max_length`` rows."""
        if sample.is_sparse:
            # sparse_resize_ modifies the given tensor object in place.
            sample.sparse_resize_(
                (max_length, *sample.shape[1:]),
                sample.sparse_dim(),
                sample.dense_dim(),
            )
            return sample
        # Padding uses torch's default dtype (not sample.dtype).
        padding = torch.zeros(
            max_length - sample.shape[0],
            *sample.shape[1:],
            device=sample.device,
        )
        return torch.cat((sample, padding))