import os
import math
import random
import json
import cv2

import torch
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import pil_loader
import torchvision.transforms.functional as T
import h5py
import numpy as np
from scipy import ndimage, misc

import pickle
import json
from PIL import Image
from torchvision import transforms
import pandas as pd


# Attribute vocabularies, keyed by attribute name.
CLASSES = {
    "shape": ["square", "circle", "triangle"],
    "color": ["red", "yellow", "blue"],
}

# Flat per-attribute vocabularies. Same values and order as CLASSES, but
# independent list objects.
colors = list(CLASSES["color"])
shapes = list(CLASSES["shape"])
# Object size values (units are not defined in this file).
sizes = [0.2, 0.4]

# Vocabulary sizes.
shape_num, color_num, size_num = len(shapes), len(colors), len(sizes)

# NOTE: a max_objects limit, if needed, is configured on the dataset class rather than in this module.


def one_hot(a, num_classes):
    """Return one-hot float rows of length ``num_classes`` for the integer labels in ``a``.

    Args:
        a: an int, list, or array of integer class indices, of any shape. It is
            converted with ``np.asarray`` and flattened first.
        num_classes: length of each one-hot row.

    Returns:
        A float array. All singleton dimensions are squeezed out, so a single
        label yields a 1-D vector of length ``num_classes``, and N labels yield
        an ``(N, num_classes)`` array.
    """
    # np.asarray lets plain Python ints and lists work, not only ndarrays.
    indices = np.asarray(a).reshape(-1)
    return np.squeeze(np.eye(num_classes)[indices])


def get_loader(dataset, batch_size, num_workers=8, shuffle=False):
    """Wrap ``dataset`` in a ``DataLoader`` with pinned memory.

    ``drop_last`` is always enabled, so a trailing batch smaller than
    ``batch_size`` is discarded.
    """
    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": True,
        "drop_last": True,
    }
    return torch.utils.data.DataLoader(dataset, **loader_kwargs)


def to_torch_img(img):
    """Convert an image array into a float32 torch tensor.

    Values are divided by 255, so 0-255 input maps to [0, 1]. The last axis is
    moved to the front (e.g. HWC -> CHW). The tensor is created on the CPU.
    """
    channels_first = np.moveaxis(img / 255, -1, 0)
    return torch.from_numpy(channels_first.astype(np.float32))


def __load_image_yolo(path, img_size):
    """Read an image with its channels unchanged and return it as a float RGB CHW array.

    Args:
        path: image file path.
        img_size: unused. Resizing is disabled; the parameter is kept for call
            compatibility.

    Returns:
        An array of shape ``(C, H, W)`` with values divided by 255. Colour
        channels are in RGB order. An alpha channel, if present, stays last.
        Single-channel images give ``C == 1``.

    Raises:
        AssertionError: if OpenCV cannot read ``path``.
    """
    img = cv2.imread(path, -1)  # -1: load as stored (keeps alpha / single channel)
    # Same check as the sibling loaders; without it a missing file fails later
    # with an opaque "NoneType is not subscriptable" error.
    assert img is not None, 'Image Not Found ' + path
    if img.ndim == 2:
        # Grayscale: add a channel axis so the HWC -> CHW transpose works.
        img = img[:, :, np.newaxis]
    if img.shape[2] >= 3:
        # Reverse only the colour channels (BGR -> RGB). A full [::-1] would move
        # an alpha channel to the front (ARGB).
        img = np.concatenate((img[:, :, 2::-1], img[:, :, 3:]), axis=2)
    img = img.transpose(2, 0, 1)  # HWC -> CHW
    # NOTE(review): 16-bit images loaded unchanged are still divided by 255 — confirm.
    return img / 255.0


