#!/usr/bin/env python
# -*-coding:utf-8 -*-
import os
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import numpy as np 
from torchvision import transforms
import cv2 
import torch 
import data.dsprint as dd 


class Shape(Dataset):
    """Dataset over image files, each loaded as an RGB PIL image.

    Args:
        image_path: sequence of image file paths, indexed by ``__getitem__``.
        image_label: labels aligned with ``image_path``. They are stored on the
            instance but NOT returned by ``__getitem__``.
        transform: optional callable applied to each loaded image, or None.
    """

    def __init__(self, image_path, image_label, transform):
        super().__init__()
        self.image_path = image_path
        self.image_label = image_label
        self.transform = transform

    def __len__(self):
        return len(self.image_path)

    def _load_image(self, idx):
        """Open the image at position ``idx`` and return an RGB copy of it.

        ``convert`` produces an in-memory image, so the file handle opened by
        ``Image.open`` can be closed as soon as the ``with`` block exits.
        """
        with Image.open(self.image_path[idx]) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        """Return the (optionally transformed) image at ``index``; no label."""
        image = self._load_image(index)
        if self.transform is not None:
            image = self.transform(image)
        return image
    
    
class ShapeDsprint(Dataset):
    """In-memory dataset pairing pre-loaded images with their targets.

    Args:
        image: indexable collection of images (e.g. a numpy array). Each item
            is passed through ``transform`` when one is given.
        label: indexable targets aligned with ``image``; exposed as
            ``self.targets``.
        transform: optional callable applied per image, or None.
    """

    def __init__(self, image, label, transform):
        super().__init__()
        self.image, self.targets, self.transform = image, label, transform
        # Positional index of every sample: 0 .. len(image) - 1.
        self.index = np.arange(len(image))

    def __len__(self):
        return len(self.image)

    def __getitem__(self, index):
        """Return ``(image, target)`` at ``index``, transforming the image if configured."""
        sample = self.image[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.targets[index]
    

def get_shape_transforms():
    """Build the train and test preprocessing pipelines for shape images.

    Both pipelines are currently identical: ``ToTensor`` followed by a resize
    to 32x32. They are returned as two separate ``Compose`` objects so that
    either one can be changed independently (e.g. to add a random horizontal
    flip for training).

    Returns:
        (train_transforms, test_transforms)
    """
    def _pipeline():
        return transforms.Compose([transforms.ToTensor(),
                                   transforms.Resize((32, 32))])

    return _pipeline(), _pipeline()


def get_shape_data(path="../image_dataset/toy_shape/"):
    """Build the train/validation ``Shape`` datasets for the toy-shape images.

    Args:
        path: dataset root; ``get_image_filenames`` reads ``images/`` and
            ``labels.npy`` from it.

    Returns:
        (tr_data, val_data): ``Shape`` datasets using the train and the test
        transforms respectively.
    """
    train_transforms, test_transforms = get_shape_transforms()
    (tr_names, val_names), (tr_labels, val_labels) = get_image_filenames(path)
    print("There are %d training images and %d validation images" % (len(tr_names),
                                                                     len(val_names)))
    prefix = path + "/images/"
    tr_files = [prefix + name for name in tr_names]
    val_files = [prefix + name for name in val_names]
    return (Shape(tr_files, tr_labels, train_transforms),
            Shape(val_files, val_labels, test_transforms))


def get_dsprint_data(opt="train_vae", latent_dir="../exp_data/shape_vae/version_01/"):
    """Build train/test ``ShapeDsprint`` datasets from ``dd.prepare_data``.

    Args:
        opt: forwarded to ``dd.prepare_data``. When ``"train_cls"``, the images
            are replaced by element 0 of the saved latent arrays, with a
            singleton axis inserted at position 1.
        latent_dir: directory containing ``mu_std_group_training.npy`` and
            ``mu_std_group_testing.npy``. Only read when ``opt == "train_cls"``.

    Returns:
        (tr_data, tt_data)

    Raises:
        ValueError: if a latent array does not have one entry per image.
    """
    train_transforms, test_transforms = get_shape_transforms()
    tr_image, tr_label, tt_image, tt_label = dd.prepare_data(opt=opt)
    if opt == "train_cls":
        # allow_pickle=True: these are locally produced experiment artifacts;
        # never point latent_dir at untrusted files.
        tr_latent = np.load(os.path.join(latent_dir, "mu_std_group_training.npy"), allow_pickle=True)
        tt_latent = np.load(os.path.join(latent_dir, "mu_std_group_testing.npy"), allow_pickle=True)
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        for split, latent, images in (("training", tr_latent, tr_image),
                                      ("testing", tt_latent, tt_image)):
            if len(latent[0]) != len(images):
                raise ValueError("%s latents have %d entries but there are %d images"
                                 % (split, len(latent[0]), len(images)))
        tr_image = np.expand_dims(tr_latent[0].copy(), axis=1)
        tt_image = np.expand_dims(tt_latent[0].copy(), axis=1)
    tr_data = ShapeDsprint(tr_image, tr_label, train_transforms)
    tt_data = ShapeDsprint(tt_image, tt_label, test_transforms)
    return tr_data, tt_data


def get_dsprite_data_split_tr(client_id, root="../image_dataset/dsprite/real_images/1/"):
    """Load one dSprites client's training split as a ``ShapeDsprint`` dataset.

    Reads ``<root>/<client_id:02d>/train.npz``: ``arr_0`` holds the images and
    ``arr_1`` the labels. Images are cast to float32 and divided by 255.
    Labels are cast to int64. Image axes are then permuted with
    ``(0, 2, 3, 1)``; presumably NCHW -> NHWC so that ``ToTensor`` receives
    HWC arrays (TODO confirm the stored layout).

    Args:
        client_id: integer client number, zero-padded to two digits in the path.
        root: directory containing the per-client sub-directories.

    Returns:
        ``ShapeDsprint`` using the training transforms.
    """
    path = os.path.join(root, "%02d" % client_id, "train.npz")
    # np.load on an .npz returns an NpzFile that keeps the archive open until
    # closed; the ``with`` block releases it once the arrays are copied out.
    with np.load(path) as content:
        tr_im = content["arr_0"].astype(np.float32) / 255.0
        tr_la = content["arr_1"].astype(np.int64)
    tr_im = np.transpose(tr_im, (0, 2, 3, 1))
    tr_transform, _ = get_shape_transforms()
    return ShapeDsprint(tr_im, tr_la, tr_transform)


def get_dsprite_data_tt(root="../image_dataset/dsprite/real_images/1/", num_clients=12,
                        num_val=128 * 5, seed=102):
    """Pool every client's ``val.npz`` into a test set and a fixed validation subset.

    For each client ``i`` in ``range(num_clients)``, reads
    ``<root>/<i:02d>/val.npz`` (``arr_0`` images, ``arr_1`` labels). Images are
    cast to float32, divided by 255 and permuted with axes ``(0, 2, 3, 1)``.
    Labels are cast to int64.

    Args:
        root: directory containing the per-client sub-directories.
        num_clients: number of client directories to read.
        num_val: size of the random validation subset, drawn without replacement.
        seed: seed for drawing the validation subset.

    Returns:
        (test_data, val_data): ``ShapeDsprint`` datasets with the test transform.
        ``val_data`` is a subset of ``test_data``.

    NOTE: this seeds the *global* NumPy RNG. That is kept so existing runs stay
    reproducible, but it also affects any later unseeded ``np.random`` call.
    """
    tt_im, tt_la = [], []
    for i in range(num_clients):
        # ``with`` closes each NpzFile instead of leaking one handle per client.
        with np.load(os.path.join(root, "%02d" % i, "val.npz")) as content:
            im = content["arr_0"].astype(np.float32) / 255.0
            la = content["arr_1"].astype(np.int64)
        tt_im.append(np.transpose(im, (0, 2, 3, 1)))
        tt_la.append(la)
    tt_im = np.concatenate(tt_im, axis=0)
    tt_la = np.concatenate(tt_la, axis=0)
    np.random.seed(seed)
    val_index = np.random.choice(np.arange(len(tt_la)), num_val, replace=False)
    val_im, val_la = tt_im[val_index], tt_la[val_index]
    _, tt_transform = get_shape_transforms()
    return ShapeDsprint(tt_im, tt_la, tt_transform), ShapeDsprint(val_im, val_la, tt_transform)


def get_synthetic_data(conf, root="../image_dataset/dsprite/sync_images/", num_clients=12):
    """Load per-client synthetic images, pool and shuffle them, and re-split into groups.

    For each client ``i`` in ``range(num_clients)``, reads
    ``<root>/num_images_train_synthetic_<N>/<i>/synthetic_epoch_<E>/25p_full/data.npz``.
    Here ``N = conf.num_images_train_synthetic`` and ``E = conf.synthetic_epoch``;
    ``arr_0`` holds the images and ``arr_1`` the labels. Images are cast to
    float32, divided by 255 and permuted with axes ``(0, 2, 3, 1)``. If a
    client has more than ``conf.num_synthetic_images`` images, a random subset
    of that size is kept.

    The pooled samples are shuffled and split into ``num_clients`` groups with
    ``np.array_split``. Groups are equal when the total divides evenly;
    otherwise earlier groups are one larger. ``np.split`` would raise in that
    case.

    Uses the global NumPy RNG without seeding it.

    Args:
        conf: object providing ``num_images_train_synthetic``,
            ``synthetic_epoch`` and ``num_synthetic_images``.
        root: directory containing the ``num_images_train_synthetic_*`` folders.
        num_clients: number of client folders to read and groups to return.

    Returns:
        (sync_im_group, sync_la_group): lists of ``num_clients`` image and label
        arrays, paired element-wise.
    """
    base = os.path.join(root, "num_images_train_synthetic_%d" % conf.num_images_train_synthetic)
    sync_im, sync_label = [], []
    for i in range(num_clients):
        npz_path = os.path.join(base, "%d" % i, "synthetic_epoch_%d" % conf.synthetic_epoch,
                                "25p_full", "data.npz")
        # ``with`` closes the NpzFile once the arrays have been copied out.
        with np.load(npz_path) as content:
            _im = content["arr_0"].astype(np.float32) / 255.0
            _la = content["arr_1"].astype(np.int64)
        _im = np.transpose(_im, (0, 2, 3, 1))
        if conf.num_synthetic_images < len(_im):
            sub_index = np.random.choice(np.arange(len(_im)), conf.num_synthetic_images, replace=False)
            _im = _im[sub_index]
            _la = _la[sub_index]
        sync_im.append(_im)
        sync_label.append(_la)
    sync_im = np.concatenate(sync_im, axis=0)
    sync_label = np.concatenate(sync_label, axis=0)
    shuffle_index = np.random.choice(np.arange(len(sync_im)), len(sync_im), replace=False)
    split_index = np.array_split(shuffle_index, num_clients)
    sync_im_group = [sync_im[v] for v in split_index]
    sync_la_group = [sync_label[v] for v in split_index]
    return sync_im_group, sync_la_group


def get_dataloader(tr_data, val_data, batch_size, workers, val_drop_last=False):
    """Wrap training and validation datasets in ``DataLoader``s.

    The training loader shuffles and drops the last incomplete batch. The
    validation loader keeps dataset order. By default it also keeps the final
    partial batch, so every validation sample is evaluated; dropping it would
    silently exclude up to ``batch_size - 1`` samples from the metrics.

    Args:
        tr_data: training dataset.
        val_data: validation dataset.
        batch_size: samples per batch for both loaders.
        workers: ``num_workers`` for both loaders.
        val_drop_last: set True to drop the validation loader's last partial
            batch (e.g. if downstream code needs fixed-size batches).

    Returns:
        (train_loader, val_loader)
    """
    train_loader = DataLoader(tr_data, batch_size=batch_size, shuffle=True,
                              drop_last=True, pin_memory=True, num_workers=workers)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False,
                            drop_last=val_drop_last, pin_memory=True, num_workers=workers)
    return train_loader, val_loader

    
    
