import os
import random
from typing import Any, Iterable

import numpy as np
import torch
import torchvision.transforms.v2 as transforms
from torchvision.datasets.folder import ImageFolder
from torchvision.transforms import functional as tF

from dae.utils.generic_utils import TaskState
from dae.utils.torch_utils import to_torch_type

# Number of DataLoader worker processes (get_data_loader uses 0 instead in deterministic mode).
NUM_WORKERS = 10
# Value passed as ``pin_memory`` to the DataLoader.
PIN_MEMORY = True
# Original misspelled name, kept as an alias because existing code (e.g. get_data_loader) reads it.
PIN_MEORY = PIN_MEMORY


####################################################################
# Data utils
####################################################################

# Maps the interpolation name used in configs to the torchvision interpolation mode.
INTERPOLATIONS = dict(
    lanczos=transforms.InterpolationMode.LANCZOS,
    bilinear=transforms.InterpolationMode.BILINEAR,
    nearest=transforms.InterpolationMode.NEAREST,
    bicubic=transforms.InterpolationMode.BICUBIC,
)


def _image_normalize_transform(cfg, mean=None, std=None, mid_transforms=None):
    """Build the tensor-conversion and normalization part of an image pipeline.

    The image is converted with ``transforms.ToImage()``, then cast to ``cfg.dtype`` with
    ``ToDtype(..., scale=True)``. Any ``mid_transforms`` are applied next, and finally
    ``transforms.Normalize(mean, std)``.

    Args:
        cfg: config whose ``dtype`` is resolved with ``to_torch_type``.
        mean: per-channel mean; a non-iterable scalar is wrapped in a one-element list.
        std: per-channel std; a non-iterable scalar is wrapped in a one-element list.
        mid_transforms: optional iterable of transforms inserted before normalization.

    Returns:
        A ``transforms.Compose`` of the steps above.
    """

    def _as_sequence(value):
        # Normalize expects sequences; wrap scalars such as a single float.
        return value if isinstance(value, Iterable) else [value]

    mean_seq = _as_sequence(mean)
    std_seq = _as_sequence(std)

    steps = [
        transforms.ToImage(),
        transforms.ToDtype(to_torch_type(cfg.dtype), scale=True),
    ]
    if mid_transforms:
        steps.extend(mid_transforms)
    steps.append(transforms.Normalize(mean_seq, std_seq))
    return transforms.Compose(steps)


def _image_resize_and_random_aug(is_train_split, im_size, resize=True, rand_crop=False, h_flip=False, rand_resize_scale=False, train_crop=None, keep_aspect_ratio=True, interpolation="bilinear"):
    """Build the geometric (resize / crop / flip) part of an image pipeline.

    Args:
        is_train_split: enables the random augmentations below. When False, ``train_crop``
            is ignored and a center crop of ``im_size`` is used.
        im_size: target size for resizing and the default crop size.
        resize: resize with ``transforms.Resize`` (skipped when a random resize is used).
        rand_crop: use ``RandomCrop`` instead of ``CenterCrop`` (train split only).
        h_flip: append ``RandomHorizontalFlip`` (train split only).
        rand_resize_scale: if a truthy number on the train split, use ``RandomResize``
            between ``im_size`` and ``im_size * rand_resize_scale``.
        train_crop: crop size used on the train split instead of ``im_size``.
        keep_aspect_ratio: if False, the plain resize targets ``(im_size, im_size)``.
        interpolation: key into ``INTERPOLATIONS``.

    Returns:
        A ``transforms.Compose`` of the selected steps.
    """
    interpolation = INTERPOLATIONS[interpolation]
    if not keep_aspect_ratio:
        print(f"Don't keep aspect ratio, train={is_train_split}, im_size={im_size}")

    t_list = []
    if not is_train_split:
        # Evaluation always crops to ``im_size``.
        train_crop = None
    if is_train_split and rand_resize_scale:
        assert isinstance(rand_resize_scale, (int, float))
        # RandomResize samples an integer size in [min_size, max_size) with torch.randint.
        # torch.randint needs integer bounds with max_size > min_size.
        # Cast the (possibly float) upper bound and keep it strictly above im_size.
        # Scales <= 1 then degrade to a fixed resize instead of failing at sample time.
        max_size = max(int(im_size * rand_resize_scale), im_size + 1)
        t_list.append(transforms.RandomResize(im_size, max_size, interpolation=interpolation))
    elif resize:
        resize_size = im_size if keep_aspect_ratio else (im_size, im_size)
        t_list.append(transforms.Resize(resize_size, interpolation=interpolation))

    crop_size = train_crop or im_size
    if is_train_split and rand_crop:
        t_list.append(transforms.RandomCrop(crop_size))
    else:
        t_list.append(transforms.CenterCrop(crop_size))
    if is_train_split and h_flip:
        t_list.append(transforms.RandomHorizontalFlip())
    return transforms.Compose(t_list)


