
import numpy as np
import os
import torch
import torchvision.transforms as transforms

from PIL import Image
from torch.utils.data import Dataset

########################################################
## GENERAL DATASET GLOBAL VARIABLES
########################################################

# Number of target classes in AwA2 (one per entry of CLASS_NAMES).
N_CLASSES = 50

# Default location of the dataset on disk. Set the DATASET_DIR environment
# variable to point somewhere else.
DATASET_DIR = os.getenv("DATASET_DIR", 'data/AwA2/')


#########################################################
## CONCEPT INFORMATION REGARDING AwA2
#########################################################

# Names of the 50 AwA2 classes; multi-word names use '+' as the separator.
# AwA2Dataset exposes this list unchanged as `task_names`/`task_attr_names`.
# NOTE(review): presumably listed in the same order as the dataset's
# classes.txt, so that a position here equals the label index -- confirm
# against the downloaded file.
CLASS_NAMES = [
    'antelope',
    'grizzly+bear',
    'killer+whale',
    'beaver',
    'dalmatian',
    'persian+cat',
    'horse',
    'german+shepherd',
    'blue+whale',
    'siamese+cat',
    'skunk',
    'mole',
    'tiger',
    'hippopotamus',
    'leopard',
    'moose',
    'spider+monkey',
    'humpback+whale',
    'elephant',
    'gorilla',
    'ox',
    'fox',
    'sheep',
    'seal',
    'chimpanzee',
    'hamster',
    'squirrel',
    'rhinoceros',
    'rabbit',
    'bat',
    'giraffe',
    'wolf',
    'chihuahua',
    'rat',
    'weasel',
    'otter',
    'buffalo',
    'zebra',
    'giant+panda',
    'deer',
    'bobcat',
    'pig',
    'lion',
    'mouse',
    'polar+bear',
    'collie',
    'walrus',
    'raccoon',
    'cow',
    'dolphin',
]

# Names of the 85 concepts (attributes) annotated for every class.
# Positions in this list are what CONCEPT_GROUPS and the `selected_concepts`
# argument of AwA2Dataset refer to, so the order must not be changed.
# NOTE(review): presumably the same order as the columns of
# predicate-matrix-binary.txt -- confirm against the dataset's predicates.txt.
CONCEPT_SEMANTICS = [
    'black',
    'white',
    'blue',
    'brown',
    'gray',
    'orange',
    'red',
    'yellow',
    'patches',
    'spots',
    'stripes',
    'furry',
    'hairless',
    'toughskin',
    'big',
    'small',
    'bulbous',
    'lean',
    'flippers',
    'hands',
    'hooves',
    'pads',
    'paws',
    'longleg',
    'longneck',
    'tail',
    'chewteeth',
    'meatteeth',
    'buckteeth',
    'strainteeth',
    'horns',
    'claws',
    'tusks',
    'smelly',
    'flys',
    'hops',
    'swims',
    'tunnels',
    'walks',
    'fast',
    'slow',
    'strong',
    'weak',
    'muscle',
    'bipedal',
    'quadrapedal',
    'active',
    'inactive',
    'nocturnal',
    'hibernate',
    'agility',
    'fish',
    'meat',
    'plankton',
    'vegetation',
    'insects',
    'forager',
    'grazer',
    'hunter',
    'scavenger',
    'skimmer',
    'stalker',
    'newworld',
    'oldworld',
    'arctic',
    'coastal',
    'desert',
    'bush',
    'plains',
    'forest',
    'fields',
    'jungle',
    'mountains',
    'ocean',
    'ground',
    'water',
    'tree',
    'cave',
    'fierce',
    'timid',
    'smart',
    'group',
    'solitary',
    'nestspot',
    'domestic',
]

# Human-readable definition of the concept groups: every group name maps to
# the names of the concepts that belong to it.
_CONCEPT_GROUP_MEMBERS = {
    'color': ['black', 'white', 'blue', 'brown', 'gray', 'orange', 'red', 'yellow'],
    'fur_pattern': ['patches', 'spots', 'stripes', 'furry', 'hairless', 'toughskin'],
    'size': ['big', 'small', 'bulbous', 'lean'],
    'limb_shape': ['flippers', 'hands', 'hooves', 'pads', 'paws', 'longleg', 'longneck'],
    'tail': ['tail'],
    'teeth_type': ['chewteeth', 'meatteeth', 'buckteeth', 'strainteeth'],
    'horns': ['horns'],
    'claws': ['claws'],
    'tusks': ['tusks'],
    'smelly': ['smelly'],
    'transport_mechanism': ['flys', 'hops', 'swims', 'tunnels', 'walks'],
    'speed': ['fast', 'slow'],
    'strength': ['strong', 'weak'],
    'muscle': ['muscle'],
    'movement_move': ['bipedal', 'quadrapedal'],
    'active': ['active', 'inactive'],
    'nocturnal': ['nocturnal'],
    'hibernate': ['hibernate'],
    'agility': ['agility'],
    'diet': ['fish', 'meat', 'plankton', 'vegetation', 'insects'],
    'feeding_type': ['forager', 'grazer', 'hunter', 'scavenger', 'skimmer', 'stalker'],
    'general_location': ['newworld', 'oldworld', 'arctic'],
    'biome': ['coastal', 'desert', 'bush', 'plains', 'forest', 'fields', 'jungle', 'mountains', 'ocean', 'ground', 'water', 'tree', 'cave'],
    'fierceness': ['fierce', 'timid'],
    'smart': ['smart'],
    'social_mode': ['group', 'solitary'],
    'nestspot': ['nestspot'],
    'domestic': ['domestic'],
}

