import os
import re
from abc import ABC, abstractmethod
from os.path import join
import pandas as pd
import numpy as np
import cv2
import h5py
from torch.utils.data import Dataset
from torchvision import transforms
import torch
from PIL import Image
from torchvision.models import ResNet18_Weights

def default_loader(path):
    """Read the image stored at *path* and return it as an RGB PIL image.

    The file is opened explicitly in binary mode and closed before returning;
    this avoids Pillow's ResourceWarning about unclosed files
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        rgb_picture = picture.convert('RGB')
    return rgb_picture

def to_PIL(img):
    """Convert an array-like image to an RGB PIL image.

    The values are cast to ``uint8`` first (so out-of-range values wrap, as
    with any numpy ``astype``). Mode ``'RGB'`` presumably requires a
    height x width x 3 array -- confirm against callers.
    """
    pixels = img.astype('uint8')
    return Image.fromarray(pixels, 'RGB')


# Dataset identifiers that ``get_dataset`` knows how to build.
supported_datasets = ['cars3d', 'celeba', 'rafd']


def get_dataset(dataset_id, path=None, **kwargs):
    """Build the dataset registered under *dataset_id*.

    Args:
        dataset_id: one of ``'cars3d'``, ``'celeba'`` or ``'rafd'``.
        path: root directory of the dataset, forwarded as ``base_dir``. When
            ``None`` the dataset class's own default ``base_dir`` is used.
            If given, it takes precedence over a ``base_dir`` in *kwargs*.
        **kwargs: extra keyword arguments forwarded to the dataset constructor
            (e.g. ``attribute``, ``value``, ``split``).

    Returns:
        A ``Cars3D``, ``CelebA`` or ``RaFD`` instance.

    Raises:
        ValueError: if *dataset_id* is not supported (a subclass of the
            ``Exception`` raised previously, so existing handlers still work).
    """
    # ``path`` must be passed by keyword: the first positional parameter of
    # every dataset constructor is ``attribute``, not ``base_dir``.
    if path is not None:
        kwargs['base_dir'] = path

    if dataset_id == 'cars3d':
        return Cars3D(**kwargs)

    if dataset_id == 'celeba':
        return CelebA(**kwargs)

    if dataset_id == 'rafd':
        return RaFD(**kwargs)

    raise ValueError('unsupported dataset: %s' % dataset_id)


class DataSet(ABC):
    """Common base class for the datasets below; remembers their root directory."""

    def __init__(self, base_dir=None):
        super(DataSet, self).__init__()
        # Root directory of the dataset on disk (None when not given).
        self._base_dir = base_dir


class Cars3D(DataSet):
    """Cars3D split into normal / anomalous samples along one factor of variation.

    Raw images and factor labels are read from ``<base_dir>/cars3d.npz``
    (keys ``imgs``, ``factors`` and ``factor_sizes``). For every value of
    factor column ``attribute``, the first 85% of its samples (in file order)
    go to the train split and the rest to the test split. Normal samples get
    label 0 and anomalous samples label 1:

    * ``setting == 'multi'``: every value except ``value`` is normal and
      ``value`` is the anomaly.
    * any other ``setting`` (one-class): only ``value`` is normal and every
      other value is an anomaly.

    Args:
        attribute: column of ``factors`` used to define normal / anomalous.
        value: the distinguished value of ``attribute`` (see ``setting``).
        split: ``'train'`` selects the train split; any other string selects
            the test split.
        knn: if True, anomalous samples are left out of the train split.
        all: if True, ignore the split and expose every raw image with label 0;
            items are then ``self.transform`` applied to a PIL image.
        setting: ``'multi'``, or anything else for the one-class setting.
        method: sub-directory of the precomputed feature file under
            ``./features/cars3d/<setting>/``; ``None`` selects
            ``imagenet_resnet18.npy``. Unused when ``all`` is True.
        base_dir: directory containing ``cars3d.npz``.
    """

    def __init__(self, attribute=0, value=0, split='train', knn=False, all=False, setting='multi', method='msad', base_dir='PATH_TO_CAR3D'):
        super().__init__(base_dir)
        self.split = 0 if split == 'train' else 1  # 0 -> train split, 1 -> test split
        self.knn = knn
        self.all = all
        self.__data_path = os.path.join(base_dir, 'cars3d.npz')
        self.data = np.load(self.__data_path)
        self.imgs = self.data['imgs']
        self.labels = self.data['factors']
        self.sizes = self.data['factor_sizes']
        max_value = self.sizes[attribute]

        test_indices = []
        train_indices = []
        train_labels = []
        test_labels = []

        # Anomalous sample counts per split, summed over every anomalous value
        # (in the one-class setting there are several of them).
        anom_train = 0
        anom_test = 0
        for val in range(max_value):
            val_indices = np.argwhere(self.labels[:, attribute] == val)[:, 0]
            split_val = int(0.85 * len(val_indices))
            test_indices.append(val_indices[split_val:])
            # 'multi': all values but `value` are normal; one-class: only `value` is.
            is_normal = (val != value) if setting == 'multi' else (val == value)
            if is_normal:
                train_indices.append(val_indices[:split_val])
                train_labels.append(np.zeros(len(train_indices[-1])))
                test_labels.append(np.zeros(len(test_indices[-1])))
            else:
                if not self.knn:
                    train_indices.append(val_indices[:split_val])
                    train_labels.append(np.ones(len(train_indices[-1])))
                    anom_train += len(train_indices[-1])
                anom_test += len(test_indices[-1])
                test_labels.append(np.ones(len(test_indices[-1])))
        train_indices = np.concatenate(train_indices, 0)
        test_indices = np.concatenate(test_indices, 0)
        self.train_labels = np.concatenate(train_labels, 0)
        self.test_labels = np.concatenate(test_labels, 0)

        if self.all:
            # Expose every sample in both splits, all labelled normal.
            train_indices = np.arange(len(self.imgs))
            test_indices = np.arange(len(self.imgs))
            self.train_labels = np.zeros(len(train_indices))
            self.test_labels = np.zeros(len(test_indices))

        if self.split:
            print(len(test_indices), 'TEST LEN')
            print(anom_test, 'Anomalous TEST')
        else:
            print(len(train_indices), 'TRAIN LEN')
            print(anom_train, 'Anomalous Train')

        if not self.all:
            # Replace raw images by precomputed features (path relative to the
            # working directory). NOTE(review): rows are indexed with the same
            # sample indices as cars3d.npz, so the feature file is assumed to be
            # in the same order -- confirm.
            if method is None:
                PATH = './features/cars3d/{}/imagenet_resnet18.npy'.format(setting)
            else:
                PATH = './features/cars3d/{}/{}/a{}_v{}.npy'.format(setting, method, attribute, value)
            self.imgs = torch.from_numpy(np.load(PATH))
            self.imgs = self.imgs.type(torch.float32)

        if self.split:
            self.imgs = self.imgs[test_indices]
        else:
            self.imgs = self.imgs[train_indices]

        self.transform = ResNet18_Weights.IMAGENET1K_V1.transforms()

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.imgs)

    def __getitem__(self, i):
        """Return ``(sample, label)`` for index *i*; label 0 = normal, 1 = anomalous."""
        if self.all:
            img = self.transform(to_PIL(self.imgs[i]))
        else:
            img = self.imgs[i]
        if self.split:
            return img, self.test_labels[i]
        return img, self.train_labels[i]


class CelebA(DataSet):
    """CelebA images split into normal / anomalous samples along one annotated attribute.

    Image paths are the sorted file names under ``<base_dir>/Img/align_test``.
    Labels come from the ``hair_color`` (``attribute=0``), ``hair_style``
    (``attribute=1``) and ``beard_styles`` (``attribute=2``) columns of
    ``<base_dir>/full_contents.csv``. For every value of the chosen column,
    the first 85% of its samples go to the train split and the rest to the
    test split. Normal samples get label 0 and anomalous samples label 1:

    * ``setting == 'multi'``: every value except ``value`` is normal.
    * any other ``setting`` (one-class): only ``value`` is normal.

    Args:
        attribute: label column used to define normal / anomalous.
        value: the distinguished value of ``attribute`` (see ``setting``).
        split: ``'train'`` selects the train split; any other string selects
            the test split.
        knn: if True, anomalous samples are left out of the train split.
        all: if True, ignore the split and expose every image with label 0;
            items are then ``self.transform`` applied to the loaded image.
        setting: ``'multi'``, or anything else for the one-class setting.
        method: sub-directory of the precomputed feature file under
            ``./features/celeba/<setting>/``; ``None`` selects
            ``imagenet_resnet18.npy``. Unused when ``all`` is True.
        base_dir: CelebA root directory.
    """

    def __init__(self, attribute=0, value=0, split='train', knn=False, all=False, setting='multi', method='msad',
                 base_dir='PATH_TO_CELEBA'):
        super().__init__(base_dir)
        self.split = 0 if split == 'train' else 1  # 0 -> train split, 1 -> test split
        self.knn = knn
        self.all = all
        filenames = sorted(os.listdir(join(self._base_dir, 'Img', 'align_test')))
        self.imgs = np.array([join(self._base_dir, 'Img', 'align_test', fname) for fname in filenames])
        # Keep the last len(self.imgs) CSV rows. NOTE(review): this assumes those
        # rows are in the same order as the sorted image file names -- confirm.
        self.full_content = pd.read_csv(join(self._base_dir, 'full_contents.csv')).sort_index().iloc[-len(self.imgs):]
        self.labels = self.full_content[['hair_color', 'hair_style', 'beard_styles']].to_numpy()
        # Number of values per label column, assuming values run from 0 to the max.
        self.sizes = np.amax(self.labels, axis=0) + 1
        max_value = self.sizes[attribute]

        test_indices = []
        train_indices = []
        train_labels = []
        test_labels = []

        # Anomalous sample counts per split, summed over every anomalous value
        # (in the one-class setting there are several of them).
        anom_train = 0
        anom_test = 0
        for val in range(max_value):
            val_indices = np.argwhere(self.labels[:, attribute] == val)[:, 0]
            split_val = int(0.85 * len(val_indices))
            test_indices.append(val_indices[split_val:])
            # 'multi': all values but `value` are normal; one-class: only `value` is.
            is_normal = (val != value) if setting == 'multi' else (val == value)
            if is_normal:
                train_indices.append(val_indices[:split_val])
                train_labels.append(np.zeros(len(train_indices[-1])))
                test_labels.append(np.zeros(len(test_indices[-1])))
            else:
                if not self.knn:
                    train_indices.append(val_indices[:split_val])
                    train_labels.append(np.ones(len(train_indices[-1])))
                    anom_train += len(train_indices[-1])
                anom_test += len(test_indices[-1])
                test_labels.append(np.ones(len(test_indices[-1])))
        train_indices = np.concatenate(train_indices, 0)
        test_indices = np.concatenate(test_indices, 0)
        self.train_labels = np.concatenate(train_labels, 0)
        self.test_labels = np.concatenate(test_labels, 0)

        if self.all:
            # Expose every sample in both splits, all labelled normal.
            train_indices = np.arange(len(self.imgs))
            test_indices = np.arange(len(self.imgs))
            self.train_labels = np.zeros(len(train_indices))
            self.test_labels = np.zeros(len(test_indices))

        if self.split:
            print(len(test_indices), 'TEST LEN')
            print(anom_test, 'Anomalous TEST')
        else:
            print(len(train_indices), 'TRAIN LEN')
            print(anom_train, 'Anomalous Train')

        if not self.all:
            # Replace image paths by precomputed features (path relative to the
            # working directory). NOTE(review): rows are indexed with the same
            # sample indices as self.imgs, so the feature file is assumed to
            # follow the sorted file-name order -- confirm.
            if method is None:
                PATH = './features/celeba/{}/imagenet_resnet18.npy'.format(setting)
            else:
                PATH = './features/celeba/{}/{}/a{}_v{}.npy'.format(setting, method, attribute, value)
            self.imgs = torch.from_numpy(np.load(PATH))
            self.imgs = self.imgs.type(torch.float32)

        if self.split:
            self.imgs = self.imgs[test_indices]
        else:
            self.imgs = self.imgs[train_indices]

        self.transform = ResNet18_Weights.IMAGENET1K_V1.transforms()

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.imgs)

    def __getitem__(self, i):
        """Return ``(sample, label)`` for index *i*; label 0 = normal, 1 = anomalous."""
        if self.all:
            # In "all" mode self.imgs holds file paths; load and transform the image.
            img = self.transform(default_loader(self.imgs[i]))
        else:
            img = self.imgs[i]
        if self.split:
            return img, self.test_labels[i]
        return img, self.train_labels[i]



class RaFD(DataSet):
    """RaFD images split into normal / anomalous samples along one file-name field.

    Every file in ``<base_dir>/Rafd`` whose name matches
    ``Rafd<num>_<num>_<w>_<w>_<w>_<w>.jpg`` becomes one sample. Three name
    fields are mapped to integer ids in first-seen order:
    - the first number (``attribute=0``, key ``'idx'``);
    - the second number (``attribute=1``, key ``'identity'``);
    - the fifth field (``attribute=2``, key ``'expression_ids'``).

    For every value of the chosen field, the first 85% of its samples go to
    the train split and the rest to the test split. Normal samples get label
    0 and anomalous samples label 1:

    * ``setting == 'multi'``: every value except ``value`` is normal.
    * any other ``setting`` (one-class): only ``value`` is normal.

    Args:
        attribute: label column used to define normal / anomalous.
        value: the distinguished value of ``attribute`` (see ``setting``).
        split: ``'train'`` selects the train split; any other string selects
            the test split.
        knn: if True, anomalous samples are left out of the train split.
        all: if True, ignore the split and expose every image with label 0;
            items are then ``self.transform`` applied to the loaded image.
        setting: ``'multi'``, or anything else for the one-class setting.
        method: sub-directory of the precomputed feature file under
            ``./features/rafd/<setting>/``; ``None`` selects
            ``imagenet_resnet18.npy``. Unused when ``all`` is True.
        base_dir: directory containing the ``Rafd`` image folder.
    """

    def __look_for_labels(self, key, value):
        """Return the integer id of *value* in ``self.mappings[key]``, registering it if new."""
        mapping = self.mappings[key]
        if value not in mapping:
            mapping.append(value)
        return mapping.index(value)

    def __list_imgs(self):
        """Scan the image directory and fill ``imgs``, ``labels``, ``mappings`` and ``sizes``."""
        img_paths = []
        labels = []
        self.mappings = {'idx': [], 'identity': [], 'expression_ids': []}
        # Raw string: '\d' and '\w' are invalid escape sequences in a plain literal.
        regex = re.compile(r'Rafd(\d+)_(\d+)_(\w+)_(\w+)_(\w+)_(\w+)\.jpg')
        # NOTE(review): os.listdir order is filesystem-dependent; sample indices
        # (and so the rows expected in the precomputed feature files) follow it.
        for file_name in os.listdir(self._base_dir):
            match = regex.match(file_name)
            if match is None:
                # Skip files that are not RaFD images (e.g. stray system files)
                # instead of failing on a None match.
                continue
            idx, identity_id, _description, _gender, expression_id, _angle = match.groups()
            img_paths.append(os.path.join(self._base_dir, file_name))
            labels.append([self.__look_for_labels('idx', idx), self.__look_for_labels('identity', identity_id),
                           self.__look_for_labels('expression_ids', expression_id)])
        self.imgs, self.labels = np.array(img_paths), np.array(labels)
        # Number of distinct values seen for each label column.
        self.sizes = np.array([len(self.mappings['idx']),
                               len(self.mappings['identity']),
                               len(self.mappings['expression_ids'])], dtype=np.int32)

    def __init__(self, attribute=0, value=0, split='train', knn=False, all=False, setting='multi', method='msad',
                 base_dir='PATH_TO_RAFD'):
        super().__init__(base_dir)
        self.split = 0 if split == 'train' else 1  # 0 -> train split, 1 -> test split
        self.knn = knn
        self.all = all
        self._base_dir = os.path.join(base_dir, 'Rafd')
        self.__list_imgs()
        max_value = self.sizes[attribute]

        test_indices = []
        train_indices = []
        train_labels = []
        test_labels = []

        # Anomalous sample counts per split, summed over every anomalous value
        # (in the one-class setting there are several of them).
        anom_train = 0
        anom_test = 0
        for val in range(max_value):
            val_indices = np.argwhere(self.labels[:, attribute] == val)[:, 0]
            split_val = int(0.85 * len(val_indices))
            test_indices.append(val_indices[split_val:])
            # 'multi': all values but `value` are normal; one-class: only `value` is.
            is_normal = (val != value) if setting == 'multi' else (val == value)
            if is_normal:
                train_indices.append(val_indices[:split_val])
                train_labels.append(np.zeros(len(train_indices[-1])))
                test_labels.append(np.zeros(len(test_indices[-1])))
            else:
                if not self.knn:
                    train_indices.append(val_indices[:split_val])
                    train_labels.append(np.ones(len(train_indices[-1])))
                    anom_train += len(train_indices[-1])
                anom_test += len(test_indices[-1])
                test_labels.append(np.ones(len(test_indices[-1])))
        train_indices = np.concatenate(train_indices, 0)
        test_indices = np.concatenate(test_indices, 0)
        self.train_labels = np.concatenate(train_labels, 0)
        self.test_labels = np.concatenate(test_labels, 0)

        if self.all:
            # Expose every sample in both splits, all labelled normal.
            train_indices = np.arange(len(self.imgs))
            test_indices = np.arange(len(self.imgs))
            self.train_labels = np.zeros(len(train_indices))
            self.test_labels = np.zeros(len(test_indices))

        if self.split:
            print(len(test_indices), 'TEST LEN')
            print(anom_test, 'Anomalous TEST')
        else:
            print(len(train_indices), 'TRAIN LEN')
            print(anom_train, 'Anomalous Train')

        if not self.all:
            # Replace image paths by precomputed features (path relative to the
            # working directory). NOTE(review): rows are indexed with the same
            # sample indices as self.imgs, so the feature file is assumed to be
            # in the same directory-listing order -- confirm.
            if method is None:
                PATH = './features/rafd/{}/imagenet_resnet18.npy'.format(setting)
            else:
                PATH = './features/rafd/{}/{}/a{}_v{}.npy'.format(setting, method, attribute, value)
            self.imgs = torch.from_numpy(np.load(PATH))
            self.imgs = self.imgs.type(torch.float32)

        if self.split:
            self.imgs = self.imgs[test_indices]
        else:
            self.imgs = self.imgs[train_indices]

        self.transform = ResNet18_Weights.IMAGENET1K_V1.transforms()

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.imgs)

    def __getitem__(self, i):
        """Return ``(sample, label)`` for index *i*; label 0 = normal, 1 = anomalous."""
        if self.all:
            # In "all" mode self.imgs holds file paths; load and transform the image.
            img = self.transform(default_loader(self.imgs[i]))
        else:
            img = self.imgs[i]
        if self.split:
            return img, self.test_labels[i]
        return img, self.train_labels[i]