class LoopSubset(torch.utils.data.Dataset):
    r"""
    Subset of a dataset exposing at most ``limit`` items.

    When training, successive epochs walk over the whole underlying dataset in blocks of
    ``limit`` items, wrapping around. The walk optionally goes through one fixed random
    permutation drawn at construction.
    When evaluating, it only uses the first ``limit`` elements, or all of them if the
    dataset is smaller.
    An empty underlying dataset yields an empty subset.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        limit: int,
        train: bool,
        shuffle_train=True,
    ) -> None:
        # NOTE(review): TaskState is assumed to expose the current epoch as ``cur_epoch``
        # (it is read in __getitem__); confirm against dae.utils.generic_utils.
        self.state = TaskState()
        self.dataset = dataset
        self.loop_len = len(dataset)
        self.train = train
        self.limit = limit
        self.shuffling = shuffle_train and train
        if self.shuffling:
            # One permutation for the lifetime of the subset; every epoch indexes through it.
            self.rand_map = np.random.permutation(self.loop_len)

    def __getitem__(self, idx):
        if isinstance(idx, list):
            return [self[i] for i in idx]
        cur_idx = int(idx)
        if self.train:
            # Shift by one block of ``limit`` items per epoch so the whole dataset is covered
            # over several epochs.
            cur_idx = (cur_idx + self.limit * self.state.cur_epoch) % self.loop_len
        if self.shuffling:
            cur_idx = self.rand_map[int(cur_idx)]
        return self.dataset[int(cur_idx)]

    def __len__(self):
        if self.loop_len == 0:
            # Nothing to loop over; reporting ``limit`` here would make __getitem__ divide by zero.
            return 0
        if self.train:
            # Training wraps around, so ``limit`` may exceed the dataset size.
            return self.limit
        # Evaluation reads indices directly and must stay in range.
        return min(self.limit, self.loop_len)


class MapDataset(torch.utils.data.Dataset):
    """Dataset view that applies ``map_fn`` to every item fetched from ``dataset``.

    Args:
        dataset: indexable dataset supporting ``len``.
        map_fn: callable applied to each fetched item.
        limit: optional cap on the reported length. ``None`` exposes the whole dataset,
            and the length never exceeds ``len(dataset)``.
    """

    def __init__(self, dataset: torch.utils.data.Dataset, map_fn, limit=None) -> None:
        self.dataset = dataset
        self.map_fn = map_fn
        self.limit = limit

    def __getitem__(self, idx):
        return self.map_fn(self.dataset[idx])

    def __getitems__(self, indices: list[int]):
        # Fetch several items at once; each one goes through __getitem__ (and thus map_fn).
        return [self[i] for i in indices]

    def __len__(self):
        if self.limit is None:
            return len(self.dataset)
        # Compare against None explicitly so limit=0 means "empty", not "no limit".
        # Clamp so the reported length never exceeds the wrapped dataset.
        return min(self.limit, len(self.dataset))


####################################################################
# Image datasets
####################################################################


############### ImageNet ###############


class ImageNet(ImageFolder):
    """``ImageFolder`` over one split directory of an ImageNet-style root.

    Images are read from ``os.path.join(root, split)``.

    Args:
        root: dataset root containing the split sub-directories.
        split: either ``"train"`` or ``"val"``.
        **kwargs: forwarded unchanged to ``ImageFolder`` (e.g. ``transform``).
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        **kwargs: Any,
    ) -> None:
        assert split in ("train", "val")
        # split_folder reads both attributes, so they must be set before the parent constructor.
        self.root, self.split = root, split
        super().__init__(self.split_folder, **kwargs)
        # The parent constructor presumably stores the split folder in ``self.root``.
        # Restore the dataset root afterwards.
        self.root = root

    @property
    def split_folder(self) -> str:
        """Directory holding the images of the selected split."""
        return os.path.join(self.root, self.split)

    def extra_repr(self) -> str:
        return f"Split: {self.split}"