def load_image_yolo(path, img_size, stride=32):
    """Load an image for YOLO-style inference.

    ``img_size`` and ``stride`` are accepted but not used. No resizing or
    padding is applied.

    Returns:
        A tuple ``(img, img0)``:

        - ``img`` is a C-contiguous RGB array in CHW layout.
        - ``img0`` is the untouched BGR HWC array returned by ``cv2.imread``.

    Raises:
        AssertionError: if OpenCV cannot read ``path``.
    """
    bgr_hwc = cv2.imread(path)  # BGR
    assert bgr_hwc is not None, 'Image Not Found ' + path
    rgb_chw = bgr_hwc[:, :, ::-1].transpose(2, 0, 1)  # BGR -> RGB, HWC -> CHW
    return np.ascontiguousarray(rgb_chw), bgr_hwc


def load_images_and_labels(dataset='clevr-hans3', split=True):
    """List image files and integer class labels for a CLEVR-Hans split.

    Files are read from ``data/clevr/<dataset>/<split>/class<k>/``, relative to
    the current working directory. Within each class folder, files are sorted by
    name. ``.DS_Store`` entries are skipped.

    Args:
        dataset: ``'clevr-hans3'`` (3 classes) or ``'clevr-hans7'`` (7 classes).
        split: split folder name, e.g. ``'train'``. It must be a string; the
            legacy default ``True`` is rejected.

    Returns:
        ``(image_paths, labels)``: parallel lists of file paths and class
        indices.

    Raises:
        TypeError: if ``split`` is not a string.
        ValueError: if ``dataset`` is not a supported CLEVR-Hans variant.
    """
    num_classes = {'clevr-hans3': 3, 'clevr-hans7': 7}
    if not isinstance(split, str):
        raise TypeError(
            "split must be a string such as 'train', 'val' or 'test', "
            "got {!r}".format(split))
    if dataset not in num_classes:
        raise ValueError(
            "dataset must be one of {}, got {!r}".format(
                sorted(num_classes), dataset))

    base_folder = os.path.join('data', 'clevr', dataset, split)
    image_paths = []
    labels = []
    for label in range(num_classes[dataset]):
        folder = os.path.join(base_folder, 'class{}'.format(label))
        for filename in sorted(os.listdir(folder)):
            if filename == '.DS_Store':  # macOS Finder metadata, not an image
                continue
            image_paths.append(os.path.join(folder, filename))
            labels.append(label)
    return image_paths, labels


def load_image_clevr(path):
    """Read ``path`` with OpenCV and return a C-contiguous RGB array in CHW layout.

    Raises:
        AssertionError: if OpenCV cannot read ``path``.
    """
    frame = cv2.imread(path)  # BGR
    assert frame is not None, 'Image Not Found ' + path
    # Flip BGR -> RGB on the channel axis, then reorder HWC -> CHW.
    return np.ascontiguousarray(np.transpose(frame[:, :, ::-1], (2, 0, 1)))


