import torch
from torchdrug import data
import torch.nn.functional as F
from torch.utils.data import default_collate


def gearnet_collate(batch):
    """
    Collate samples shaped ``(graph, tensor, ...)`` into a single batch.

    The first field of every sample is gathered into a list and packed through
    the ``pack`` method of the first graph. Every remaining field is combined
    with :func:`torch.stack` along a new leading dimension.

    Parameters:
        batch (list): list of samples; each sample is an iterable whose first
            item is a :class:`data.Graph <torchdrug.data.Graph>` followed by
            zero or more tensors

    Returns:
        tuple: the packed graphs followed by one stacked tensor per extra field

    Raises:
        ValueError: if the first item of the first sample is not a
            :class:`data.Graph <torchdrug.data.Graph>`
    """
    graphs, *fields = zip(*batch)
    graphs = list(graphs)
    # The extra fields are stacked before the graph type is checked, so a
    # stacking failure surfaces ahead of the "Unsupported type" error.
    stacked = [torch.stack(field, dim=0) for field in fields]

    first = graphs[0]
    if not isinstance(first, data.Graph):
        raise ValueError("Unsupported type {}".format(type(first)))
    return (first.pack(graphs), *stacked)
    
def gearnet_collate_masked(batch):
    """
    Collate ``(graph, target)`` samples and report the node count of each graph.

    Graphs are packed through the ``pack`` method of the first graph, targets
    go through :func:`torch.utils.data.default_collate`, and the ``num_node``
    attribute of every graph is collected into a plain list.

    Parameters:
        batch (list): list of ``(graph, target)`` pairs where ``graph`` is a
            :class:`data.Graph <torchdrug.data.Graph>`

    Returns:
        tuple: ``(packed_graphs, collated_targets, num_nodes)``

    Raises:
        ValueError: if the first graph is not a
            :class:`data.Graph <torchdrug.data.Graph>`
    """
    graphs, targets = zip(*batch)
    graphs = list(graphs)
    # Per-graph node counts are read before the type check below.
    node_counts = [graph.num_node for graph in graphs]
    target_batch = default_collate(targets)

    first = graphs[0]
    if not isinstance(first, data.Graph):
        raise ValueError("Unsupported type {}".format(type(first)))
    return first.pack(graphs), target_batch, node_counts


def gearnet_positions_collate(batch):
    """
    Collate ``(graph, target, positions)`` samples into a single batch.

    Graphs are packed through the ``pack`` method of the first graph; the
    second and third fields are each passed to
    :func:`torch.utils.data.default_collate`.

    Parameters:
        batch (list): list of 3-item samples whose first item is a
            :class:`data.Graph <torchdrug.data.Graph>`

    Returns:
        tuple: ``(packed_graphs, collated_targets, collated_positions)``

    Raises:
        ValueError: if the first graph is not a
            :class:`data.Graph <torchdrug.data.Graph>`
    """
    graphs, targets, positions = zip(*batch)
    graphs = list(graphs)
    # Both auxiliary fields are collated before the graph type is checked.
    target_batch = default_collate(targets)
    position_batch = default_collate(positions)

    first = graphs[0]
    if not isinstance(first, data.Graph):
        raise ValueError("Unsupported type {}".format(type(first)))
    return first.pack(graphs), target_batch, position_batch
  

def gearnet_paired_collate(batch):
    """
    Collate samples shaped ``(graph, graph, tensor, ...)`` into a single batch.

    The first two fields of every sample are packed independently through the
    ``pack`` method of their respective first graph. Every remaining field is
    combined with :func:`torch.stack` along a new leading dimension.

    Parameters:
        batch (list): list of samples; each sample is an iterable whose first
            two items are :class:`data.Graph <torchdrug.data.Graph>` objects,
            followed by zero or more tensors

    Returns:
        tuple: the two packed graph batches followed by one stacked tensor per
        extra field

    Raises:
        ValueError: if the first item of either graph field is not a
            :class:`data.Graph <torchdrug.data.Graph>`; the message names the
            type of the offending item
    """
    graphs_a, graphs_b, *fields = zip(*batch)
    # Both graph fields are handed to ``pack`` as lists.
    graphs_a = list(graphs_a)
    graphs_b = list(graphs_b)
    first_a = graphs_a[0]
    first_b = graphs_b[0]

    # Validate before stacking, and report whichever item is actually
    # unsupported rather than always the first one.
    for candidate in (first_a, first_b):
        if not isinstance(candidate, data.Graph):
            raise ValueError("Unsupported type {}".format(type(candidate)))

    stacked = [torch.stack(field, dim=0) for field in fields]
    return (first_a.pack(graphs_a), first_b.pack(graphs_b), *stacked)


