import os
import torchvision
import yaml
from torchvision.datasets.stl10 import STL10
from torchvision.datasets import FashionMNIST, ImageFolder
from torchvision.datasets.cifar import CIFAR100, CIFAR10
from dataloaders.cub200 import Cub2011
from dataloaders.pets import Pets
from dataloaders.bmw10 import BMW10
from utils import INNTransform
import torch
import pandas as pd
from PIL import Image

class Cub2011_MOD(torch.utils.data.Dataset):
    """Image dataset driven by CUB-style metadata text files under ``root``.

    The constructor reads ``images.txt``, ``image_class_labels.txt``,
    ``train_test_split.txt`` and ``classes.txt`` from ``root`` and keeps only
    the rows belonging to the requested split. Images are opened lazily from
    ``<root>/images/<filepath>`` when an item is requested.

    Args:
        root: Dataset directory; a leading ``~`` is expanded.
        train: If truthy keep rows with ``is_training_img == 1``, otherwise
            keep rows with ``is_training_img == 0``.
        transform: Optional callable applied to the RGB image in
            ``__getitem__`` (not applied by ``get_custom_item``).
        download: Accepted but ignored by this class.

    Raises:
        FileNotFoundError: If loading the metadata fails for any reason; the
            underlying error is attached as ``__cause__``.
    """

    base_folder = 'images'

    def __init__(self, root, train=True, transform=None, download=True):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.train = train

        try:
            self._load_metadata()
        except Exception as err:
            # Keep the historical exception type and message, but chain the
            # real cause so parse errors are not mistaken for missing files.
            raise FileNotFoundError('Missing data for cub20') from err

    @staticmethod
    def loader(path):
        """Open the image at ``path`` and return it converted to RGB."""
        with open(path, 'rb') as f:
            # convert() forces the pixel data to be read before the file closes.
            return Image.open(f).convert('RGB')

    def _read_table(self, filename, columns):
        """Read a space-separated, header-less metadata file from ``root``."""
        return pd.read_csv(os.path.join(self.root, filename), sep=' ', names=columns)

    def _load_metadata(self):
        """Populate ``self.class_names`` and the split-filtered ``self.data``."""
        images = self._read_table('images.txt', ['img_id', 'filepath'])
        image_class_labels = self._read_table('image_class_labels.txt', ['img_id', 'target'])
        train_test_split = self._read_table('train_test_split.txt', ['img_id', 'is_training_img'])
        self.class_names = self._read_table('classes.txt', ['target', 'name'])

        merged = images.merge(image_class_labels, on='img_id')
        merged = merged.merge(train_test_split, on='img_id')

        wanted_flag = 1 if self.train else 0
        self.data = merged[merged.is_training_img == wanted_flag]

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.data)

    def get_custom_item(self, idx):
        """Return ``(img, name)`` for the row at position ``idx``.

        ``img`` is whatever ``self.loader`` returns for the image path (by
        default an RGB PIL image); ``self.transform`` is not applied.
        ``name`` is the matching entry of ``classes.txt`` with everything up
        to the last ``.`` removed and underscores replaced by spaces.
        """
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        img = self.loader(path)

        matches = self.class_names[self.class_names.target == sample.target]
        raw_name = matches.name.to_string(index=False).strip()
        name = raw_name.split('.')[-1].replace('_', ' ')
        return img, name

    def __getitem__(self, idx):
        """Return ``(img, target)`` for the row at position ``idx``.

        ``img`` is the RGB image, passed through ``self.transform`` when one
        was given; ``target`` is the file's label shifted down by one.
        """
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = sample.target - 1  # Targets start at 1 by default, so shift to 0
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)

        return img, target

