from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional, Union, cast
import hashlib
import functools

import numpy as np
import torch
from torchvision import transforms as torch_transforms
from torchvision.transforms import functional
from torch.utils.data.dataset import Dataset
from PIL import Image

from vis_datasets.lib.dirs import get_dataset_dir
from vis_datasets.wrappers.data_sample import DataSample
from .transforms import TRANSFORMS_WITH_NONE
from .image_utils import (
    resize,
    calc_top_left_coordinates,
    paste_fg_on_bg,
    crop_image_to_square,
    load_image,
)


# NOTE(review): 61 is presumably the number of object categories; confirm.
OBJECTS = [*range(61)]

def _get_default_fg_dir() -> Path:
    """Return the default foreground directory of the ``objects_2d`` dataset."""
    dataset_root = get_dataset_dir("objects_2d")
    return dataset_root / "foreground_images"

# Transforms given either by name (looked up in TRANSFORMS_WITH_NONE) or as
# ready-made modules.
TransformOptions = Union[
    list[str],
    list[torch.nn.Module],
]
# One transform list shared by all classes, or a mapping from class index to
# that class's transform list.
CompositeTransform = Union[
    TransformOptions,
    dict[int, TransformOptions],
]

@dataclass
class ImageForegroundsConfig:
    """Configuration for foregrounds loaded from image files.

    Consumed by `ForegroundsDataset`, which loads the images with
    `load_foregrounds(fg_dir, img_size)`.
    """

    # Positions (in sorted class-directory order) of the classes to keep.
    fg_class_indices: list[int]
    # Transforms applied to every foreground, or a dict mapping a class index
    # (an entry of `fg_class_indices`) to that class's transforms.
    transforms: CompositeTransform
    # Size passed to `resize` when the foregrounds are loaded.
    img_size: int
    # Directory laid out as fg_dir/$OBJECT_CLASS/$FILE_NAME.
    fg_dir: Path = field(default_factory=_get_default_fg_dir)

@dataclass
class RandomForegroundsConfig:
    """Configuration for randomly generated (noise) foregrounds.

    Consumed by `ForegroundsDataset` via `generate_random_foregrounds`.
    """

    # Number of foreground classes; one random image is generated per class.
    n_classes: int
    # Shared transforms, or a dict keyed by class index in range(n_classes).
    transforms: CompositeTransform
    # Seed of the RNG that generates the images.
    seed: int
    # Side length in pixels of the generated square images.
    img_size: int

# Any of the foreground configurations accepted by `ForegroundsDataset`.
FGConfig = Union[
    ImageForegroundsConfig,
    RandomForegroundsConfig,
]

