import math
import os
import pickle
import sys
import warnings

import anndata as ad
import h5py
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
from scipy.sparse import csr_matrix

from models._utils import (
    filter_batches_by_distance,
    group_spots_into_unique_batches,
    set_seed,
    sf_normalize,
    tokenize_data,
)

# Suppress every Python warning for the rest of the run.
warnings.filterwarnings('ignore')

# Project helper from models._utils; presumably seeds the RNGs that the later
# np.random.shuffle call relies on -- confirm against its implementation.
set_seed(42)

# Placeholder root of the raw data. tokenize_dataset reads .h5ad files from the
# GEO/, SpatialOmics/, STimage-1K4M/, STOmicsDB/ and HEST/ sub-folders below it.
ROOT_DIR = 'path/to/SpaVis-6M'
# Placeholder output root. Token files are written to train/ and val/ below it.
OUT_PATH = 'path/to/output'
# The number of token files written per split is rounded up to a multiple of this.
N_GPU = 4
def fast_slice(adata, begin, end):
    """Return rows ``[begin, end)`` of ``adata.X`` as a new float32 CSR matrix.

    The slice is assembled directly from the CSR buffers instead of going
    through generic sparse indexing. ``data`` and ``indices`` are passed as
    NumPy slices of the source buffers; the row-pointer array is always a
    fresh, rebased array. Whether the result ends up sharing memory with the
    source depends on SciPy and on ``X.data`` already being float32.

    Args:
        adata: Object exposing a CSR sparse matrix as ``adata.X``.
        begin: Index of the first row to include.
        end: Index one past the last row to include.

    Returns:
        ``scipy.sparse.csr_matrix`` of shape ``(end - begin, X.shape[1])``
        with dtype float32.

    Raises:
        TypeError: If ``adata.X`` is not stored in CSR format. Any other
            layout would have its pointer array misread as row pointers.
        IndexError: If ``0 <= begin <= end <= n_rows`` does not hold.
    """
    X = adata.X
    if getattr(X, 'format', None) != 'csr':
        raise TypeError(
            f'fast_slice requires a CSR matrix, got {type(X).__name__}'
        )
    n_rows = X.shape[0]
    if not 0 <= begin <= end <= n_rows:
        raise IndexError(
            f'row range [{begin}, {end}) is invalid for a matrix with {n_rows} rows'
        )

    # Offsets into data/indices that delimit the requested rows.
    lo, hi = X.indptr[begin], X.indptr[end]
    # Subtracting allocates a new array, so the source indptr is never modified.
    new_indptr = X.indptr[begin:end + 1] - lo
    return csr_matrix(
        (X.data[lo:hi], X.indices[lo:hi], new_indptr),
        shape=(end - begin, X.shape[1]),
        dtype=np.float32,
    )


def _slide_path(dataset_name, short_name, slide_id):
    """Return the ``.h5ad`` path of one slide, or ``None`` for an unknown dataset."""
    layouts = {
        'GEO': f'{ROOT_DIR}/GEO/{short_name}_{slide_id}.h5ad',
        'SpatialOmics': f'{ROOT_DIR}/SpatialOmics/{short_name}/{slide_id}.h5ad',
        'STimage-1K4M': f'{ROOT_DIR}/STimage-1K4M/{slide_id}_count.h5ad',
        'STOmicsDB': f'{ROOT_DIR}/STOmicsDB/{short_name}_{slide_id}.h5ad',
        'HEST': f'{ROOT_DIR}/HEST/{slide_id}.h5ad',
    }
    return layouts.get(dataset_name)


def _chunk_bounds(num_batches, chunk_len, n_chunks, n_gpu):
    """Return the ``(begin, end)`` mini-batch offsets of each output chunk.

    ``n_chunks`` must be a multiple of ``n_gpu``. All but the last ``n_gpu``
    chunks hold ``chunk_len`` mini-batches; the last ``n_gpu`` chunks share
    what is left in equal parts (a remainder smaller than ``n_gpu`` is dropped).
    """
    n_full = n_chunks - n_gpu
    if n_full == 0 and chunk_len * (n_gpu - 1) < num_batches:
        # Exactly n_gpu chunks, and chunk_len-sized strides still leave the
        # last chunk non-empty: keep that layout (list slicing clips the end).
        return [(k * chunk_len, (k + 1) * chunk_len) for k in range(n_chunks)]

    tail_len = (num_batches - chunk_len * n_full) // n_gpu
    if tail_len <= 0:
        raise ValueError(
            f'{num_batches} mini-batches cannot fill the last {n_gpu} chunks'
        )
    tail_start = n_full * chunk_len
    bounds = [(k * chunk_len, (k + 1) * chunk_len) for k in range(n_full)]
    bounds.extend(
        (tail_start + k * tail_len, tail_start + (k + 1) * tail_len)
        for k in range(n_gpu)
    )
    return bounds


