import torch
import os
import PIL
from PIL import Image
import numpy as np
import pandas as pd
import h5py
import pickle
from scipy import sparse

import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from torchvision import transforms, datasets

class VectorDataset(torch.utils.data.Dataset):
    """Dataset that pairs two indexable containers sample-by-sample.

    Item ``i`` is the tuple ``(X[i], Y[i])``. The dataset length is the size of the
    first dimension of ``X`` (``X.shape[0]``), so ``X`` must expose ``.shape``
    (e.g. a tensor or ndarray); ``Y`` only needs to support indexing.
    """

    def __init__(self, X, Y):
        self.X, self.Y = X, Y

    def __len__(self):
        rows = self.X.shape[0]
        return rows

    def __getitem__(self, i):
        # Index X before Y, then return both as a tuple.
        return tuple(container[i] for container in (self.X, self.Y))

class VectorDataset2(torch.utils.data.Dataset):
    """Dataset that zips three indexable containers sample-by-sample.

    Item ``i`` is the tuple ``(X[i], Y[i], S[i])``. The dataset length is the size
    of the first dimension of ``X`` (``X.shape[0]``).
    """

    def __init__(self, X, Y, S):
        self.X = X
        self.Y = Y
        self.S = S

    def __len__(self):
        num_samples = self.X.shape[0]
        return num_samples

    def __getitem__(self, i):
        x_i = self.X[i]
        y_i = self.Y[i]
        s_i = self.S[i]
        return x_i, y_i, s_i

class npy_ysx_pl(LightningDataModule):
    """Lightning data module wrapping ``npy_ysx`` datasets for train and (optionally) validation.

    Args:
        train_data_dir: argument passed to ``npy_ysx`` to build the training dataset.
        test_data_dir: argument passed to ``npy_ysx`` to build the validation dataset;
            when None no validation dataset is created.
        batch_size: batch size used by both loaders.

    NOTE(review): ``npy_ysx`` is not defined or imported in the visible part of this
    file — confirm it is provided elsewhere. Also, ``val_dataloader`` will raise
    AttributeError if called when ``test_data_dir`` is None.
    """

    def __init__(self, train_data_dir, test_data_dir = None, batch_size: int = 100):
        super().__init__()
        self.train_data_dir = train_data_dir
        self.test_data_dir = test_data_dir
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = npy_ysx(self.train_data_dir)
        if self.test_data_dir is None:
            return
        self.val_dataset = npy_ysx(self.test_data_dir)

    def _make_loader(self, dataset, is_train):
        # Training batches are shuffled and ragged tails dropped; validation keeps order and every sample.
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=5,
                          shuffle=is_train, drop_last=is_train)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, False)

class MNIST_pl(LightningDataModule):
    """Lightning data module serving MNIST from in-memory tensors.

    Images are float32 tensors of shape (1, 28, 28) scaled from [0, 255] to [0, 1].
    Labels are integer class indices when ``class_no`` is True, otherwise one-hot
    vectors of length 10.

    Args:
        data_dir: root directory passed to ``torchvision.datasets.MNIST``.
        batch_size: batch size used by both loaders.
        class_no: return integer labels instead of one-hot labels.
    """

    def __init__(self, data_dir: str = './data', batch_size: int = 100, class_no: bool = False):
        super().__init__()
        self.class_no = class_no
        self.data_dir = data_dir
        self.num_classes = 10
        self.batch_size = batch_size

    def prepare_data(self):
        # Download ahead of setup(); per the Lightning docs this hook runs on a single
        # process, avoiding concurrent downloads when setup() runs on every process.
        datasets.MNIST(root=self.data_dir, train=True, download=True)
        datasets.MNIST(root=self.data_dir, train=False, download=True)

    @staticmethod
    def _to_images(raw):
        """Return ``raw`` (N images of 28x28 pixels) as an (N, 1, 28, 28) float32 tensor divided by 255."""
        # A single vectorised reshape + cast replaces copying the images one by one in Python.
        return raw.reshape(len(raw), 1, 28, 28).to(torch.float32) / 255.

    def setup(self, stage=None):
        # download=True is kept so setup() still works if prepare_data() was not called.
        train_data = datasets.MNIST(root=self.data_dir, train=True, download=True)
        test_data = datasets.MNIST(root=self.data_dir, train=False, download=True)

        train_X = self._to_images(train_data.data)
        test_X = self._to_images(test_data.data)

        train_Y = train_data.targets
        test_Y = test_data.targets
        if not self.class_no:
            train_Y = F.one_hot(train_Y, self.num_classes)
            test_Y = F.one_hot(test_Y, self.num_classes)

        self.train_dataset = VectorDataset(train_X, train_Y)
        self.val_dataset = VectorDataset(test_X, test_Y)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, self.batch_size, num_workers = 5, shuffle = True, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, self.batch_size, num_workers = 5, shuffle = False, drop_last=False)