class InnSTL10(STL10):
    """STL10 whose items are passed through an ``INNTransform``.

    Args:
        root: Forwarded to ``STL10``.
        img_transform: Given to ``INNTransform`` as ``start_transform``.
        neg_factor: Given to ``INNTransform`` as ``neg_conf_factor``.
        num_negatives: Given to ``INNTransform`` as ``num_negatives``.
        num_classes: Stored and handed to the transform on every item.
        **kwargs: Forwarded unchanged to ``STL10``.
    """

    def __init__(self, root, img_transform, neg_factor=None, num_negatives=None, num_classes=10, **kwargs):
        super().__init__(root, **kwargs)
        transform = INNTransform(
            start_transform=img_transform,
            num_negatives=num_negatives,
            neg_conf_factor=neg_factor,
        )
        self.inn_transform = transform
        self.num_classes = num_classes

    def __getitem__(self, item):
        """Return the 4-tuple ``(imgs, y_source, inn_target, y_target)`` from the transform."""
        base_sample, base_target = super().__getitem__(item)
        transformed = self.inn_transform(base_sample, base_target, self.num_classes)
        imgs, y_source, inn_target, y_target = transformed
        return (imgs, y_source, inn_target, y_target)


class InnCIFAR100(CIFAR100):
    """CIFAR100 whose items are passed through an ``INNTransform``.

    Args:
        root: Forwarded to ``CIFAR100``.
        img_transform: Given to ``INNTransform`` as ``start_transform``.
        neg_factor: Given to ``INNTransform`` as ``neg_conf_factor``.
        num_negatives: Given to ``INNTransform`` as ``num_negatives``.
        num_classes: Stored and handed to the transform on every item.
        **kwargs: Forwarded unchanged to ``CIFAR100``.
    """

    def __init__(self, root, img_transform, neg_factor=None, num_negatives=None, num_classes=100, **kwargs):
        super().__init__(root, **kwargs)
        transform = INNTransform(
            start_transform=img_transform,
            num_negatives=num_negatives,
            neg_conf_factor=neg_factor,
        )
        self.inn_transform = transform
        self.num_classes = num_classes

    def __getitem__(self, item):
        """Return the 4-tuple ``(imgs, y_source, inn_target, y_target)`` from the transform."""
        base_sample, base_target = super().__getitem__(item)
        transformed = self.inn_transform(base_sample, base_target, self.num_classes)
        imgs, y_source, inn_target, y_target = transformed
        return (imgs, y_source, inn_target, y_target)


class InnCIFAR10(CIFAR10):
    """CIFAR10 whose items are passed through an ``INNTransform``.

    Args:
        root: Forwarded to ``CIFAR10``.
        img_transform: Given to ``INNTransform`` as ``start_transform``.
        neg_factor: Given to ``INNTransform`` as ``neg_conf_factor``.
        num_negatives: Given to ``INNTransform`` as ``num_negatives``.
        num_classes: Stored and handed to the transform on every item.
        **kwargs: Forwarded unchanged to ``CIFAR10``.
    """

    def __init__(self, root, img_transform, neg_factor=None, num_negatives=None, num_classes=10, **kwargs):
        super().__init__(root, **kwargs)
        transform = INNTransform(
            start_transform=img_transform,
            num_negatives=num_negatives,
            neg_conf_factor=neg_factor,
        )
        self.inn_transform = transform
        self.num_classes = num_classes

    def __getitem__(self, item):
        """Return the 4-tuple ``(imgs, y_source, inn_target, y_target)`` from the transform."""
        base_sample, base_target = super().__getitem__(item)
        transformed = self.inn_transform(base_sample, base_target, self.num_classes)
        imgs, y_source, inn_target, y_target = transformed
        return (imgs, y_source, inn_target, y_target)


class InnFMNIST(FashionMNIST):
    """FashionMNIST whose items are passed through an ``INNTransform``.

    Args:
        root: Forwarded to ``FashionMNIST``.
        img_transform: Given to ``INNTransform`` as ``start_transform``.
        neg_factor: Given to ``INNTransform`` as ``neg_conf_factor``.
        num_negatives: Given to ``INNTransform`` as ``num_negatives``.
        num_classes: Stored and handed to the transform on every item.
        **kwargs: Forwarded unchanged to ``FashionMNIST``.
    """

    def __init__(self, root, img_transform, neg_factor=None, num_negatives=None, num_classes=10, **kwargs):
        super().__init__(root, **kwargs)
        self.num_classes = num_classes
        self.inn_transform = INNTransform(start_transform=img_transform, num_negatives=num_negatives,
                                          neg_conf_factor=neg_factor)

    def __getitem__(self, item):
        """Return the 4-tuple ``(imgs, y_source, inn_target, y_target)`` from the transform."""
        # Zero-argument super() starts the lookup right after this class, so
        # FashionMNIST itself is no longer skipped in the MRO.
        sample, target = super().__getitem__(item)
        imgs, y_source, inn_target, y_target = self.inn_transform(sample, target, self.num_classes)
        return imgs, y_source, inn_target, y_target

