import os
import pickle

os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = str(pow(2, 40))

import anndata as ad
import cv2
import h5py
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
from scipy.sparse import csr_matrix
from tqdm import tqdm

from config_files._constants import mapping
from models._utils import (
    adjust_crop,
    get_safe_region,
    set_seed,
    sf_normalize,
    tokenize_data,
)


# Filesystem locations. These are placeholders and must be pointed at real
# directories before the script is run.
ROOT_DIR = 'path/to/HEST'
OUT_PATH = 'path/to/output'
LABEL_PATH = 'path/to/annotation'


# Optional pre-seeded label tables, keyed by project name. When a project is
# present here its table is reused (and extended in place) by the main loop;
# otherwise the main loop starts that project with an empty table.
label_mapping = {}

set_seed(42)

# Sample metadata. The second CSV column becomes the row index; the main loop
# looks rows up by sample id (see the 'magnification' lookup further down).
meta_data = pd.read_csv(f'{ROOT_DIR}/hest/HEST_v1_0_2.csv', index_col=1)

# Gene name -> gene id lookup used to rename var_names before alignment.
# NOTE: pickle executes arbitrary code on load; only use a trusted file here.
with open('gene_name_id_dict.pkl', 'rb') as dict_file:
    gene_name_id_dict = pickle.load(dict_file)

# Template AnnData whose var_names define the gene set every sample is
# aligned to later on.
h5ad_head = sc.read_h5ad('./model.h5ad')
h5ad_head_vars = h5ad_head.var_names

# Read the per-gene mean expression values. They are used further down as the
# divisor when scaling each spot's expression before tokenisation.
print('Reading mean expression value for each gene...')
mean = np.load('Visium_mean.npy')
mean = np.nan_to_num(mean)  # NaN entries become 0 and are handled by the guard below
# Round half up (np.round would round halves to the nearest even value).
rounded_values = np.where((mean % 1) >= 0.5, np.ceil(mean), np.floor(mean))
# Guard on the *rounded* value: a mean in (0, 0.5) rounds down to 0, and
# testing the unrounded mean would let that 0 through as a divisor.
mean = np.where(rounded_values == 0, 1, rounded_values)
print('Done!')
print('\n')


