import numpy as np
from scipy.special import expit

import torch
from torch.utils.data import Dataset, DataLoader

import lightning.pytorch as L
from lightning.pytorch.utilities.combined_loader import CombinedLoader

from typing import Literal

def _gen(dgp: Literal['lin', 'tanh', 'sin'], rng: np.random.Generator, n: int):
    """Simulate ``n`` labelled samples from a logistic model with a chosen link.

    Features ``x`` are standard-normal draws with 3 columns. The latent logit is
    ``z = f(x.mean(axis=1))``, where ``f`` is selected by ``dgp``. Labels ``y``
    are Bernoulli draws with success probability ``expit(z)``.

    Args:
        dgp: Link function name: ``'lin'`` (identity), ``'tanh'``
            (``3 * tanh``) or ``'sin'`` (``3 * sin``).
        rng: Generator used for every random draw; it is advanced in place.
        n: Number of samples to draw.

    Returns:
        Dict with ``"x"`` of shape ``(n, 3)``, ``"z"`` of shape ``(n,)`` and
        integer 0/1 labels ``"y"`` of shape ``(n,)``.

    Raises:
        ValueError: If ``dgp`` is not a supported name. This is checked before
            any draw, so ``rng`` is left untouched on failure.
    """
    links = {
        'lin': lambda v: v,
        'tanh': lambda v: 3 * np.tanh(v),
        'sin': lambda v: 3 * np.sin(v),
    }
    # Validate eagerly: the previous version only noticed a bad name after
    # rng.normal had already consumed random state from the caller's generator.
    # The isinstance guard keeps unhashable inputs on the ValueError path.
    link = links.get(dgp) if isinstance(dgp, str) else None
    if link is None:
        raise ValueError(f"Unknown dgp: {dgp}")

    x = rng.normal(size=(n, 3))
    z = link(x.mean(axis=1))
    y = rng.binomial(1, expit(z))

    return {
        "x": x,
        "z": z,
        "y": y,
    }
class _Dataset(Dataset):
    """In-memory torch ``Dataset`` over a row subset of a ``_gen``-style dict.

    The arrays stored under ``"x"``, ``"z"`` and ``"y"`` are indexed with ``i``
    (e.g. a slice or an integer index array) and copied into ``float32``
    tensors once, at construction. Each item is a dict with those same three
    keys, holding the corresponding row of each tensor.
    """

    # Keys read from the source dict, in the order items are emitted.
    _FIELDS = ("x", "z", "y")

    def __init__(self, dataset, i):
        for field in self._FIELDS:
            setattr(self, field, torch.tensor(dataset[field][i], dtype=torch.float32))

    def __len__(self):
        # Length is the leading dimension of the feature tensor.
        return len(self.x)

    def __getitem__(self, idx):
        return {field: getattr(self, field)[idx] for field in self._FIELDS}

class CoordinatedBinaryClassificationDataModule(L.LightningDataModule):
    """Lightning data module serving several equally sized, disjoint partitions.

    A single pool of ``n * n_datasets`` samples is drawn with ``_gen`` and cut
    into consecutive chunks of ``n`` rows, exposed as ``data1``, ``data2``, ...
    in ``self.train_datasets``.

    Args:
        dgp: Link function name forwarded to ``_gen``.
        rng: NumPy generator used to draw the data (advanced in place).
        n: Number of samples per partition.
        batch_size: Batch size for every underlying ``DataLoader``.
        n_datasets: Number of partitions. Defaults to 4, the value that was
            previously hard-coded.

    Raises:
        ValueError: If ``n_datasets`` is less than 1 (checked before ``rng``
            is used).
    """

    def __init__(
        self,
        dgp: Literal['lin', 'tanh', 'sin'],
        rng: np.random.Generator,
        n: int,
        batch_size: int,
        n_datasets: int = 4,
    ):
        super().__init__()
        if n_datasets < 1:
            raise ValueError(f"n_datasets must be >= 1, got {n_datasets}")
        self.batch_size = batch_size
        self.n_datasets = n_datasets
        dataset = _gen(dgp, rng, n * n_datasets)
        self.train_datasets = {
            f"data{i + 1}": _Dataset(dataset, slice(i * n, (i + 1) * n))
            for i in range(n_datasets)
        }

    def _loaders(self, shuffle: bool) -> dict:
        """Build one ``DataLoader`` per partition, keyed like ``train_datasets``."""
        return {
            name: DataLoader(ds, batch_size=self.batch_size, shuffle=shuffle)
            for name, ds in self.train_datasets.items()
        }

    def train_dataloader(self):
        # "min_size": each step yields one batch from every partition together,
        # stopping when the shortest loader is exhausted (all are size n here).
        return CombinedLoader(self._loaders(shuffle=True), mode="min_size")

    def predict_dataloader(self):
        # "sequential": partitions are iterated one after another, unshuffled.
        # NOTE(review): prediction runs on the same partitions used for
        # training — confirm this is intended.
        return CombinedLoader(self._loaders(shuffle=False), mode="sequential")

class BinaryClassificationDataModule(L.LightningDataModule):
    """Lightning data module with train / validation / predict row splits.

    ``4 * n`` samples are drawn once with ``_gen`` and kept in
    ``dataset_orig``. Rows ``[0, 2n)`` form the training pool, ``[2n, 3n)``
    the validation split and ``[3n, 4n)`` the prediction split.
    ``resample_traindataset`` replaces the training split with a bootstrap
    resample (with replacement) of the training pool.

    NOTE(review): ``val_dataset`` is built, but this class defines no
    ``val_dataloader`` — confirm whether validation is meant to be skipped.
    """

    def __init__(self, dgp: Literal['lin', 'tanh', 'sin'], rng: np.random.Generator, n: int, batch_size: int):
        super().__init__()
        self.batch_size = batch_size
        self.n = n

        train_end, val_end, total = 2 * n, 3 * n, 4 * n

        # Draw all samples in one call; the splits below are row ranges of it.
        self.dataset_orig = _gen(dgp, rng, total)
        # Initially the training split is the whole training pool, in order.
        self.resample_indx = np.arange(0, train_end)

        self.train_dataset = _Dataset(self.dataset_orig, self.resample_indx)
        self.val_dataset = _Dataset(self.dataset_orig, slice(train_end, val_end))
        self.predict_dataset = _Dataset(self.dataset_orig, slice(val_end, total))

    def resample_traindataset(self, rng: np.random.Generator):
        """Replace the training split with a bootstrap sample of the training pool.

        Draws ``2 * n`` row indices with replacement from ``[0, 2n)`` using
        ``rng``, stores them in ``resample_indx`` and rebuilds
        ``train_dataset`` from ``dataset_orig``. Each call resamples from the
        original pool, not from the previous resample.
        """
        pool_size = 2 * self.n
        pool = np.arange(0, pool_size)
        self.resample_indx = rng.choice(pool, size=pool_size, replace=True)
        self.train_dataset = _Dataset(self.dataset_orig, self.resample_indx)

    def train_dataloader(self):
        """Shuffled loader over the current training split."""
        loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
        return loader

    def predict_dataloader(self):
        """Unshuffled loader over the held-out prediction split."""
        loader = DataLoader(self.predict_dataset, batch_size=self.batch_size, shuffle=False)
        return loader
