import os
import time
from pathlib import Path
from typing import Tuple, Literal, Dict, Callable, Optional

import torch
from PIL import Image
from einops import rearrange
from torch import Tensor
from torch.utils.data import Dataset
from torch.utils.data import random_split, DataLoader
from torchvision import transforms

from margflow.datasets.dataset_abstracts import DatasetIdentifier
from margflow.datasets.encoders.autoencoder import JaffeEncoder, JaffeDecoder

# Registry of the available pretrained autoencoder variants, keyed by the name
# that is passed as ``encoder_model``. Each entry holds the latent dimensionality
# and the locations of the encoder and decoder weight files.
# NOTE(review): the weight paths are relative, so they are resolved against the
# current working directory at load time, not against this file's location.
# Confirm which directory the calling scripts are expected to run from.
model_types = {
    "10-dim": {
        "latent_dim": 10,
        "encoder_path": "../margflow/datasets/encoders/weights/jaffe-10dim-encoder.pt",
        "decoder_path": "../margflow/datasets/encoders/weights/jaffe-10dim-decoder.pt",
    }
}


class AnnotatedImageFolderDataset(Dataset):
    def __init__(
        self, data_dir: str | Path, labels: Dict[str, Tensor], transform: Callable | None = None
    ):
        self.data_dir = Path(data_dir)
        self.image_paths = list(self.data_dir.glob("*.png"))
        self.label_data = labels
        self.transform = transform

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx) -> Image.Image | Tuple[Image.Image, Tensor]:
        img_path = self.image_paths[idx]
        img = Image.open(img_path).convert("L")
        if self.transform:
            img = self.transform(img)
        file_name = img_path.stem
        parts = file_name.split(".")
        key = f"{parts[0]}-{parts[1]}"
        labels = self.label_data[key]
        return img, labels