class CLEVRHans(torch.utils.data.Dataset):
    """CLEVR-Hans image classification dataset.

    Each item is ``(image, labels)``:

    - ``image`` is an RGB tensor resized to ``(img_size, img_size)`` and
      rescaled from [0, 1] to [-1, 1].
    - ``labels`` is a one-hot float32 vector over the dataset's classes.

    Image paths come from ``load_images_and_labels``, which reads from
    ``data/clevr/<dataset>/<split>/``. ``base_path`` is only used by the
    ``images_folder``, ``scenes_path`` and ``img_db`` helpers.
    """

    # Number of target classes for each supported dataset variant.
    _NUM_CLASSES = {"clevr-hans3": 3, "clevr-hans7": 7}
    # Accepted split names. Note: "test" isn't very useful since it doesn't
    # have ground-truth scene information.
    _SPLITS = ("train", "val", "test")

    def __init__(self, dataset, base_path, split, img_size=128):
        super().__init__()
        # Explicit exceptions rather than assert, so validation survives `python -O`.
        if split not in self._SPLITS:
            raise ValueError(
                "split must be one of {}, got {!r}".format(self._SPLITS, split))
        if dataset not in self._NUM_CLASSES:
            raise ValueError(
                "dataset must be one of {}, got {!r}".format(
                    sorted(self._NUM_CLASSES), dataset))
        self.img_size = img_size
        self.dataset = dataset
        self.base_path = base_path
        self.split = split
        # self.max_objects = 10
        self.transform = transforms.Compose(
            [transforms.Resize((img_size, img_size))]
        )
        self.image_paths, self.labels = load_images_and_labels(
            dataset=dataset, split=split)

    @property
    def images_folder(self):
        """Path ``<base_path>/images/<split>``."""
        return os.path.join(self.base_path, "images", self.split)

    @property
    def scenes_path(self):
        """Path of the scene JSON for this split.

        Raises:
            ValueError: for the ``test`` split, which has no scenes.
        """
        if self.split == "test":
            raise ValueError("Scenes are not available for test")
        return os.path.join(
            self.base_path, "scenes", "CLEVR_{}_scenes.json".format(self.split)
        )

    def img_db(self):
        """Open ``<base_path>/<split>-images.h5`` read-only.

        The caller is responsible for closing the returned file.
        """
        path = os.path.join(self.base_path, "{}-images.h5".format(self.split))
        return h5py.File(path, "r")

    def __getitem__(self, item):
        path = self.image_paths[item]
        # The context manager releases the file handle deterministically.
        # convert() returns a loaded copy, so closing the source image is safe.
        with Image.open(path) as pil_image:
            image = transforms.ToTensor()(pil_image.convert("RGB"))[:3, :, :]
        image = self.transform(image)
        image = (image - 0.5) * 2.0  # Rescale to [-1, 1].
        labels = torch.zeros(
            (self._NUM_CLASSES[self.dataset], ), dtype=torch.float32)
        labels[self.labels[item]] = 1.0
        return image, labels

    def __len__(self):
        return len(self.labels)


class CLEVRConcept(torch.utils.data.Dataset):
    """Dataset of concept feature rows read from space-delimited CSV files.

    Rows come from ``data/clevr/concept_data/<split>/<dataset>_pos.csv`` and
    ``<dataset>_neg.csv``, relative to the current working directory. Positive
    rows come first, then negative rows.

    NOTE: rows from the ``_pos`` file are labelled 0.0 and rows from the
    ``_neg`` file are labelled 1.0, the opposite of what the file names suggest.
    This inversion is deliberate in the original code (the 1/0 variant was left
    commented out). Confirm with consumers before changing it.
    """

    # Label given to every row of the respective CSV file (see the class docstring).
    _POS_LABEL = 0.0
    _NEG_LABEL = 1.0

    def __init__(self, dataset, split):
        super().__init__()
        self.dataset = dataset
        self.split = split
        self.data, self.labels = self.load_csv()
        print('concept data: ', self.data.shape, 'labels: ', len(self.labels))

    def _read_rows(self, suffix):
        """Return the values of ``<dataset><suffix>.csv`` for this split as a numpy array."""
        path = os.path.join('data', 'clevr', 'concept_data', self.split,
                            self.dataset + suffix + '.csv')
        return pd.read_csv(path, delimiter=' ').values

    def load_csv(self):
        """Load positive and negative rows.

        Returns:
            ``(data, labels)`` as float32 tensors. ``data`` stacks positive
            rows, then negative rows. ``labels`` holds one value per row.
        """
        pos_data = self._read_rows('_pos')
        neg_data = self._read_rows('_neg')
        pos_labels = np.full(len(pos_data), self._POS_LABEL)
        neg_labels = np.full(len(neg_data), self._NEG_LABEL)
        data = torch.tensor(
            np.concatenate([pos_data, neg_data], axis=0), dtype=torch.float32)
        labels = torch.tensor(
            np.concatenate([pos_labels, neg_labels], axis=0), dtype=torch.float32)
        return data, labels

    def __getitem__(self, item):
        return self.data[item], self.labels[item]

    def __len__(self):
        return len(self.data)