def _load_imagenet_split(cfg, ds_cfg, is_train_split, transform):
    """Return the ImageNet ``train`` or ``val`` split under ``ds_cfg.root`` with ``transform``.

    ``cfg`` is accepted to match the common loader signature used by ``LOADERS`` but is unused.
    """
    split_name = "val"
    if is_train_split:
        split_name = "train"
    return ImageNet(root=ds_cfg.root, split=split_name, transform=transform)


####################################################################
# Dataset loading
####################################################################

# Lower-cased dataset name -> loader called as loader(cfg, ds_cfg, is_train_split, transform).
LOADERS = dict(imagenet=_load_imagenet_split)


class DecoderEncoderScales(list):
    """Two-element list holding the decoder view (index 0) and encoder view (index 1) of a sample."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert len(self) == 2

    @property
    def decoder_x(self):
        return self[0]

    @property
    def encoder_x(self):
        return self[1]

    @property
    def shape(self):
        """Shape of the decoder view."""
        return self.decoder_x.shape

    def to(self, *args, **kwargs):
        """Call ``.to(*args, **kwargs)`` on every entry, replacing it in place, and return ``self``."""
        for position, item in enumerate(self):
            self[position] = item.to(*args, **kwargs)
        return self


class SquareCrop(torch.nn.Module):
    """Crop an image to a square whose side equals the image's shorter side.

    Args:
        random: if truthy, the crop offset is drawn uniformly with the stdlib ``random``
            module. Otherwise the crop is centered.
    """

    def __init__(self, random=False):
        super().__init__()
        self.random = random

    def _crop_params(self, h, w):
        """Return ``(top, left, size)`` of the square crop for an image of height ``h``, width ``w``."""
        crop_size = min(h, w)
        # Must test ``self.random``. The bare name ``random`` here is the stdlib module,
        # which is always truthy and made every crop random, including at evaluation.
        if self.random:
            top = random.randint(0, h - crop_size)
            left = random.randint(0, w - crop_size)
        else:
            top = (h - crop_size) // 2
            left = (w - crop_size) // 2
        return top, left, crop_size

    def forward(self, img):
        if isinstance(img, torch.Tensor):
            # Read the last two dims so (H, W), (C, H, W) and batched tensors all work.
            h, w = img.shape[-2:]
        else:
            w, h = img.size  # non-tensor images (PIL) report (width, height)
        top, left, crop_size = self._crop_params(int(h), int(w))
        return tF.crop(img, top, left, crop_size, crop_size)


class EncoderDecoderPairTransform(torch.nn.Module):
    """Produce decoder- and encoder-resolution views of the same image.

    A shared square crop (and optional flip) is applied once, then the result is resized twice.
    The encoder view always uses bilinear interpolation. The decoder view uses the
    ``interpolation`` entry of ``augs`` (default ``"bilinear"``).

    Args:
        enc_size: size given to ``transforms.Resize`` for the encoder view.
        dec_size: size given to ``transforms.Resize`` for the decoder view.
        augs: mapping of augmentation options, unpacked into ``create_shared_augs``.
        is_train_split: gates the random crop / flip options.
        post_transform: optional transform applied separately to each view after resizing.

    ``forward`` returns ``DecoderEncoderScales([decoder_view, encoder_view])``.
    """

    def __init__(self, enc_size, dec_size, augs, is_train_split, post_transform=None):
        super().__init__()
        self.enc_size, self.dec_size = enc_size, dec_size
        self.augs = augs
        self.is_train_split = is_train_split

        # create_shared_augs records ``self.interpolation``, so it must run before dec_resize is built.
        self.shared_augs = self.create_shared_augs(**augs)
        bilinear_mode = INTERPOLATIONS["bilinear"]
        self.enc_resize = transforms.Resize(enc_size, interpolation=bilinear_mode)
        decoder_mode = INTERPOLATIONS[self.interpolation]
        self.dec_resize = transforms.Resize(dec_size, interpolation=decoder_mode)
        self.post_transform = post_transform

    def create_shared_augs(self, resize=True, rand_crop=False, h_flip=False, rand_resize_scale=False, train_crop=None, keep_aspect_ratio=True, interpolation="bilinear"):
        """Build the crop/flip pipeline shared by both views and record the decoder interpolation.

        Only ``rand_crop``, ``h_flip`` and ``interpolation`` are used. The remaining keyword
        arguments are accepted so the same ``augs`` mapping used by
        ``_image_resize_and_random_aug`` can be unpacked here.
        """
        self.interpolation = interpolation
        steps = [SquareCrop(random=(self.is_train_split and rand_crop))]
        if self.is_train_split and h_flip:
            steps.append(transforms.RandomHorizontalFlip())
        return transforms.Compose(steps)

    def forward(self, x):
        cropped = self.shared_augs(x)
        enc_view = self.enc_resize(cropped)
        dec_view = self.dec_resize(cropped)
        if self.post_transform is not None:
            enc_view, dec_view = self.post_transform(enc_view), self.post_transform(dec_view)
        return DecoderEncoderScales([dec_view, enc_view])


def make_transform(cfg, ds_cfg, is_train_split):
    """Build the per-sample image transform for one dataset split.

    If ``ds_cfg.encoder_im_size`` is set and differs from ``ds_cfg.im_size``, the result is an
    ``EncoderDecoderPairTransform`` yielding both resolutions. Otherwise it is a
    single-resolution resize/augment + normalize pipeline.

    Random augmentations are enabled only on the train split and only when
    ``cfg.deterministic`` is off. This rule applies to both paths, so deterministic runs never
    get random crops or flips.
    """
    use_random_augs = is_train_split and not cfg.deterministic
    encoder_im_size = ds_cfg.get("encoder_im_size", None)

    if encoder_im_size is not None and encoder_im_size != ds_cfg.im_size:
        return EncoderDecoderPairTransform(
            enc_size=encoder_im_size,
            dec_size=ds_cfg.im_size,
            augs=ds_cfg.augs,
            is_train_split=use_random_augs,
            post_transform=_image_normalize_transform(cfg, mean=ds_cfg.normalize.mean, std=ds_cfg.normalize.std),
        )

    return transforms.Compose(
        [
            _image_resize_and_random_aug(use_random_augs, ds_cfg.im_size, **ds_cfg.augs),
            _image_normalize_transform(cfg, mean=ds_cfg.normalize.mean, std=ds_cfg.normalize.std),
        ]
    )


def get_data_loader(cfg, ds_cfg, dataset, is_train_split):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    If ``ds_cfg.limit`` is set, the dataset is first wrapped in ``LoopSubset``.
    Returns ``None`` when the train split is empty.

    The batch size is ``cfg.training.gpu_batch_size``. On a non-train split,
    ``cfg.testing.gpu_batch_size`` is used instead when it is set (truthy).
    Worker processes are disabled in deterministic mode.
    """
    subset_limit = ds_cfg.limit
    if subset_limit is not None:
        dataset = LoopSubset(dataset, subset_limit, is_train_split and not cfg.deterministic)
    if is_train_split and not len(dataset):
        return None

    train_batch = cfg.training.gpu_batch_size
    batch_size = train_batch if is_train_split else (cfg.testing.gpu_batch_size or train_batch)
    assert batch_size

    workers = 0 if cfg.deterministic else NUM_WORKERS
    loader_options = dict(
        batch_size=batch_size,
        shuffle=is_train_split,
        num_workers=workers,
        persistent_workers=workers > 0,
        pin_memory=PIN_MEORY,
    )
    return torch.utils.data.DataLoader(dataset, **loader_options)


