import math
import os
import pickle
import warnings

import anndata as ad
import h5py
import numpy as np
import pandas as pd
from PIL import Image
import scanpy as sc
import scipy.sparse as sp
from scipy.sparse import csr_matrix

from models._utils import (
    set_seed,
    sf_normalize,
    tokenize_data,
)

# Suppress every Python warning for the whole run.
warnings.filterwarnings('ignore')

# Seed the RNGs once at import time (helper from models._utils).
set_seed(42)

ROOT_DIR = 'path/to/HEST'  # must contain hest_st/ and hest_patches_image/
OUT_PATH = 'path/to/output'  # train/ and val/ shard folders are created here
N_GPU = 4  # the number of output shards is rounded up to a multiple of this

# Slides reserved for downstream evaluation, grouped by dataset.
# tokenize_dataset skips every id in the flattened list below.
_DOWNSTREAM_GROUPS = {
    'DLPFC': ['MISC1', 'MISC2', 'MISC3', 'MISC4', 'MISC5', 'MISC6',
              'MISC7', 'MISC8', 'MISC9', 'MISC10', 'MISC11', 'MISC12'],
    'HBC': ['TENX13'],
    'PSC': ['NCBI672', 'NCBI673', 'NCBI674', 'NCBI675'],
    'HHK': ['NCBI709', 'NCBI710', 'NCBI711', 'NCBI712', 'NCBI713', 'NCBI714'],
}
downstream_slide = [
    slide for group in _DOWNSTREAM_GROUPS.values() for slide in group
]


def _batch_bounds(n_obs, target_size, n_gpu):
    """Split ``n_obs`` rows into contiguous ``(begin, end)`` ranges, one per output file.

    The file count is ``ceil(n_obs / target_size)`` rounded up to a multiple of
    ``n_gpu``. All but the last ``n_gpu`` files hold ``chunk_len`` rows (the
    length of the first piece ``np.array_split`` would give for the un-rounded
    count). The remaining rows are shared equally by the last ``n_gpu`` files,
    and up to ``n_gpu - 1`` trailing rows that do not divide evenly are dropped.

    Returns:
        ``(bounds, chunk_len)``: the list of ``(begin, end)`` pairs and the
        regular chunk length.

    Raises:
        ValueError: if there are no rows, or too few rows to give each of the
            final ``n_gpu`` files at least one row.
    """
    if n_obs <= 0:
        raise ValueError('No spots left to tokenize.')
    n_batches = math.ceil(n_obs / target_size)
    chunk_len = math.ceil(n_obs / n_batches)
    if n_batches % n_gpu != 0:
        n_batches += n_gpu - n_batches % n_gpu
    n_regular = n_batches - n_gpu
    tail_start = n_regular * chunk_len
    tail_len = (n_obs - tail_start) // n_gpu
    if tail_len <= 0:
        raise ValueError(
            f'Cannot split {n_obs} spots into {n_batches} batches with '
            f'{n_gpu} non-empty final shards.'
        )
    bounds = [(b * chunk_len, (b + 1) * chunk_len) for b in range(n_regular)]
    # The last n_gpu files always use the equal tail split. This also covers
    # n_regular == 0, where chunk_len (from the un-rounded count) would
    # otherwise overrun the data.
    bounds += [
        (tail_start + k * tail_len, tail_start + (k + 1) * tail_len)
        for k in range(n_gpu)
    ]
    return bounds, chunk_len


def _fill_nan(x):
    """Return ``x`` with NaN replaced by 0 (and inf clipped), for dense or scipy-sparse input.

    ``np.nan_to_num`` returns a scipy sparse matrix unchanged, so for sparse
    input only the stored values are cleaned.
    """
    if sp.issparse(x):
        x = x.copy()
        x.data = np.nan_to_num(x.data)
        return x
    return np.nan_to_num(x)


