"""
ImageNet Latent Dataset with safetensors.
"""

import re
import os
import numpy as np
from glob import glob
from tqdm import tqdm

import torch
from torch.utils.data import Dataset

from safetensors import safe_open


class ImgLatentDataset(Dataset):
    """Map-style dataset over pre-extracted latents stored in safetensors shards.

    Every ``*.safetensors`` file directly inside ``data_dir`` must contain the
    tensors ``"labels"``, ``"latents"`` and ``"latents_flip"``, indexed along
    dim 0 by image. Shards are read in sorted-filename order, so dataset index
    ``i`` is the ``i``-th image of the concatenated shards.

    Args:
        data_dir: Directory holding the safetensors shards. When
            ``latent_norm`` is true the path must end in
            ``/data/<vae_name>/<name>_<resolution>`` (``<name>`` without
            underscores); ``<vae_name>`` and ``<resolution>`` name the
            statistics cache file.
        latent_norm: If true, standardize each latent with a mean/std reduced
            over dims 0, 2 and 3 of a random subset of the data (presumably
            per-channel statistics for ``(N, C, H, W)`` latents -- confirm).
        latent_multiplier: Factor every returned latent is multiplied by,
            applied after the optional normalization.
    """

    # Extracts (vae_name, resolution) from the data directory path.
    _STATS_PATH_PATTERN = re.compile(r"/data/([^/]+)/[^_]+_(\d+)$")
    # Directory (relative to the current working directory) for cached stats.
    _STATS_CACHE_DIR = "./buffers/vaes/stat"

    def __init__(self, data_dir, latent_norm=True, latent_multiplier=1.0):
        self.data_dir = data_dir
        self.latent_norm = latent_norm
        self.latent_multiplier = latent_multiplier

        self.files = sorted(glob(os.path.join(data_dir, "*.safetensors")))
        self.img_to_file_map = self.get_img_to_safefile_map()

        if latent_norm:
            self._latent_mean, self._latent_std = self.get_latent_stats()
        else:
            # Identity statistics; __getitem__ skips normalization anyway
            # when latent_norm is false.
            self._latent_mean, self._latent_std = torch.tensor(0), torch.tensor(1)

    def get_img_to_safefile_map(self):
        """Map each global image index to its shard file and in-file index.

        Returns:
            dict mapping a contiguous ``int`` index (starting at 0, following
            the order of ``self.files``) to
            ``{"safe_file": <path>, "idx_in_file": <int>}``.
        """
        img_to_file = {}
        for safe_file in self.files:
            with safe_open(safe_file, framework="pt", device="cpu") as f:
                # The leading dim of "labels" is the number of images in
                # this shard.
                num_imgs = f.get_slice("labels").get_shape()[0]
                offset = len(img_to_file)
                for i in range(num_imgs):
                    img_to_file[offset + i] = {
                        "safe_file": safe_file,
                        "idx_in_file": i,
                    }
        return img_to_file

    def _latent_stats_cache_file(self):
        """Return the cache path for this dataset's latent statistics.

        The name depends only on the VAE name and resolution parsed from
        ``data_dir``, so directories sharing both share one cache file.

        Raises:
            ValueError: if ``data_dir`` does not follow the expected layout.
        """
        # A trailing slash would otherwise defeat the `$` anchor.
        path = os.fspath(self.data_dir).rstrip("/")
        match = self._STATS_PATH_PATTERN.search(path)
        if match is None:
            raise ValueError(
                f"Cannot derive the latent-stats cache name from data_dir "
                f"{self.data_dir!r}; expected a path ending in "
                f"'/data/<vae_name>/<name>_<resolution>'."
            )
        vae_name, res = match.group(1), match.group(2)
        return os.path.join(self._STATS_CACHE_DIR, f"{vae_name}_{res}.pt")

    def get_latent_stats(self):
        """Load cached latent mean/std, computing and caching them if absent.

        Returns:
            ``(mean, std)`` tensors as produced by ``compute_latent_stats``.

        Raises:
            ValueError: if ``data_dir`` does not follow the expected layout,
                or (when no cache exists) the dataset is empty.
        """
        cache_file = self._latent_stats_cache_file()
        if os.path.exists(cache_file):
            latent_stats = torch.load(cache_file)
        else:
            latent_stats = self.compute_latent_stats()
            # Make sure the cache directory exists before writing into it.
            os.makedirs(os.path.dirname(cache_file), exist_ok=True)
            torch.save(latent_stats, cache_file)
        return latent_stats["mean"], latent_stats["std"]

    def compute_latent_stats(self):
        """Estimate latent mean/std from up to 10,000 randomly chosen images.

        Sampling uses NumPy's global RNG, so results depend on its seed. Only
        the ``"latents"`` tensor (not ``"latents_flip"``) is sampled. Stats are
        reduced over dims 0, 2 and 3 with ``keepdim=True`` so they broadcast
        against a single-image latent slice.

        Returns:
            dict with tensors under ``"mean"`` and ``"std"``.

        Raises:
            ValueError: if the dataset contains no images.
        """
        num_total = len(self.img_to_file_map)
        if num_total == 0:
            raise ValueError(
                f"No images found in {self.data_dir!r}; "
                f"cannot compute latent statistics."
            )
        num_samples = min(10000, num_total)
        random_indices = np.random.choice(num_total, num_samples, replace=False)
        latents = []
        for idx in tqdm(random_indices):
            img_info = self.img_to_file_map[idx]
            safe_file, img_idx = img_info["safe_file"], img_info["idx_in_file"]
            with safe_open(safe_file, framework="pt", device="cpu") as f:
                features = f.get_slice("latents")
                # Slice (not index) so the leading dim of size 1 is kept for
                # the concatenation below.
                feature = features[img_idx : img_idx + 1]
                latents.append(feature)
        latents = torch.cat(latents, dim=0)
        mean = latents.mean(dim=[0, 2, 3], keepdim=True)
        std = latents.std(dim=[0, 2, 3], keepdim=True)
        latent_stats = {"mean": mean, "std": std}
        print(latent_stats)
        return latent_stats

    def __len__(self):
        return len(self.img_to_file_map)

    def __getitem__(self, idx):
        """Return ``(feature, label)`` for image ``idx``.

        With probability 1/2 the latent is read from ``"latents_flip"``
        instead of ``"latents"`` (presumably latents of the horizontally
        flipped image -- confirm against the extraction script). The latent is
        normalized when ``latent_norm`` is set, then scaled by
        ``latent_multiplier``.
        """
        img_info = self.img_to_file_map[idx]
        safe_file, img_idx = img_info["safe_file"], img_info["idx_in_file"]
        with safe_open(safe_file, framework="pt", device="cpu") as f:
            tensor_key = "latents" if np.random.uniform(0, 1) > 0.5 else "latents_flip"
            features = f.get_slice(tensor_key)
            labels = f.get_slice("labels")
            feature = features[img_idx : img_idx + 1]
            label = labels[img_idx : img_idx + 1]

        if self.latent_norm:
            feature = (feature - self._latent_mean) / self._latent_std
        feature = feature * self.latent_multiplier

        # remove the first batch dimension (=1) kept by get_slice()
        feature = feature.squeeze(0)
        label = label.squeeze(0)
        return feature, label
