import torch
import numpy as np
import pandas as pd
import random
from omegaconf import OmegaConf
from sklearn.model_selection import train_test_split
config = OmegaConf.load('config.yaml')
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import umap
import scanpy as sc
from scipy import stats
import glob
import os
from PIL import Image


def normalize_by_norm(data):
    """Scale every item along the first axis of ``data`` to unit L2 norm, in place.

    Each ``data[i]`` is divided by ``np.linalg.norm(data[i])``, i.e. the 2-norm
    of the flattened item (the Frobenius norm for 2-D items). Items whose norm
    is zero are left unchanged instead of being turned into NaN/inf by a
    division by zero.

    Args:
        data: Mutable, indexable container whose items support division by a
            scalar (callers in this file pass NumPy arrays and torch tensors).
            Assumes a floating dtype — integer arrays would truncate on
            assignment.

    Returns:
        The same ``data`` object, modified in place.
    """
    for i, item in enumerate(data):
        norm = np.linalg.norm(item)
        # An all-zero item has no direction; dividing it would yield NaN.
        if norm > 0:
            data[i] = item / norm
    return data


def umap_scatter_anomalies(X, y):
    """Show a 2-D UMAP scatter plot of ``X`` with clean and anomalous points coloured.

    Args:
        X: Feature matrix with one row per sample; embedded with
            ``umap.UMAP(n_components=2, random_state=42)``.
        y: Labels aligned with the rows of ``X`` (list or array). Rows labelled
            ``0`` are drawn in blue as 'clean', rows labelled ``1`` in red as
            'anomaly'; rows with any other label are not plotted.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    embedding = umap.UMAP(n_components=2, random_state=42).fit_transform(X)
    labels = np.asarray(y)
    plt.figure(figsize=(12, 8))
    for value, color, name in ((0, 'b', 'clean'), (1, 'r', 'anomaly')):
        # Boolean mask + column slicing stays valid when a class has no rows;
        # unpacking with ``*zip(*points)`` would call scatter with no x/y.
        points = embedding[labels == value]
        plt.scatter(points[:, 0], points[:, 1], marker='o', color=color, s=4, alpha=1, label=name)
    plt.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
    plt.legend(loc='upper right', fontsize=20)
    plt.grid(False)
    plt.show()


class LoadDataset(Dataset):
    """Dataset of ``.jpg`` images found directly inside ``root``.

    The first ``n_samples`` files (in sorted filename order) are decoded
    eagerly with ``transforms.ToTensor()`` and stacked into ``self.data``.
    ``__getitem__`` re-reads the file from disk and returns the transformed
    image tensor (no label).

    Args:
        root: Directory searched (non-recursively) for ``*.jpg`` files.
        n_samples: Maximum number of images to load.

    Raises:
        RuntimeError: If no image was selected (no ``.jpg`` files in ``root``
            or ``n_samples <= 0``). ``torch.stack`` also raises if the images
            do not all share one shape — assumes a homogeneous folder.
    """

    def __init__(self, root, n_samples):
        self.root = root
        self.transform = transforms.ToTensor()
        # glob returns files in filesystem order; sort so the subset picked by
        # ``n_samples`` is the same on every machine.
        self.filenames = sorted(glob.glob(os.path.join(root, '*.jpg')))[:max(n_samples, 0)]
        if not self.filenames:
            raise RuntimeError(f'no .jpg images selected from {root!r} (n_samples={n_samples})')

        images = []
        for filename in self.filenames:
            # The context manager releases the file handle once decoded.
            with Image.open(filename) as image:
                images.append(self.transform(image))

        self.data = torch.stack(images)
        self.len = len(self.filenames)

    def __getitem__(self, index):
        with Image.open(self.filenames[index]) as image:
            return self.transform(image)

    def __len__(self):
        return self.len


def gaussian_data_loader(n_samples, out_dim, latent_dim, batch_size, scenario, noise_per=0, snr_db=-15, shift=1):
    """Build train/test loaders for a synthetic linear-Gaussian model.

    Rows are ``x = A z`` with ``z ~ N(0, I_latent_dim)`` and a random mixing
    matrix ``A`` of shape ``(out_dim, latent_dim)`` scaled by
    ``snr / sqrt(latent_dim)`` (``snr = 10 ** (snr_db / 20)``), so output
    coordinates have variance about ``snr**2``. Under 'feature_noise' the scale
    is additionally multiplied by ``sqrt(noise_per)``. That means
    ``noise_per == 0`` makes every sample all-zero in that scenario. The
    training split has ``n_samples`` rows and is corrupted according to
    ``scenario``. The test split has 10 000 rows and is left clean.

    Scenarios:
        'sample_noise':  add N(0, 1) noise to a ``noise_per`` fraction of
                         training rows.
        'feature_noise': add N(0, 1) noise to a ``noise_per`` fraction of
                         columns, for every training row.
        'domain_shift':  training rows use ``A + shift * G`` (rescaled to the
                         same entry variance), test rows use ``A``; training
                         rows also get sample noise as in 'sample_noise'.
        'anomalies':     overwrite a ``noise_per`` fraction of training rows
                         with N(0, 1) vectors; the test split is cut to
                         ``n_samples // 2`` rows and an extra loader of
                         ``n_samples // 2`` N(0, 1) anomalies is built.

    Returns:
        ``(train_loader, test_loader)``; for 'anomalies',
        ``(train_loader, test_loader, test_anomaly_loader)`` whose test
        batches also carry a label tensor (0 = clean, 1 = anomaly).

    Raises:
        ValueError: If ``scenario`` is not one of the four above.
    """
    num_test = 10_000
    mixing = np.random.randn(out_dim, latent_dim)
    snr = 10 ** (snr_db / 20)

    shifted_mixing = None
    if scenario in ('sample_noise', 'anomalies'):
        mixing *= snr / latent_dim ** 0.5
    elif scenario == 'feature_noise':
        mixing *= snr * (noise_per ** 0.5) / (latent_dim ** 0.5)
    elif scenario == 'domain_shift':
        # Perturb the *unscaled* matrix, then bring both to the same entry variance.
        shifted_mixing = mixing + shift * np.random.randn(out_dim, latent_dim)
        mixing *= snr / latent_dim ** 0.5
        shifted_mixing *= snr / ((shift ** 2 + 1) * latent_dim) ** 0.5
    else:
        raise ValueError('scenario not implemented')

    latent = np.random.randn(n_samples + num_test, latent_dim)

    if shifted_mixing is not None:
        latent_train, latent_test = train_test_split(latent, test_size=num_test, random_state=0)
        train_data = latent_train @ shifted_mixing.T
        test_data = latent_test @ mixing.T
    else:
        train_data, test_data = train_test_split(latent @ mixing.T, test_size=num_test, random_state=0)

    num_rows = len(train_data)
    anomaly_loader = None
    if scenario in ('sample_noise', 'domain_shift'):
        count = int(noise_per * num_rows)
        rows = random.sample(range(num_rows), count)
        train_data[rows] += np.random.normal(loc=0, scale=1, size=(count, out_dim)).astype('float32')
    elif scenario == 'feature_noise':
        count = int(noise_per * out_dim)
        cols = random.sample(range(out_dim), count)
        train_data[:, cols] += np.random.normal(loc=0, scale=1, size=(num_rows, count)).astype('float32')
    elif scenario == 'anomalies':
        count = int(noise_per * num_rows)
        rows = random.sample(range(num_rows), count)
        train_data[rows] = np.random.normal(loc=0, scale=1, size=(count, out_dim))
        test_data = test_data[:n_samples // 2]
        fake = np.random.normal(loc=0, scale=1, size=(n_samples // 2, out_dim))
        anomaly_set = TensorDataset(torch.Tensor(fake), torch.ones(len(fake)))
        anomaly_loader = DataLoader(anomaly_set, batch_size=batch_size, shuffle=True)

    train_loader = DataLoader(TensorDataset(torch.Tensor(train_data)), batch_size=batch_size, shuffle=True)
    if anomaly_loader is None:
        test_set = TensorDataset(torch.Tensor(test_data))
    else:
        test_set = TensorDataset(torch.Tensor(test_data), torch.zeros(len(test_data)))
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

    if anomaly_loader is None:
        return train_loader, test_loader
    return train_loader, test_loader, anomaly_loader


def single_cell_data_loader(n_samples, n_features, batch_size, source, target, scenario, snr_db=0, noise_per=0,
                                   verbose=0):
    """Build train/test loaders from the 5-batch pancreatic single-cell dataset.

    The expression table (transposed so rows are cells) is filtered, log1p /
    per-cell normalised, and reduced to ``n_features`` highly dispersed genes.
    Every gene column is then z-scored. Cells are grouped by their ``batchlb``
    label into 'Baron_b1', 'Mutaro_b2', 'Segerstolpe_b3', 'Wang_b4' and
    'Xin_b5'.

    Args:
        n_samples: Rows kept from 'Baron_b1' (as source/target, or as the
            training size in the noise scenarios).
        n_features: Number of top-dispersion genes kept.
        batch_size: Batch size of both loaders (``drop_last=True``).
        source, target: Batch names used for training / testing in
            'domain_shift'.
        scenario: 'domain_shift', 'sample_noise' or 'feature_noise'. The noise
            scenarios split 'Baron_b1', scale every row to unit norm and add
            unit-norm Gaussian noise divided by ``snr = 10 ** (snr_db / 20)``
            to a ``noise_per`` fraction of training rows / gene columns.
        snr_db, noise_per: Noise strength and fraction (noise scenarios only).
        verbose: When ``1``, show a UMAP plot coloured by batch.

    Returns:
        ``(train_loader, test_loader)``.

    Raises:
        ValueError: For an unknown ``source``/``target`` (domain shift) or
            ``scenario``.
    """
    batch_names = ['Baron_b1', 'Mutaro_b2', 'Segerstolpe_b3', 'Wang_b4', 'Xin_b5']

    # Load the expression table (cells end up as rows after the transpose) and per-cell metadata.
    expression = pd.read_csv('data/batch_effect/dataset4/myData_pancreatic_5batches.txt', sep='\t', header=0, index_col=0)
    adata = sc.AnnData(np.transpose(expression))
    metadata = pd.read_csv('data/batch_effect/dataset4/mySample_pancreatic_5batches.txt', header=0, index_col=0,
                           sep='\t')
    adata.obs['cell_type'] = metadata.loc[adata.obs_names, ['celltype']]
    adata.obs['batch'] = metadata.loc[adata.obs_names, ['batchlb']]
    sc.pp.filter_cells(adata, min_genes=300)
    sc.pp.filter_genes(adata, min_cells=10)
    # NOTE(review): log1p runs before per-cell normalisation, the reverse of
    # the usual scanpy order — confirm this is intended.
    sc.pp.log1p(adata)
    sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
    sc.pp.filter_genes_dispersion(data=adata, n_top_genes=n_features, min_mean=0.0125, max_mean=3, min_disp=0.5)
    cell_batches = np.array(adata.obs.batch)
    data = stats.zscore(adata.X, axis=0)

    # Row indices of every batch, keyed by batch name.
    batch_idx = {name: np.where(cell_batches == name)[0] for name in batch_names}

    if verbose == 1:
        embedding = umap.UMAP().fit_transform(data)
        _, ax = plt.subplots(figsize=(8, 8))
        plot_styles = [('Baron_b1', 'blue', 'Batch 1'),
                       ('Mutaro_b2', 'orange', 'Batch 2'),
                       ('Segerstolpe_b3', 'green', 'Batch 3'),
                       ('Wang_b4', 'red', 'Batch 4'),
                       ('Xin_b5', 'purple', 'Batch 5')]
        for name, color, label in plot_styles:
            points = embedding[batch_idx[name]]
            ax.scatter(points[:, 0], points[:, 1], color=color, s=1, label=label)
        plt.legend(loc='best')
        plt.title(f'UMAP of the data with z-score and {n_features} features')
        plt.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
        plt.grid(False)
        plt.show()

    batches = {name: data[idx] for name, idx in batch_idx.items()}

    if scenario == 'domain_shift':
        if source not in batch_names:
            raise ValueError("'source' not implemented")
        if source == 'Baron_b1':
            batches['Baron_b1'] = batches['Baron_b1'][: n_samples]
        train_set = TensorDataset(torch.Tensor(batches[source]))

        if target not in batch_names:
            raise ValueError("'target' not implemented")
        if target == 'Baron_b1':
            batches['Baron_b1'] = batches['Baron_b1'][: n_samples]
        test_set = TensorDataset(torch.Tensor(batches[target]))

    elif scenario in ('sample_noise', 'feature_noise'):
        snr = 10 ** (snr_db / 20)
        train_data, test_data = train_test_split(batches['Baron_b1'], train_size=n_samples, random_state=0)
        train_data = normalize_by_norm(train_data)
        test_data = normalize_by_norm(test_data)
        num_rows, num_genes = len(train_data), train_data.shape[1]
        if scenario == 'sample_noise':
            n_noisy = int(noise_per * num_rows)
            rows = random.sample(range(num_rows), n_noisy)
            noise = np.random.normal(loc=0, scale=1, size=(n_noisy, num_genes)).astype('float32')
            train_data[rows] += normalize_by_norm(noise) / snr
        else:
            n_noisy = int(noise_per * num_genes)
            cols = random.sample(range(num_genes), n_noisy)
            noise = np.random.normal(loc=0, scale=1, size=(num_rows, n_noisy)).astype('float32')
            train_data[:, cols] += normalize_by_norm(noise) / snr
        train_set = TensorDataset(torch.Tensor(train_data))
        test_set = TensorDataset(torch.Tensor(test_data))

    else:
        raise ValueError('Scenario not implemented')

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, drop_last=True)
    return train_loader, test_loader


def anomaly_data_loader_celeba(n_samples, anomaly_per, batch_size, verbose=0):
    """Build anomaly-detection loaders from ``data/8_celeba.npz``.

    ``X`` is z-scored column-wise over the whole file. Rows with ``y == 0``
    are clean and rows with ``y == 1`` are anomalies. ``n_samples`` clean rows
    form the training set, and ``int(n_samples * anomaly_per)`` of those
    training rows are overwritten with anomalous rows. The test loaders hold
    up to 1000 held-out anomalies (label 1) and an equal number of held-out
    clean rows (label 0).

    Args:
        n_samples: Number of clean rows drawn for training.
        anomaly_per: Fraction of training rows replaced by anomalies; ``0``
            gives a fully clean training set.
        batch_size: Batch size of all loaders.
        verbose: When ``1``, show a UMAP plot of clean training rows vs. all
            anomalies before contamination.

    Returns:
        ``(train_loader, test_loader, test_anomaly_loader)``.
    """
    n_test_samples = 1000
    # NOTE(review): allow_pickle=True lets np.load unpickle objects from the
    # file; only use a trusted copy of this dataset. Kept in case the arrays
    # are stored with object dtype.
    # The context manager closes the underlying .npz file handle.
    with np.load('data/8_celeba.npz', allow_pickle=True) as data:
        X, y = data['X'].astype('float32'), data['y']
    X = stats.zscore(X, axis=0)
    clean = X[np.where(y == 0)[0]]
    anomalies = X[np.where(y == 1)[0]]
    train_data, test_data = train_test_split(clean, train_size=n_samples, random_state=0)

    if verbose == 1:
        plot_data = np.concatenate((train_data, anomalies))
        labels = np.concatenate((np.zeros(len(train_data)), np.ones(len(anomalies))))
        umap_scatter_anomalies(plot_data, labels)

    n_anomalies = int(len(train_data) * anomaly_per)
    if n_anomalies == 0:
        # train_test_split rejects train_size=0, so a clean training set is
        # handled here: every anomaly goes to the test pool, shuffled with a
        # fixed seed so the selection is deterministic.
        anomaly_test = anomalies[np.random.RandomState(0).permutation(len(anomalies))]
    else:
        anomalies_samples = random.sample(range(len(train_data)), n_anomalies)
        anomaly_train, anomaly_test = train_test_split(anomalies, train_size=n_anomalies, random_state=0)
        train_data[anomalies_samples] = anomaly_train
    anomaly_test = anomaly_test[:n_test_samples]
    # Balance the clean test set with the anomaly test set.
    test_data = test_data[:len(anomaly_test)]

    train_set = TensorDataset(torch.Tensor(train_data))
    test_set = TensorDataset(torch.Tensor(test_data), torch.zeros(len(test_data)))
    test_anomaly_set = TensorDataset(torch.Tensor(anomaly_test), torch.ones(len(anomaly_test)))
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)
    test_anomaly_loader = DataLoader(test_anomaly_set, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader, test_anomaly_loader


def mnist_data_loader(n_samples, batch_size, num_workers, scenario, noise_per=0, snr_db=15, source='mnist'):
    """Build MNIST / MNIST-M train and test loaders, optionally corrupting the training images.

    Args:
        n_samples: Number of (shuffled) training images kept.
        batch_size: Batch size of both loaders (``drop_last=True``).
        num_workers: ``DataLoader`` worker processes.
        scenario:
            - 'domain_shift' trains on one of MNIST / MNIST-M and tests on the
              other, chosen by ``source``. No normalisation or noise is applied.
            - 'sample_noise' / 'feature_noise' use MNIST for both splits and
              scale every image to unit L2 norm. When ``noise_per != 0`` they
              add unit-norm Gaussian noise divided by ``snr`` to a
              ``noise_per`` fraction of training images or pixel positions,
              respectively.
            - Any other scenario with ``noise_per == 0`` yields normalised,
              clean MNIST.
        noise_per: Fraction of training images / pixel positions corrupted.
        snr_db: Signal-to-noise ratio in dB; ``snr = 10 ** (snr_db / 20)``.
        source: For 'domain_shift' only. ``'mnist'`` trains on the MNIST
            training split and tests on 10 000 MNIST-M images. Any other value
            trains on 5000 MNIST-M images and tests on the MNIST training
            split.

    Returns:
        ``(train_loader, test_loader)``. MNIST images are replicated to
        3 channels, giving ``(N, 3, H, W)``. MNIST-M images are used as
        loaded by ``LoadDataset`` (presumably already 3-channel RGB).

    Raises:
        ValueError: If ``noise_per != 0`` and ``scenario`` is not
            'domain_shift', 'sample_noise' or 'feature_noise'.
    """
    snr = 10 ** (snr_db / 20)
    # Which split holds single-channel MNIST (as opposed to MNIST-M).
    train_is_mnist = scenario != 'domain_shift' or source == 'mnist'
    test_is_mnist = scenario != 'domain_shift' or source != 'mnist'

    if scenario == 'domain_shift':
        if source == 'mnist':
            train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor()).data / 255.0
            test_data = LoadDataset(root="./data/mnistm", n_samples=10_000).data
        else:
            train_data = LoadDataset(root="./data/mnistm", n_samples=5000).data
            test_data = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor()).data / 255.0

    else:
        train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor()).data / 255.0
        test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor()).data / 255.0

    # Tensor.numpy() shares memory, so this shuffles the tensor's rows in place.
    np.random.shuffle(train_data.numpy())
    train_data = train_data[:n_samples]

    if scenario != 'domain_shift':
        train_data = normalize_by_norm(train_data)
        test_data = normalize_by_norm(test_data)
        height, width = train_data.shape[1], train_data.shape[2]
        num_features = height * width
        if noise_per != 0:
            if scenario == 'sample_noise':
                n_samples_to_noise = int(noise_per * len(train_data))
                samples_to_noise = random.sample(range(len(train_data)), n_samples_to_noise)
                noise = torch.randn(size=(n_samples_to_noise, height, width))
                noise = normalize_by_norm(noise) / snr
                train_data[samples_to_noise] += noise
            elif scenario == 'feature_noise':
                n_features_to_noise = int(noise_per * num_features)
                features_to_noise = random.sample(range(num_features), n_features_to_noise)
                # Convert flat pixel indices into (row, col) index arrays.
                features_to_noise = np.unravel_index(indices=features_to_noise, shape=(height, width))
                noise = torch.randn(size=(len(train_data), n_features_to_noise))
                noise = normalize_by_norm(noise) / snr
                train_data[:, features_to_noise[0], features_to_noise[1]] += noise
            else:
                raise ValueError('scenario not implemented')

    # MNIST is (N, H, W); add a channel axis and replicate it 3x so both
    # splits share the (N, 3, H, W) layout used for MNIST-M.
    if train_is_mnist:
        train_data = train_data[:, None, :, :].repeat(1, 3, 1, 1)
    if test_is_mnist:
        test_data = test_data[:, None, :, :].repeat(1, 3, 1, 1)

    train_dataset = TensorDataset(train_data)
    test_dataset = TensorDataset(test_data)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False,
                             drop_last=True)
    return train_loader, test_loader


def gaussian_data_loader_non_linear(n_samples, out_dim, latent_dim, batch_size, scenario, noise_per=0, snr_db=-15, shift=1):
    """Build train/test loaders for a synthetic cubic-polynomial Gaussian model.

    Rows are ``x = D1 z + D2 z**2 + D3 z**3`` (element-wise powers of ``z``)
    with ``z ~ N(0, I_latent_dim)`` and random ``D1, D2, D3`` of shape
    ``(out_dim, latent_dim)``. Every train and test row is then scaled to unit
    L2 norm. All corruptions are unit-norm Gaussian directions divided by
    ``snr = 10 ** (snr_db / 20)``. The training split has ``n_samples`` rows
    and the test split has 10 000 clean rows.

    Scenarios:
        'sample_noise':  add noise to a ``noise_per`` fraction of training rows.
        'feature_noise': add noise (unit norm over the selected columns) to a
                         ``noise_per`` fraction of columns, for every training row.
        'domain_shift':  training rows use ``D_i + shift * R`` (one shared
                         random ``R``), test rows use ``D_i``; training rows
                         also get sample noise as in 'sample_noise'.
        'anomalies':     overwrite a ``noise_per`` fraction of training rows
                         with noise vectors; the test split is cut to
                         ``n_samples // 2`` rows and an extra loader of
                         ``n_samples // 2`` anomalies drawn the same way is built.
    Unrecognised ``scenario`` values are not rejected; they give clean data
    and a two-loader return.

    Returns:
        ``(train_loader, test_loader)``; for 'anomalies',
        ``(train_loader, test_loader, test_anomaly_loader)`` whose test
        batches also carry a label tensor (0 = clean, 1 = anomaly).
    """
    n_test_samples = 10_000
    snr = 10 ** (snr_db / 20)
    D1 = np.random.randn(out_dim, latent_dim)
    D2 = np.random.randn(out_dim, latent_dim)
    D3 = np.random.randn(out_dim, latent_dim)

    latent_features = np.random.randn(n_samples + n_test_samples, latent_dim)

    if scenario == 'domain_shift':
        rand_mat = np.random.randn(out_dim, latent_dim)
        domain_shift_D1 = D1 + shift * rand_mat
        domain_shift_D2 = D2 + shift * rand_mat
        domain_shift_D3 = D3 + shift * rand_mat
        latent_train_data, latent_test_data = train_test_split(latent_features, test_size=n_test_samples, random_state=0)
        train_data = latent_train_data @ domain_shift_D1.T + latent_train_data**2 @ domain_shift_D2.T + \
                     latent_train_data**3 @ domain_shift_D3.T
        test_data = latent_test_data @ D1.T + latent_test_data**2 @ D2.T + latent_test_data**3 @ D3.T

    else:
        data = latent_features @ D1.T + latent_features ** 2 @ D2.T + latent_features ** 3 @ D3.T
        train_data, test_data = train_test_split(data, test_size=n_test_samples, random_state=0)

    train_data = normalize_by_norm(train_data)
    test_data = normalize_by_norm(test_data)

    if scenario == 'sample_noise' or scenario == 'domain_shift':
        n_samples_to_noise = int(noise_per * len(train_data))
        samples_to_noise = random.sample(range(len(train_data)), n_samples_to_noise)
        noise = np.random.normal(loc=0, scale=1, size=(n_samples_to_noise, out_dim)).astype('float32')
        noise = normalize_by_norm(noise) / snr
        train_data[samples_to_noise] += noise

    elif scenario == 'feature_noise':
        n_features_to_noise = int(noise_per * out_dim)
        features_to_noise = random.sample(range(out_dim), n_features_to_noise)
        noise = np.random.normal(loc=0, scale=1, size=(len(train_data), n_features_to_noise)).astype('float32')
        noise = normalize_by_norm(noise) / snr
        train_data[:, features_to_noise] += noise

    elif scenario == 'anomalies':
        n_anomalies = int(noise_per * len(train_data))
        anomalies_samples = random.sample(range(len(train_data)), n_anomalies)
        anomalies = np.random.normal(loc=0, scale=1, size=(n_anomalies, out_dim))
        anomalies = normalize_by_norm(anomalies) / snr
        train_data[anomalies_samples] = anomalies
        test_data = test_data[:n_samples//2]
        test_anomaly_data = np.random.normal(loc=0, scale=1, size=(n_samples//2, out_dim))
        # Scale test anomalies exactly like the injected training anomalies
        # (unit-norm direction / snr). Raw N(0, 1) vectors have norm
        # ~sqrt(out_dim) and would be trivially separable from the unit-norm
        # clean test rows.
        test_anomaly_data = normalize_by_norm(test_anomaly_data) / snr
        test_anomaly_set = TensorDataset(torch.Tensor(test_anomaly_data), torch.ones(len(test_anomaly_data)))
        test_anomaly_loader = DataLoader(test_anomaly_set, batch_size=batch_size, shuffle=True)

    train_set = TensorDataset(torch.Tensor(train_data))
    if scenario == 'anomalies':
        test_set = TensorDataset(torch.Tensor(test_data), torch.zeros(len(test_data)))
    else:
        test_set = TensorDataset(torch.Tensor(test_data))

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)
    if scenario == 'anomalies':
        return train_loader, test_loader, test_anomaly_loader
    else:
        return train_loader, test_loader





