import functools

import lightning as L
import numpy as np
import scanpy as sc
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, random_split


class CellDataset(Dataset):
    """Dense, fixed-width expression vectors loaded from an ``.h5ad`` file.

    On construction the file is read with scanpy, the 1000 most highly
    variable genes are selected (the selection is computed on a
    log1p-transformed copy), the un-logged values of those genes are passed
    through ``sc.pp.normalize_total(..., target_sum=1)``, and the result is
    stored in ``self.data`` as a float32 array that is zero-padded on the
    right up to ``data_dimensions`` columns.

    Args:
        file: Path handed to ``sc.read_h5ad``.
        data_dimensions: Number of columns of every stored row. Must be at
            least the number of selected genes.
        extended_filtering: Accepted so subclasses and callers can pass it
            through; this class does not read it.

    Raises:
        ValueError: If the selected genes do not fit in ``data_dimensions``.
    """

    def __init__(self, file, data_dimensions, extended_filtering=False):
        super().__init__()
        self.data_dimensions = data_dimensions
        # Misspelled alias of ``data_dimensions``; kept so that any existing
        # reader of the old attribute name keeps working.
        self.data_dimenions = data_dimensions

        adata = sc.read_h5ad(file)

        # Gene selection is done on a log-transformed copy so that `adata`
        # itself keeps its original values for the normalization below.
        log_adata = sc.pp.log1p(adata, copy=True)
        hvg_table = sc.pp.highly_variable_genes(
            log_adata, n_top_genes=1000, inplace=False)
        adata_filtered = adata[:, hvg_table['highly_variable']]
        adata_normalized = sc.pp.normalize_total(
            adata_filtered, target_sum=1, copy=True)

        # `.X` may be a sparse matrix or already a dense array depending on
        # how the file was written; only sparse containers offer `toarray`.
        matrix = adata_normalized.X
        if hasattr(matrix, "toarray"):
            matrix = matrix.toarray()
        matrix = np.asarray(matrix)

        n_cells, n_genes = matrix.shape
        if n_genes > data_dimensions:
            raise ValueError(
                f"data_dimensions={data_dimensions} is smaller than the "
                f"number of selected genes ({n_genes})")

        padded = np.zeros((n_cells, data_dimensions), dtype=np.float32)
        padded[:, :n_genes] = matrix
        self.data = padded

    def __len__(self):
        """Return the number of rows (cells) in ``self.data``."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return row ``idx`` of the padded float32 matrix."""
        return self.data[idx]


class SparseCellDataset(CellDataset):
    """Sparse ``(positions, values)`` sequence view over ``CellDataset`` rows.

    Each item lists only the non-zero entries of a row, wrapped in a start
    and an end marker. Marker positions are placed just past the valid column
    range (``data_dimensions``, ``data_dimensions + 1``); the padding position
    (``data_dimensions + 2``) and padding value are not used here but are
    exposed as attributes for whoever batches the items.

    Args:
        file: Path forwarded to ``CellDataset``.
        data_dimensions: Row width forwarded to ``CellDataset``.
        extended_filtering: Forwarded to ``CellDataset``.
    """

    def __init__(self, file, data_dimensions, extended_filtering=False):
        super().__init__(file, data_dimensions, extended_filtering)
        # Special token positions live directly after the last real column.
        self.start_position = self.data_dimensions
        self.end_position = self.data_dimensions + 1
        self.pad_position = self.data_dimensions + 2
        # Values paired with the special positions above.
        self.start_value = -1
        self.end_value = -2
        self.pad_value = -3

    def __getitem__(self, idx):
        """Return ``(positions, values)`` tensors for row ``idx``.

        ``positions`` is an int64 tensor ``[start, <non-zero columns>, end]``
        and ``values`` is a tensor of the same length holding
        ``[start_value, <non-zero entries>, end_value]`` in the dtype of the
        stored data.
        """
        row = self.data[idx]
        # Compute the non-zero columns once and reuse them for both outputs.
        nonzero_cols = row.nonzero()[0]

        positions = torch.cat((
            torch.tensor([self.start_position], dtype=torch.long),
            torch.tensor(nonzero_cols, dtype=torch.long),
            torch.tensor([self.end_position], dtype=torch.long),
        ))

        entries = torch.tensor(row[nonzero_cols])
        # Build the markers in the entries' dtype so the concatenation does
        # not depend on implicit integer/float promotion.
        values = torch.cat((
            torch.tensor([self.start_value], dtype=entries.dtype),
            entries,
            torch.tensor([self.end_value], dtype=entries.dtype),
        ))
        return positions, values


