import torch
import numpy as np
import torchvision.datasets as datasets
from torchvision import transforms
import scanpy as sc
from PIL import Image
import os
from sklearn.preprocessing import LabelEncoder
from scGeneFit.functions import load_example_data


from gwdr.src.utils import PCA


# Floating-point dtype that the loaders below cast their tensors to.
DTYPE = torch.double

user_path = os.path.expanduser('~')
print('user_path:', user_path)

# The following line needs to be adapted to your specific setup.
# NOTE: this must stay a plain string. A trailing comma here would turn it
# into a one-element tuple and break every `data_folder + '/...'` path
# concatenation used by the loaders.
data_folder = os.path.abspath('./code_DistR/gwdr/data')
print('data_folder :', data_folder)


def load_dataset(name: str, d: int=50, device: str='cpu', Yonly=False):
    """
    Load the dataset registered under ``name`` in ``dataset_dict``.

    Args:
        name: key of ``dataset_dict`` selecting the loader to call.
        d: target number of components. When it is not ``None`` and the
            last dimension of the features is larger than ``d``, the features
            are passed through ``PCA(n_components=d).fit_transform``.
        device: device the returned tensors are moved to.
        Yonly: when true, only the labels are returned (the loader is still
            called, so the features are loaded anyway).

    Returns:
        ``(X, Y)`` on ``device`` with ``Y`` cast to ``torch.int64``, or just
        ``Y`` when ``Yonly`` is true.
    """
    assert name in dataset_dict
    loader = dataset_dict[name]
    X, Y = loader()

    if not Yonly:
        too_wide = d is not None and X.shape[-1] > d
        if too_wide:
            reducer = PCA(n_components=d)
            X = reducer.fit_transform(X)
        return X.to(device=device), Y.to(device=device, dtype=torch.int64)

    return Y.to(device=device, dtype=torch.int64)


def load_mnist(N=10000):
    """
    Build tensors from the first ``N`` items of the MNIST training split.

    The split is read from ``data_folder`` with downloading disabled, then
    converted by ``torch_dataset`` using 28x28 images.
    """
    train_split = datasets.MNIST(
        root=data_folder,
        train=True,
        download=False,
        transform=None,
    )
    return torch_dataset(train_split, N=N, p=28)

def load_small_mnist(N=2000):
    """
    Same as the MNIST loader but with a smaller default subset size.

    Reads the MNIST training split from ``data_folder`` (downloading is
    disabled) and converts its first ``N`` items with ``torch_dataset``.
    """
    train_split = datasets.MNIST(
        root=data_folder,
        train=True,
        download=False,
        transform=None,
    )
    return torch_dataset(train_split, N=N, p=28)

def load_fashion_mnist(N=10000):
    """
    Build tensors from the first ``N`` items of the Fashion-MNIST training split.

    The split is read from ``data_folder`` with downloading disabled, then
    converted by ``torch_dataset`` using 28x28 images.
    """
    train_split = datasets.FashionMNIST(
        root=data_folder,
        train=True,
        download=False,
        transform=None,
    )
    return torch_dataset(train_split, N=N, p=28)

def load_small_fashion_mnist(N=2000):
    """
    Same as the Fashion-MNIST loader but with a smaller default subset size.

    Reads the Fashion-MNIST training split from ``data_folder`` (downloading
    is disabled) and converts its first ``N`` items with ``torch_dataset``.
    """
    train_split = datasets.FashionMNIST(
        root=data_folder,
        train=True,
        download=False,
        transform=None,
    )
    return torch_dataset(train_split, N=N, p=28)

def torch_dataset(dataset, N=2000, p=28):
    """
    Convert the first ``N`` items of an image dataset into flat tensors.

    Args:
        dataset: indexable object where ``dataset[i]`` yields an
            ``(image, label)`` pair. Each image is converted with
            ``transforms.ToTensor()`` and, once squeezed, must fit a
            ``(p, p)`` slot.
        N: number of leading items to convert.
        p: side length of the (square) images.

    Returns:
        ``(X, Y)`` where ``X`` has shape ``(N, p * p)`` and dtype ``DTYPE``,
        and ``Y`` has shape ``(N,)`` and holds the labels.
    """
    X = torch.empty(N, p, p)
    Y = torch.empty(N)
    converter = transforms.ToTensor()
    for i in range(N):
        img, label = dataset[i]
        X[i] = converter(img).squeeze()
        Y[i] = label
    # Use the explicit flattened size rather than -1: inferring a dimension
    # is ambiguous on an empty tensor, so reshape(N, -1) fails when N == 0.
    X = X.reshape(N, p * p).to(DTYPE)
    return X, Y