class InnImgNet(ImageFolder):
    """``ImageFolder`` over ``<root>/<split>`` with items passed through an ``INNTransform``.

    Args:
        root: Parent directory that contains one sub-directory per split.
        split: Sub-directory of ``root`` used as the ``ImageFolder`` root.
        img_transform: Given to ``INNTransform`` as ``start_transform``.
        neg_factor: Given to ``INNTransform`` as ``neg_conf_factor``.
        num_negatives: Given to ``INNTransform`` as ``num_negatives``.
        num_classes: Stored and handed to the transform on every item.
        **kwargs: Forwarded unchanged to ``ImageFolder``.
    """

    def __init__(self, root, split, img_transform, neg_factor=None, num_negatives=None, num_classes=1000, **kwargs):
        root = os.path.join(root, split)
        super().__init__(root, **kwargs)
        self.num_classes = num_classes
        self.inn_transform = INNTransform(start_transform=img_transform, num_negatives=num_negatives,
                                          neg_conf_factor=neg_factor)

    def __getitem__(self, item):
        """Return the 4-tuple ``(imgs, y_source, inn_target, y_target)`` from the transform."""
        # Zero-argument super() starts the lookup right after this class, so
        # ImageFolder itself is no longer skipped in the MRO.
        sample, target = super().__getitem__(item)
        imgs, y_source, inn_target, y_target = self.inn_transform(sample, target, self.num_classes)
        return imgs, y_source, inn_target, y_target


class InnCub(Cub2011):
    """``Cub2011`` whose items are passed through an ``INNTransform``.

    Args:
        root: Forwarded to ``Cub2011``.
        img_transform: Given to ``INNTransform`` as ``start_transform``.
        neg_factor: Given to ``INNTransform`` as ``neg_conf_factor``.
        num_negatives: Given to ``INNTransform`` as ``num_negatives``.
        num_classes: Stored and handed to the transform on every item.
        **kwargs: Forwarded unchanged to ``Cub2011``.
    """

    def __init__(self, root, img_transform, neg_factor=None, num_negatives=None, num_classes=200, **kwargs):
        super().__init__(root, **kwargs)
        transform = INNTransform(
            start_transform=img_transform,
            num_negatives=num_negatives,
            neg_conf_factor=neg_factor,
        )
        self.inn_transform = transform
        self.num_classes = num_classes

    def __getitem__(self, item):
        """Return the 4-tuple ``(imgs, y_source, inn_target, y_target)`` from the transform."""
        base_sample, base_target = super().__getitem__(item)
        transformed = self.inn_transform(base_sample, base_target, self.num_classes)
        imgs, y_source, inn_target, y_target = transformed
        return (imgs, y_source, inn_target, y_target)

class InnCub20(Cub2011_MOD):
    """``Cub2011_MOD`` whose items are passed through an ``INNTransform``.

    Args:
        root: Forwarded to ``Cub2011_MOD``.
        img_transform: Given to ``INNTransform`` as ``start_transform``.
        neg_factor: Given to ``INNTransform`` as ``neg_conf_factor``.
        num_negatives: Given to ``INNTransform`` as ``num_negatives``.
        num_classes: Stored and handed to the transform on every item.
        **kwargs: Forwarded unchanged to ``Cub2011_MOD``.
    """

    def __init__(self, root, img_transform, neg_factor=None, num_negatives=None, num_classes=20, **kwargs):
        super().__init__(root, **kwargs)
        transform = INNTransform(
            start_transform=img_transform,
            num_negatives=num_negatives,
            neg_conf_factor=neg_factor,
        )
        self.inn_transform = transform
        self.num_classes = num_classes

    def __getitem__(self, item):
        """Return the 4-tuple ``(imgs, y_source, inn_target, y_target)`` from the transform."""
        base_sample, base_target = super().__getitem__(item)
        transformed = self.inn_transform(base_sample, base_target, self.num_classes)
        imgs, y_source, inn_target, y_target = transformed
        return (imgs, y_source, inn_target, y_target)