def get_image_filenames(path="../image_dataset/toy_shape/", val_percent=0.1):
    """Split toy-shape image names and labels into train/validation, stratified by shape.

    The ``.jpg`` names in ``<path>/images/`` are sorted and paired by position
    with the rows of the flattened ``<path>/labels.npy``. This assumes the
    naming scheme makes sorted order match label order (TODO confirm).
    Column 1 of each label row is used as the shape id. For each shape,
    ``int(count * val_percent)`` samples are drawn for validation and the rest
    go to training.

    NOTE: seeds the *global* NumPy RNG with 1240 so the split is reproducible.
    This also affects later unseeded ``np.random`` calls.

    Args:
        path: dataset root.
        val_percent: fraction of each shape's samples to hold out.

    Returns:
        ([tr_images, val_images], [tr_labels, val_labels])
    """
    # ``endswith`` rather than substring: a stray "x.jpg.bak" would otherwise
    # be counted and shift every later image out of alignment with its label.
    images = np.array(sorted(v for v in os.listdir(path + "/images/") if v.endswith(".jpg")))
    label = np.load(path + "/labels.npy", allow_pickle=True)
    label = np.array([v for q in label for v in q])

    shape_index = label[:, 1]
    tr_index, val_index = [], []
    np.random.seed(1240)
    for i in np.unique(shape_index):
        index = np.where(shape_index == i)[0]
        _val = np.random.choice(np.arange(len(index)), int(len(index) * val_percent), replace=False)
        val_index.append(index[_val])
        tr_index.append(np.delete(index, _val))
    # np.concatenate keeps the integer dtype even when every piece is empty.
    tr_index = np.concatenate(tr_index)
    val_index = np.concatenate(val_index)
    tr_images, val_images = images[tr_index], images[val_index]
    tr_labels, val_labels = label[tr_index], label[val_index]
    return [tr_images, val_images], [tr_labels, val_labels]