# Public mapping: group name -> list of positions of its concepts inside
# CONCEPT_SEMANTICS (same keys and member order as the table above).
CONCEPT_GROUPS = {
    group_name: list(map(CONCEPT_SEMANTICS.index, member_names))
    for group_name, member_names in _CONCEPT_GROUP_MEMBERS.items()
}



class AwA2Dataset(Dataset):
    """Torch ``Dataset`` over the images of the AwA2 dataset.

    Indexing the dataset returns a tuple ``(image, concepts, label)``:
    the transformed image, a ``torch.FloatTensor`` with the selected
    entries of the class-level binary predicate row of the sample's class,
    and the sample's class index.
    """

    def __init__(
        self,
        root,
        training_augment=True,
        split='train',
        image_size=224,
        concept_transform=None,
        sample_transform=None,
        selected_concepts=None,
        seed=42,
        save_split=False,
    ):
        """Load the annotations and the requested split of the dataset.

        Args:
            root (str): Directory holding ``predicate-matrix-binary.txt``,
                ``classes.txt``, the ``JPEGImages`` folder and, optionally,
                previously saved ``{split}_split.npz`` files.
            training_augment (bool, optional): Whether random augmentations
                are applied; only honoured when ``split == 'train'``.
                Defaults to True.
            split (str, optional): Name of the split to load. Freshly
                generated splits are 'train', 'val' and 'test'. Defaults to
                'train'.
            image_size (int, optional): Side length of the produced images.
                Defaults to 224.
            concept_transform (callable, optional): Applied to the selected
                concept vector of every sample. Defaults to the identity.
            sample_transform (callable, optional): Forwarded to
                ``get_transform_awa2``. Defaults to None.
            selected_concepts (list of int, optional): Positions (in
                ``CONCEPT_SEMANTICS``) of the concepts to return. Defaults
                to all of them.
            seed (int, optional): Seed used if the split has to be
                generated. Defaults to 42.
            save_split (bool, optional): Whether freshly generated splits
                are written to ``root``. Defaults to False.

        Raises:
            ValueError: If ``root`` does not exist.
        """
        self.root = root
        self.training_augment = training_augment
        self.split = split
        self.concept_transform = concept_transform or identity_transform
        self.name = 'AwA2'

        if not os.path.exists(self.root):
            raise ValueError(
                f'{self.root} does not exist yet. Please download the '
                f'dataset first.'
            )

        # Random augmentations are only ever enabled for the training split.
        is_train = split == 'train'
        self.transform = get_transform_awa2(
            train=is_train,
            augment_data=training_augment if is_train else False,
            image_size=image_size,
            sample_transform=sample_transform,
        )

        # One row of binary predicates per class.
        self.predicate_binary_mat = np.genfromtxt(
            os.path.join(root, 'predicate-matrix-binary.txt'),
            dtype='int',
        )

        # Map every class name to its index, following the order in which
        # the classes appear in classes.txt (the name is the second
        # tab-separated field of each line).
        self.class_to_index = {}
        with open(os.path.join(root, 'classes.txt')) as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank (e.g. trailing) lines.
                    continue
                class_name = line.split('\t')[1].strip()
                self.class_to_index[class_name] = len(self.class_to_index)

        split_file_path = os.path.join(self.root, f'{split}_split.npz')
        if os.path.exists(split_file_path):
            # NpzFile keeps the archive open until closed, so materialise
            # the arrays inside a context manager.
            with np.load(split_file_path) as split_info:
                self.img_paths = split_info['paths']
                self.img_labels = split_info['labels']
        else:
            print(
                f"Split files for AWA2 could not be found. Generating new "
                f"train, validation, and test splits with seed {seed}."
            )
            split_info = self._generate_splits(
                seed=seed,
                save=save_split,
            )[split]
            self.img_paths = split_info['paths']
            self.img_labels = split_info['labels']

        if selected_concepts is None:
            selected_concepts = list(range(len(CONCEPT_SEMANTICS)))
        self.selected_concepts = selected_concepts
        self.concept_names = self.concept_attr_names = list(
            np.array(CONCEPT_SEMANTICS)[selected_concepts]
        )
        self.task_names = self.task_attr_names = CLASS_NAMES

    def _generate_splits(self, seed, train_size=0.6, val_size=0.2, save=False):
        """Randomly partition all images into train, val and test splits.

        Every ``.jpg`` file (case-insensitive) found under
        ``<root>/JPEGImages`` is a sample whose class is given by the name of
        the directory that contains it.

        Args:
            seed (int): Seed of the shuffle that defines the partition.
            train_size (float, optional): Fraction of samples assigned to the
                training split. Defaults to 0.6.
            val_size (float, optional): Fraction of samples assigned to the
                validation split; the remainder forms the test split.
                Defaults to 0.2.
            save (bool, optional): If True, each split is also written to
                ``<root>/{name}_split.npz``. Defaults to False.

        Returns:
            dict: Maps 'train', 'val' and 'test' to a dict with the arrays
                'paths' (absolute image paths) and 'labels' (class indices).

        Raises:
            KeyError: If an image sits in a directory whose name is not a
                key of ``self.class_to_index``.
        """
        samples = []
        img_dir = os.path.join(self.root, 'JPEGImages')
        for dirpath, _, filenames in os.walk(img_dir):
            for filename in filenames:
                if not filename.lower().endswith('.jpg'):
                    continue
                path = os.path.abspath(os.path.join(dirpath, filename))
                class_dir = os.path.basename(os.path.dirname(path))
                samples.append((path, self.class_to_index[class_dir]))

        # os.walk yields entries in an arbitrary, filesystem-dependent order.
        # Sorting makes the partition a function of the seed alone, so the
        # train/val/test datasets (each of which may call this method on its
        # own) always agree with each other.
        samples.sort()
        image_paths = np.array([path for path, _ in samples])
        image_classes = np.array([label for _, label in samples])

        # A private generator yields the same permutation as seeding the
        # global NumPy generator would, without clobbering the global state.
        indices = np.arange(len(image_paths))
        np.random.RandomState(seed).shuffle(indices)

        train_end = int(train_size * len(image_paths))
        val_end = train_end + int(val_size * len(image_paths))
        split_indices = {
            'train': indices[:train_end],
            'val': indices[train_end:val_end],
            'test': indices[val_end:],
        }

        splits = {}
        for split_name, selected in split_indices.items():
            splits[split_name] = {
                'paths': image_paths[selected],
                'labels': image_classes[selected],
            }
            if save:
                np.savez(
                    os.path.join(self.root, f'{split_name}_split.npz'),
                    paths=splits[split_name]['paths'],
                    labels=splits[split_name]['labels'],
                )
        return splits

    def __getitem__(self, index):
        """Return ``(image, concepts, label)`` for the sample at ``index``."""
        # Convert every image to RGB (not only grayscale ones): palette,
        # RGBA or CMYK files would otherwise reach the 3-channel Normalize
        # step with the wrong number of channels. convert() returns an
        # in-memory copy, so the file can be closed right away.
        with Image.open(self.img_paths[index]) as raw_img:
            img = raw_img.convert('RGB')
        if self.transform:
            img = self.transform(img)
        label_idx = self.img_labels[index]
        # Concepts are class-level: every sample of a class shares its row.
        concepts = self.predicate_binary_mat[label_idx, :]
        concepts = self.concept_transform(
            np.array(concepts)[self.selected_concepts]
        )
        return img, torch.FloatTensor(concepts), label_idx

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.img_paths)

