import json
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision import datasets, transforms
import torch
import pandas as pd
from .util import scale_dataframe, flatten_image, flatten_image_maintain_dim


def get_dataloaders_for_tabular(
    batch_size=64,
    normalize: bool = True,
    separate_sign_normalization: bool = True,
    path_to_data: str = None,
    maintain_dimensionality: bool = None,
    device: torch.device = None,
    tuning: bool = False,
) -> tuple[DataLoader, DataLoader]:
    """
    Get two DataLoaders for a pre-processed tabular dataset.

    The dataset is split randomly with a fixed seed. Without ``tuning`` the
    split is 0.9 train / 0.1 test and ``(train_loader, test_loader)`` is
    returned. With ``tuning`` the split is 0.8 train / 0.1 val / 0.1 test and
    ``(train_loader, val_loader)`` is returned; the third split is not used.

    Args:
        batch_size: Batch size used for both loaders.
        normalize: If true, a ``TabularDatasetTransform`` is attached to the
            dataset; the flag is also forwarded to ``TabularDataset``.
        separate_sign_normalization: Forwarded to ``TabularDataset`` as
            ``seperate_sign_normalization``.
        path_to_data: Path of the data file, forwarded to ``TabularDataset``.
        maintain_dimensionality: Forwarded to ``TabularDataset``.
        device: Forwarded to ``TabularDataset``.
        tuning: Selects the three-way split described above.

    Returns:
        A 2-tuple of DataLoaders, both created with ``shuffle=True``.
    """

    data = TabularDataset(
        path_to_data,
        transform=TabularDatasetTransform() if normalize else None,
        normalize=normalize,
        seperate_sign_normalization=separate_sign_normalization,
        maintain_dimensionality=maintain_dimensionality,
        device=device,
    )

    # NOTE: Since the variance in the first results is fairly small, the
    # experiments could also be made sensitive to the splitter seed.
    # NOTE: !! A fixed seed is required to avoid leakage between tuning and
    # benchmarking (presumably so the held-out test share is identical in
    # both modes -- confirm against torch.utils.data.random_split).
    generator = torch.Generator().manual_seed(42)
    if tuning:
        # 0.8 train, 0.1 val, 0.1 test (latter not used for tuning)
        split_list = [0.8, 0.1, 0.1]
    else:
        # 0.9 train, 0.1 test
        split_list = [0.9, 0.1]
    data_splits = random_split(data, split_list, generator=generator)

    # Only the first two splits are handed out; in tuning mode the second one
    # is the validation split.
    # NOTE(review): the evaluation loader is shuffled as well; kept as-is.
    train_loader = DataLoader(data_splits[0], batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(data_splits[1], batch_size=batch_size, shuffle=True)
    return train_loader, test_loader


def get_inference_dataloader(
    batch_size: int = 64,
    normalize: bool = True,
    separate_sign_normalization: bool = True,
    path_to_data: str = None,
) -> DataLoader:
    """
    Build a single DataLoader over the complete tabular dataset.

    Intended for inference: the data is neither split nor shuffled, so
    batches follow the row order of the underlying dataset.
    """

    # The transform is only attached when normalization is requested.
    transform = None
    if normalize:
        transform = TabularDatasetTransform()

    dataset = TabularDataset(
        path_to_data,
        transform=transform,
        normalize=normalize,
        seperate_sign_normalization=separate_sign_normalization,
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)


def get_flattened_mnist_dataloaders(
    batch_size: int = 128,
    size_share: float = 1.0,
    maintain_dimensionality: bool = False,
    path_to_data: str = "./data",
) -> tuple[DataLoader, DataLoader]:
    """
    Get the train and test DataLoaders for the MNIST dataset without labels.

    Every image is passed through ``transforms.Resize(32)``,
    ``transforms.ToTensor()`` and then one of the flatten helpers.

    Args:
        batch_size: Batch size used for both loaders.
        size_share: If smaller than 1, only the leading
            ``int(size_share * len(dataset))`` samples of each dataset are
            kept.
        maintain_dimensionality: Selects ``flatten_image_maintain_dim``
            instead of ``flatten_image`` as the final transform.
        path_to_data: Root directory handed to ``ImageOnlyMNIST``.

    Returns:
        ``(train_loader, test_loader)``, both created with ``shuffle=True``.
    """

    # Both pipelines differ only in the final flatten step.
    flatten = flatten_image_maintain_dim if maintain_dimensionality else flatten_image
    all_transforms = transforms.Compose(
        [
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Lambda(flatten),
        ]
    )

    train_data = ImageOnlyMNIST(
        path_to_data,
        train=True,
        transform=all_transforms,
        target_transform=None,
    )
    test_data = ImageOnlyMNIST(
        path_to_data,
        train=False,
        transform=all_transforms,
        target_transform=None,
    )
    if size_share < 1:
        # Keep a deterministic prefix of each dataset, not a random sample.
        train_data = Subset(train_data, range(int(size_share * len(train_data))))
        test_data = Subset(test_data, range(int(size_share * len(test_data))))

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader


# Dataset classes
class ImageOnlyMNIST(Dataset):
    """MNIST wrapper that yields only the image of each sample."""

    def __init__(
        self, root, train=True, transform=None, target_transform=None
    ) -> None:
        """
        Wrap ``datasets.MNIST``; all arguments are forwarded unchanged.

        The wrapped dataset is created with ``download=True``.
        """
        self.data = datasets.MNIST(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=True,
        )

    def __len__(self) -> int:
        """Return the number of samples in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return element 0 (the image) of the wrapped sample at ``index``."""
        return self.data[index][0]


class TabularDataset(Dataset):
    """
    Arbitrary tabular dataset.

    The data file must be named ``<prefix>processed.csv`` or
    ``<prefix>processed.parquet``. A JSON file ``<prefix>conf.json`` must
    exist next to it; its ``"target"`` entry names the column(s) that are
    dropped from the loaded frame.
    """

    def __init__(
        self,
        path_to_data: str,
        transform=None,
        normalize: bool = True,
        seperate_sign_normalization: bool = True,
        maintain_dimensionality: bool = False,
        device: torch.device = None,
    ):
        """
        Load the data file, drop the target column and apply the transform.

        Args:
            path_to_data: Path ending in ``processed.csv`` or
                ``processed.parquet``.
            transform: Optional callable invoked once as
                ``transform(df, normalize, seperate_sign_normalization)``;
                its result replaces the loaded frame.
            normalize: Passed to ``transform``; unused without a transform.
            seperate_sign_normalization: Passed to ``transform``; unused
                without a transform.
            maintain_dimensionality: If true, items are reshaped with
                ``view(1, 1, -1)``.
            device: Device on which the item tensors are created.

        Raises:
            ValueError: If ``path_to_data`` has neither supported suffix.
        """
        if path_to_data.endswith("processed.csv"):
            prefix = path_to_data[: -len("processed.csv")]
            is_parquet = False
        elif path_to_data.endswith("processed.parquet"):
            prefix = path_to_data[: -len("processed.parquet")]
            is_parquet = True
        else:
            raise ValueError(
                "Data path does not end in processed.csv or processed.parquet"
            )

        with open(prefix + "conf.json", "r") as fp:
            target = json.load(fp)["target"]

        if is_parquet:
            # read_parquet takes no ``header`` argument.
            frame = pd.read_parquet(path_to_data)
        else:
            frame = pd.read_csv(path_to_data, header=0)
        # The target is removed for both formats so items contain features only.
        self.df = frame.drop(target, axis=1)

        self.transform = transform
        self.normalized = normalize
        self.seperate_sign_normalization = seperate_sign_normalization
        self.maintain_dimensionality = maintain_dimensionality
        self.device = device

        if self.transform:
            self.df = self.transform(
                self.df, self.normalized, self.seperate_sign_normalization
            )

    def __len__(self):
        """Return the number of rows in the (possibly transformed) frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return row ``idx`` as a float32 tensor on ``self.device``."""
        row = torch.tensor(
            self.df.iloc[idx].values,
            dtype=torch.float32,
            device=self.device,
        )
        if self.maintain_dimensionality:
            return row.view(1, 1, -1)
        return row


# Dataset transformation classes
class TabularDatasetTransform:
    def __call__(self, df: pd.DataFrame, normalize: bool, seperate: bool):
        """
        Drop rows containing NaN values and optionally rescale the frame.

        With ``normalize`` the cleaned frame is passed through
        ``scale_dataframe`` (intended to map values to the [-1, 1] scale,
        which matters for measuring the reconstruction loss after the last
        layer's TanH activation). Requesting ``seperate`` scaling without
        ``normalize`` raises a ``RuntimeWarning``.
        """

        if not normalize:
            if seperate:
                # NOTE(review): the warning class is raised as an exception
                # rather than emitted via warnings.warn -- confirm intent.
                raise RuntimeWarning(
                    "The will be no seperate scaling for positive and negative"
                    " values unless the normalize parameter is positive!"
                )
            return df.dropna()

        return scale_dataframe(df.dropna(), seperate=seperate)
