import itertools
import torch
from loaders.image_loader import load_images


class CondLoader:
    """Common base class for condition loaders; holds no state of its own."""

    def __init__(self):
        """Initialise the base by delegating to the next class in the MRO."""
        # Kept cooperative so subclasses can combine this base with mixins.
        super().__init__()

class ImageCondLoader(CondLoader):
    """Build image-conditioning embeddings for a set of parameter types.

    Each known parameter type is mapped to an image directory. Every call to
    ``load_image_cond`` draws one batch of images per type, encodes it with
    ``guider`` and averages the embeddings of the batch into a single vector.
    """

    def __init__(self, param_root_dir, image_data_name, image_num, guider):
        """
        :param param_root_dir: Root directory of the parameters. Only stored
            for now; the type-to-directory map is still hard-coded (see
            ``_load_param_type_map``).
        :param image_data_name: Dataset name forwarded to ``load_images`` as
            ``data_name``.
        :param image_num: Batch size handed to every image loader, i.e. the
            number of image embeddings averaged into one condition.
        :param guider: Encoder object exposing ``device`` and
            ``encode_image(images)``.
        """
        super().__init__()
        self.param_root_dir = param_root_dir
        self.image_data_name = image_data_name
        self.image_num = image_num
        self.guider = guider

        self.param_type_map = self._load_param_type_map()
        self.image_cond_loaders = self._init_image_loader()
        # Reserved for cached embeddings; not populated by the current code.
        self.image_cond_dict = {}

    def _load_param_type_map(self):
        """Return the mapping from parameter-type name to image directory."""
        # TODO load real map.json via self.param_root_dir
        # Art, Clipart, Product, RealWorld
        param_type_map = {
            "Clipart": "/nfs196/hjc/datasets/Office-Home/Clipart",
            "Product": "/nfs196/hjc/datasets/Office-Home/Product",
            "RealWorld": "/nfs196/hjc/datasets/Office-Home/RealWorld",
        }
        return param_type_map

    @staticmethod
    def _cycle_batches(image_loader):
        """Yield batches from ``image_loader`` endlessly.

        ``itertools.cycle`` keeps a copy of every item it has yielded, which
        for image batches means holding a whole pass over the data in memory.
        A re-iterable loader is therefore restarted instead of buffered.
        """
        if hasattr(image_loader, "__next__"):
            # A one-shot iterator cannot be restarted; buffering is the only
            # way to repeat it, so keep the itertools.cycle behaviour here.
            yield from itertools.cycle(image_loader)
            return
        while True:
            produced = False
            for batch in image_loader:
                produced = True
                yield batch
            if not produced:
                # Empty loader: stop instead of spinning forever, so that
                # next() raises StopIteration as itertools.cycle would.
                return

    def _init_image_loader(self):
        """Create one endlessly repeating batch iterator per parameter type.

        :return: Dict mapping parameter-type name to an iterator over batches
            produced by ``load_images`` for that type's image directory.
        """
        image_loaders = {}
        for param_type, image_dir in self.param_type_map.items():
            image_loader = load_images(
                data_dir=image_dir,
                data_name=self.image_data_name,
                data_type="test",
                batch_size=self.image_num
            )
            image_loaders[param_type] = self._cycle_batches(image_loader)
        return image_loaders

    def load_image_cond(self, param_types):
        """Return one freshly sampled condition embedding per requested type.

        :param param_types: Iterable of parameter-type names. Duplicates are
            allowed and share the embedding computed during this call.
        :return: Tensor stacking the per-type embeddings along dim 0, in the
            order given by ``param_types``.
        :raises KeyError: If a requested type has no image loader.
        """
        param_types = list(param_types)
        # Fail fast with a readable message: a silent ``dict.get`` miss would
        # only surface later as an opaque TypeError inside ``torch.stack``,
        # after all batches had already been loaded and encoded.
        unknown = [t for t in param_types if t not in self.image_cond_loaders]
        if unknown:
            raise KeyError(
                f"Unknown param type(s) {list(dict.fromkeys(unknown))}; "
                f"available: {list(self.image_cond_loaders)}"
            )

        image_cond_dict = {}
        # Every loader is advanced on each call, whether or not its type was
        # requested, which keeps the sampling sequence of the original code.
        for param_type, image_loader in self.image_cond_loaders.items():
            images, _labels, _ = next(image_loader)
            images = images.to(self.guider.device)
            embeds = self.guider.encode_image(images)
            # TODO Cache encoded features to save time
            # (image_num, dim) -> (dim)
            image_cond_dict[param_type] = torch.mean(embeds, dim=0)

        image_conds = [image_cond_dict[t] for t in param_types]
        return torch.stack(image_conds, dim=0)

    # def load_image_cond(self, param_types):
    #     if len(self.image_cond_dict) == 0:
    #         print('--->', 'Cache image condition embedding')
    #         for param_type, image_loader in self.image_cond_loaders.items():
    #             samples = next(image_loader)
    #             images, labels, _ = samples
    #             images = images.to(self.guider.device)
    #             embeds = self.guider.encode_image(images)
    #             embeds = torch.mean(embeds, dim=0)  # (image_num, dim) -> (dim)
    #             self.image_cond_dict[param_type] = embeds
    #         torch.save(self.image_cond_dict,
    #                    "/nfs196/hjc/projects/PP/outputs/rn18_OH_Ar_base/unet_v2_epo30000_bs512_lr1e-4_cache/image_cond_dict.pt")
    #
    #     image_conds = list(map(self.image_cond_dict.get, param_types))
    #     image_conds = torch.stack(image_conds, dim=0)
    #     return image_conds