class eYaleB_pl(LightningDataModule):
    """Lightning data module for the Extended Yale B faces stored as pickled dicts.

    ``setup`` reads ``YaleBFaceTrain.dat`` and ``YaleBFaceTest.dat`` from ``data_dir``.
    Each pickle is read as a mapping with ``'image'``, ``'person'``, ``'azimuth'`` and
    ``'elevation'`` entries.
    - Images are reshaped to (1, 128, 128) and divided by 255.
    - Azimuth is divided by 180 and elevation by 90.

    Args:
        data_dir: directory containing the two ``.dat`` files.
        batch_size: batch size used by both loaders.
        class_no: return integer person labels instead of one-hot vectors of length 38.
        lighting: when True each sample is ``(image, label, [azimuth, elevation])``;
            otherwise ``(image, label)``.
    """

    def __init__(self, data_dir: str = './data/YaleBFace128', batch_size: int = 100, class_no: bool = False, lighting = False):
        super().__init__()
        self.class_no = class_no
        self.data_dir = data_dir
        self.num_classes = 38
        self.batch_size = batch_size
        self.lighting = lighting

    @staticmethod
    def _load_split(path):
        """Load one pickled split and return ``(X, Y, S)`` tensors.

        X: float32 (n, 1, 128, 128) images divided by 255.
        Y: int64 (n,) person labels.
        S: float32 (n, 2) holding azimuth / 180 and elevation / 90.
        where n = len(data['person']).
        """
        # NOTE: pickle.load can execute arbitrary code; only load trusted data files.
        with open(path, 'rb') as f:
            data = pickle.load(f)

        n = len(data['person'])

        X = torch.zeros(n, 1, 128, 128)
        if n:  # np.stack rejects an empty list
            # One stacked copy (cast to float32 on assignment) instead of n per-image copies.
            X[:, 0] = torch.from_numpy(
                np.stack([data['image'][i].reshape((128, 128)) for i in range(n)]))
        X /= 255.

        S = torch.zeros(n, 2)
        S[:, 0] = torch.from_numpy(np.asarray(data['azimuth'][:n]))
        S[:, 1] = torch.from_numpy(np.asarray(data['elevation'][:n]))
        S[:, 0] /= 180.0
        S[:, 1] /= 90.0

        Y = torch.from_numpy(np.asarray(data['person'])).long()
        return X, Y, S

    def setup(self, stage=None):
        train_X, train_Y, train_S = self._load_split('%s/YaleBFaceTrain.dat' % self.data_dir)
        test_X, test_Y, test_S = self._load_split('%s/YaleBFaceTest.dat' % self.data_dir)

        if not self.class_no:
            train_Y = F.one_hot(train_Y, self.num_classes)
            test_Y = F.one_hot(test_Y, self.num_classes)

        if self.lighting:
            self.train_dataset = VectorDataset2(train_X, train_Y, train_S)
            self.val_dataset = VectorDataset2(test_X, test_Y, test_S)
        else:
            self.train_dataset = VectorDataset(train_X, train_Y)
            self.val_dataset = VectorDataset(test_X, test_Y)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, self.batch_size, num_workers = 5, shuffle = True, drop_last=True)

    def val_dataloader(self):
        # drop_last=False: every validation sample must be evaluated (previously the last partial batch was discarded).
        return DataLoader(self.val_dataset, self.batch_size, num_workers = 5, shuffle = False, drop_last=False)

