import os
import time
from typing import Tuple, Literal

import torch
from einops import rearrange, repeat
from numpy.typing import NDArray
from torch import Tensor
from torch.nn.functional import one_hot
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from margflow.datasets.dataset_abstracts import DatasetIdentifier
from margflow.datasets.encoders.autoencoder import ResNetEncoder, ResNetDecoder

def _autoencoder_entry(latent_dim: int) -> dict:
    """Return the registry entry for the pretrained autoencoder with ``latent_dim`` latents."""
    stem = f"../margflow/datasets/encoders/weights/autoencoder-{latent_dim}dim"
    return {
        "latent_dim": latent_dim,
        "encoder_path": f"{stem}-encoder.pt",
        "decoder_path": f"{stem}-decoder.pt",
    }


# Registry of pretrained autoencoders, keyed by the name accepted as `encoder_model`.
# NOTE: the weight paths are relative and therefore resolve against the current
# working directory of the running process.
model_types = {f"{dim}-dim": _autoencoder_entry(dim) for dim in (2, 20)}


class MNIST(DatasetIdentifier):
    """MNIST as a fixed set of data points, optionally encoded by a pretrained autoencoder.

    With an ``encoder_model`` the images are passed through the matching ``ResNetEncoder``
    and the encoded splits are cached on disk; without one the raw images are flattened.
    ``filter_class`` restricts the data to a single digit label.
    """

    def __init__(
        self,
        args,
        encoder_model: str | None = "2-dim",
        filter_class: int | None = None,
        sample_from_means: Literal["never", "once", "infinite"] = "never",
    ):
        """Configure the dataset and, if requested, load the pretrained autoencoder weights.

        Args:
            args: Forwarded unchanged to ``DatasetIdentifier.__init__``.
            encoder_model: Key into ``model_types``, or ``None`` to work on raw pixels.
            filter_class: Keep only samples with this label; ``None`` keeps every class.
            sample_from_means: ``"never"`` stores the encoder's mean output, ``"once"``
                stores one reparameterized sample per image. ``"infinite"`` is accepted
                here but rejected by ``encode_data``.

        Raises:
            TypeError: If ``filter_class`` is neither an ``int`` nor ``None``.
            ValueError: If ``encoder_model`` is not a key of ``model_types``, or if
                sampling is requested without an encoder.
        """
        super().__init__(args)
        self.encoder_model = encoder_model
        self.dataset_suffix += f"_mnist_encoded_with_{encoder_model}"
        # Raise explicitly: an `assert` would be stripped when Python runs with -O.
        if filter_class is not None and not isinstance(filter_class, int):
            raise TypeError(
                f"filter_class must be an int or None, got {type(filter_class).__name__}"
            )
        self.filter_class = filter_class
        self.sample_from_means = sample_from_means
        self.no_classes = 10 if filter_class is None else 1
        if encoder_model is not None:
            if encoder_model not in model_types:
                raise ValueError(f"Model type {encoder_model} does not exist")
            model_config = model_types[encoder_model]
            # map_location="cpu" keeps loading independent of the device the weights
            # were saved from; the modules themselves are created on the CPU.
            self.encoder = ResNetEncoder(latent_dim=model_config["latent_dim"])
            self.encoder.load_state_dict(
                torch.load(model_config["encoder_path"], map_location="cpu")
            )
            self.decoder = ResNetDecoder(latent_dim=model_config["latent_dim"])
            self.decoder.load_state_dict(
                torch.load(model_config["decoder_path"], map_location="cpu")
            )
            # NOTE(review): neither module is switched to eval() here. If ResNetEncoder /
            # ResNetDecoder contain BatchNorm or Dropout layers, that is probably wanted
            # before encoding - confirm against the autoencoder implementation.
        if encoder_model is None and sample_from_means != "never":
            raise ValueError(
                f"For sample_from_means {sample_from_means}, you need to specify an encoder model."
            )

    def encode(self, x: Tensor, return_sample: bool) -> Tensor:
        """Encode ``x``; return a reparameterized sample if ``return_sample``, else the mean."""
        mean, log_var = self.encoder(x)
        return self.reparameterize(mean, log_var) if return_sample else mean

    def decode(self, x: Tensor) -> Tensor:
        """Pass ``x`` through the decoder and return its output."""
        return self.decoder(x)

    @staticmethod
    def reparameterize(mu: Tensor, log_var: Tensor) -> Tensor:
        """Return ``mu + eps * exp(0.5 * log_var)`` with ``eps`` drawn from a standard normal."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    @torch.no_grad()
    def encode_data(self, overwrite=False):
        """Return the encoded train and validation splits, using an on-disk cache.

        Args:
            overwrite: If true, ignore any cached files and re-encode (and re-save) the data.

        Returns:
            Tuple ``(train, val)`` of encoded tensors.

        Raises:
            NotImplementedError: If the data has to be encoded and ``sample_from_means``
                is neither ``"never"`` nor ``"once"``.
        """
        mnist_encoded_dir = self.dataset_folder.parent / "encoded_data/"
        os.makedirs(mnist_encoded_dir, exist_ok=True)
        filter_class = "" if self.filter_class is None else f"_class{self.filter_class}"
        # NOTE(review): the cache name depends on the encoder and class filter only, so
        # "never" and "once" share the same files. Pass overwrite=True after switching
        # sample_from_means, or consider adding the mode to the file name.
        encoded_train_data_path = (
            mnist_encoded_dir / f"mnist_encoded_train_data_{self.encoder_model}{filter_class}.pt"
        )
        encoded_val_data_path = (
            mnist_encoded_dir / f"mnist_encoded_val_data_{self.encoder_model}{filter_class}.pt"
        )
        cache_available = encoded_train_data_path.is_file() and encoded_val_data_path.is_file()
        if cache_available and not overwrite:
            return torch.load(encoded_train_data_path), torch.load(encoded_val_data_path)

        # Reject unsupported modes before the (potentially slow) dataset load/download.
        if self.sample_from_means not in ("never", "once"):
            raise NotImplementedError(
                f"Parameter sample_from_means {self.sample_from_means} is not implemented."
            )
        return_sample = self.sample_from_means == "once"

        train_dataset, _ = self.load_entire_split(train=True, filter_class=self.filter_class)
        val_dataset, _ = self.load_entire_split(train=False, filter_class=self.filter_class)
        print("Encoding data...")
        start = time.perf_counter()
        train_dataset = self.encode(train_dataset, return_sample=return_sample)
        val_dataset = self.encode(val_dataset, return_sample=return_sample)
        end = time.perf_counter()
        print(f"Encoded train and val dataset on CPU in {end - start:.4f} seconds")
        torch.save(train_dataset, encoded_train_data_path)
        torch.save(val_dataset, encoded_val_data_path)

        return train_dataset, val_dataset

    @torch.no_grad()
    def load_dataset(self, overwrite=False) -> Tuple[NDArray, NDArray, NDArray | None]:
        """Return ``(train, val, None)`` as NumPy arrays.

        With an encoder the encoded splits from ``encode_data`` are returned; otherwise
        each image is flattened to a single vector.

        Args:
            overwrite: Forwarded to ``encode_data``; unused when no encoder is configured.
        """
        if self.encoder_model is not None:
            train_dataset, val_dataset = self.encode_data(overwrite=overwrite)
        else:
            train_dataset, _ = self.load_entire_split(train=True, filter_class=self.filter_class)
            val_dataset, _ = self.load_entire_split(train=False, filter_class=self.filter_class)
            train_dataset = rearrange(train_dataset, "b c h w -> b (c h w)")
            val_dataset = rearrange(val_dataset, "b c h w -> b (c h w)")
        return train_dataset.numpy(), val_dataset.numpy(), None

    def load_entire_split(self, train: bool, filter_class=None) -> Tuple[Tensor, Tensor]:
        """Load a complete MNIST split into memory as one batch.

        Args:
            train: Select the training split if true, the test split otherwise.
            filter_class: If given, keep only samples whose label equals this value.

        Returns:
            Tuple ``(images, labels)``; both are filtered consistently when
            ``filter_class`` is set, so they always have the same length.
        """
        transform = transforms.ToTensor()
        dataset = datasets.MNIST(
            root=self.dataset_folder.parent, train=train, transform=transform, download=True
        )
        # A single batch the size of the dataset yields the whole split at once.
        data_loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
        dataset_tensor, labels = next(iter(data_loader))
        if filter_class is not None:
            class_idx = labels == filter_class
            dataset_tensor = dataset_tensor[class_idx]
            # Filter the labels too; otherwise they no longer line up with the images.
            labels = labels[class_idx]
        return dataset_tensor, labels

    def sample(self, *args, **kwargs):
        """Not supported: this dataset only provides a fixed set of data points."""
        raise NotImplementedError("MNIST dataset is only available in fixed datapoints mode")


class ConditionalMNIST(MNIST):
    """Encoded MNIST together with digit labels, arranged in class-balanced batches."""

    def __init__(
        self,
        args,
        encoder_model: str | None = "2-dim",
        sample_from_means: Literal["never", "once", "infinite"] = "never",
        one_hot_: bool = False,
    ):
        """Set up the conditional dataset; all classes are kept (no class filter).

        Args:
            args: Forwarded to ``MNIST.__init__``.
            encoder_model: Key into ``model_types``; ``None`` is not supported here.
            sample_from_means: Forwarded to ``MNIST.__init__``.
            one_hot_: If true, labels are returned one-hot encoded instead of as a column
                of class indices.

        Raises:
            NotImplementedError: If ``encoder_model`` is ``None``.
        """
        super().__init__(
            args=args,
            encoder_model=encoder_model,
            sample_from_means=sample_from_means,
            filter_class=None,
        )
        self.one_hot = one_hot_
        # The stratified batching below is written for encoded (2-D) data only.
        if self.encoder_model is None:
            raise NotImplementedError

    def create_stratified_batches(
        self, dataset: Tensor, labels: Tensor, samples_per_class: int, no_batches: int
    ) -> Tuple[Tensor, Tensor]:
        """Arrange ``dataset`` so that every batch holds the same number of samples per class.

        The first ``samples_per_class`` rows of each class are split across ``no_batches``
        batches and the per-class results are concatenated along the first axis.

        Args:
            dataset: 2-D tensor of samples (rows are samples).
            labels: Class index for every row of ``dataset``.
            samples_per_class: Number of rows to take from each class.
            no_batches: Number of batches; must divide ``samples_per_class``.

        Returns:
            Tuple ``(data, batch_labels)``. ``data`` has shape
            ``(no_classes * samples_per_class // no_batches, no_batches, d)``;
            ``batch_labels`` has one entry per row of ``data``, either as a column of
            class indices or one-hot encoded when ``self.one_hot`` is set.

        Raises:
            ValueError: If ``samples_per_class`` is not divisible by ``no_batches`` or if
                some class has fewer than ``samples_per_class`` rows.
        """
        if samples_per_class % no_batches:
            raise ValueError("samples_per_class must be divisible by no_batches")

        per_class = [dataset[labels == digit] for digit in range(self.no_classes)]

        smallest_class = min(chunk.shape[0] for chunk in per_class)
        if samples_per_class > smallest_class:
            raise ValueError(
                f"Not all classes have enough data points: samples_per_class {samples_per_class}"
            )

        # Truncate every class to the same size, then fold it into (rows, batch, feature).
        batched_per_class = [
            rearrange(chunk[:samples_per_class], "(a b) d -> a b d", b=no_batches)
            for chunk in per_class
        ]
        stratified = torch.cat(batched_per_class, dim=0)

        # One-hot labels start from a flat index vector; plain labels get a trailing axis.
        if self.one_hot:
            label_pattern = "n -> (n r) "
        else:
            label_pattern = "n -> (n r) 1"
        rows_per_class = samples_per_class // no_batches
        class_ids = torch.arange(0, self.no_classes)
        batch_labels = repeat(class_ids, label_pattern, r=rows_per_class)
        if self.one_hot:
            batch_labels = one_hot(batch_labels.long())

        return stratified, batch_labels

    @torch.no_grad()
    def load_dataset(
        self, overwrite=False
    ) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor] | None]:
        """Return stratified ``(data, labels)`` pairs for train and validation, plus ``None``.

        Args:
            overwrite: Forwarded to ``encode_data``.

        Returns:
            ``((train_data, train_labels), (val_data, val_labels), None)`` with the labels
            converted to float.
        """
        train_encoded, val_encoded = self.encode_data(overwrite=overwrite)
        train_targets = self.load_entire_split(train=True)[1]
        val_targets = self.load_entire_split(train=False)[1]

        # Sizes chosen empirically for MNIST: e.g. one digit has only 5421 training samples.
        train_pair = self.create_stratified_batches(train_encoded, train_targets, 5400, 135)
        val_pair = self.create_stratified_batches(val_encoded, val_targets, 880, 4)

        return (
            (train_pair[0], train_pair[1].float()),
            (val_pair[0], val_pair[1].float()),
            None,
        )
