import os
from collections import defaultdict
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from pycocotools.coco import COCO

from semantic_aug.few_shot_dataset import FewShotDataset
from semantic_aug.generative_augmentation import GenerativeAugmentation
from semantic_aug.util import get_img_num_per_cls

# Root of the local COCO checkout. "PATH" appears to be a placeholder that
# must be replaced with the real data location before use.
COCO_DIR = "PATH/data/coco"

# Local image folders for the COCO 2017 train / val splits.
TRAIN_IMAGE_DIR, VAL_IMAGE_DIR = (
    os.path.join(COCO_DIR, folder_name)
    for folder_name in ("train2017", "val2017"))

# Default LVIS v1 annotation files for the train / val splits.
DEFAULT_TRAIN_INSTANCES, DEFAULT_VAL_INSTANCES = (
    os.path.join(COCO_DIR, relative_path)
    for relative_path in ("annotations/lvis_v1_train.json",
                          "annotations/lvis_v1_val.json"))


class LVISDataset(FewShotDataset):
    """Few-shot image classification dataset built from LVIS v1 annotations.

    Each annotated image is labelled with the category of its largest-area
    annotation; only images whose label appears in ``class_names`` are kept.
    Image files are resolved under the local COCO 2017 train/val image
    folders according to each image's ``coco_url``.
    """

    class_names = [
        'garbage_truck', 'dirt_bike', 'bunk_bed', 'kimono', 'parachute',
        'birdcage', 'water_scooter', 'grizzly', 'cruise_ship', 'robe',
        'pelican', 'pony', 'sweat_pants', 'crib', 'canoe', 'kayak', 'horse_carriage',
        'aquarium', 'jeep', 'black_sheep', 'eagle', 'tow_truck', 'jet_plane', 'ram_(animal)',
        'camper_(vehicle)', 'kitten', 'parrot', 'helicopter', 'shopping_cart', 'trunk',
        'tartan', 'hamburger'
    ]

    # Alternative, larger class list kept for reference; swap it in for
    # ``class_names`` above to use it.
    # class_names = [
    #     'cast', 'coatrack', 'crab_(animal)', 'crayon',
    #     'crow', 'flannel', 'lip_balm', 'needle', 'oil_lamp',
    #     'parakeet', 'passport', 'peeler_(tool_for_fruit_and_vegetables)',
    #     'poker_(fire_stirring_tool)', 'postcard', 'tricycle',
    #     'wine_bucket', 'Band_Aid', 'dishrag', 'hairpin', 'hummingbird',
    #     'jewelry', 'kitchen_table', 'ladybug', 'musical_instrument', 'papaya',
    #     'perfume', 'nest', 'jam', 'fruit_juice', 'freight_car', 'identity_card',
    #     'bottle_opener', 'volleyball', 'tinsel', 'snowmobile', 'shaker', 'squirrel',
    #     'tea_bag', 'walking_stick', 'beret', 'candy_cane', 'cardigan', 'cayenne_(spice)',
    #     'cufflink', 'fish_food', 'dental_floss', 'cowbell', 'corkscrew', 'birdfeeder',
    #     'shark', 'pocketknife', 'mail_slot', 'lizard', 'hammock', 'hairnet',
    #     'grocery_bag', 'cincture', 'ginger', 'garbage_truck', 'flipper_(footwear)', 'drill',
    #     'chili_(vegetable)', 'sleeping_bag', 'melon', 'lollipop', 'nightshirt', 'parka',
    #     'skewer', 'bread-bin', 'flamingo', 'parasol', 'pretzel', 'recliner',
    #     'record_player', 'binoculars', 'underdrawers', 'food_processor', 'golfcart', 'honey',
    #     'shaving_cream', 'wind_chime', 'ambulance', 'beanbag', 'duct_tape', 'pinecone',
    #     'puppy', 'dirt_bike', 'brownie', 'camel', 'can_opener', 'shield',
    #     'sour_cream', 'turtle', 'windsock', 'bunk_bed', 'clothespin', 'alligator',
    #     'videotape', 'tapestry', 'sweet_potato', 'phonograph_record', 'ironing_board', 'eraser',
    #     'chocolate_bar', 'cape', 'kimono', 'packet', 'palette', 'cantaloup',
    #     'gelatin', 'runner_carpet', 'sushi', 'barrow', 'flute_glass', 'foal',
    #     'bouquet', 'tiara', 'mop', 'projectile_(weapon)', 'iron_(for_clothing)', 'coconut',
    #     'thumbtack', 'reamer_(juicer)', 'parachute', 'gazelle', 'CD_player', 'birdcage',
    #     'webcam', 'wok', 'olive_oil', 'meatball', 'Lego', 'ice_maker',
    #     'trophy_cup', 'cash_register', 'bean_curd', 'hammer', 'camcorder', 'gargle',
    #     'watering_can', 'water_scooter', 'soupspoon', 'footstool', 'drum_(musical_instrument)', 'coffeepot',
    #     'cigarette_case', 'anklet', 'thermometer', 'solar_array', 'rhinoceros', 'razorblade',
    #     'pottery', 'pliers', 'football_(American)', 'cock', 'grizzly', 'domestic_ass',
    #     'cruise_ship', 'artichoke', 'toolbox', 'router_(computer_equipment)', 'pelican', 'mat_(gym_equipment)',
    #     'gun', 'freshener', 'salami', 'robe', 'raft', 'pony',
    #     'hot_sauce', 'chocolate_cake', 'business_card', 'wedding_ring', 'tape_measure', 'silo',
    #     'fire_hose', 'tortilla', 'mandarin_orange', 'cub_(animal)', 'crock_pot', 'clothes_hamper',
    #     'cleansing_agent', 'cart', 'basketball_backboard', 'snowman', 'easel', 'waffle',
    #     'tiger', 'sweat_pants', 'peanut_butter', 'clipboard', 'water_cooler', 'turkey_(food)',
    #     'teakettle', 'loveseat', 'alcohol', 'racket', 'fishing_rod', 'crib',
    #     'canoe', 'bowler_hat', 'step_stool', 'saddlebag', 'file_cabinet', 'Ferris_wheel',
    #     'brussels_sprouts', 'deadbolt', 'birdhouse', 'thermos_bottle', 'kayak', 'gravestone',
    #     'flashlight', 'crescent_roll', 'horse_carriage', 'aquarium', 'jeep', 'pacifier',
    #     'black_sheep', 'wrench', 'postbox_(public)', 'eagle', 'card', 'boiled_egg',
    #     'tow_truck', 'jet_plane', 'bathrobe', 'vacuum_cleaner', 'seashell', 'ram_(animal)',
    #     'cornet', 'camper_(vehicle)', 'bamboo', 'tights_(clothing)', 'paintbrush', 'medicine',
    #     'clasp', 'kitten', 'eggplant', 'blackberry', 'basketball', 'turban',
    #     'teacup', 'measuring_stick', 'garden_hose', 'frog', 'crisp_(potato_chip)', 'bead',
    #     'parrot', 'helicopter', 'eggbeater', 'armband', 'windmill', 'shopping_cart',
    #     'egg_yolk', 'coleslaw', 'wig', 'water_tower', 'water_ski', 'trunk',
    #     'toast_(food)', 'tartan', 'mashed_potato', 'award', 'hamburger', 'deer',
    # ]

    num_classes: int = len(class_names)

    def __init__(self, *args, split: str = "train", seed: int = 0,
                 train_image_dir: str = TRAIN_IMAGE_DIR,
                 val_image_dir: str = VAL_IMAGE_DIR,
                 train_instances_file: str = DEFAULT_TRAIN_INSTANCES,
                 val_instances_file: str = DEFAULT_VAL_INSTANCES,
                 examples_per_class: Optional[float] = None,
                 generative_aug: Optional[GenerativeAugmentation] = None,
                 synthetic_probability: float = 0.5,
                 use_randaugment: bool = False,
                 image_size: Tuple[int, int] = (256, 256), **kwargs):
        """Index the LVIS annotations, subsample per class and build transforms.

        Args:
            split: ``"train"`` or ``"val"``; selects the instances file and
                the transform pipeline. Any other value raises ``KeyError``.
            seed: Seed for the per-class image shuffle (and, in the
                imbalanced case, the shuffle of per-class counts).
            train_image_dir: Folder for images whose ``coco_url`` contains
                ``train2017``. Used for both splits.
            val_image_dir: Folder for images whose ``coco_url`` contains
                ``val2017``. Used for both splits.
            train_instances_file: LVIS annotation file used when
                ``split == "train"``.
            val_instances_file: LVIS annotation file used when
                ``split == "val"``.
            examples_per_class: ``None`` keeps every image. A value ``>= 1``
                keeps that many images per class after shuffling. A value
                below 1 is passed as ``imb_factor`` to
                ``get_img_num_per_cls`` (``img_max=16``, ``imb_type='exp'``)
                to get imbalanced per-class counts.
            generative_aug: Forwarded to ``FewShotDataset``.
            synthetic_probability: Forwarded to ``FewShotDataset``.
            use_randaugment: For the train split, use ``RandAugment`` instead
                of random flip + rotation.
            image_size: Target size passed to ``transforms.Resize``.
        """

        super().__init__(
            *args, examples_per_class=examples_per_class,
            synthetic_probability=synthetic_probability,
            generative_aug=generative_aug, **kwargs)

        # Raises KeyError for any split other than "train" / "val".
        instances_file = {"train": train_instances_file,
                          "val": val_instances_file}[split]

        class_to_images = defaultdict(list)
        class_to_annotations = defaultdict(list)

        self.cocoapi = COCO(instances_file)
        for image_id, image_info in self.cocoapi.imgs.items():

            annotations = self.cocoapi.imgToAnns[image_id]
            if not annotations:
                continue

            # The local folder is chosen from the URL rather than from
            # ``split``: either split may reference images from either folder.
            coco_url = image_info["coco_url"]
            if "train2017" in coco_url:
                image_path = os.path.join(
                    train_image_dir, coco_url.split("train2017/")[-1])
            elif "val2017" in coco_url:
                image_path = os.path.join(
                    val_image_dir, coco_url.split("val2017/")[-1])
            else:
                # Skip images from any other source so the per-class image
                # and annotation lists stay index-aligned.
                continue

            # The image is labelled by its largest-area annotation.
            maximal_ann = max(annotations, key=lambda ann: ann["area"])
            class_name = self.cocoapi.cats[maximal_ann["category_id"]]["name"]

            class_to_images[class_name].append(image_path)
            class_to_annotations[class_name].append(maximal_ann)

        # Seeded random order of each class's images; subsampling below
        # keeps a prefix of this order.
        rng = np.random.default_rng(seed)
        class_to_ids = {key: rng.permutation(len(class_to_images[key]))
                        for key in self.class_names}

        if examples_per_class is not None:
            if examples_per_class >= 1:
                class_to_ids = {key: ids[:examples_per_class]
                                for key, ids in class_to_ids.items()}
            else:
                # Imbalanced case: examples_per_class in (0, 1) acts as the
                # imbalance factor for the per-class counts.
                # NOTE(review): values <= 0 also land here; confirm how
                # get_img_num_per_cls handles them.
                img_num_per_cls = get_img_num_per_cls(
                    img_max=16,
                    num_class=self.num_classes,
                    imb_type='exp',
                    imb_factor=examples_per_class)

                # Which classes end up with few images depends on ``seed``
                # rather than on the order of ``class_names``.
                rng.shuffle(img_num_per_cls)

                # class_to_ids iterates in class_names order, matching the
                # positions in img_num_per_cls.
                class_to_ids = {
                    key: ids[:img_num_per_cls[cls_id]]
                    for cls_id, (key, ids) in enumerate(class_to_ids.items())}

        self.class_stat = {key: len(ids) for key, ids in class_to_ids.items()}

        self.class_to_images = {
            key: [class_to_images[key][i] for i in ids]
            for key, ids in class_to_ids.items()}

        self.class_to_annotations = {
            key: [class_to_annotations[key][i] for i in ids]
            for key, ids in class_to_ids.items()}

        # Flatten in class_names order so all_images, all_annotations and
        # all_labels share one index space. Comprehensions avoid the
        # quadratic cost of ``sum(list_of_lists, [])``.
        self.all_images = [path for key in self.class_names
                           for path in self.class_to_images[key]]

        self.all_annotations = [ann for key in self.class_names
                                for ann in self.class_to_annotations[key]]

        self.all_labels = [label for label, key in enumerate(self.class_names)
                           for _ in self.class_to_images[key]]

        # Split-specific augmentation inserted after resizing; ``split`` is
        # already known to be "train" or "val" at this point.
        if split == "train" and use_randaugment:
            augmentation = [transforms.RandAugment()]
        elif split == "train":
            augmentation = [transforms.RandomHorizontalFlip(p=0.5),
                            transforms.RandomRotation(degrees=15.0)]
        else:
            augmentation = []

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            *augmentation,
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float),
            # Broadcast a single-channel tensor to 3 channels (no-op for RGB).
            transforms.Lambda(lambda x: x.expand(3, *image_size)),
            # (x - 0.5) / 0.5 maps [0, 1] pixel values to [-1, 1].
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5]),
        ])

    def __len__(self) -> int:
        """Return the number of images kept after per-class subsampling."""

        return len(self.all_images)

    def get_image_by_idx(self, idx: int) -> Image.Image:
        """Load image ``idx`` from disk as an RGB PIL image.

        The file is opened in a ``with`` block so its handle is released
        deterministically. ``convert`` returns an in-memory copy that stays
        valid after the file is closed.
        """

        with Image.open(self.all_images[idx]) as image:
            return image.convert('RGB')

    def get_label_by_idx(self, idx: int) -> int:
        """Return the class index (position in ``class_names``) of image ``idx``."""

        return self.all_labels[idx]

    def get_metadata_by_idx(self, idx: int) -> Dict:
        """Return metadata for image ``idx``.

        The dict contains the class ``name`` and the ``mask`` that
        ``COCO.annToMask`` renders from the image's largest annotation. Every
        field of that annotation is merged in as well.
        """

        annotation = self.all_annotations[idx]

        return dict(name=self.class_names[self.all_labels[idx]],
                    mask=self.cocoapi.annToMask(annotation),
                    **annotation)
