from semantic_aug.generative_augmentation import GenerativeAugmentation
from typing import Any, Tuple
from torch.utils.data import Dataset
from collections import defaultdict
from itertools import product
from tqdm import tqdm
from PIL import Image
import glob

import torchvision.transforms as transforms
import torch
import numpy as np
import abc
import random
import os


class FewShotDataset(Dataset):
    """Base dataset for few-shot classification with optional synthetic images.

    Subclasses supply the real examples by overriding ``get_image_by_idx``,
    ``get_label_by_idx`` and ``get_metadata_by_idx``, and must provide
    ``__len__`` (it is used by ``load_images``). Synthetic images are located
    on disk by ``load_augmentations`` and mixed into ``__getitem__`` according
    to ``generative_aug``:

    * ``'dafusion'``: with probability ``synthetic_probability`` the real image
      at ``idx`` is replaced by a random same-class synthetic image.
    * ``'autodiff'``: every item is the real image plus one random same-class
      and one random other-class synthetic image.
    * any other value: only real images are returned.

    NOTE(review): the class is not built on ``abc.ABCMeta``, so the
    ``@abc.abstractmethod`` markers are presumably documentation only and are
    not enforced at instantiation -- confirm before relying on them.
    """

    num_classes: int = None
    # Presumably a sequence of class-name strings set by subclasses.
    class_names: list = None

    def __init__(self, examples_per_class: int = None,
                 generative_aug: str = None,
                 synthetic_probability: float = 0.5,
                 synthetic_dir: dict = None):
        """Store configuration and build the image transform.

        Args:
            examples_per_class: Number of real examples per class; stored for
                subclasses, not used by this base class.
            generative_aug: Augmentation mode, ``'dafusion'`` or
                ``'autodiff'``; any other value disables synthetic images.
            synthetic_probability: Chance of returning a synthetic image
                instead of the real one in ``'dafusion'`` mode.
            synthetic_dir: Mapping with a ``'samecls'`` root directory and,
                for ``'autodiff'``, an ``'othercls'`` root directory.
        """
        self.examples_per_class = examples_per_class
        self.generative_aug = generative_aug

        self.synthetic_probability = synthetic_probability
        self.synthetic_dir = synthetic_dir
        # One bucket per kind: dataset index -> list of (image_path, label).
        self.synthetic_examples = {
            'samecls': defaultdict(list),
            'othercls': defaultdict(list),
        }

        # PIL image -> float tensor, with values rescaled from [0, 1] to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5]),
        ])

    @abc.abstractmethod
    def get_image_by_idx(self, idx: int) -> Image.Image:
        """Return the real image at ``idx``; must be overridden."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_label_by_idx(self, idx: int) -> int:
        """Return the class label of the real example at ``idx``; must be overridden."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_metadata_by_idx(self, idx: int) -> dict:
        """Return metadata for ``idx``; ``load_images`` reads its ``"name"`` key."""
        raise NotImplementedError

    def load_images(self, root: str, examples: dict, num_repeats: int):
        """Register ``num_repeats`` synthetic image paths for every index.

        Images are looked up at ``root/<class_name>/<idx>/*.png``, where
        ``<class_name>`` is the metadata ``"name"`` with spaces replaced by
        underscores. Each selected path is appended to ``examples[idx]`` as a
        ``(path, label)`` tuple, using the label of the real example.

        Args:
            root: Directory that holds the synthetic images.
            examples: Mapping from index to list (e.g. ``defaultdict(list)``)
                that receives the ``(path, label)`` tuples.
            num_repeats: Number of synthetic images to register per index.

        Raises:
            FileNotFoundError: If fewer than ``num_repeats`` images exist for
                some index.
        """
        for idx in tqdm(range(len(self))):
            label = self.get_label_by_idx(idx)
            metadata = self.get_metadata_by_idx(idx)
            name_w_underscore = metadata.get("name", "").replace(" ", "_")

            # Escape the class name so characters like '[' are matched literally.
            pattern = os.path.join(
                root, glob.escape(name_w_underscore), f'{idx}', '*.png')
            # glob order is filesystem-dependent; sort so the chosen subset
            # is reproducible across machines and runs.
            img_paths = sorted(glob.glob(pattern))

            if len(img_paths) < num_repeats:
                raise FileNotFoundError(
                    f"expected at least {num_repeats} images matching "
                    f"{pattern!r}, found {len(img_paths)}")

            examples[idx].extend(
                (path, label) for path in img_paths[:num_repeats])

    def load_augmentations(self, num_repeats: int):
        """(Re)load synthetic image paths from ``synthetic_dir``.

        Same-class images are always loaded; other-class images are loaded
        only in ``'autodiff'`` mode. Previously loaded entries are discarded
        first so repeated calls do not accumulate duplicates.

        Args:
            num_repeats: Number of synthetic images to register per index.

        Raises:
            TypeError: If ``synthetic_dir`` was not provided.
            FileNotFoundError: Propagated from ``load_images``.
        """
        if self.synthetic_dir is None:
            raise TypeError(
                "synthetic_dir must be a dict of root directories to load "
                "augmentations")

        for bucket in self.synthetic_examples.values():
            bucket.clear()

        self.load_images(self.synthetic_dir['samecls'],
                         self.synthetic_examples['samecls'],
                         num_repeats)

        if self.generative_aug == 'autodiff':
            self.load_images(self.synthetic_dir['othercls'],
                             self.synthetic_examples['othercls'],
                             num_repeats)

    @staticmethod
    def _open_image(image):
        """Return ``image`` unchanged, or load it as RGB if it is a file path."""
        if isinstance(image, str):
            # convert() forces a full load, so the file can be closed right away;
            # RGB guarantees the 3 channels the Normalize transform expects.
            with Image.open(image) as img:
                return img.convert('RGB')
        return image

    def _pick_synthetic(self, kind: str, idx: int):
        """Return a random synthetic image of ``kind`` for ``idx``."""
        candidates = self.synthetic_examples[kind][idx]
        if not candidates:
            raise IndexError(
                f"no '{kind}' synthetic images loaded for index {idx}; "
                f"call load_augmentations first")
        image, _ = random.choice(candidates)
        return image

    def __getitem__(self, idx: int) -> Tuple[Any, int]:
        """Return the transformed sample at ``idx`` and its label.

        In ``'autodiff'`` mode the first element is a list of three tensors
        ``[real, same_class_synthetic, other_class_synthetic]``; otherwise it
        is a single tensor.
        """
        if self.generative_aug == 'dafusion':
            # Short-circuit: only draw a random number when synthetics exist.
            if len(self.synthetic_examples['samecls'][idx]) > 0 and \
                    np.random.uniform() < self.synthetic_probability:
                image, label = random.choice(
                    self.synthetic_examples['samecls'][idx])
                image = self._open_image(image)
            else:
                image = self.get_image_by_idx(idx)
                label = self.get_label_by_idx(idx)

            return self.transform(image), label

        if self.generative_aug == 'autodiff':
            image = self.get_image_by_idx(idx)
            label = self.get_label_by_idx(idx)
            samecls_image = self._pick_synthetic('samecls', idx)
            othercls_image = self._pick_synthetic('othercls', idx)

            data = [
                self.transform(image),
                self.transform(self._open_image(samecls_image)),
                self.transform(self._open_image(othercls_image)),
            ]
            return data, label

        image = self.get_image_by_idx(idx)
        label = self.get_label_by_idx(idx)
        return self.transform(image), label