class vggface2_h5(torch.utils.data.Dataset):
    """VGGFace2 samples read from HDF5 files under ``data_home``.

    Modes:
      * ``attr=True``: reads the ``image_attr``, ``label_attr`` and ``attribute_attr``
        datasets from ``train.h5``. ``train`` is ignored in this mode, and so is
        ``attr_path`` for choosing the file.
      * ``attr=False``: reads ``image``/``label`` from ``train.h5`` (``train=True``) or
        ``test.h5`` (``train=False``). If ``attr_path`` is given, per-sample attributes
        are loaded from ``vgg_attr_train.npy`` / ``vgg_attr_test.npy`` in that directory.

    Each item is a list ``[image / 255., label (int64)]``. The attributes are appended
    as a third element when ``attr`` is True or ``attr_path`` is set.

    NOTE(review): the HDF5 files are opened in ``__init__`` and never closed explicitly.
    When used with DataLoader workers (``num_workers > 0``) the handle is inherited by
    worker processes — confirm this is safe for the h5py/HDF5 build in use.
    """

    def __init__(self, data_home, train = True, attr = False, attr_path = None):
        self.attr = attr
        self.attr_path = attr_path
        if not attr:
            h5_file = '%s/train.h5' % data_home if train else '%s/test.h5' % data_home
            handle = h5py.File(h5_file, 'r')
            if attr_path is not None:
                npy_file = "%s/vgg_attr_train.npy"% attr_path if train else "%s/vgg_attr_test.npy" % attr_path
                self.s = np.load(npy_file)
            self.x, self.y = handle['image'], handle['label']
            return

        handle = h5py.File('%s/train.h5' % data_home, 'r')
        self.x = handle['image_attr']
        self.y = handle['label_attr']
        self.s = handle['attribute_attr']

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        image = torch.from_numpy(self.x[idx]) / 255.
        label = torch.from_numpy(np.array(self.y[idx])).long()
        sample = [image, label]
        has_attributes = self.attr or self.attr_path is not None
        if has_attributes:
            sample.append(torch.from_numpy(np.array(self.s[idx])))
        return sample

class vggface2_h5_pl(LightningDataModule):
    """Lightning data module for VGGFace2 HDF5 data (see ``vggface2_h5``).

    Args:
        data_dir: directory holding ``train.h5`` / ``test.h5``.
        batch_size: batch size of the main loaders.
        attr: when True, train/val loaders are dicts ``{"main": ..., "sub": ...}``.
            ``"sub"`` draws batches of 4 from the attribute-labelled training subset.
        attr_path: optional directory of per-sample attribute ``.npy`` files, forwarded
            to the main train/val datasets.
    """

    def __init__(self, data_dir: str = '.', batch_size: int = 100, attr = False, attr_path = None):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.attr = attr
        self.attr_path = attr_path

    def setup(self, stage=None):
        self.train_dataset = vggface2_h5(self.data_dir, train=True, attr = False, attr_path = self.attr_path)
        self.val_dataset = vggface2_h5(self.data_dir, train=False, attr = False, attr_path = self.attr_path)
        if self.attr:
            self.train_dataset2 = vggface2_h5(self.data_dir, train=True, attr = self.attr)

    def train_dataloader(self):
        if self.attr:
            loader_main = DataLoader(self.train_dataset, batch_size = self.batch_size, num_workers = 5, shuffle = True, drop_last=True)
            loader_sub = DataLoader(self.train_dataset2, batch_size = 4, shuffle = True, drop_last=True)
            return {"main":loader_main, "sub":loader_sub}
        return DataLoader(self.train_dataset, self.batch_size, num_workers = 5, shuffle = True, drop_last=True)

    def val_dataloader(self):
        # Validate on the held-out split, in a fixed order, without dropping the last
        # partial batch. Previously attr mode validated on the *training* set, and both
        # modes shuffled and dropped samples.
        loader_main = DataLoader(self.val_dataset, batch_size = self.batch_size, num_workers = 5, shuffle = False, drop_last=False)
        if self.attr:
            # setup() builds no held-out attribute subset, so "sub" still draws from the
            # training attribute subset, exactly as before.
            loader_sub = DataLoader(self.train_dataset2, batch_size = 4, shuffle = True, drop_last=True)
            return {"main":loader_main, "sub":loader_sub}
        return loader_main