class ForegroundsDataset(Dataset[DataSample[Image.Image]]):
    """Dataset of transformed foreground objects, one source image per class.

    Foregrounds are either loaded from disk (`ImageForegroundsConfig`) or
    generated as random noise (`RandomForegroundsConfig`). Sample indices are
    mapped to classes with `_get_item_idx`, so the dataset length
    (`n_samples`) is independent of the number of classes.
    """

    def __init__(
        self,
        config: FGConfig,
        n_samples: int,
        sampling_seed: int,
    ) -> None:
        """
        Args:
            config: where the foregrounds come from and how to transform them.
            n_samples: value reported by `__len__`.
            sampling_seed: seeds torch's global RNG and the salt of the
                index -> class mapping.

        Raises:
            ValueError: if more classes are requested than foreground
                categories exist on disk.
        """
        self.n_samples = n_samples

        # Process-wide side effect: seeds torch's *global* RNG (presumably so
        # that random torchvision transforms are reproducible).
        torch.random.manual_seed(sampling_seed)
        self.rng = np.random.default_rng(sampling_seed)
        # Salt for the index -> class mapping. Note that `10^6` would be
        # bitwise XOR in Python (== 12), not one million.
        self.class_salt = int(self.rng.integers(low=0, high=10**6))

        if isinstance(config, RandomForegroundsConfig):
            self.n_classes = config.n_classes
            self.fgs = generate_random_foregrounds(
                config,
                img_size=config.img_size,
            )
            self._filter_foregrounds()
            fg_class_indices = list(range(self.n_classes))
        else:
            fg_class_indices = config.fg_class_indices
            self.n_classes = len(fg_class_indices)
            self.fgs = load_foregrounds(config.fg_dir, config.img_size)
            if len(fg_class_indices) > len(self.fgs):
                raise ValueError(
                    f"Number of classes ({self.n_classes}) greater than "
                    f"the number of foreground categories ({len(self.fgs)})"
                )
            self._filter_foregrounds(fg_class_indices)

        self._init_transforms(
            config=config,
            fg_class_indices=fg_class_indices,
        )

    def _init_transforms(
        self,
        config: FGConfig,
        fg_class_indices: list[int],
    ) -> None:
        """Build `self.fg_transforms` from `config.transforms`.

        A list yields one `Compose` shared by all classes. A dict maps an
        original class index (an entry of `fg_class_indices`) to that class's
        transforms; the result is re-keyed by the class's position in the
        filtered class list, which is the index `__getitem__` uses.
        """
        if isinstance(config.transforms, dict):
            self.fg_transforms = {
                position: ForegroundsDataset._compose_transforms(
                    config.transforms[class_idx], config.img_size,
                )
                for position, class_idx in enumerate(fg_class_indices)
            }
        else:
            self.fg_transforms = ForegroundsDataset._compose_transforms(
                config.transforms, config.img_size
            )

    @staticmethod
    def _compose_transforms(
        config_transforms: TransformOptions, img_size: int
    ) -> torch_transforms.Compose:
        """Compose transforms given by name or as modules.

        Strings are looked up in `TRANSFORMS_WITH_NONE` and called with
        ``img_size``; any other entry is used as-is.
        """
        fg_transforms = [
            TRANSFORMS_WITH_NONE[transform](img_size)
            if isinstance(transform, str)
            else transform
            for transform in config_transforms
        ]
        return torch_transforms.Compose(fg_transforms)

    def _filter_foregrounds(
        self, fg_class_indices: Optional[list[int]] = None,
    ) -> None:
        """Keep only the selected foreground classes.

        Args:
            fg_class_indices: positions (in `self.fgs` order) of the classes
                to keep, in the order they should appear. ``None`` keeps all.

        Sets `self.fgs`, `self.fg_classes` and `self.fg_imgs` accordingly.
        """
        if fg_class_indices is not None:
            fg_key_list = list(self.fgs.keys())
            selected_classes = [fg_key_list[idx] for idx in fg_class_indices]
            self.fgs = {
                fg_class: self.fgs[fg_class] for fg_class in selected_classes
            }
        self.fg_classes: list[str] = list(self.fgs.keys())
        self.fg_imgs: list[list[ImageRecord]] = list(self.fgs.values())

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(
        self,
        idx: Union[int, tuple[int, bool]],
        is_class_idx: bool = False,
    ) -> DataSample[Image.Image]:
        """Return the transformed foreground for sample ``idx``.

        Args:
            idx: sample index, mapped to a class via `_get_item_idx`. The
                subscript form ``dataset[i, True]`` arrives here as the tuple
                ``(i, True)`` and is unpacked into ``idx`` and
                ``is_class_idx``.
            is_class_idx: if true, ``idx`` is used directly as the class
                index (a position in the filtered class list).

        Returns:
            A `DataSample` whose input is the transformed foreground image and
            whose target is the class index as a tensor.
        """
        if isinstance(idx, tuple):
            idx, is_class_idx = idx
        # `int()` also accepts numpy / 0-d tensor integers (e.g. a background
        # target forwarded by Composite2DDataset); a tensor would otherwise
        # fail as a key of the `self.fg_transforms` dict.
        idx = int(idx)
        if is_class_idx:
            class_idx = idx
        else:
            class_idx = _get_item_idx(idx, self.n_classes, self.class_salt)
        fg_img = self.fg_imgs[class_idx][0].image

        if isinstance(self.fg_transforms, dict):
            fg_aug = self.fg_transforms[class_idx](fg_img)
        else:
            fg_aug = self.fg_transforms(fg_img)

        return DataSample(
            input=fg_aug,
            target=torch.tensor(class_idx, requires_grad=False),
        )