def _load_spot_images(spot_id):
    """Load a spot's patch and its first available augmented patch as RGB arrays.

    Returns:
        ``(image, augmented_image, pos_label)``, where ``pos_label`` is the
        ``pos`` (0-8) of the ``{spot_id}_aug_pos_{pos}.png`` file that was used.

    Raises:
        FileNotFoundError: if no augmented patch exists for the spot. Previously
            the last spot's augmented image and label were silently reused.
    """
    image_dir = f'{ROOT_DIR}/hest_patches_image'
    with Image.open(f'{image_dir}/{spot_id}.png') as img:
        img_array = np.array(img.convert('RGB'))
    for pos in range(9):
        aug_path = f'{image_dir}/{spot_id}_aug_pos_{pos}.png'
        if os.path.exists(aug_path):
            with Image.open(aug_path) as img_aug:
                aug_array = np.array(img_aug.convert('RGB'))
            return img_array, aug_array, pos
    raise FileNotFoundError(
        f'No augmented image {image_dir}/{spot_id}_aug_pos_<0-8>.png for spot {spot_id}'
    )


def _process_slide(row, gene_name_id_dict, h5ad_head, h5ad_head_vars):
    """Load one HEST slide and align its genes to the vocabulary of ``h5ad_head``.

    Args:
        row: One row of the HEST info table (uses ``dataset``, ``id`` and
            ``dataset_title``).
        gene_name_id_dict: Mapping used to rename gene names to gene ids.
        h5ad_head: AnnData whose vars define the vocabulary. Its rows are
            prepended for the outer join and removed again afterwards.
        h5ad_head_vars: ``h5ad_head.var_names``.

    Returns:
        The aligned AnnData with spots under 100 total counts removed, or
        ``None`` if fewer than 10% of the vocabulary genes are in the slide.
    """
    dataset_name = row['dataset']
    slide_id = row['id']
    adata = sc.read_h5ad(f'{ROOT_DIR}/hest_st/{slide_id}.h5ad')
    adata.obs['batch_slide'] = row['dataset_title'] + '_' + slide_id
    adata.obs['batch_dataset'] = row['dataset_title']

    # Store expression as float32 CSR. sp.issparse also recognises scipy
    # sparse *arrays*, which sp.isspmatrix does not.
    adata.X = adata.X.astype('float32')
    if sp.issparse(adata.X):
        adata.X = adata.X.tocsr()
    elif isinstance(adata.X, np.ndarray):
        adata.X = csr_matrix(adata.X)

    # Rename genes to ids where a mapping exists; keep other names as-is.
    print(f'The first var index is {adata.var_names[0]}')
    print(f'The ori shape is {adata.X.shape}')
    adata.var_names = [gene_name_id_dict.get(name, name) for name in adata.var_names]
    if adata.var_names.duplicated().any():
        adata.var_names_make_unique()

    # Prefix spot names with the slide id so they stay unique across slides.
    adata.obs_names = [f'{slide_id}_{obs}' for obs in adata.obs_names]

    # Drop genes outside the vocabulary.
    adata_filtered = adata[:, adata.var_names.isin(h5ad_head_vars)]
    fraction = adata_filtered.n_vars / h5ad_head.n_vars
    print(f'{dataset_name}-{slide_id}: {fraction*100:.2f}% of gene name transition')
    if fraction < 0.1:
        print(f'Warning {dataset_name}_{slide_id} has less than 10% gene name transition')
        return None
    if not adata_filtered.var_names.is_unique:
        raise ValueError(f'Duplicate gene ids remain in slide {slide_id}')

    # Outer-join with the header so every slide carries the full vocabulary,
    # then strip the header rows (they come first in the concatenation).
    merged = ad.concat([h5ad_head, adata_filtered], join='outer', axis=0)
    adata_output = merged[h5ad_head.n_obs:].copy()
    adata_output.obs['slide'] = slide_id
    sc.pp.filter_cells(adata_output, min_counts=100)
    print(f'The processed shape is {adata_output.shape}')
    return adata_output