def identity_transform(x):
    """Return ``x`` itself; used as the default no-op transform."""
    return x

def get_transform_awa2(
    train,
    augment_data,
    image_size=224,
    sample_transform=None,
):
    """Build the image preprocessing pipeline used by the AwA2 data loader.

    Args:
        train (bool): Whether the transform is meant for the training fold.
        augment_data (bool): Whether random augmentations (a random resized
            crop and a random horizontal flip) should be applied. They are
            only used when `train` is also true.
        image_size (int, optional): Width and height of the produced images.
            Defaults to 224.
        sample_transform (callable, optional): Extra step applied to the
            image tensor after `ToTensor` and before normalization. Defaults
            to None, in which case the tensor is passed through unchanged.

    Returns:
        torchvision.transforms.Compose: the transform to apply to each image
            of the AwA2 dataset being loaded.
    """
    if sample_transform is None:
        sample_transform = identity_transform

    if train and augment_data:
        # Random augmentations. The integer interpolation code 2 is PIL's
        # bilinear resampling filter.
        spatial_steps = [
            transforms.RandomResizedCrop(
                image_size,
                scale=(0.7, 1.0),
                ratio=(0.75, 1.3333333333333333),
                interpolation=2,
            ),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        # Deterministic path: resize to a slightly larger square
        # (256/224 times the target side) and then crop its center.
        enlarged_side = int(image_size * (256.0 / 224.0))
        spatial_steps = [
            transforms.Resize((enlarged_side, enlarged_side)),
            transforms.CenterCrop(image_size),
        ]

    # Both pipelines finish the same way: tensor conversion, the optional
    # user-provided step, and per-channel normalization (these values match
    # the widely used ImageNet channel mean/std).
    shared_tail = [
        transforms.ToTensor(),
        sample_transform,
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(spatial_steps + shared_tail)