class ConditionalJaffe(DatasetIdentifier):
    """JAFFE images encoded by a pretrained autoencoder, paired with their labels.

    Images are read from ``data_dir``, split into train and validation subsets with
    a fixed seed, passed through the encoder and cached on disk as latent codes.
    Only the fixed-datapoint interface is supported; ``sample`` always raises.
    """

    def __init__(
        self,
        args,
        data_dir: (
            str | Path
        ) = "/home/jahn0002/ownCloud/Institution/coding/marginal_flow/experiments/data/jaffe/images/cropped",  # fixme make path relative
        encoder_model: str = "10-dim",
        sample_from_means: Literal["never", "once", "infinite"] = "never",
    ):
        """Load the pretrained encoder/decoder and record the data locations.

        Args:
            args: Forwarded unchanged to ``DatasetIdentifier.__init__``.
            data_dir: Directory containing the ``*.png`` images. The label file
                ``expLabels.txt`` is looked up two directory levels above it.
            encoder_model: Key into ``model_types`` selecting the pretrained weights.
            sample_from_means: ``"never"`` encodes every image to the encoder mean,
                ``"once"`` to a single reparameterized sample. ``"infinite"`` is
                accepted here but rejected by ``encoded_data``.

        Raises:
            ValueError: If ``encoder_model`` is not a key of ``model_types``.
        """
        super().__init__(args)
        self.encoder_model = encoder_model
        self.dataset_suffix += f"_jaffe_encoded_with_{encoder_model}"
        self.sample_from_means = sample_from_means
        self.val_split = 0.2

        if encoder_model not in model_types:
            raise ValueError(f"Model type {encoder_model} does not exist")
        spec = model_types[encoder_model]
        # map_location="cpu" lets checkpoints that were saved from a GPU load on
        # CPU-only hosts; the freshly constructed modules live on the CPU anyway.
        self.encoder = JaffeEncoder(latent_dim=spec["latent_dim"])
        self.encoder.load_state_dict(torch.load(spec["encoder_path"], map_location="cpu"))
        self.decoder = JaffeDecoder(latent_dim=spec["latent_dim"])
        self.decoder.load_state_dict(torch.load(spec["decoder_path"], map_location="cpu"))
        # The networks are only used for inference here. Switching to eval mode keeps
        # dropout / batch-norm style layers (if the models contain any) from behaving
        # as during training; it has no effect otherwise.
        self.encoder.eval()
        self.decoder.eval()
        self.data_dir = Path(data_dir)
        self.labels_file = self.data_dir.parent.parent / "expLabels.txt"
        # Filled by ``setup_datasets``; ``random_split`` returns subset views of the
        # full image dataset.
        self.train_dataset: Optional[Dataset] = None
        self.val_dataset: Optional[Dataset] = None

    @staticmethod
    def _load_labels_from_file(label_file: Path) -> Dict[str, Tensor]:
        """Parse the whitespace-separated label table into a ``{PIC: ratings}`` dict.

        The first line is treated as a header and skipped, as are blank lines. Of
        every remaining row, fields 1-6 are converted to a float tensor and stored
        under the row's eighth field.

        Raises:
            IndexError: If a non-blank row has fewer than eight fields.
            ValueError: If one of the six rating fields is not a number.
        """
        result_dict: Dict[str, Tensor] = {}
        with open(label_file, "r") as f:
            next(f, None)  # skip header
            for row in f:
                values = row.split()
                if not values:
                    # Tolerate blank lines, e.g. a trailing empty line at the end.
                    continue
                pic_value = values[7]  # The 'PIC' value is the last element
                numerical_values = [float(v) for v in values[1:7]]  # Exclude 'N' and 'PIC'
                result_dict[pic_value] = torch.tensor(numerical_values)
        return result_dict

    @staticmethod
    def _image_transform():
        """Return the preprocessing pipeline: to tensor, then normalize with 0.5 / 0.5."""
        return transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

    def encode(self, x: Tensor, return_sample: bool) -> Tensor:
        """Encode ``x`` and return the encoder mean or a reparameterized sample.

        Args:
            x: Input batch passed to the encoder as is.
            return_sample: If true, draw one sample from the mean and log-variance
                returned by the encoder; otherwise return the mean.
        """
        mean, log_var = self.encoder(x)
        return self.reparameterize(mean, log_var) if return_sample else mean

    def decode(self, x: Tensor) -> Tensor:
        """Pass ``x`` through the decoder and return its output."""
        return self.decoder(x)

    @staticmethod
    def reparameterize(mu: Tensor, log_var: Tensor) -> Tensor:
        """Return ``mu + eps * exp(0.5 * log_var)`` with ``eps`` drawn from N(0, 1)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    @torch.no_grad()
    def encoded_data(self, overwrite=False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Return the encoded train/validation data, computing and caching it if needed.

        The tensors are cached in the ``encoded_data`` directory next to
        ``self.dataset_folder`` (an attribute that is not set in this class).

        Args:
            overwrite: If true, ignore existing cache files and re-encode.

        Returns:
            ``(train_encoded, train_labels, val_encoded, val_labels)``.

        Raises:
            NotImplementedError: If ``sample_from_means`` is neither ``"never"`` nor
                ``"once"``.
        """
        # Validate the mode before looking at the cache; otherwise an unsupported
        # mode would silently succeed whenever cache files of another run exist.
        if self.sample_from_means == "never":
            return_sample = False
        elif self.sample_from_means == "once":
            return_sample = True
        else:
            raise NotImplementedError(
                f"Parameter sample_from_means {self.sample_from_means} is not implemented."
            )

        jaffe_encoded_dir = Path(self.dataset_folder.parent) / "encoded_data"
        os.makedirs(jaffe_encoded_dir, exist_ok=True)
        # Encoder means and sampled codes must not share cache files, or one mode
        # would be served the other's data. The default mode keeps the historical
        # file names so that caches written for it remain valid.
        tag = f"{self.encoder_model}_sampled" if return_sample else self.encoder_model
        # Order matches the return value: train data, train labels, val data, val labels.
        cache_paths = [
            jaffe_encoded_dir / f"jaffe_encoded_{split}_{kind}_{tag}.pt"
            for split in ("train", "val")
            for kind in ("data", "labels")
        ]

        if not overwrite and all(path.is_file() for path in cache_paths):
            train_encoded, train_labels, val_encoded, val_labels = (
                torch.load(path) for path in cache_paths
            )
            return train_encoded, train_labels, val_encoded, val_labels

        self.setup_datasets()
        print("Encoding JAFFE data...")
        start = time.perf_counter()
        train_images, train_labels = self.load_entire_split(train=True)
        val_images, val_labels = self.load_entire_split(train=False)
        train_encoded = self.encode(train_images, return_sample=return_sample)
        val_encoded = self.encode(val_images, return_sample=return_sample)
        end = time.perf_counter()
        print(f"Encoded train and val dataset on CPU in {end - start:.4f} seconds")

        results = (train_encoded, train_labels, val_encoded, val_labels)
        for path, tensor in zip(cache_paths, results):
            torch.save(tensor, path)

        return results

    @torch.no_grad()
    def load_dataset(
        self, overwrite: bool = False
    ) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], None]:
        """Return ``((train_x, train_y), (val_x, val_y), None)`` for training.

        The encoded data gets an extra axis of size one inserted after the first
        axis, and the labels are cast to float. The third element is always ``None``.

        Args:
            overwrite: Forwarded to ``encoded_data``.
        """
        train_encoded, train_labels, val_encoded, val_labels = self.encoded_data(
            overwrite=overwrite
        )
        train_encoded = rearrange(train_encoded, "b c -> b 1 c")  # 1 batch
        val_encoded = rearrange(val_encoded, "b c -> b 1 c")  # 1 batch
        return (train_encoded, train_labels.float()), (val_encoded, val_labels.float()), None

    def setup_datasets(self, stage=None):
        """Build the image dataset and split it into train and validation subsets.

        The split uses ``self.val_split`` as the validation fraction and a generator
        seeded with 42. ``stage`` is accepted but not used.

        Raises:
            FileNotFoundError: If the label file does not exist or ``data_dir``
                contains no PNG images.
        """
        if not self.labels_file.exists():
            raise FileNotFoundError(f"Labels file {self.labels_file} does not exist")
        labels_dict = self._load_labels_from_file(self.labels_file)
        full_dataset = AnnotatedImageFolderDataset(
            data_dir=self.data_dir, labels=labels_dict, transform=self._image_transform()
        )
        if len(full_dataset) == 0:
            # Fail here with a clear message; an empty dataset would otherwise only
            # surface later as a DataLoader complaint about ``batch_size=0``.
            raise FileNotFoundError(f"No PNG images found in {self.data_dir}")
        self.train_dataset, self.val_dataset = random_split(
            full_dataset,
            [1 - self.val_split, self.val_split],
            generator=torch.Generator().manual_seed(42),
        )

    def load_entire_split(self, train: bool) -> Tuple[Tensor, Tensor]:
        """Load a whole split as one ``(images, labels)`` batch.

        Args:
            train: Select the training split if true, the validation split otherwise.
        """
        if self.train_dataset is None or self.val_dataset is None:
            # Build the splits on first use instead of failing on ``len(None)``.
            self.setup_datasets()
        dataset = self.train_dataset if train else self.val_dataset
        # A single batch as large as the split yields every sample at once.
        data_loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
        dataset_tensor, labels = next(iter(data_loader))
        return dataset_tensor, labels

    def sample(self, *args, **kwargs):
        """Always raise: this dataset only provides a fixed set of datapoints."""
        raise NotImplementedError("Jaffe dataset is only available in fixed datapoints mode")