def make_dataset_and_loader(cfg, is_train, ds_cfg=None):
    """Build the dataset and its loader for the train or test split.

    Args:
        cfg: global config. ``cfg.dataset`` / ``cfg.test_dataset`` are used when ``ds_cfg`` is None.
        is_train: selects the train or test split.
        ds_cfg: optional explicit dataset config.

    Returns:
        ``(dataset, loader)``. A loader function in ``LOADERS`` may return this pair itself.
        Otherwise the loader comes from ``get_data_loader`` and may be ``None``.

    Raises:
        ValueError: if ``ds_cfg.name`` (lower-cased) is not a key of ``LOADERS``.
    """
    if ds_cfg is None:
        ds_cfg = cfg.dataset if is_train else cfg.test_dataset

    name = ds_cfg.name.lower()
    load_split = LOADERS.get(name)
    if load_split is None:
        raise ValueError(f"No dataset named {name}")

    transform = make_transform(cfg, ds_cfg, is_train)
    loaded = load_split(cfg, ds_cfg, is_train, transform)
    if not (isinstance(loaded, tuple) and len(loaded) == 2):
        return loaded, get_data_loader(cfg, ds_cfg, loaded, is_train and not cfg.deterministic)

    dataset, loader = loaded
    return dataset, loader


def dataset_from_name(cfg):
    """Build train and test splits.

    Returns:
        ``((train_dataset, test_dataset), (train_loader, test_loader))``.
    """
    # The train split is built first, then the test split.
    datasets, loaders = zip(*(make_dataset_and_loader(cfg, is_train) for is_train in (True, False)))
    return datasets, loaders