def get_subset_of_image(image, label, shape_index, color_index, scale_index, path="../image_dataset/toy_shape/"):
    """Select the images matching every requested (shape, color, scale) combination.

    Label columns used: 1 = shape, 3 = color, 2 = scale. Matches are gathered
    in nested order shape -> color -> scale, and within each combination in
    ascending row order.

    Args:
        image: numpy array of image file names aligned with ``label`` rows.
        label: 2-D label array (indexed as ``label[:, col]``).
        shape_index, color_index, scale_index: iterables of ids to include.
        path: dataset root prepended (with ``/images/``) to each file name.

    Returns:
        (sub_image, sub_label): list of file paths and the matching label rows.
    """
    index_group = []
    for s in shape_index:
        for c in color_index:
            for k in scale_index:
                mask = (label[:, 1] == s) & (label[:, 3] == c) & (label[:, 2] == k)
                index = np.where(mask)[0]
                print("There are %d images with shape index %d and color index %d" % (len(index),
                                                                                    s, c))
                # Collect every scale. Appending outside this loop would keep
                # only the last entry of ``scale_index``.
                index_group.append(index)
    # Keep an integer dtype even when nothing matched, so the fancy indexing
    # below yields empty results instead of raising.
    if index_group:
        index_group = np.concatenate(index_group)
    else:
        index_group = np.array([], dtype=np.int64)
    sub_image = [path + "/images/" + v for v in image[index_group]]
    return sub_image, label[index_group]