def tokenize_dataset(file_df, gene_name_id_dict, h5ad_head, h5ad_head_vars, train=True):
    """Tokenize HEST slides and write gene tokens plus image patches to HDF5 shards.

    Slides listed in ``downstream_slide`` and slides with under 10% vocabulary
    overlap are skipped. The remaining spots are concatenated, filtered by
    count and by patch-image availability, then shuffled. They are written to
    ``{OUT_PATH}/train`` or ``{OUT_PATH}/val`` as ``tokens-{batch}.h5``. The
    number of files is a multiple of ``N_GPU`` (see ``_batch_bounds``).

    Args:
        file_df: HEST info table (columns ``dataset``, ``id``, ``dataset_title``).
        gene_name_id_dict: Gene name -> gene id mapping.
        h5ad_head: AnnData defining the gene vocabulary.
        h5ad_head_vars: ``h5ad_head.var_names``.
        train: Use ~10000 spots per file and the train folder when True;
            otherwise ~2000 spots per file and the val folder.

    Raises:
        ValueError: if no slide survives filtering, or too few spots remain to
            fill one shard per GPU.
        FileNotFoundError: if a spot has no augmented patch image.
    """
    processed_adata_list = []
    less_than_10 = []
    n_files = len(file_df)
    for idx, (_, row) in enumerate(file_df.iterrows()):
        dataset_name = row['dataset']
        slide_id = row['id']
        if slide_id in downstream_slide:
            print(f'{idx+1}/{n_files} Processing file dataset: {dataset_name} id: {slide_id} skipped')
            print('\n')
            continue
        print(f'{idx+1}/{n_files} Processing file dataset: {dataset_name} id: {slide_id}')
        print('\n')
        adata_output = _process_slide(row, gene_name_id_dict, h5ad_head, h5ad_head_vars)
        if adata_output is None:
            less_than_10.append(f'{dataset_name}_{slide_id}')
        else:
            processed_adata_list.append(adata_output)

    if less_than_10:
        print(f'Skipped {len(less_than_10)} slide(s) with less than 10% gene name transition: {less_than_10}')
    if not processed_adata_list:
        raise ValueError('No slides left to tokenize after skipping downstream and low-overlap slides.')

    print('Combining all processed adata...')
    combined_adata = ad.concat(processed_adata_list, join='inner')
    print((f'The combined adata has {combined_adata.n_obs} spots and {combined_adata.n_vars} gene'))
    print('Done!')
    print('\n')

    print('Filter spots by gene counts...')
    sc.pp.filter_cells(combined_adata, min_counts=100)
    print((f'The combined adata after filter has {combined_adata.n_obs} spots'))
    print('Done!')
    print('\n')

    print('Filter spots by images...')
    image_files = {
        file_name[:-4]
        for file_name in os.listdir(f'{ROOT_DIR}/hest_patches_image')
        if file_name.endswith('.png')
    }
    combined_adata = combined_adata[combined_adata.obs.index.isin(image_files)].copy()
    print(f'The combined adata after filter has {combined_adata.n_obs} spots')
    print('Done!\n')

    # Per-gene divisor: the mean expression rounded half-up, with 0 replaced
    # by 1 so the division below is safe. Computed once (the old code
    # aliased and mutated it inside the batch loop).
    # NOTE(review): assumes Visium_mean.npy is ordered like
    # combined_adata.var_names — confirm.
    print('Reading mean expression value for each gene...')
    mean = np.nan_to_num(np.load('Visium_mean.npy'))
    rounded_values = np.where((mean % 1) >= 0.5, np.ceil(mean), np.floor(mean))
    gene_scale = np.where(rounded_values == 0, 1, rounded_values).reshape((1, -1))
    print('Done!')
    print('\n')

    print('Shuffling obs...')
    obs_indices = np.arange(combined_adata.n_obs)
    np.random.shuffle(obs_indices)
    combined_adata = combined_adata[obs_indices].copy()
    print('Done!')
    print('\n')

    combined_adata.obs['batch_slide_encoded'] = pd.factorize(combined_adata.obs['batch_slide'])[0]
    combined_adata.obs['batch_dataset_encoded'] = pd.factorize(combined_adata.obs['batch_dataset'])[0]

    print('Total number of spots:', len(combined_adata))
    print('\n')

    # Tokenize. Rows of adata.X and spot_obs share the shuffled order, so a
    # positional [begin:end] slice selects the same spots from both.
    adata = combined_adata
    spot_obs = adata.obs[['batch_slide_encoded', 'batch_dataset_encoded']].copy()
    spot_obs.insert(0, 'spot_id', adata.obs.index)
    spot_obs = spot_obs.reset_index(drop=True)
    print(f'The length of dataset: {spot_obs.shape[0]}')

    target_size = 10000 if train else 2000
    bounds, chunk_len = _batch_bounds(spot_obs.shape[0], target_size, N_GPU)
    n_batches = len(bounds)
    print('N_BATCHES: ', n_batches)
    print('chunk_len: ', chunk_len)

    save_dir = f'{OUT_PATH}/train/' if train else f'{OUT_PATH}/val/'
    os.makedirs(save_dir, exist_ok=True)

    for batch, (begin, end) in enumerate(bounds):
        # Tokenize expressions.
        x = _fill_nan(adata.X[begin:end])
        x = sf_normalize(x)
        out = x / gene_scale
        tokenized_idx = tokenize_data(out, 4096, 30)

        # Load image patches for the same spots.
        batch_obs = spot_obs.iloc[begin:end]
        spot_names = list(batch_obs['spot_id'])
        image_arrays = []
        image_aug_arrays = []
        image_aug_pos_arrays = []
        for spot_id in spot_names:
            img_array, img_aug_array, pos_label = _load_spot_images(spot_id)
            image_arrays.append(img_array)
            image_aug_arrays.append(img_aug_array)
            image_aug_pos_arrays.append(pos_label)
        # HWC -> CHW
        image_arrays = np.stack(image_arrays).transpose(0, 3, 1, 2)
        image_aug_arrays = np.stack(image_aug_arrays).transpose(0, 3, 1, 2)

        with h5py.File(f'{save_dir}/tokens-{batch}.h5', 'w') as f:
            f.create_dataset('tokenized_gene', data=tokenized_idx)
            f.create_dataset('images', data=image_arrays)
            f.create_dataset('images_aug', data=image_aug_arrays)
            f.create_dataset('pos_label', data=np.array(image_aug_pos_arrays, dtype=int))
            f.create_dataset('spot_names', data=np.array(spot_names, dtype='S'))
            f.create_dataset('batch_slide_id', data=np.asarray(batch_obs['batch_slide_encoded'], dtype=int))
            f.create_dataset('batch_dataset_id', data=np.asarray(batch_obs['batch_dataset_encoded'], dtype=int))

        print(f'{batch+1}/{n_batches}  length of this batch: {end-begin} from {begin} to {end}')


