import numpy as np
from scipy.special import expit

import torch
from torch.utils.data import Dataset, DataLoader

import lightning.pytorch as L
from lightning.pytorch.utilities.combined_loader import CombinedLoader

from typing import Literal

class _Dataset(Dataset):
    """Simulated pairwise-comparison dataset.

    Each of the ``n`` examples consists of:

    * two feature vectors ``x1`` and ``x2`` (length ``d``) drawn from a
      standard normal;
    * latent scores ``r1 = f(mean(x1))`` and ``r2 = f(mean(x2))``;
    * their difference ``z = r1 - r2``;
    * a binary label ``y ~ Bernoulli(sigmoid(z))``.

    The link ``f`` is selected by ``dgp``:

    * ``'lin'``  -> ``f(x) = x``
    * ``'tanh'`` -> ``f(x) = 3 * tanh(x)``
    * ``'sin'``  -> ``f(x) = 3 * sin(x)``

    All fields are stored as ``float32`` tensors.
    """

    # Maps each supported ``dgp`` name to its elementwise link function.
    _LINKS = {
        'lin': lambda x: x,
        'tanh': lambda x: 3 * np.tanh(x),
        'sin': lambda x: 3 * np.sin(x),
    }

    def __init__(self, dgp: Literal['lin', 'tanh', 'sin'], rng: np.random.Generator, n: int, d: int = 3):
        """Draw ``n`` examples from ``rng``.

        Args:
            dgp: name of the link function (see the class docstring).
            rng: generator used for every random draw; it is advanced in place.
            n: number of examples.
            d: dimension of each feature vector ``x1`` / ``x2``. Defaults to 3,
                the value that was previously hard-coded.

        Raises:
            ValueError: if ``dgp`` is not a supported name or ``d < 1``. Both
                checks run before any draw, so ``rng`` is untouched on error.
        """
        try:
            f = self._LINKS[dgp]
        except (KeyError, TypeError):
            raise ValueError(f"Unknown dgp: {dgp}") from None
        if d < 1:
            # The mean over an empty feature axis would be NaN.
            raise ValueError(f"d must be >= 1, got {d}")

        self.n = n

        # Draw order (x1, x2, then labels) is kept fixed so seeded runs are reproducible.
        x1 = rng.normal(size=(n, d))
        x2 = rng.normal(size=(n, d))
        r1 = f(x1.mean(axis=1))
        r2 = f(x2.mean(axis=1))
        z = r1 - r2
        y = rng.binomial(1, expit(z))

        self.x1 = torch.tensor(x1, dtype=torch.float32)
        self.x2 = torch.tensor(x2, dtype=torch.float32)
        self.r1 = torch.tensor(r1, dtype=torch.float32)
        self.r2 = torch.tensor(r2, dtype=torch.float32)
        self.z = torch.tensor(z, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        """Return the number of examples ``n``."""
        return self.n

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors keyed by field name."""
        return {
            "x1": self.x1[idx],
            "x2": self.x2[idx],
            "r1": self.r1[idx],
            "r2": self.r2[idx],
            "z": self.z[idx],
            "y": self.y[idx],
        }

class PreferenceDataModule(L.LightningDataModule):
    """Lightning data module over several independent simulated preference datasets.

    Builds ``num_datasets`` training and ``num_datasets`` validation datasets,
    named ``"data1"``, ``"data2"``, ... . Every dataset is drawn from the same
    ``rng`` with the same ``dgp`` and size ``n``. All training datasets are
    drawn first, then all validation datasets.

    Each loader is a ``CombinedLoader`` over a dict keyed by dataset name.
    """

    def __init__(
        self,
        dgp: Literal['lin', 'tanh', 'sin'],
        rng: np.random.Generator,
        n: int,
        batch_size: int,
        num_datasets: int = 3,
    ):
        """Create the datasets eagerly.

        Args:
            dgp: link-function name forwarded to each dataset.
            rng: shared generator; it is advanced by every dataset built here.
            n: number of examples per dataset.
            batch_size: batch size used by every ``DataLoader``.
            num_datasets: number of train (and of val) datasets. Defaults to 3,
                the value that was previously hard-coded.

        Raises:
            ValueError: if ``num_datasets < 1``.
        """
        super().__init__()
        if num_datasets < 1:
            raise ValueError(f"num_datasets must be >= 1, got {num_datasets}")
        self.batch_size = batch_size
        names = [f"data{i + 1}" for i in range(num_datasets)]
        # Keep the draw order: all train sets first, then all val sets, so
        # seeded runs reproduce the same data.
        self.train_datasets = {name: _Dataset(dgp, rng, n) for name in names}
        self.val_datasets = {name: _Dataset(dgp, rng, n) for name in names}

    def _make_loaders(self, datasets, shuffle):
        """Wrap each named dataset in a ``DataLoader`` with the module's batch size."""
        return {
            name: DataLoader(ds, batch_size=self.batch_size, shuffle=shuffle)
            for name, ds in datasets.items()
        }

    def train_dataloader(self):
        """Shuffled training batches from all datasets, combined with ``min_size``."""
        return CombinedLoader(self._make_loaders(self.train_datasets, shuffle=True), mode="min_size")

    def val_dataloader(self):
        """Unshuffled validation batches from all datasets, combined with ``min_size``."""
        return CombinedLoader(self._make_loaders(self.val_datasets, shuffle=False), mode="min_size")

    def predict_dataloader(self):
        """Unshuffled batches over the datasets, iterated one after another (``sequential``)."""
        # NOTE(review): this predicts over the *training* datasets, not the
        # validation ones; confirm that this is intended.
        return CombinedLoader(self._make_loaders(self.train_datasets, shuffle=False), mode="sequential")


