from semantic_aug.few_shot_dataset import FewShotDataset
from semantic_aug.generative_augmentation import GenerativeAugmentation
from semantic_aug.util import get_img_num_per_cls
from typing import Any, Tuple, Dict

import numpy as np
import torchvision.transforms as transforms
import torchvision
import torch
import glob
import os

from scipy.io import loadmat
from PIL import Image
from collections import defaultdict

# Default dataset root used when ``PetsDataset`` is built without an explicit
# ``image_dir``. "PATH" looks like a placeholder prefix that must be replaced
# with the local data location (or pass ``image_dir``) -- NOTE(review): confirm.
DEFAULT_IMAGE_DIR = "PATH/data/pet"


class PetsDataset(FewShotDataset):
    """Oxford-IIIT Pet breed-classification dataset for few-shot training.

    The split's image list is read from
    ``<image_dir>/annotations/trainval.txt`` (``split="train"``) or
    ``<image_dir>/annotations/test.txt`` (``split="val"``). Every non-blank,
    non-comment line must start with an image stem followed by a 1-based
    class id. The image itself is loaded from
    ``<image_dir>/images-sz/352/<stem>.jpg``.

    Per-class subsampling is controlled by ``examples_per_class``:

    * ``None`` -- keep every image (order permuted per class by ``seed``).
    * ``>= 1`` -- keep at most that many images per class.
    * in ``(0, 1)`` -- long-tailed subset: the value is passed as
      ``imb_factor`` to ``get_img_num_per_cls`` (``img_max=16``,
      ``imb_type='exp'``), and the resulting per-class counts are shuffled
      across classes.
    """

    # Index ``i`` holds the breed for annotation class id ``i + 1``.
    class_names = [
        'abyssinian',
        'american bulldog',
        'american pit bull terrier',
        'basset hound',
        'beagle',
        'bengal',
        'birma',
        'bombay',
        'boxer',
        'british shorthair',
        'chihuahua',
        'egyptian mau',
        'english cocker spaniel',
        'english setter',
        'german shorthaired',
        'great pyrenees',
        'havanese',
        'japanese chin',
        'keeshond',
        'leonberger',
        'maine coon',
        'miniature pinscher',
        'newfoundland',
        'persian',
        'pomeranian',
        'pug',
        'ragdoll',
        'russian blue',
        'saint bernard',
        'samoyed',
        'scottish terrier',
        'shiba inu',
        'siamese',
        'sphynx',
        'staffordshire bull terrier',
        'wheaten terrier',
        'yorkshire terrier'
    ]

    num_classes: int = len(class_names)

    # Annotation list file (under ``<image_dir>/annotations``) for each split.
    _split_listfiles = {"train": "trainval.txt", "val": "test.txt"}

    def __init__(self, *args, split: str = "train", seed: int = 0,
                 image_dir: str = DEFAULT_IMAGE_DIR,
                 examples_per_class: int = None,
                 generative_aug: GenerativeAugmentation = None,
                 synthetic_probability: float = 0.5,
                 use_randaugment: bool = False,
                 image_size: Tuple[int, int] = (256, 256), **kwargs):
        """Build the index of image paths and labels for one split.

        Args:
            *args, **kwargs: Forwarded unchanged to ``FewShotDataset``.
            split: ``"train"`` (reads ``trainval.txt``, random augmentation)
                or ``"val"`` (reads ``test.txt``, no augmentation).
            seed: Seed for the per-class permutation and the shuffle of the
                imbalanced per-class counts.
            image_dir: Dataset root containing ``annotations/`` and
                ``images-sz/352/``.
            examples_per_class: Per-class subsampling, ``None``, ``>= 1``, or
                in ``(0, 1)`` (see the class docstring).
            generative_aug: Forwarded to ``FewShotDataset``.
            synthetic_probability: Forwarded to ``FewShotDataset``.
            use_randaugment: For ``split="train"``, use ``RandAugment``
                instead of horizontal flip + 15 degree rotation.
            image_size: ``(height, width)`` that images are resized to.

        Raises:
            ValueError: If ``split`` is not ``"train"`` or ``"val"``, or if a
                class id in the list file is outside ``1..num_classes``.
            RuntimeError: If the split's annotation list file does not exist.
        """
        # Fail fast: any other split previously left ``listfile`` unbound.
        if split not in self._split_listfiles:
            raise ValueError(
                f"split must be one of {sorted(self._split_listfiles)}, "
                f"got {split!r}")

        super().__init__(
            *args, examples_per_class=examples_per_class,
            synthetic_probability=synthetic_probability,
            generative_aug=generative_aug, **kwargs)

        listfile = os.path.join(
            image_dir, 'annotations', self._split_listfiles[split])
        self.images, self.targets = self._read_annotations(listfile)

        class_to_images = defaultdict(list)
        for file_name, target in zip(self.images, self.targets):
            # Class ids are 1-based. Reject out-of-range ids instead of
            # letting ``target - 1`` wrap around via negative indexing.
            if not 1 <= target <= self.num_classes:
                raise ValueError(
                    f"class id {target} for image {file_name!r} in "
                    f"{listfile} is outside 1..{self.num_classes}")
            class_name = self.class_names[target - 1]
            image_path = os.path.join(
                image_dir, 'images-sz/352', f'{file_name}.jpg')
            class_to_images[class_name].append(image_path)

        # One seeded permutation per class, drawn in ``class_names`` order so
        # the selected subset is reproducible for a given ``seed``.
        rng = np.random.default_rng(seed)
        class_to_ids = {key: rng.permutation(len(class_to_images[key]))
                        for key in self.class_names}

        if examples_per_class is not None:
            if examples_per_class >= 1:
                class_to_ids = {key: ids[:examples_per_class]
                                for key, ids in class_to_ids.items()}
            else:
                # Imbalanced case: examples_per_class in (0, 1) is the
                # imbalance factor of an exponential long-tail profile.
                img_num_per_cls = get_img_num_per_cls(
                    img_max=16,
                    num_class=self.num_classes,
                    imb_type='exp',
                    imb_factor=examples_per_class)

                # Randomise which classes land in the head vs. the tail.
                rng.shuffle(img_num_per_cls)

                class_to_ids = {
                    key: ids[:img_num_per_cls[cls_id]]
                    for cls_id, (key, ids) in enumerate(class_to_ids.items())}

        self.class_stat = {key: len(ids) for key, ids in class_to_ids.items()}

        self.class_to_images = {
            key: [class_to_images[key][i] for i in ids]
            for key, ids in class_to_ids.items()}

        # Flatten in class order. A comprehension avoids the quadratic
        # re-copying done by ``sum(list_of_lists, [])``.
        self.all_images = [
            path for key in self.class_names
            for path in self.class_to_images[key]]

        self.all_labels = [
            label for label, key in enumerate(self.class_names)
            for _ in self.class_to_images[key]]

        if split == 'train':
            if use_randaugment:
                augmentations = [transforms.RandAugment()]
            else:
                augmentations = [
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomRotation(degrees=15.0),
                ]
        else:
            augmentations = []

        self.transform = self._make_transform(image_size, augmentations)

    @staticmethod
    def _read_annotations(listfile: str) -> Tuple[list, list]:
        """Parse an annotation list file into ``(image_stems, class_ids)``.

        Blank lines and lines whose first field starts with ``#`` are
        skipped. Only the first two whitespace-separated fields of each
        remaining line are used; the second is parsed as an ``int``.

        Raises:
            RuntimeError: If ``listfile`` does not exist.
        """
        try:
            with open(listfile, 'r') as f:
                lines = f.readlines()
        except FileNotFoundError as err:
            raise RuntimeError(
                "Dataset not found or corrupted: missing annotation file "
                f"{listfile!r}") from err

        images, targets = [], []
        for line in lines:
            fields = line.split()
            if not fields or fields[0].startswith('#'):
                continue
            images.append(fields[0])
            targets.append(int(fields[1]))
        return images, targets

    @staticmethod
    def _make_transform(image_size: Tuple[int, int], augmentations: list):
        """Compose resize, then ``augmentations``, then tensor conversion.

        The result yields float tensors expanded to ``(3, *image_size)`` and
        normalised with mean 0.5 / std 0.5 per channel.
        """
        return transforms.Compose([
            transforms.Resize(image_size),
            *augmentations,
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float),
            # Broadcast to 3 channels. This is a no-op for the RGB images
            # returned by ``get_image_by_idx``; presumably it guards
            # single-channel inputs from elsewhere -- TODO confirm.
            transforms.Lambda(lambda x: x.expand(3, *image_size)),
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        """Return the number of (sub-sampled) real images in this split."""
        return len(self.all_images)

    def get_image_by_idx(self, idx: int) -> Image.Image:
        """Load image ``idx`` from disk as an RGB PIL image."""
        # ``convert`` returns a new, fully loaded image, so the source file
        # can be closed deterministically by the context manager.
        with Image.open(self.all_images[idx]) as image:
            return image.convert('RGB')

    def get_label_by_idx(self, idx: int) -> int:
        """Return the 0-based class index of image ``idx``."""
        return self.all_labels[idx]

    def get_metadata_by_idx(self, idx: int) -> dict:
        """Return ``{"name": <breed name>}`` for image ``idx``."""
        return dict(name=self.class_names[self.all_labels[idx]])
