# file: user_extensions/datasets/mnist.py
import gzip
import os
from pathlib import Path
import shutil
import struct

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, random_split
from torchvision import transforms

from prism.core.registry import DATASETS
from prism.data.base_datamodule import BaseDataModule

# Morpho-MNIST CSV columns used as style features, in the order they appear
# in each sample's morpho tensor. Must stay a list: it is used as a pandas
# column selector.
MORPHO_FEATURES_TO_USE = ['length', 'thickness', 'slant', 'width', 'height']
# Feature name -> position of that feature in the morpho tensor.
MORPHO_TENSOR_INDEX_MAP = dict(zip(MORPHO_FEATURES_TO_USE, range(len(MORPHO_FEATURES_TO_USE))))


class _MNISTWithMorpho(Dataset):
    """One MNIST split read from IDX files, optionally paired with Morpho-MNIST features.

    Each item is a ``(image, label, morpho)`` tuple:

    * ``image`` is the raw image converted to PIL and then passed through
      ``transform`` if one was given.
    * ``label`` is a ``torch.long`` scalar tensor.
    * ``morpho`` is a float tensor holding the ``MORPHO_FEATURES_TO_USE``
      columns. It is all zeros when ``use_morpho`` is False.

    Args:
        data_dir: Directory containing ``{prefix}-images-idx3-ubyte[.gz]``
            and ``{prefix}-labels-idx1-ubyte[.gz]``, plus
            ``{prefix}-morpho.csv`` when ``use_morpho`` is set. ``prefix`` is
            ``train`` or ``t10k``.
        train: Load the training split if True, the test split otherwise.
        transform: Optional callable applied to the PIL image.
        use_morpho: Read morphometric features from CSV instead of zero-filling.

    Raises:
        FileNotFoundError: If a required IDX or CSV file is missing.
        ValueError: If an IDX file is malformed, or the image, label and
            morpho row counts disagree.
    """

    def __init__(self, data_dir, train=True, transform=None, use_morpho=False):
        self.transform = transform
        self.use_morpho = use_morpho
        prefix = 'train' if train else 't10k'
        data_path = Path(data_dir)

        images_path = self._decompress_if_needed(data_path / f'{prefix}-images-idx3-ubyte.gz')
        labels_path = self._decompress_if_needed(data_path / f'{prefix}-labels-idx1-ubyte.gz')

        self.images = self._load_idx_file(images_path)
        self.labels = self._load_idx_file(labels_path)
        num_samples = len(self.labels)

        # Explicit checks rather than `assert`, so they still run under `python -O`.
        if len(self.images) != num_samples:
            raise ValueError(
                f"Image/label count mismatch: {len(self.images)} images vs {num_samples} labels"
            )

        if self.use_morpho:
            morpho_csv = data_path / f'{prefix}-morpho.csv'
            if not morpho_csv.exists():
                raise FileNotFoundError(f"Morpho file not found: {morpho_csv}")
            self.morpho = pd.read_csv(morpho_csv)[MORPHO_FEATURES_TO_USE].to_numpy(dtype=np.float32)
            if len(self.morpho) != num_samples:
                raise ValueError(
                    f"Morpho/label count mismatch: {len(self.morpho)} rows in "
                    f"{morpho_csv} vs {num_samples} labels"
                )
        else:
            # Zero placeholders keep the item structure identical whether or
            # not morpho features are in use.
            self.morpho = np.zeros((num_samples, len(MORPHO_FEATURES_TO_USE)), dtype=np.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label, morpho)`` for sample ``idx``."""
        pil_img = transforms.ToPILImage()(self.images[idx])
        if self.transform:
            pil_img = self.transform(pil_img)

        return (
            pil_img,
            torch.tensor(self.labels[idx], dtype=torch.long),
            torch.tensor(self.morpho[idx], dtype=torch.float),
        )

    def _load_idx_file(self, path):
        """Parse an uncompressed IDX file of unsigned bytes into a NumPy array.

        Two magic numbers are recognised, both with big-endian headers:

        * 2051: a 3-D image tensor ``(items, rows, cols)`` with a 16-byte header.
        * 2049: a 1-D label vector ``(items,)`` with an 8-byte header.

        Returns:
            A read-only ``uint8`` array that views the file's bytes.

        Raises:
            ValueError: If the magic number is unknown, or the payload size
                does not match the header.
        """
        with open(path, 'rb') as f:
            data = f.read()
        magic, num_items = struct.unpack_from('>II', data, 0)
        if magic == 2051:
            num_rows, num_cols = struct.unpack_from('>II', data, 8)
            shape = (num_items, num_rows, num_cols)
            offset = 16
        elif magic == 2049:
            shape = (num_items,)
            offset = 8
        else:
            raise ValueError(f"Unknown IDX magic number: {magic}")

        expected = int(np.prod(shape))
        available = len(data) - offset
        # Reject truncated or padded files with a clear message instead of a reshape error.
        if available != expected:
            raise ValueError(
                f"Corrupt IDX file {path}: header declares {expected} data bytes, found {available}"
            )
        return np.frombuffer(data, dtype=np.uint8, count=expected, offset=offset).reshape(shape)

    def _decompress_if_needed(self, gz_path_str):
        """Return the path of the uncompressed file, decompressing the archive if needed.

        The uncompressed file is the archive path with its final suffix
        removed. If that file already exists it is reused as-is. Otherwise the
        archive is streamed into a per-process temporary file, which is then
        renamed into place. An interrupted run therefore never leaves behind a
        truncated file that later runs would mistake for a complete one.

        Args:
            gz_path_str: Path (``str`` or ``Path``) to the ``.gz`` archive.

        Returns:
            The uncompressed file's path, as a ``str``.

        Raises:
            FileNotFoundError: If neither the uncompressed file nor the archive exists.
        """
        gz_path = Path(gz_path_str)
        raw_path = gz_path.with_suffix('')
        if raw_path.exists():
            return str(raw_path)
        if not gz_path.exists():
            raise FileNotFoundError(f"File not found: {raw_path} or {gz_path}")

        # A PID-suffixed temp name keeps concurrent writers from clobbering each other.
        tmp_path = raw_path.with_name(f'{raw_path.name}.{os.getpid()}.tmp')
        try:
            with gzip.open(gz_path, 'rb') as f_in, open(tmp_path, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
            os.replace(tmp_path, raw_path)
        except BaseException:
            if tmp_path.exists():
                tmp_path.unlink()
            raise
        return str(raw_path)


@DATASETS.register("mnist")
class MNISTDataModule(BaseDataModule):
    """Data module serving MNIST train/val/test splits from raw IDX files.

    Style (morpho) features are zero-filled here; ``MorphoMNISTDataModule``
    turns on loading them from CSV. ``data_dir``, ``data_config``,
    ``transform`` and the ``train_ds``/``val_ds``/``test_ds`` slots are not
    defined in this class. They are presumably provided by ``BaseDataModule``
    (TODO confirm).
    """

    def __init__(self, config):
        super().__init__(config)
        # Validation size exactly as configured: an absolute sample count, or
        # a fraction in (0, 1) of the training split. It is resolved against
        # the real training-set length in `setup` via `_resolve_val_size`.
        self.split_size = self.data_config.val_split_size
        self.use_morpho = False

    def prepare_data(self):
        """Materialise the uncompressed IDX files once, before ``setup``.

        The datasets are built only for their side effects. Missing raw files
        are decompressed from the ``.gz`` archives, and missing inputs
        (including the morpho CSVs when ``use_morpho`` is set) raise early.
        """
        _MNISTWithMorpho(self.data_dir, train=True, use_morpho=self.use_morpho)
        _MNISTWithMorpho(self.data_dir, train=False, use_morpho=self.use_morpho)

    def setup(self, stage=None):
        """Build the test split, plus the train/val split when ``stage`` is 'fit' or None.

        Datasets that already exist (are truthy) are left untouched. The
        train/val split uses a fixed seed (42), so it is reproducible.

        Raises:
            ValueError: If ``val_split_size`` cannot be resolved to a valid
                sample count (see ``_resolve_val_size``).
        """
        if not self.test_ds:
            self.test_ds = _MNISTWithMorpho(
                self.data_dir, train=False, transform=self.transform, use_morpho=self.use_morpho
            )

        if stage in ('fit', None) and not (self.train_ds and self.val_ds):
            full_train = _MNISTWithMorpho(
                self.data_dir, train=True, transform=self.transform, use_morpho=self.use_morpho
            )
            val_size = self._resolve_val_size(len(full_train))
            train_size = len(full_train) - val_size
            self.train_ds, self.val_ds = random_split(
                full_train, [train_size, val_size],
                generator=torch.Generator().manual_seed(42),
            )

    def _resolve_val_size(self, total):
        """Convert ``self.split_size`` into an integer sample count out of ``total``.

        Accepted forms:

        * an int;
        * an integral float (e.g. ``5000.0`` parsed from a YAML/JSON config);
        * a float fraction strictly between 0 and 1, which is scaled by
          ``total`` and rounded.

        Raises:
            ValueError: If the value is a non-integral float outside (0, 1),
                or the resulting count falls outside ``[0, total]``.
        """
        size = self.split_size
        if isinstance(size, float):
            if 0.0 < size < 1.0:
                size = round(total * size)
            elif size.is_integer():
                size = int(size)
            else:
                raise ValueError(
                    "val_split_size must be a sample count or a fraction in (0, 1), "
                    f"got {self.split_size!r}"
                )
        # Without this check random_split fails deep inside torch with an opaque message.
        if not 0 <= size <= total:
            raise ValueError(
                f"val_split_size must be between 0 and {total}, got {self.split_size!r}"
            )
        return int(size)

    @property
    def style_feature_map(self):
        """Mapping from morpho feature name to its index in the style tensor.

        A fresh copy is returned so callers cannot mutate the module-level table.
        """
        return dict(MORPHO_TENSOR_INDEX_MAP)


@DATASETS.register("morpho-mnist")
class MorphoMNISTDataModule(MNISTDataModule):
    """MNIST data module that pairs every image with Morpho-MNIST style features.

    It behaves exactly like ``MNISTDataModule``, except that the datasets it
    builds read ``{prefix}-morpho.csv`` instead of zero-filling the features.
    """

    def __init__(self, config):
        super().__init__(config)
        # Set only after the parent constructor, which initialises the flag to False.
        self.use_morpho = True