def load_coil(dir=None):
    """
    Load the COIL-20 images as flattened tensors with their object labels.

    Ref : Sameer A Nene, Shree K Nayar, Hiroshi Murase, et al.
    Columbia object image library (coil-20). 1996.

    Args:
        dir: directory containing the image files. Defaults to
            ``data_folder + '/coil-20-proc/'``.

    Returns:
        ``(X, Y)`` where ``X`` has shape ``(1440, 16384)`` and dtype ``DTYPE``
        (one row per file, first channel only, flattened) and ``Y`` has shape
        ``(1440,)`` and holds the object number parsed from each file name.
    """
    if dir is None:
        dir = data_folder + '/coil-20-proc/'
    n = 1440
    p = 16384
    X = torch.empty((n, p), dtype=DTYPE)
    Y = torch.empty(n)
    # Build the converter once instead of once per file.
    convert_tensor = transforms.ToTensor()
    # NOTE(review): os.listdir returns entries in arbitrary order, so the row
    # order of X/Y is filesystem dependent; every entry is assumed to be an
    # image file and at most `n` entries are expected -- confirm.
    for i, filename in enumerate(os.listdir(dir)):
        # Close each image as soon as its pixels are copied into X rather
        # than keeping all of them alive in a list that is never used.
        with Image.open(os.path.join(dir, filename)) as img:
            X[i] = convert_tensor(img)[0].view(-1).to(DTYPE)
        # The object number starts at index 3 of the file name and has one
        # digit when the next character is '_', two digits otherwise.
        if filename[4] == '_':
            Y[i] = int(filename[3])
        else:
            Y[i] = int(filename[3:5])
    return X, Y


def load_snareseq1():
    """
    Read ``SNAREseq_1.txt`` and ``SNAREseq_label.txt`` from the ``SCOT``
    sub-folder of ``data_folder`` (both tab-delimited) and return them as
    ``DTYPE`` tensors ``(X, Y)``.

    Ref : Chen, S., Lake, B. B., and Zhang, K. High-throughput
    sequencing of transcriptome and chromatin accessibility
    in the same cell. Nature Biotechnology, 37(12):1452–
    1457, 2019.
    """
    scot_dir = data_folder + "/SCOT"

    features = np.genfromtxt(scot_dir + "/SNAREseq_1.txt", delimiter="\t")
    X = torch.from_numpy(features).to(DTYPE)

    labels = np.genfromtxt(scot_dir + "/SNAREseq_label.txt", delimiter="\t")
    Y = torch.from_numpy(labels).to(DTYPE)

    return X, Y


def load_snareseq2():
    """
    Read ``SNAREseq_2.txt`` and ``SNAREseq_label.txt`` from the ``SCOT``
    sub-folder of ``data_folder`` (both tab-delimited) and return them as
    ``DTYPE`` tensors ``(X, Y)``.

    Ref : Chen, S., Lake, B. B., and Zhang, K. High-throughput
    sequencing of transcriptome and chromatin accessibility
    in the same cell. Nature Biotechnology, 37(12):1452–
    1457, 2019.
    """
    scot_dir = data_folder + "/SCOT"

    features = np.genfromtxt(scot_dir + "/SNAREseq_2.txt", delimiter="\t")
    X = torch.from_numpy(features).to(DTYPE)

    labels = np.genfromtxt(scot_dir + "/SNAREseq_label.txt", delimiter="\t")
    Y = torch.from_numpy(labels).to(DTYPE)

    return X, Y


def load_citeseq():
    """
    Fetch the "CITEseq" example data through ``load_example_data`` and return
    its first two outputs as ``DTYPE`` tensors ``(X, Y)``; the third output is
    discarded.

    Ref : Marlon Stoeckius, Christoph Hafemeister, William Stephenson,
    Brian Houck-Loomis, Pratip K Chattopadhyay, Harold Swerdlow, Rahul Satija, and Peter Smibert.
    Simultaneous epitope and transcriptome measurement in single cells.
    Nature Methods, 14(9):865, 2017.
    """
    features, labels, _ = load_example_data("CITEseq")
    features_t = torch.from_numpy(features).to(DTYPE)
    labels_t = torch.from_numpy(labels).to(DTYPE)
    return features_t, labels_t


def load_zeisel():
    """
    Fetch the "zeisel" example data through ``load_example_data`` and return
    its first two outputs as ``DTYPE`` tensors ``(X, Y)``; the third output is
    discarded.

    Ref : Amit Zeisel, Ana B Munoz-Manchado, Simone Codeluppi, Peter Lonnerberg,
    Gioele La Manno, Anna Jureus, Sueli Marques, Hermany Munguba, Liqun He, Christer Betsholtz, et al.
    Cell types in the mouse cortex and hippocampus revealed by single-cell RNA-seq.
    Science, 347(6226):1138–1142, 2015.
    """
    features, labels, _ = load_example_data("zeisel")
    features_t = torch.from_numpy(features).to(DTYPE)
    labels_t = torch.from_numpy(labels).to(DTYPE)
    return features_t, labels_t


def load_pbmc():
    """
    Load scanpy's ``pbmc3k_processed`` dataset as ``DTYPE`` tensors.

    Returns:
        ``(X, Y)`` where ``X`` is built from the AnnData ``.X`` matrix and
        ``Y`` holds the ``louvain`` annotations encoded as numbers by a
        ``LabelEncoder``.
    """
    adata = sc.datasets.pbmc3k_processed()
    features = torch.Tensor(adata.X).to(DTYPE)

    cluster_names = adata.obs.louvain.values
    encoded = LabelEncoder().fit_transform(cluster_names)
    labels = torch.Tensor(encoded).to(DTYPE)

    return features, labels


# Registry of available datasets: maps each name accepted by `load_dataset`
# to the loader function that is called (without arguments) to obtain (X, Y).
dataset_dict = {
    'mnist': load_mnist,
    'small_mnist': load_small_mnist,
    'fashion_mnist': load_fashion_mnist,
    'small_fashion_mnist': load_small_fashion_mnist,
    'coil': load_coil,
    'snareseq1': load_snareseq1,
    'snareseq2': load_snareseq2,
    'citeseq': load_citeseq,
    'zeisel': load_zeisel,
    'pbmc': load_pbmc,
}