class InnPets(Pets):
    """``Pets`` whose items are passed through an ``INNTransform``.

    Args:
        root: Forwarded to ``Pets``.
        img_transform: Given to ``INNTransform`` as ``start_transform``.
        neg_factor: Given to ``INNTransform`` as ``neg_conf_factor``.
        num_negatives: Given to ``INNTransform`` as ``num_negatives``.
        **kwargs: Forwarded unchanged to ``Pets``.
    """

    def __init__(self, root, img_transform, neg_factor=None, num_negatives=None, **kwargs):
        super().__init__(root, **kwargs)
        transform = INNTransform(
            start_transform=img_transform,
            num_negatives=num_negatives,
            neg_conf_factor=neg_factor,
        )
        self.inn_transform = transform

    def __getitem__(self, item):
        """Return the 4-tuple ``(imgs, y_source, inn_target, y_target)`` from the transform."""
        base_sample, base_target = super().__getitem__(item)
        # NOTE(review): num_classes is never set by this class; presumably
        # the Pets base class provides it -- verify.
        transformed = self.inn_transform(base_sample, base_target, self.num_classes)
        imgs, y_source, inn_target, y_target = transformed
        return (imgs, y_source, inn_target, y_target)


class InnTiny(ImageFolder):
    """``ImageFolder`` over ``<root>/tiny-64/train`` with items passed through an ``INNTransform``.

    The sub-directory is fixed; this class always reads the ``train`` folder.

    Args:
        root: Directory containing ``tiny-64/train``; may be a ``str`` or
            any path-like object.
        img_transform: Given to ``INNTransform`` as ``start_transform``.
        neg_factor: Given to ``INNTransform`` as ``neg_conf_factor``.
        num_negatives: Given to ``INNTransform`` as ``num_negatives``.
        num_classes: Stored and handed to the transform on every item.
        **kwargs: Forwarded unchanged to ``ImageFolder``.
    """

    def __init__(self, root, img_transform, neg_factor=None, num_negatives=None, num_classes=200, **kwargs):
        # os.path.join (instead of string concatenation) accepts path-like
        # roots and matches how InnImgNet builds its directory.
        root = os.path.join(root, 'tiny-64', 'train')
        super().__init__(root, **kwargs)
        self.num_classes = num_classes
        self.inn_transform = INNTransform(start_transform=img_transform, num_negatives=num_negatives,
                                          neg_conf_factor=neg_factor)

    def __getitem__(self, item):
        """Return the 4-tuple ``(imgs, y_source, inn_target, y_target)`` from the transform."""
        sample, target = super().__getitem__(item)
        imgs, y_source, inn_target, y_target = self.inn_transform(sample, target, self.num_classes)
        return imgs, y_source, inn_target, y_target

class InnBMW10(BMW10):
    """``BMW10`` whose items are passed through an ``INNTransform``.

    Args:
        root: Forwarded to ``BMW10``.
        img_transform: Given to ``INNTransform`` as ``start_transform``.
        neg_factor: Given to ``INNTransform`` as ``neg_conf_factor``.
        num_negatives: Given to ``INNTransform`` as ``num_negatives``.
        num_classes: Stored and handed to the transform on every item.
        **kwargs: Forwarded unchanged to ``BMW10``.
    """

    def __init__(self, root, img_transform, neg_factor=None, num_negatives=None, num_classes=10, **kwargs):
        super().__init__(root, **kwargs)
        transform = INNTransform(
            start_transform=img_transform,
            num_negatives=num_negatives,
            neg_conf_factor=neg_factor,
        )
        self.inn_transform = transform
        self.num_classes = num_classes

    def __getitem__(self, item):
        """Return the 4-tuple ``(imgs, y_source, inn_target, y_target)`` from the transform."""
        base_sample, base_target = super().__getitem__(item)
        transformed = self.inn_transform(base_sample, base_target, self.num_classes)
        imgs, y_source, inn_target, y_target = transformed
        return (imgs, y_source, inn_target, y_target)