@dataclass
class ImageRecord:
    """Container for one loaded or generated image."""

    # The image itself; loaders and generators in this module store PIL images.
    image: Image.Image

@functools.cache
def load_foregrounds(
    fg_dir: Union[str, Path],
    img_size: int,
) -> dict[str, list[ImageRecord]]:
    """Load one foreground image per class from a directory.

    Results are memoised with `functools.cache`: repeated calls with the same
    arguments return the *same* dict object, so callers must not mutate it.

    Args:
        fg_dir: directory laid out as ``fg_dir/$OBJECT_CLASS/$FILE_NAME``
            with exactly one image file per class subdirectory.
        img_size: size passed to `resize` for every loaded image.

    Returns:
        Mapping from class name (the subdirectory name) to a one-element list
        of `ImageRecord`s, in sorted class-name order.

    Raises:
        ValueError: if ``fg_dir`` does not exist, or a class directory does
            not contain exactly one file.
    """
    foregrounds_dir = Path(fg_dir)
    if not foregrounds_dir.exists():
        raise ValueError(
            f"Foregrounds directory {foregrounds_dir} does not exist."
        )

    fgs: dict[str, list[ImageRecord]] = {}
    for class_dir in sorted(foregrounds_dir.iterdir()):
        # e.g. 'car', 'cow'
        fg_class = class_dir.name
        fg_files = sorted(class_dir.iterdir())
        # Only one foreground object is kept per class, so there is no
        # ambiguity about which one is used. Raise instead of `assert` so the
        # check survives `python -O`.
        if len(fg_files) != 1:
            raise ValueError(
                f"Expected exactly one foreground file in {class_dir}, "
                f"found {len(fg_files)}."
            )
        fgs[fg_class] = [
            ImageRecord(image=resize(load_image(fg_file), img_size))
            for fg_file in fg_files
        ]
    print(f"{len(fgs)} foregrounds loaded.")
    return fgs

def generate_random_foregrounds(
    config: RandomForegroundsConfig,
    img_size: int,
) -> dict[str, list[ImageRecord]]:
    """Generate one uniform-noise RGBA image per class.

    Args:
        config: provides the number of classes and the RNG seed.
        img_size: side length in pixels of the square images.

    Returns:
        Mapping from class name (``"0"``, ``"1"``, ...) to a one-element list
        holding that class's image.
    """
    rng = np.random.default_rng(config.seed)
    class_imgs: dict[str, list[ImageRecord]] = {}
    for class_idx in range(config.n_classes):
        # PIL images need 8-bit channels. Previously a float32 array was
        # passed with an explicit "RGBA" mode, so its raw bytes were
        # reinterpreted as uint8. Only a quarter of the buffer was read and
        # the image showed float bit patterns instead of uniform noise.
        # With a uint8 (H, W, 4) array PIL infers the "RGBA" mode itself.
        pixels = rng.integers(
            0, 256, (img_size, img_size, 4), dtype=np.uint8,
        )
        class_imgs[str(class_idx)] = [
            ImageRecord(image=Image.fromarray(pixels))
        ]
    return class_imgs

def _get_item_idx(idx: int, n_items: int, salt: int) -> int:
    """Map a sample index to an item index in ``[0, n_items)``.

    The mapping is shuffled (via SHA-256) but deterministic: the same
    ``idx + salt`` and ``n_items`` always give the same item.

    Args:
        idx: sample index; any integer-like value accepted by ``int()`` works.
        n_items: number of items to choose from; must be positive.
        salt: offset added to ``idx`` before hashing; different salts give
            different mappings.

    Raises:
        ValueError: if ``n_items`` is not positive.
    """
    if n_items <= 0:
        raise ValueError(f"n_items must be positive, got {n_items}")
    # Hash the decimal representation of the key. The previous
    # `bytes(idx + salt)` built a zero-filled buffer *of length* idx + salt.
    # That cost O(idx) time and memory per call and raised for negative keys.
    key = str(int(idx + salt)).encode("ascii")
    digest = hashlib.sha256(key).digest()
    return int.from_bytes(digest, byteorder="big", signed=False) % n_items


def _get_default_bg_dir() -> Path:
    """Return the default background directory of the ``objects_2d`` dataset."""
    dataset_root = get_dataset_dir("objects_2d")
    return dataset_root / "background_images"