def get_test_tensor(image, image_shape=(32, 32, 3), device=torch.device("cuda")):
    """Load image files with OpenCV into a normalized float tensor batch.

    Args:
        image: iterable of image file paths.
        image_shape: ``(height, width, channels)`` used to view the batch as
            ``(-1, channels, height, width)``. A tuple default avoids a shared
            mutable default; lists are still accepted.
        device: device the returned tensor is moved to.

    Returns:
        (tensor, image_npy): the batch as a ``(N, C, H, W)`` float32 tensor on
        ``device``, and a float32 numpy copy in ``(N, H, W, C)`` layout. Both
        are in RGB order, divided by 255.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the files.
    """
    frames = []
    for file_name in image:
        frame = cv2.imread(file_name)
        # cv2.imread signals failure by returning None rather than raising.
        if frame is None:
            raise FileNotFoundError("cv2.imread could not read image: %s" % file_name)
        frames.append(frame.astype(np.float32))
    # OpenCV decodes to BGR; reversing the last axis gives RGB.
    images = np.stack(frames, axis=0)[:, :, :, ::-1]
    images = images / 255.0
    image_npy = images.copy()
    images = torch.from_numpy(images).to(torch.float32).permute(0, 3, 1, 2)
    return images.to(device).view(-1, image_shape[2], image_shape[0], image_shape[1]), image_npy


def get_test_tensor_npy(image_loader, batch_size, shuffle):
    """Wrap a dataset in an evaluation ``DataLoader`` that keeps every sample.

    Args:
        image_loader: the dataset to iterate (despite the name, not a loader).
        batch_size: samples per batch; the final partial batch is kept.
        shuffle: whether to shuffle sample order.

    Returns:
        A ``DataLoader`` with pinned memory and one worker process.
    """
    loader_kwargs = dict(batch_size=batch_size, shuffle=shuffle, drop_last=False,
                         pin_memory=True, num_workers=1)
    return DataLoader(image_loader, **loader_kwargs)