def split_train_val(file_df, frac=0.05, random_state=1):
    """Split ``file_df`` into disjoint train and validation tables.

    The validation rows are a ``frac`` random sample. Train is everything
    else. Both outputs get a fresh 0..n-1 index. The sampled rows are dropped
    by their ORIGINAL index labels. Resetting the index before the drop would
    remove the wrong rows and leak validation slides into train.

    Returns:
        ``(train_df, val_df)``.
    """
    val_df = file_df.sample(frac=frac, random_state=random_state)
    train_df = file_df.drop(val_df.index)
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True)


if __name__ == '__main__':
    # NOTE: pickle.load can execute arbitrary code; only load this trusted local file.
    with open('gene_name_id_dict.pkl', 'rb') as f:
        gene_name_id_dict = pickle.load(f)
    h5ad_head = sc.read_h5ad('./model.h5ad')
    h5ad_head_vars = h5ad_head.var_names
    file_df = pd.read_csv('./dataset_info/HEST_info.csv')
    file_df = file_df.sample(frac=1, random_state=42).reset_index(drop=True)
    train_df, val_df = split_train_val(file_df, frac=0.05, random_state=1)
    tokenize_dataset(val_df, gene_name_id_dict, h5ad_head, h5ad_head_vars, train=False)
    tokenize_dataset(train_df, gene_name_id_dict, h5ad_head, h5ad_head_vars, train=True)