@dataclass
class ImageBackgroundsConfig:
    """Configuration for backgrounds loaded from image files."""

    # Size passed to `resize` after each background is cropped to a square.
    img_size: int
    # Directory whose entries are the background images.
    bg_dir: Path = field(default_factory=_get_default_bg_dir)

@dataclass
class RandomBackgroundsConfig:
    """Configuration for random-noise backgrounds (`generate_random_backgrounds`)."""

    # Number of backgrounds to generate; must be positive.
    n_backgrounds: int
    # Base seed; `generate_random_backgrounds` offsets it by 73 before seeding.
    seed: int
    # Side length in pixels of the square backgrounds.
    img_size: int

@dataclass
class UniformBackgroundsConfig:
    """Configuration for a single uniformly coloured background."""

    # Background colour; only [0, 0, 0] (black) is currently implemented.
    color: list[float]
    # Side length in pixels of the square background.
    img_size: int

# Any of the background configurations accepted by `BackgroundsDataset`.
BGConfig = Union[
    ImageBackgroundsConfig, RandomBackgroundsConfig, UniformBackgroundsConfig,
]

class BackgroundsDataset(Dataset[DataSample[Image.Image]]):
    """Dataset of background images, selected deterministically per index.

    Backgrounds are loaded from disk, generated as random noise, or a single
    uniform image, depending on the config type. Sample ``idx`` maps to
    background ``_get_item_idx(idx, len(self.bgs), salt)``. That list position
    is also returned as the sample's target.
    """

    def __init__(
        self,
        config: BGConfig,
        n_samples: int,
        sampling_seed: int,
    ) -> None:
        """
        Args:
            config: selects how the backgrounds are obtained.
            n_samples: value reported by `__len__`.
            sampling_seed: seeds torch's global RNG and the salt of the
                index -> background mapping.

        Raises:
            ValueError: for an unsupported config type, or when no
                backgrounds are available.
        """
        self.n_samples = n_samples

        # Process-wide side effect: seeds torch's *global* RNG.
        torch.random.manual_seed(sampling_seed)
        self.rng = np.random.default_rng(sampling_seed)
        # Salt for the index -> background mapping. Note that `10^6` would be
        # bitwise XOR in Python (== 12), not one million.
        self.bgs_salt = int(self.rng.integers(low=0, high=10**6))

        if isinstance(config, RandomBackgroundsConfig):
            self.bgs = generate_random_backgrounds(config)
        elif isinstance(config, ImageBackgroundsConfig):
            self.bgs = load_backgrounds(config.bg_dir, config.img_size)
        elif isinstance(config, UniformBackgroundsConfig):
            self.bgs = generate_uniform_background(config)
        else:
            raise ValueError("Invalid background config provided")

        if not self.bgs:
            # Fail early with a clear message; otherwise every __getitem__
            # would fail because there is no background to pick from.
            raise ValueError("No backgrounds available to sample from.")

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, idx: int) -> DataSample:
        """Return a copy of the background for sample ``idx`` and its index."""
        bg_idx = _get_item_idx(idx, len(self.bgs), self.bgs_salt)
        # Copy: `self.bgs` may be the list memoised by `load_backgrounds`, so
        # downstream code must not be able to modify the stored image.
        bg_img = self.bgs[bg_idx].image.copy()

        return DataSample(
            input=bg_img,
            target=torch.tensor(bg_idx, requires_grad=False),
        )

def generate_random_backgrounds(
    config: RandomBackgroundsConfig,
) -> list[ImageRecord]:
    """Generate ``config.n_backgrounds`` uniform-noise RGB backgrounds.

    Args:
        config: number of backgrounds, seed and image size.

    Returns:
        List of square ``config.img_size`` RGB images wrapped in
        `ImageRecord`s.

    Raises:
        ValueError: if ``config.n_backgrounds`` is ``None`` or not positive.
    """
    if config.n_backgrounds is None or config.n_backgrounds <= 0:
        raise ValueError("No valid random background config provided.")
    # The seed is offset, presumably so the stream differs from random
    # foregrounds generated with the same seed.
    rng = np.random.default_rng(config.seed + 73)
    shape = (config.img_size, config.img_size, 3)
    backgrounds: list[ImageRecord] = []
    for _ in range(config.n_backgrounds):
        # PIL images need 8-bit channels. Previously a float32 array was
        # passed with an explicit "RGB" mode, so its raw bytes were
        # reinterpreted as uint8 and the result was not uniform noise.
        # With a uint8 (H, W, 3) array PIL infers the "RGB" mode itself.
        pixels = rng.integers(0, 256, shape, dtype=np.uint8)
        backgrounds.append(ImageRecord(image=Image.fromarray(pixels)))
    return backgrounds

