import torch
import torch.nn as nn

import torchvision.transforms as T

import numpy as np

from oucl.scenarios.transforms import load_transform, NumpyRawTransform, ExtractPatches

class DefaultCollate(nn.Module):
    """Collate a list of dataset samples into batched tensors.

    Samples are indexed positionally: ``item[0]`` is the input passed through
    ``transform``, ``item[1]`` is the label, ``item[2]`` is the sample index,
    and ``item[3]`` (optional) is an object label.

    Args:
        transform: Callable applied to ``item[0]`` of every sample. It may
            return a single tensor, or a list/tuple of tensors (one per view).
    """

    def __init__(self, transform):
        super().__init__()
        self.transform = transform

    def forward(self, batch):
        """Collate ``batch`` into tensors.

        Args:
            batch: Non-empty sequence of indexable samples (see class docstring).

        Returns:
            Tuple ``(images, labels, indices, object_labels)``:

            * ``images`` -- ``torch.stack`` of the transformed inputs. When the
              transform returns a list/tuple of views, the result is shaped
              ``(num_views, batch_size, ...)``.
            * ``labels`` / ``indices`` -- built with ``torch.tensor``.
            * ``object_labels`` -- tensor of ``item[3]`` values, or ``None``
              when the first sample has no fourth field.
        """
        images = [self.transform(item[0]) for item in batch]
        labels = [item[1] for item in batch]
        indices = [item[2] for item in batch]

        # Only the first sample is inspected; all samples are assumed to share
        # the same number of fields.
        if len(batch[0]) > 3:
            object_labels = [item[3] for item in batch]
        else:
            object_labels = None

        if isinstance(images[0], (list, tuple)):
            # Stack views per sample -> (B, V, ...), then swap the first two
            # axes so the view axis leads -> (V, B, ...).
            images = torch.stack([torch.stack(views) for views in images]).swapaxes(0, 1)
        else:
            images = torch.stack(images)

        # Convert labels, indices and (optional) object labels to tensors.
        labels = torch.tensor(labels)
        indices = torch.tensor(indices)
        if object_labels is not None:
            object_labels = torch.tensor(object_labels)

        return images, labels, indices, object_labels
    
class RawCollate(nn.Module):
    """Collate samples into batched tensors plus a batch of raw inputs.

    Samples are indexed positionally: ``item[0]`` is the input, ``item[1]`` is
    the label, ``item[2]`` is the sample index, and ``item[3]`` (optional) is
    an object label. ``item[0]`` is fed to both ``transform`` and
    ``raw_transform``.

    Args:
        transform: Callable applied to ``item[0]``. It may return a single
            tensor, or a list/tuple of tensors (one per view).
        raw_transform: Callable applied to ``item[0]``. It may return either a
            ``numpy.ndarray`` or a tensor.
    """

    def __init__(self, transform, raw_transform):
        super().__init__()
        self.transform = transform
        self.raw_transform = raw_transform

    def forward(self, batch):
        """Collate ``batch``.

        Args:
            batch: Non-empty sequence of indexable samples (see class docstring).

        Returns:
            Tuple ``((images, raws), labels, indices, object_labels)``:

            * ``images`` -- ``torch.stack`` of the transformed inputs. For
              multi-view transforms it is shaped ``(num_views, batch_size, ...)``.
            * ``raws`` -- raw-transformed inputs stacked along axis 0, with
              ``np.stack`` if the first one is an ndarray, else ``torch.stack``.
            * ``labels`` / ``indices`` -- built with ``torch.tensor``.
            * ``object_labels`` -- tensor of ``item[3]`` values, or ``None``
              when the first sample has no fourth field.
        """
        # All `transform` calls happen before any `raw_transform` call.
        images = [self.transform(item[0]) for item in batch]
        raws = [self.raw_transform(item[0]) for item in batch]

        labels = [item[1] for item in batch]
        indices = [item[2] for item in batch]

        # Only the first sample is inspected; all samples are assumed to share
        # the same number of fields.
        if len(batch[0]) > 3:
            object_labels = [item[3] for item in batch]
        else:
            object_labels = None

        if isinstance(images[0], (list, tuple)):
            # (B, V, ...) -> (V, B, ...) so the view axis leads.
            images = torch.stack([torch.stack(views) for views in images]).swapaxes(0, 1)
        else:
            images = torch.stack(images)

        # Keep raw inputs in whatever array type the raw transform produced.
        if isinstance(raws[0], np.ndarray):
            raws = np.stack(raws, 0)
        else:
            raws = torch.stack(raws, 0)

        # Convert labels, indices and (optional) object labels to tensors.
        labels = torch.tensor(labels)
        indices = torch.tensor(indices)
        if object_labels is not None:
            object_labels = torch.tensor(object_labels)

        return (images, raws), labels, indices, object_labels
    
def load_collate_fn(collate_type, transform, img_size, mean, std, config):
    """Build the collate callable selected by ``collate_type``.

    Args:
        collate_type: ``'include_raw'`` selects ``RawCollate``; any other value
            selects ``DefaultCollate``.
        transform: Forwarded to ``load_transform`` together with ``img_size``,
            ``mean``, ``std`` and ``config`` (presumably a transform name or
            spec -- confirm against ``load_transform``).
        img_size: Forwarded to ``load_transform`` and, for ``'include_raw'``,
            to ``NumpyRawTransform``.
        mean: Forwarded to ``load_transform``.
        std: Forwarded to ``load_transform``.
        config: Forwarded to ``load_transform``; its ``alter_raw`` attribute
            is read when ``collate_type == 'include_raw'``.

    Returns:
        A ``RawCollate`` or ``DefaultCollate`` wrapping the loaded transform.
    """
    augment = load_transform(transform, img_size, mean, std, config)

    if collate_type == 'include_raw':
        raw_transform = NumpyRawTransform(img_size, config.alter_raw)
        return RawCollate(augment, raw_transform)

    return DefaultCollate(augment)