def paired_collate(batch):
    """
    Stack every field of a batch of tensor tuples along a new leading dimension.

    Parameters:
        batch (list): list of samples; each sample is an iterable of at least
            two tensors

    Returns:
        list: one stacked tensor per field, in field order
    """
    # The explicit two-name unpacking enforces that each sample has >= 2 fields.
    first, second, *extras = zip(*batch)
    return [torch.stack(column, dim=0) for column in (first, second, *extras)]

def onehot_collate(batch):
    """
    Stack every field of a batch of tensor tuples along a new leading dimension.

    Parameters:
        batch (list): list of samples; each sample is an iterable of at least
            one tensor

    Returns:
        list: one stacked tensor per field, in field order
    """
    # The explicit unpacking enforces that each sample has at least one field.
    head, *rest = zip(*batch)
    return [torch.stack(column, dim=0) for column in (head, *rest)]


def universal_batcher(*inputs, repeats: int = 1):
    """
    Replicate a single sample ``repeats`` times to form a batch.

    The first input is handled specially: if it is a
    :class:`data.Graph <torchdrug.data.Graph>`, its ``repeat_interleave`` method
    is called with ``repeats``; otherwise it is squeezed of all size-1
    dimensions and stacked ``repeats`` times along a new leading dimension.
    Every other input has only its dimension 0 squeezed before being stacked
    ``repeats`` times in the same way.

    Parameters:
        *inputs: the sample fields; the first may be a graph or a tensor, the
            rest are tensors
        repeats (int, optional): number of copies to place in the batch

    Returns:
        tuple: the replicated first input followed by one replicated tensor per
        remaining input
    """
    primary, *others = inputs

    if isinstance(primary, data.Graph):
        primary_batch = primary.repeat_interleave(repeats)
    else:
        # Note: full ``squeeze()`` here, versus ``squeeze(0)`` for the others.
        primary_copies = [primary.squeeze() for _ in range(repeats)]
        primary_batch = torch.stack(primary_copies, dim=0)

    replicated = [
        torch.stack([tensor.squeeze(0) for _ in range(repeats)], dim=0)
        for tensor in others
    ]
    return (primary_batch, *replicated)
    
    
def gearnet_annotate_collate(batch):
    """
    Collate samples shaped ``(graph, label, extra, ...)`` into a single batch.

    Graphs are packed through the ``pack`` method of the first graph. The
    second field of every sample is unpacked but not returned.
    NOTE(review): confirm that dropping the second field is intended.

    Each extra field whose per-sample values are all :class:`torch.Tensor`
    objects is passed to :func:`torch.utils.data.default_collate`; any other
    extra field is returned unchanged as a tuple of per-sample values.

    Parameters:
        batch (list): list of samples; each sample is an iterable whose first
            item is a :class:`data.Graph <torchdrug.data.Graph>`, followed by a
            second item and zero or more extra items

    Returns:
        tuple: the packed graphs followed by one entry per extra field

    Raises:
        ValueError: if the first item of the first sample is not a
            :class:`data.Graph <torchdrug.data.Graph>`
    """
    graphs, _unused_labels, *fields = zip(*batch)
    graphs = list(graphs)
    first = graphs[0]
    if not isinstance(first, data.Graph):
        raise ValueError("Unsupported type {}".format(type(first)))

    # ``torch.tensor`` is a factory function and cannot be used as the class
    # argument of ``isinstance``; the tensor class is ``torch.Tensor``. Each
    # field here is a tuple of per-sample values, so the samples are inspected
    # rather than the tuple itself.
    collated = [
        default_collate(field)
        if all(isinstance(sample, torch.Tensor) for sample in field)
        else field
        for field in fields
    ]
    return (first.pack(graphs), *collated)