@functools.cache
def load_backgrounds(
    bg_dir: Union[str, Path],
    img_size: int,
) -> list[ImageRecord]:
    """Load every background image in a directory.

    Each direct entry of ``bg_dir`` is loaded as an image, cropped to a
    square and resized. Entries are processed in sorted order. List positions
    are used by `BackgroundsDataset` as background targets, so sorting makes
    them reproducible across machines and filesystems. Results are memoised
    with `functools.cache`; callers must not mutate the returned list.

    Args:
        bg_dir: directory whose entries are the background image files.
        img_size: size passed to `resize` after cropping to a square.

    Returns:
        List ``[bg0, bg1, ...]`` of `ImageRecord`s.

    Raises:
        ValueError: if ``bg_dir`` does not exist.
    """
    backgrounds_dir = Path(bg_dir)
    if not backgrounds_dir.exists():
        raise ValueError(
            f"Backgrounds directory {backgrounds_dir} does not exist."
        )
    bgs: list[ImageRecord] = [
        ImageRecord(
            image=resize(crop_image_to_square(load_image(bg_file)), img_size),
        )
        for bg_file in sorted(backgrounds_dir.iterdir())
    ]
    print(f"{len(bgs)} backgrounds loaded.")
    return bgs

def generate_uniform_background(
    config: UniformBackgroundsConfig,
) -> list[ImageRecord]:
    """Return a single uniformly coloured RGB background.

    Only black (``[0, 0, 0]``) is currently supported.

    Raises:
        NotImplementedError: for any other colour.
    """
    # `list()` so that a tuple such as (0, 0, 0) is accepted as well.
    if list(config.color) != [0, 0, 0]:
        raise NotImplementedError(
            "Only black backgrounds implemented yet."
        )
    # uint8 is the 8-bit channel dtype PIL expects (mode inferred as "RGB");
    # the previous int8 only worked because every value was zero.
    pixels = np.zeros(
        (config.img_size, config.img_size, 3), dtype=np.uint8
    )
    img = Image.fromarray(pixels)
    return [ImageRecord(image=img)]


@dataclass
class Composite2DDatasetConfig:
    """Configuration of `Composite2DDataset`.

    Raises:
        ValueError: on construction, if ``fg_bg_correlation`` or
            ``fg_availability`` lies outside ``[0, 1]``.
    """

    # Side length in pixels the foregrounds are resized to.
    img_size: int
    # Seed of the dataset's own RNG (foreground availability / correlation).
    sampling_seed: int
    # Normalize each channel with mean 0.5 and std 0.5 after `ToTensor`.
    normalize: bool = True
    # Label a sample with the foreground class (True) or background class.
    class_from_foreground: bool = True
    # Probability of picking the foreground whose class index equals the
    # background's class index.
    fg_bg_correlation: float = 0.0
    # Probability that a foreground is pasted onto the background at all.
    fg_availability: float = 1.0

    def __post_init__(self) -> None:
        # Both values are compared against uniform draws in [0, 1); values
        # outside [0, 1] (or NaN) would silently behave like 0 or 1.
        for name in ("fg_bg_correlation", "fg_availability"):
            value = getattr(self, name)
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {value}")