class CellDataModule(L.LightningDataModule):
    """Lightning data module serving dense ``CellDataset`` rows.

    Args:
        train_data_dir: Path handed to the dataset as its ``file`` argument.
        batch_size: Batch size of both dataloaders.
        image_size: When not ``None``, overrides ``data_dimensions``.
        val_split: Fraction of the dataset held out for validation.
        seed: Seed of the generator used for the train/validation split.
        data_dimensions: Row width requested from the dataset (used only
            when ``image_size`` is ``None``).
        input_mode: Stored on the instance; not read by this class.
        extended_filtering: Forwarded to the dataset constructor.
    """

    def __init__(
        self,
        train_data_dir: str,
        batch_size: int = 32,
        image_size: int = None,
        val_split: float = 0.1,
        seed: int = 42,
        data_dimensions: int = 1000,
        input_mode: str = "scrna",
        extended_filtering: bool = False,
    ):
        super().__init__()
        self.train_data_dir = train_data_dir
        self.batch_size = batch_size
        self.val_split = val_split
        self.seed = seed
        self.data_dimensions = data_dimensions if image_size is None else image_size
        self.input_mode = input_mode
        self.extended_filtering = extended_filtering
        self.save_hyperparameters()

    def setup(self, stage: str = None):
        """Load the dataset and split it into train and validation subsets.

        The validation subset gets ``int(len(dataset) * val_split)`` items and
        the split is reproducible for a fixed ``seed``.
        """
        full_dataset = CellDataset(
            self.train_data_dir, self.data_dimensions, self.extended_filtering)
        total_size = len(full_dataset)
        val_size = int(total_size * self.val_split)
        train_size = total_size - val_size

        self.train_dataset, self.val_dataset = random_split(
            full_dataset, [train_size, val_size],
            generator=torch.Generator().manual_seed(self.seed)
        )

    def train_dataloader(self):
        """Return a shuffling dataloader over the training subset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return a non-shuffling dataloader over the validation subset."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size)


class SparseCellDataModule(CellDataModule):
    """Data module serving padded ``(positions, values)`` batches.

    Uses ``SparseCellDataset`` as the underlying dataset and
    ``collate_sparse_batch`` to pad variable-length items within a batch.
    The special token positions of the dataset are mirrored on the module and
    round-tripped through ``state_dict`` / ``load_state_dict``.
    """

    def setup(self, stage: str = None):
        """Load the sparse dataset, split it, and record its token positions."""
        full_dataset = SparseCellDataset(
            self.train_data_dir, self.data_dimensions, self.extended_filtering)
        total_size = len(full_dataset)
        val_size = int(total_size * self.val_split)
        train_size = total_size - val_size

        self.train_dataset, self.val_dataset = random_split(
            full_dataset, [train_size, val_size],
            generator=torch.Generator().manual_seed(self.seed)
        )
        # Both subsets wrap `full_dataset`, so read the positions from it.
        self.start_position = full_dataset.start_position
        self.end_position = full_dataset.end_position
        self.pad_position = full_dataset.pad_position

    def state_dict(self):
        """Return the special token positions.

        The attributes only exist after ``setup`` or ``load_state_dict`` has
        run.
        """
        return {
            "start_position": self.start_position,
            "end_position": self.end_position,
            "pad_position": self.pad_position,
        }

    def load_state_dict(self, state_dict):
        """Restore the special token positions saved by ``state_dict``."""
        self.start_position = state_dict["start_position"]
        self.end_position = state_dict["end_position"]
        self.pad_position = state_dict["pad_position"]

    def _make_loader(self, subset, **loader_kwargs):
        """Build a dataloader over ``subset`` that pads with its dataset's pad tokens."""
        source = subset.dataset
        # A partial over the module-level collate function can be pickled,
        # unlike a lambda, so the loader also works with worker processes
        # started via "spawn".
        collate = functools.partial(
            collate_sparse_batch,
            pad_position=source.pad_position,
            pad_value=source.pad_value,
        )
        return DataLoader(
            subset, batch_size=self.batch_size, collate_fn=collate, **loader_kwargs)

    def train_dataloader(self):
        """Return a shuffling, padding dataloader over the training subset."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return a non-shuffling, padding dataloader over the validation subset."""
        return self._make_loader(self.val_dataset)


def collate_sparse_batch(batch, pad_position=-1, pad_value=-1):
    """
    Collate variable-length ``(positions, values)`` items into two padded tensors.

    Every sequence is right-padded to the longest one in the batch: position
    sequences with ``pad_position`` and value sequences with ``pad_value``.
    Returns ``(padded_positions, padded_values)``, both batch-first.
    """
    def _stack_field(field, fill):
        # Gather one field from every sample and pad it to a common length.
        sequences = [sample[field] for sample in batch]
        return pad_sequence(sequences, batch_first=True, padding_value=fill)

    return _stack_field(0, pad_position), _stack_field(1, pad_value)
