from functools import partial
from multiprocessing import cpu_count

import lightning as L
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class CaloImageDataset(Dataset):
    """Dense calorimeter images loaded from a single ``.npy`` file.

    The array is cast to float32 and given a singleton axis at position 1
    (right after the sample axis). On access, each sample is divided by the
    maximum over the whole file and clipped into ``[0, 1]``.

    Args:
        file: anything ``np.load`` accepts (path or file-like object).
    """

    def __init__(
        self,
        file,
    ):
        super().__init__()
        images = np.load(file).astype('float32')
        # Insert the singleton axis at position 1, e.g. (N, H, W) -> (N, 1, H, W).
        self.data = images[:, np.newaxis, ...]
        # Global maximum over every sample; this is the normalisation scale.
        self.max_value = self.data.max()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` divided by ``max_value`` and clipped to [0, 1]."""
        normalised = self.data[idx] / self.max_value
        return normalised.clip(0, 1)


class SparseCaloImageDataset(CaloImageDataset):
    """Calorimeter images as variable-length ``(positions, values)`` sequences.

    Each sample is normalised exactly like the dense parent dataset, flattened,
    and reduced to its non-zero pixels. The sequence is framed by a start token
    and an end token. Their positions are ``data_dimensions`` and
    ``data_dimensions + 1``, and their values are the negative sentinels
    ``start_value`` / ``end_value``.

    Args:
        file: anything ``np.load`` accepts.
        data_dimensions: base index for the special-token positions. It is
            presumably the flattened image size (e.g. 32 * 32 = 1024), so the
            special positions fall just past the pixel indices. TODO confirm.
    """

    def __init__(
        self,
        file,
        data_dimensions
    ):
        super().__init__(file)
        self.start_position = data_dimensions
        self.end_position = self.start_position + 1
        self.pad_position = self.start_position + 2
        # Negative sentinels cannot collide with pixel values, which are
        # clipped into [0, 1] before being emitted.
        self.start_value = -1
        self.end_value = -2
        self.pad_value = -3

    def __getitem__(self, idx):
        """Return ``(positions, values)`` for sample ``idx``.

        ``positions`` is an int64 tensor and ``values`` a float32 tensor, both of
        length ``nnz + 2``: start token, the non-zero pixels in flattened order,
        end token.
        """
        # Reuse the parent's normalise-and-clip, then work on a flat view once.
        flat = super().__getitem__(idx).ravel()
        nonzero_idx = flat.nonzero()[0]

        positions = torch.cat((
            torch.tensor([self.start_position], dtype=torch.long),
            torch.as_tensor(nonzero_idx, dtype=torch.long),
            torch.tensor([self.end_position], dtype=torch.long),
        ))
        # Build the sentinels as float32 explicitly rather than relying on
        # torch.cat promoting int64 sentinel tensors to the pixels' float32.
        values = torch.cat((
            torch.tensor([self.start_value], dtype=torch.float32),
            torch.as_tensor(flat[nonzero_idx], dtype=torch.float32),
            torch.tensor([self.end_value], dtype=torch.float32),
        ))
        return positions, values


class CaloImageDataModule(L.LightningDataModule):
    """Lightning data module serving dense calorimeter images.

    The training file's maximum (``max_value``) is the single normalisation
    scale. It is applied to the validation split as well, and it is the value
    saved by ``state_dict`` and restored by ``load_state_dict``.

    Args:
        train_data_dir: forwarded to ``np.load`` by the dataset. Despite the
            name, it must be a file path rather than a directory.
        val_data_dir: same as ``train_data_dir``, for the validation split.
        batch_size: samples per batch for both loaders.
        image_size: stored on the module and in hparams; not read by this class.
        input_mode: stored on the module and in hparams; not read by this class.
    """

    def __init__(self, train_data_dir: str, val_data_dir: str, batch_size: int = 32, image_size: int = 32, input_mode: str = "calo_image"):
        super().__init__()
        self.train_data_dir = train_data_dir
        self.val_data_dir = val_data_dir
        self.batch_size = batch_size
        self.image_size = image_size
        self.input_mode = input_mode
        self.save_hyperparameters()

    def setup(self, stage: str):
        """Load both splits and normalise them with the training maximum."""
        self.train_dataset = CaloImageDataset(
            self.train_data_dir)
        self.max_value = self.train_dataset.max_value
        self.val_dataset = CaloImageDataset(
            self.val_data_dir)
        # Use the training scale for validation too. Otherwise each split is
        # divided by its own maximum and the two disagree on intensity units.
        self._sync_max_value()

    def _sync_max_value(self):
        """Set ``self.max_value`` on every dataset that ``setup`` has created."""
        for name in ("train_dataset", "val_dataset"):
            dataset = getattr(self, name, None)
            if dataset is not None:
                dataset.max_value = self.max_value

    def state_dict(self):
        """Return the normalisation scale so it can be stored in a checkpoint."""
        state = {"max_value": self.max_value}
        return state

    def load_state_dict(self, state_dict):
        """Restore ``max_value`` and apply it to any datasets already built."""
        self.max_value = state_dict["max_value"]
        # Restoring may happen after ``setup``; keep the datasets in step with
        # the restored scale instead of the freshly computed one.
        self._sync_max_value()

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)


class SparseCaloImageDataModule(CaloImageDataModule):
    """Lightning data module serving calorimeter images as sparse token sequences.

    Both splits use ``SparseCaloImageDataset``, and batches are padded with
    ``collate_sparse_batch``. The checkpointed state holds the training
    ``max_value`` plus the start/end/pad token positions taken from the
    training dataset.

    Args:
        train_data_dir: forwarded to ``np.load`` by the dataset (a file path).
        val_data_dir: same as ``train_data_dir``, for the validation split.
        batch_size: samples per batch for both loaders.
        data_dimensions: forwarded to ``SparseCaloImageDataset``; it determines
            the special-token positions.
        input_mode: forwarded to the parent module and stored in hparams.
    """

    def __init__(self, train_data_dir: str, val_data_dir: str, batch_size: int = 32, data_dimensions: int = 1024, input_mode: str = "calo_image"):
        # Pass input_mode by keyword. Passed positionally, it would land in the
        # parent's ``image_size`` slot, and ``input_mode`` would silently fall
        # back to its default.
        super().__init__(train_data_dir, val_data_dir, batch_size, input_mode=input_mode)
        self.data_dimensions = data_dimensions

    def setup(self, stage: str):
        """Build both sparse datasets; record the training scale and token positions."""
        self.train_dataset = SparseCaloImageDataset(
            self.train_data_dir, self.data_dimensions)
        self.max_value = self.train_dataset.max_value
        self.start_position = self.train_dataset.start_position
        self.end_position = self.train_dataset.end_position
        self.pad_position = self.train_dataset.pad_position
        self.val_dataset = SparseCaloImageDataset(
            self.val_data_dir, self.data_dimensions)
        # Scale the validation split by the training maximum, not by its own.
        self._sync_max_value()

    def _sync_max_value(self):
        """Set ``self.max_value`` on every dataset that ``setup`` has created."""
        for name in ("train_dataset", "val_dataset"):
            dataset = getattr(self, name, None)
            if dataset is not None:
                dataset.max_value = self.max_value

    def state_dict(self):
        """Return the normalisation scale and special-token positions."""
        state = {"max_value": self.max_value, "start_position": self.start_position,
                 "end_position": self.end_position, "pad_position": self.pad_position}
        return state

    def load_state_dict(self, state_dict):
        """Restore the saved state and apply ``max_value`` to any built datasets."""
        self.max_value = state_dict["max_value"]
        self.start_position = state_dict["start_position"]
        self.end_position = state_dict["end_position"]
        self.pad_position = state_dict["pad_position"]
        self._sync_max_value()

    def train_dataloader(self):
        # Unlike a lambda, functools.partial can be pickled, so this loader also
        # works if DataLoader worker processes are enabled.
        collate = partial(collate_sparse_batch,
                          pad_position=self.train_dataset.pad_position,
                          pad_value=self.train_dataset.pad_value)
        return DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=collate, shuffle=True)

    def val_dataloader(self):
        collate = partial(collate_sparse_batch,
                          pad_position=self.val_dataset.pad_position,
                          pad_value=self.val_dataset.pad_value)
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=collate)


def collate_sparse_batch(batch, pad_position=-1, pad_value=-1):
    """Collate ``(positions, values)`` samples into two right-padded batch tensors.

    Args:
        batch: sequence of samples whose element 0 is a positions tensor and
            element 1 a values tensor. Lengths along dim 0 may differ between
            samples.
        pad_position: fill value for the padded tail of the positions.
        pad_value: fill value for the padded tail of the values.

    Returns:
        ``(positions, values)``, each stacked with the batch dimension first and
        padded to the longest sequence in ``batch``.
    """
    position_seqs = [sample[0] for sample in batch]
    value_seqs = [sample[1] for sample in batch]

    return (
        pad_sequence(position_seqs, batch_first=True, padding_value=pad_position),
        pad_sequence(value_seqs, batch_first=True, padding_value=pad_value),
    )
