from functools import partial
from pathlib import Path

import lightning as L
import torch
from PIL import Image
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T

from sed.utils import exists


def convert_image_to(pil_img_type, image):
    """Return ``image`` in PIL mode ``pil_img_type``.

    The image is returned unchanged (same object) when it is already in the
    requested mode; otherwise ``image.convert(pil_img_type)`` is returned.
    """
    already_matching = image.mode == pil_img_type
    return image if already_matching else image.convert(pil_img_type)

# dataset classes


class ImageDataset(Dataset):
    """Dataset of image files found recursively under a directory.

    Each item is loaded with PIL and passed through ``self.transform``. The
    transform optionally converts to ``pil_img_type``, resizes to ``image_size``,
    optionally flips horizontally at random, center-crops to ``image_size`` and
    converts to a tensor with ``T.ToTensor``.
    """

    def __init__(
        self,
        dir,
        image_size,
        exts=('jpg', 'jpeg', 'png', 'tiff', 'JPEG'),
        augment_horizontal_flip=False,
        pil_img_type=None
    ):
        """
        Args:
            dir: Root directory, searched recursively with ``**/*.<ext>``.
                (The name shadows the ``dir`` builtin; kept for keyword callers.)
            image_size: Passed to both ``T.Resize`` and ``T.CenterCrop``.
            exts: File extensions to collect. A tuple default avoids a shared
                mutable default; any iterable of strings is accepted.
            augment_horizontal_flip: If true, insert ``T.RandomHorizontalFlip()``.
            pil_img_type: Optional PIL mode (e.g. ``'L'``) to convert images to.
        """
        super().__init__()
        self.folder = dir
        self.image_size = image_size
        root = Path(dir)
        # dict.fromkeys drops duplicates while keeping glob order: where glob
        # matching is case-insensitive (e.g. Windows), 'jpeg' and 'JPEG' would
        # otherwise both yield the same file.
        self.paths = list(dict.fromkeys(
            p for ext in exts for p in root.glob(f'**/*.{ext}')))

        maybe_convert_fn = partial(convert_image_to, pil_img_type) if exists(
            pil_img_type) else nn.Identity()

        self.transform = T.Compose([
            T.Lambda(maybe_convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Load ``self.paths[index]`` and return the transformed tensor."""
        # The whole transform (including ToTensor) runs inside the ``with``, so the
        # file is fully read before PIL closes it; without the context manager the
        # handle stays open until garbage collection.
        with Image.open(self.paths[index]) as img:
            return self.transform(img)


class SparseImageDataset(ImageDataset):
    """ImageDataset variant that returns each image as a sparse (positions, values) pair.

    ``positions`` holds the flat indices of the non-zero entries of the flattened
    transformed image tensor, framed by ``start_position`` and ``end_position``.
    ``values`` holds the matching entries, framed by ``start_value`` and ``end_value``.
    ``pad_position``/``pad_value`` are not used by this class itself; they are
    exposed for padding batches of variable-length items.
    """

    def __init__(
        self,
        dir,
        image_size,
        data_dimensions,
        pil_img_type=None
    ):
        super().__init__(dir, image_size, pil_img_type=pil_img_type)
        # Sentinel positions are data_dimensions, +1, +2.
        # NOTE(review): assumes data_dimensions >= number of elements in the
        # flattened image, so sentinels never collide with real indices — confirm.
        self.start_position = data_dimensions
        self.end_position = self.start_position + 1
        self.pad_position = self.start_position + 2
        self.start_value = -1
        self.end_value = -2
        self.pad_value = -3

    def __getitem__(self, index):
        """Return ``(positions, values)`` 1-D tensors for image ``index``."""
        path = self.paths[index]
        # Close the file once the transform has produced a tensor.
        with Image.open(path) as img:
            flat = self.transform(img).flatten()
        # nonzero(as_tuple=True)[0] is always 1-D. The previous
        # .nonzero().squeeze() collapsed to a 0-d tensor when exactly one
        # entry was non-zero, which torch.cat rejects.
        nz_positions = flat.nonzero(as_tuple=True)[0]
        nz_values = flat[nz_positions]
        positions = torch.cat((
            torch.tensor([self.start_position], dtype=nz_positions.dtype),
            nz_positions,
            torch.tensor([self.end_position], dtype=nz_positions.dtype),
        ))
        # Sentinels take the image dtype explicitly instead of relying on
        # torch.cat's int64/float type promotion.
        values = torch.cat((
            torch.tensor([self.start_value], dtype=nz_values.dtype),
            nz_values,
            torch.tensor([self.end_value], dtype=nz_values.dtype),
        ))
        return positions, values


class ImageDataModule(L.LightningDataModule):
    """Lightning data module serving dense ``ImageDataset`` batches.

    Training and validation images are read from ``train_data_dir`` and
    ``val_data_dir``. All constructor arguments are recorded via
    ``save_hyperparameters()``.
    """

    def __init__(self, train_data_dir: str, val_data_dir: str, batch_size: int = 32, image_size: int = 28, pil_img_type: str = 'L', input_mode: str = "image"):
        super().__init__()
        self.train_data_dir, self.val_data_dir = train_data_dir, val_data_dir
        self.batch_size = batch_size
        self.image_size = image_size
        self.pil_img_type = pil_img_type
        # Stored and recorded in hparams, but not read anywhere in this class.
        self.input_mode = input_mode
        self.save_hyperparameters()

    def _build_dataset(self, data_dir):
        """Create an ``ImageDataset`` for ``data_dir`` with this module's image settings."""
        return ImageDataset(data_dir, self.image_size, pil_img_type=self.pil_img_type)

    def setup(self, stage: str):
        # Both splits are built regardless of ``stage``.
        self.train_dataset = self._build_dataset(self.train_data_dir)
        self.val_dataset = self._build_dataset(self.val_data_dir)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.train_dataset, shuffle=True, batch_size=self.batch_size)

    def val_dataloader(self):
        """Loader over the validation split, in dataset order."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size)


class SparseImageDataModule(ImageDataModule):
    """Data module serving ``SparseImageDataset`` batches padded by ``collate_sparse_batch``.

    Adds ``data_dimensions``, which ``SparseImageDataset`` uses to place its
    start/end/pad sentinel positions.
    """

    def __init__(self, train_data_dir: str, val_data_dir: str, batch_size: int = 32, image_size: int = 28, pil_img_type: str = 'L', data_dimensions: int = 1024, input_mode: str = 'image'):
        super().__init__(train_data_dir, val_data_dir, batch_size, image_size, pil_img_type, input_mode)
        self.data_dimensions = data_dimensions

    def setup(self, stage: str):
        """Build both sparse splits (regardless of ``stage``) and mirror the sentinels."""
        self.train_dataset = SparseImageDataset(
            self.train_data_dir, self.image_size, self.data_dimensions, pil_img_type=self.pil_img_type)
        # Copy the sentinel positions onto the module so state_dict can report them.
        self.start_position = self.train_dataset.start_position
        self.end_position = self.train_dataset.end_position
        self.pad_position = self.train_dataset.pad_position
        self.val_dataset = SparseImageDataset(
            self.val_data_dir, self.image_size, self.data_dimensions, pil_img_type=self.pil_img_type)

    def state_dict(self):
        """Return the sentinel positions recorded by ``setup``.

        These attributes are only assigned in ``setup``, so calling this earlier
        raises AttributeError.
        """
        state = {"start_position": self.start_position,
                 "end_position": self.end_position, "pad_position": self.pad_position}
        return state

    def load_state_dict(self, state_dict):
        """Restore the sentinel positions onto the module.

        Only the module attributes change. The datasets keep their own values,
        and the dataloaders read padding sentinels from the datasets.
        """
        self.start_position = state_dict["start_position"]
        self.end_position = state_dict["end_position"]
        self.pad_position = state_dict["pad_position"]

    @staticmethod
    def _sparse_collate(dataset):
        """Return a collate function that pads with ``dataset``'s pad sentinels.

        A ``functools.partial`` of a module-level function is picklable, unlike a
        lambda closing over ``self``. This lets the DataLoader send it to worker
        processes started with the ``spawn`` method.
        """
        return partial(collate_sparse_batch,
                       pad_position=dataset.pad_position,
                       pad_value=dataset.pad_value)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          collate_fn=self._sparse_collate(self.train_dataset), shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          collate_fn=self._sparse_collate(self.val_dataset))


def collate_sparse_batch(batch, pad_position=-1, pad_value=-1):
    """
    Pad a batch of ``(positions, values)`` pairs to the longest sequence in the batch.

    Args:
        batch: Sequence of ``(positions, values)`` tensor pairs.
        pad_position: Fill value for the padded positions tensor.
        pad_value: Fill value for the padded values tensor.

    Returns:
        ``(padded_positions, padded_values)``, each stacked batch-first by
        ``pad_sequence``.
    """
    def _pad_column(column, fill):
        # Gather element ``column`` of every sample, then right-pad to a common length.
        return pad_sequence([sample[column] for sample in batch],
                            batch_first=True, padding_value=fill)

    return _pad_column(0, pad_position), _pad_column(1, pad_value)