projects = list(mapping)
for project in projects:
    # Every project writes into its own output directory.
    save_dir = f'{OUT_PATH}/{project}/'
    os.makedirs(save_dir, exist_ok=True)
    # Raw label -> integer id table for this project. A pre-seeded table from
    # label_mapping is reused as-is (same object); otherwise start empty.
    label_dict = label_mapping[project] if project in label_mapping else {}
    print(f"Project: {project}")
    # One filtered AnnData per sample, in mapping[project] order.
    adata_list = []
    # First pass over the project's samples: load, filter and (optionally)
    # label each one, then collect the results in adata_list.
    for sample_id, label_id in mapping[project].items():
        adata = sc.read_h5ad(f'{ROOT_DIR}/st/{sample_id}.h5ad')
        print('Remove spots not in Visium...')
        print(f'The processed shape is {adata.shape}')
        adata = adata[adata.obs['in_tissue'] == 1.0]
        print((f'The combined adata after remove spots not in Visium has {adata.n_obs} spots'))
        print('Done!')
        print('\n')

        print('Filter spots by gene counts...')
        # .A1 flattens the (n, 1) matrix a sparse sum returns into a 1-D array.
        spot_expression_sums = adata.X.sum(axis=1).A1 if sp.isspmatrix(adata.X) else adata.X.sum(axis=1)
        # .copy() turns the view into a real AnnData so obs can be modified below.
        adata = adata[spot_expression_sums >= 100].copy()
        print(f'The shape after filtering spots with gene expression count < 100 is {adata.X.shape[0]}')
        print('\n')

        # Add label
        adata.obs['label'] = None
        if label_id is not None:
            labels = pd.read_csv(f'{LABEL_PATH}/{label_id}_anno.csv')
            # Annotation files have either 2 columns (spot id, label) or 3
            # (spot id, an extra column, label).
            if len(labels.columns) == 2:
                labels.columns = ['spot_id', 'label']
            else:
                labels.columns = ['spot_id', 'c', 'label']
            labels = labels.set_index('spot_id')
            for spot_id in adata.obs.index:
                # Annotation rows are keyed as "<label_id>_<spot_id>".
                label_key = f'{label_id}_{spot_id}'
                if label_key in labels.index:
                    spot_label = labels.loc[label_key, 'label']
                    # Integer ids are handed out in order of first appearance
                    # and shared by all samples of the project.
                    if spot_label not in label_dict:
                        label_dict[spot_label] = len(label_dict)
                    adata.obs.loc[spot_id, 'label'] = label_dict[spot_label]
            # Spots without an annotation are dropped.
            adata = adata[adata.obs['label'].notnull()].copy()
        print(f'The shape after adding label is {adata.X.shape}')
        print('\n')

        # The sample is stored with its obs names exactly as loaded. (An
        # earlier version deep-copied the object after this append and renamed
        # the copy's obs index, but the copy was never used; that dead work
        # has been removed without changing what ends up in adata_list.)
        adata_list.append(adata)

    # Concatenate adata in adata_list (genes restricted to those shared by
    # every sample). keys/index_unique append the sample id to each obs name
    # so that barcodes repeated across samples stay unique in the union.
    union_adata = ad.concat(
        adata_list,
        join='inner',
        axis=0,
        keys=list(mapping[project].keys()),
        index_unique='_',
    )

    # Find top 5000 highly variable genes (HVG).
    # flavor='seurat_v3' expects raw count data (unlike the other flavors,
    # which expect log-normalised input), so it is run on the counts directly
    # rather than after normalize_total/log1p. union_adata is only used for
    # this selection, so no normalised copy of it is needed.
    sc.pp.highly_variable_genes(
        union_adata,
        n_top_genes=5000,
        flavor='seurat_v3',
    )
    hvg_gene_name = union_adata.var_names[union_adata.var['highly_variable']]
    # Save highly variable genes to a file
    hvg_gene_name.to_series().to_csv(f'{save_dir}/hvg_gene_names.csv', index=False)

    # Dense, log-normalised HVG expression matrix per sample, in the same
    # order as adata_list.
    expression_list = []
    for sample_adata in adata_list:
        # NOTE(review): neither call requests a copy, so they appear to
        # overwrite X on the very objects stored in adata_list; the per-sample
        # loop further down reads those same objects. Confirm this is intended.
        sc.pp.normalize_total(sample_adata, target_sum=1e4)
        sc.pp.log1p(sample_adata)
        hvg_mask = sample_adata.var_names.isin(hvg_gene_name)
        expression = sample_adata[:, hvg_mask].X.toarray()
        nonzero_rate = np.count_nonzero(expression) / expression.size
        print(f'None Zero Rate : {nonzero_rate}')
        print(f'The shape of expression is {expression.shape}')
        expression_list.append(expression)


    # Second pass over the samples: build the image patches and tokenised
    # expression for each sample and write them to one HDF5 file per sample.
    for idx, (sample_id, label_id) in enumerate(mapping[project].items()):
        print(f"Sample ID: {sample_id}, Label ID: {label_id}")
        # The metadata value carries a one-character suffix that [:-1] strips
        # (presumably of the form '20x' -- verify against the CSV).
        magnification = int(meta_data.loc[sample_id, 'magnification'][:-1])
        # Half the side length of the patch cropped around each spot, scaled
        # linearly with magnification (112 px at 20x).
        patch_size = 112 * magnification / 20

        adata = adata_list[idx]
        wsi_path = f'{ROOT_DIR}/wsis/{sample_id}.tif'
        # NOTE(review): OpenCV decodes colour images in BGR channel order;
        # confirm that consumers of the saved patches expect that.
        wsi = cv2.imread(wsi_path)
        if wsi is None:
            # cv2.imread signals a missing or undecodable file by returning
            # None instead of raising; fail here with a clear message rather
            # than with an AttributeError on wsi.shape.
            raise FileNotFoundError(f'Could not read whole-slide image: {wsi_path}')
        wsi_width, wsi_height = wsi.shape[1], wsi.shape[0]
        # Sparse matrix conversion
        if not sp.isspmatrix_csr(adata.X):
            if sp.isspmatrix(adata.X):
                adata.X = adata.X.tocsr()
            elif isinstance(adata.X, np.ndarray):
                adata.X = csr_matrix(adata.X)
        print(f'The ori shape is {adata.X.shape}')
        print('\n')

        # Rename genes to their ids wherever the lookup knows the name; names
        # without an entry are kept unchanged.
        current_var_names = adata.var_names
        print(f'The first var index is {current_var_names[0]}')
        print(f'The ori shape is {adata.X.shape}')
        new_var_names = [
            gene_name_id_dict[gene_name] if gene_name in gene_name_id_dict else gene_name
            for gene_name in current_var_names
        ]
        adata.var_names = new_var_names
        if adata.var_names.duplicated().any():
            adata.var_names_make_unique()

        # Keep only genes known to the model template, then concatenate onto
        # the template with an outer join so the result spans the template's
        # genes.
        known_gene_mask = adata.var_names.isin(h5ad_head_vars)
        adata_filtered = adata[:, known_gene_mask]
        # [1:] removes the first row, which comes from h5ad_head
        # (assumes the template holds exactly one observation -- TODO confirm).
        adata = ad.concat([h5ad_head, adata_filtered], join='outer', axis=0)[1:]

        print('Remove spots not in Visium...')
        print(f'The processed shape is {adata.shape}')
        in_tissue_mask = (adata.obs['in_tissue'] == 1.0).astype('bool')
        adata = adata[in_tissue_mask]
        print((f'The combined adata after remove spots not in Visium has {adata.n_obs} spots'))
        print('Done!')
        print('\n')

        print('Filter spots by gene counts...')
        # A sparse sum yields an (n, 1) matrix; .A1 flattens it to 1-D.
        if sp.isspmatrix(adata.X):
            spot_expression_sums = adata.X.sum(axis=1).A1
        else:
            spot_expression_sums = adata.X.sum(axis=1)
        adata = adata[spot_expression_sums >= 100].copy()
        print(f'The shape after filtering spots with gene expression count < 100 is {adata.X.shape[0]}')
        print('\n')

        # Preserve the barcode as a regular column, then switch obs to a
        # positional index so rows can be addressed by integer below.
        adata.obs['spot_id'] = adata.obs.index
        adata.obs.reset_index(drop=True, inplace=True)
        adata_coords = adata.obsm['spatial']

        # Working copy of obs with the positional index exposed as an
        # int64 'idx' column.
        adata_obs = adata.obs.reset_index().rename(columns={'index':'idx'})
        adata_obs['idx'] = adata_obs['idx'].astype('i8')


        # Tokenize: replace NaNs with 0, normalise each spot with
        # sf_normalize, scale every gene by its reference value, then hand the
        # result to tokenize_data.
        x = sf_normalize(np.nan_to_num(adata.X))
        # This is an alias of the module-level `mean` array, not a copy, so the
        # in-place += below also updates `mean`: remaining zeros become 1 to
        # keep the division well defined.
        median_counts_per_gene = mean
        median_counts_per_gene += median_counts_per_gene == 0
        gene_scale = median_counts_per_gene.reshape((1, -1))
        out = x / gene_scale
        # NOTE(review): the meaning of 4096 and 30 is defined by tokenize_data,
        # which is not visible here.
        tokenized_idx = tokenize_data(out, 4096, 30)

        # Crop one image patch (plus one "augmented" patch) per spot and write
        # the sample's arrays to '<save_dir>/<sample_id>.h5'.
        image_arrays = []
        image_aug_arrays = []
        image_aug_pos_arrays = []
        spot_names = []
        spot_pos_arrays = []
        dense_expression_array = []
        # Row positions of the spots that produced a usable patch. Needed to
        # keep the per-row arrays computed before this loop (tokens, labels)
        # aligned with the images when a spot is skipped.
        kept_rows = []

        dense_expression = expression_list[idx]
        if dense_expression.shape[0] != adata_obs.shape[0]:
            raise ValueError(
                f'Sample {sample_id}: HVG expression has {dense_expression.shape[0]} rows '
                f'but the filtered spot table has {adata_obs.shape[0]}'
            )

        for i in tqdm(range(len(adata_obs)), ncols=0, total=len(adata_obs)):
            spot_id = adata_obs['spot_id'][i]
            row, col = adata_obs['array_row'][i], adata_obs['array_col'][i]
            center_x, center_y = adata_coords[i]
            # Square crop of side 2 * patch_size centred on the spot.
            top_left = (int(center_x - patch_size), int(center_y - patch_size))
            bottom_right = (int(center_x + patch_size), int(center_y + patch_size))
            top_left_aug, bottom_right_aug, pos_label = get_safe_region(
                center_x, center_y, patch_size, wsi_width, wsi_height
            )

            img = wsi[top_left[1]:bottom_right[1], top_left[0]:bottom_right[0]]
            img_aug = wsi[top_left_aug[1]:bottom_right_aug[1], top_left_aug[0]:bottom_right_aug[0]]

            if img.size == 0:
                print(f"Warning: Patch out of image bounds for barcode {spot_id}. Skipping.")
                continue
            if img_aug.size == 0:
                print(f"Warning: AUG Patch out of image bounds for barcode {spot_id}. Adjustting.")
                print(f"Befor adjuest: ", top_left_aug, bottom_right_aug)
                top_left_aug, bottom_right_aug = adjust_crop(top_left_aug, bottom_right_aug, wsi.shape[0], wsi.shape[1], 3 * patch_size)
                img_aug = wsi[top_left_aug[1]:bottom_right_aug[1], top_left_aug[0]:bottom_right_aug[0]]
                print('After adjuest:', top_left_aug, bottom_right_aug)

            img = cv2.resize(img, (224, 224))
            img_aug = cv2.resize(img_aug, (224, 224))
            image_arrays.append(np.array(img))  # (224, 224, 3)
            image_aug_arrays.append(np.array(img_aug))
            spot_names.append(spot_id)
            image_aug_pos_arrays.append(pos_label)
            spot_pos_arrays.append((row, col))
            dense_expression_array.append(dense_expression[i])
            kept_rows.append(i)

        # (N, H, W, C) -> (N, C, H, W)
        image_arrays = np.stack(image_arrays)
        image_arrays = np.transpose(image_arrays, (0, 3, 1, 2))

        # The augmented patches only take part in the consistency check below;
        # they are not written to the output file.
        image_aug_arrays = np.stack(image_aug_arrays)
        image_aug_arrays = np.transpose(image_aug_arrays, (0, 3, 1, 2))

        spot_pos_arrays = np.stack(spot_pos_arrays)
        dense_expression_array = np.stack(dense_expression_array)

        # Tokens and labels were computed for every spot; if any spot was
        # skipped above, restrict them to the surviving rows so that row k of
        # every dataset refers to the same spot.
        kept_rows = np.asarray(kept_rows, dtype=int)
        n_skipped = len(adata_obs) - len(kept_rows)
        if n_skipped:
            print(f'Sample {sample_id}: {n_skipped} spots skipped; restricting tokens and labels to kept spots.')
            tokenized_idx = tokenized_idx[kept_rows]

        # Labels exist exactly when the sample has an annotation file. Testing
        # the label sum instead would wrongly drop the labels of a sample whose
        # spots all belong to class 0.
        spot_labels = None
        if label_id is not None:
            spot_labels = np.array(adata.obs['label'], dtype=int)[kept_rows]

        # Report the current sample's matrix.
        print(f'Sample {sample_id} has {dense_expression_array.shape[0]} spots and {dense_expression_array.shape[1]} genes.')

        if len(adata_obs) != len(adata_coords):
            raise ValueError(
                f'Sample {sample_id}: {len(adata_obs)} spots but {len(adata_coords)} coordinates'
            )
        output_lengths = {
            'tokenized_gene': len(tokenized_idx),
            'images': len(image_arrays),
            'images_aug': len(image_aug_arrays),
            'images_aug_pos': len(image_aug_pos_arrays),
            'spot_names': len(spot_names),
            'spot_pos': len(spot_pos_arrays),
            'expression': len(dense_expression_array),
        }
        if spot_labels is not None:
            output_lengths['spot_label'] = len(spot_labels)
        if len(set(output_lengths.values())) != 1:
            raise ValueError(f'Sample {sample_id}: inconsistent output lengths {output_lengths}')

        print(label_dict)
        with h5py.File(f'{save_dir}/{sample_id}.h5', 'w') as f:
            f.create_dataset('tokenized_gene', data=tokenized_idx)
            f.create_dataset('images', data=image_arrays)
            if spot_labels is not None:
                f.create_dataset('spot_label', data=spot_labels)
            f.create_dataset('spot_names', data=np.array(spot_names, dtype='S'))
            # Only the gene expression values are saved in the downstream dataset of the gene expression prediction task.
            if project in ['PSC', 'HER2+', 'HHK']:
                f.create_dataset('expression', data=dense_expression_array, compression="gzip")
            f.create_dataset('spot_pos', data=spot_pos_arrays, dtype=int)