def tokenize_dataset(file_df, gene_name_id_dict, h5ad_head, h5ad_head_vars, train=True):
    """Load, harmonise and tokenize every slide listed in ``file_df``.

    For each row the slide's ``.h5ad`` file is read, spots without
    ``array_row``/``array_col`` coordinates are dropped, gene names are mapped
    through ``gene_name_id_dict``, genes outside ``h5ad_head_vars`` are
    removed and the result is aligned to the gene axis of ``h5ad_head``.
    Spots are then grouped into mini-batches of 9 per slide, shuffled, and
    written as ``tokens-<i>.h5`` files under ``OUT_PATH``.

    Args:
        file_df: pandas DataFrame with the columns ``dataset``,
            ``spatial_info``, ``id`` and ``short_name``.
        gene_name_id_dict: Mapping from gene name to gene id; names missing
            from the mapping are kept as they are.
        h5ad_head: Template AnnData whose ``var`` axis defines the vocabulary.
        h5ad_head_vars: Vocabulary gene ids used to filter each slide.
        train: Write to ``{OUT_PATH}/train`` when true, else ``{OUT_PATH}/val/``.

    Returns:
        None. The results are the ``.h5`` files, each holding one
        ``tokenized_gene`` dataset.

    Raises:
        ValueError: If no slide survives filtering, if no mini-batch is
            produced, if a filtered slide has duplicated gene ids, or if
            there are too few mini-batches to fill the final ``N_GPU`` chunks.

    Reads ``Visium_mean.npy`` from the current working directory.
    """
    processed_adata_list = []
    less_than_10 = []
    for idx, (_, row) in enumerate(file_df.iterrows()):
        dataset_name = row['dataset']
        if not row['spatial_info']:
            continue
        slide_id = row['id']
        short_name = row['short_name']
        path = _slide_path(dataset_name, short_name, slide_id)
        if path is None:
            continue
        adata = sc.read_h5ad(path)
        if 'array_row' not in adata.obs.columns or 'array_col' not in adata.obs.columns:
            continue

        # Keep only spots that have both grid coordinates.
        has_coords = ~adata.obs[['array_row', 'array_col']].isna().any(axis=1)
        n_with_coords = has_coords.sum()
        # Guard the division so an empty slide reports 0% instead of NaN.
        spatial_percent = (n_with_coords / adata.n_obs) * 100 if adata.n_obs else 0.0
        print(f'{dataset_name}-{slide_id}: {spatial_percent:.2f} % have spatial information')
        if n_with_coords == 0:
            continue
        adata = adata[has_coords]

        # Sparse matrix conversion
        adata.X = adata.X.astype('float32')
        if not sp.isspmatrix_csr(adata.X):
            if sp.isspmatrix(adata.X):
                adata.X = adata.X.tocsr()
            elif isinstance(adata.X, np.ndarray):
                adata.X = csr_matrix(adata.X)
        print('\n')
        print(f'{idx+1}/{len(file_df)} Processing file dataset: {dataset_name} id: {slide_id}')

        # Replace gene names by gene ids wherever the mapping knows them.
        original_names = list(adata.var_names)
        print(f'The first var index is {original_names[0]}')
        print(f'The ori shape is {adata.X.shape}')
        adata.var_names = [gene_name_id_dict.get(name, name) for name in original_names]
        if adata.var_names.duplicated().any():
            adata.var_names_make_unique()

        # Prefix obs names with dataset and slide so they stay unique after concatenation.
        adata.obs.index = [f"{dataset_name}_{slide_id}_{obs}" for obs in adata.obs.index]

        # Remove vars not present in the vocabulary
        adata_filtered = adata[:, adata.var_names.isin(h5ad_head_vars)]
        coverage = adata_filtered.n_vars / h5ad_head.n_vars
        print(f'{dataset_name}-{slide_id}: {coverage*100:.2f}% of gene name transition')
        if coverage < 0.1:
            print(f'Warning {dataset_name}_{slide_id} has less than 10% gene name transition')
            less_than_10.append(f'{dataset_name}_{slide_id}')
            continue
        # Explicit check instead of assert, which is stripped under ``python -O``.
        if not adata_filtered.var_names.is_unique:
            raise ValueError(f'{dataset_name}_{slide_id} has duplicated gene ids after filtering')

        # The outer join with the template expands the slide to the template's gene axis;
        # the template rows are dropped again right after.
        adata = ad.concat([h5ad_head, adata_filtered], join='outer', axis=0)
        adata_output = adata[h5ad_head.n_obs:]
        adata_output.obs['slide'] = slide_id
        sc.pp.filter_cells(adata_output, min_counts=100)
        print(f'The processed shape is {adata_output.shape}')
        processed_adata_list.append(adata_output)

    print('Nmber of less than 10% gene name transition:', len(less_than_10))
    print('The following files have less than 10% gene name transition:')
    print(less_than_10)
    print('\n')
    if not processed_adata_list:
        raise ValueError('No slide passed the spatial and gene-coverage filters; nothing to tokenize')
    print('Combining all processed adata...')
    combined_adata = ad.concat(processed_adata_list, join='inner')
    print((f'The combined adata has {combined_adata.n_obs} spots and {combined_adata.n_vars} gene'))
    print('Done!')
    print('\n')

    # Load the per-gene mean expression, round it half-up and make it safe as a divisor.
    print('Loading mean expression value for each gene...')
    mean = np.load('Visium_mean.npy')
    mean = np.nan_to_num(mean)
    rounded_values = np.where((mean % 1) >= 0.5, np.ceil(mean), np.floor(mean))
    mean = np.where(mean == 0, 1, rounded_values)
    # Means below 0.5 were just rounded down to 0; clamp them to 1 once, up front.
    gene_scale = np.where(mean == 0, 1, mean).reshape((1, -1))
    print('Done!')
    print('\n')

    print("Grouping spots into unique mini-batches by slide...")
    unique_slides = combined_adata.obs['slide'].unique()
    all_batches = []
    total_slides = len(unique_slides)
    for i, slide in enumerate(unique_slides, 1):
        print(f"Processing slide {i}/{total_slides}: {slide}", flush=True)

        adata_slide = combined_adata[combined_adata.obs['slide'] == slide]
        batches = group_spots_into_unique_batches(adata_slide, batch_size=9)

        # Convert positions within the slide to global obs names.
        slide_indices = adata_slide.obs.index.to_numpy()
        slide_batches = [slide_indices[batch] for batch in batches]
        all_batches.extend(filter_batches_by_distance(slide_batches, combined_adata))
    print('\n')
    print('Done!')
    print('\n')
    if not all_batches:
        raise ValueError('No mini-batch was produced; nothing to tokenize')

    print('Shuffling obs...')
    np.random.shuffle(all_batches)
    print('Done!')
    print('\n')

    print('Total number of spots:', len(all_batches)*len(all_batches[0]))
    print('\n')

    # Tokenize
    num_batches = len(all_batches)
    num_spots = len(all_batches[0])
    n_chunks = math.ceil(num_batches / 1000)
    # Size of the first section np.array_split would return, computed without
    # copying every obs name into one large array.
    chunk_len = -(-num_batches // n_chunks)
    if n_chunks % N_GPU != 0:
        n_chunks += N_GPU - n_chunks % N_GPU
    print('N_BATCHES: ', n_chunks)
    print('set chunk_len: ', chunk_len)
    bounds = _chunk_bounds(num_batches, chunk_len, n_chunks, N_GPU)

    save_dir = f'{OUT_PATH}/train' if train else f'{OUT_PATH}/val/'
    os.makedirs(save_dir, exist_ok=True)
    for batch, (begin, end) in enumerate(bounds):
        chunk_obs = np.concatenate(all_batches[begin:end])
        adata = combined_adata[chunk_obs]
        x = adata.X
        # Replace NaN values with 0.
        # NOTE(review): np.nan_to_num appears to return SciPy sparse input unchanged;
        # if X is still sparse here, NaNs are not actually replaced -- confirm.
        x = np.nan_to_num(x)
        x = sf_normalize(x)
        out = x / gene_scale
        # 4096 and 30 are passed straight to the project helper; see models._utils for their meaning.
        tokenized_idx = tokenize_data(out, 4096, 30)
        # Restore the (mini-batch, spot, token) layout.
        tokenized_idx = tokenized_idx.reshape(-1, num_spots, tokenized_idx.shape[-1])
        with h5py.File(f'{save_dir}/tokens-{batch}.h5', 'w') as f:
            f.create_dataset('tokenized_gene', data=tokenized_idx)
        print(f'{batch+1}/{n_chunks}  length of this batch: {end-begin} from {begin} to {end}')


def split_train_val(file_df, val_frac=0.05, seed=42):
    """Shuffle ``file_df`` and split it into disjoint train and validation frames.

    Args:
        file_df: pandas DataFrame with one row per slide.
        val_frac: Fraction of rows to hold out for validation.
        seed: ``random_state`` used for both the shuffle and the hold-out draw.

    Returns:
        ``(train_df, val_df)``, both with a fresh ``0..n-1`` index.
    """
    shuffled = file_df.sample(frac=1, random_state=seed).reset_index(drop=True)
    val_rows = shuffled.sample(frac=val_frac, random_state=seed)
    # Drop by the sampled labels BEFORE resetting the validation index. Resetting
    # first would make drop() remove the first rows instead of the sampled ones
    # and leave the validation slides inside the training set.
    train_df = shuffled.drop(val_rows.index).reset_index(drop=True)
    return train_df, val_rows.reset_index(drop=True)


if __name__ == '__main__':
    # NOTE: pickle.load can execute arbitrary code; only use a trusted gene_name_id_dict.pkl.
    with open('gene_name_id_dict.pkl', 'rb') as f:
        gene_name_id_dict = pickle.load(f)
    h5ad_head = sc.read_h5ad('./model.h5ad')
    h5ad_head_vars = h5ad_head.var_names
    file_df = pd.read_csv('./dataset_info/SpaVis_6M.csv')
    train_df, val_df = split_train_val(file_df, val_frac=0.05, seed=42)
    tokenize_dataset(val_df, gene_name_id_dict, h5ad_head, h5ad_head_vars, train=False)
    tokenize_dataset(train_df, gene_name_id_dict, h5ad_head, h5ad_head_vars, train=True)