class Composite2DDataset(Dataset[DataSample[torch.Tensor]]):
    """Dataset that pastes foreground objects onto backgrounds.

    Sample ``idx`` takes ``backgrounds[idx]``. With probability
    ``fg_availability`` it pastes a foreground onto it, using
    `paste_fg_on_bg` at the position from `calc_top_left_coordinates`. The
    result is then converted to a (optionally normalized) float tensor.
    """

    def __init__(
        self,
        config: Composite2DDatasetConfig,
        foregrounds: Dataset,
        backgrounds: Dataset,
    ) -> None:
        """
        Args:
            config: composition options.
            foregrounds: yields `DataSample`s or ``(image, class)`` tuples.
                When ``fg_bg_correlation > 0`` it must also support
                ``__getitem__(class_idx, is_class_idx=True)``, as
                `ForegroundsDataset` does.
            backgrounds: yields `DataSample`s or ``(image, class)`` tuples.

        Raises:
            ValueError: if ``class_from_foreground`` is set while
                ``fg_availability != 1.0``.
        """
        # Samples without a foreground are labelled with the background
        # class, so foreground labels are only consistent if a foreground is
        # always present. Raise instead of `assert` so the check survives
        # `python -O`.
        if config.class_from_foreground and config.fg_availability != 1.0:
            raise ValueError(
                "class_from_foreground requires fg_availability == 1.0; "
                "samples without a foreground would be labelled with the "
                "background class."
            )
        self.config = config
        self.foregrounds = foregrounds
        self.backgrounds = backgrounds

        self.n_samples = min(len(foregrounds), len(backgrounds))
        # Used for foreground availability and correlated foreground sampling.
        # NOTE(review): draws depend on the order in which samples are
        # requested, not on `idx`, so they vary with shuffling. With
        # multi-process loading each worker presumably holds its own copy of
        # this state; confirm whether per-sample reproducibility matters.
        self.rng = np.random.default_rng(self.config.sampling_seed + 583)

        self._init_transforms()

    def _init_transforms(self) -> None:
        """Build the PIL -> tensor conversion, with optional normalization."""
        img_transforms: list[Callable] = [torch_transforms.ToTensor()]
        if self.config.normalize:
            img_transforms.append(
                torch_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            )
        self.img_transforms = torch_transforms.Compose(img_transforms)

    def __len__(self) -> int:
        return self.n_samples

    @staticmethod
    def _unpack(instance) -> tuple:
        """Split a `DataSample` or ``(input, target)`` pair into its parts."""
        if isinstance(instance, DataSample):
            return instance.input, instance.target
        sample_input, sample_target = instance
        return sample_input, sample_target

    def __getitem__(self, idx: int) -> DataSample[torch.Tensor]:
        """Compose sample ``idx``.

        Returns:
            A `DataSample` with the composed image tensor as input and the
            foreground or background class (per ``class_from_foreground``)
            as target.
        """
        bg_img, bg_class = self._unpack(self.backgrounds[idx])
        bg_img = bg_img.copy()

        if self.config.fg_availability < 1.0:
            show_fg = self.rng.random() < self.config.fg_availability
        else:
            show_fg = True

        if show_fg:
            # Short-circuit keeps the RNG untouched when correlation is 0.
            sync_fg_bg = (
                self.config.fg_bg_correlation > 0
                and self.rng.random() < self.config.fg_bg_correlation
            )
            if sync_fg_bg:
                # Use the background's class index as the foreground class.
                # `self.foregrounds[a, True]` would pass the tuple `(a, True)`
                # as a single index, so the flag is passed explicitly.
                # `int()` turns a tensor target into a plain index.
                fg_instance = self.foregrounds.__getitem__(
                    int(bg_class), is_class_idx=True,
                )
            else:
                fg_instance = self.foregrounds[idx]
            fg_img, fg_class = self._unpack(fg_instance)

            fg_obj = resize(fg_img, self.config.img_size)
            # NOTE(review): (0.5, 0.5) presumably requests a centred
            # placement; confirm against calc_top_left_coordinates.
            start_x, start_y = calc_top_left_coordinates(
                fg_obj, self.config.img_size, 0.5, 0.5,
            )
            img = paste_fg_on_bg(fg_obj, bg_img, start_x, start_y)
            class_idx = fg_class if self.config.class_from_foreground else bg_class
        else:
            img = bg_img
            class_idx = bg_class

        x = self.img_transforms(img).float()
        target = (
            class_idx
            if isinstance(class_idx, torch.Tensor)
            else torch.tensor(class_idx, requires_grad=False)
        )
        return DataSample(
            input=x,
            target=target,
        )
