import os
import math

import numpy as np
import pandas as pd
import tifffile as tiff
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from tqdm import tqdm


# Root directory under which this script writes everything it produces: the copy
# of the metadata CSV, the per-tile TIFF folder and the tiled index CSV.
# NOTE(review): the value looks like a redacted placeholder -- confirm it points
# at the real storage root before running.
HPC_STORAGE_FILEPATH = "/../../../"


def extract_meta_to_image_dict() -> dict:
    """Load the rxrx3-core train split and stack per-channel records into arrays.

    Each record's ``__key__`` is split on its last underscore into an image
    key and a channel number. All records sharing an image key are written
    into one ``(6, 512, 512)`` ``uint8`` array, with channel number ``c``
    stored at index ``c - 1``.

    Returns:
        dict mapping image key -> stacked ``(6, 512, 512)`` ``uint8`` array.
        Channels that never appear for a key stay zero-filled.
    """
    train_split = load_dataset("recursionpharma/rxrx3-core")['train']

    stacked = {}
    for record in tqdm(train_split, total=len(train_split)):
        image_key, channel_str = record['__key__'].rsplit('_', 1)
        if image_key not in stacked:
            # Allocate lazily, the first time any channel of this image shows up.
            stacked[image_key] = np.zeros((6, 512, 512), dtype=np.uint8)
        # presumably record['jp2'] is a decoded single-channel 512x512 image -- TODO confirm
        channel_plane = np.array(record['jp2'])
        stacked[image_key][int(channel_str) - 1, :, :] = channel_plane
    return stacked


def _compound_name_and_dose(meta_row: pd.Series) -> tuple:
    """Return ``(compound_name, compound_uM)`` derived from one metadata row."""
    if meta_row['treatment'] == 'EMPTY_control':
        return 'DMSO', "0"
    if meta_row['perturbation_type'] == 'CRISPR':
        return meta_row['gene'], "0"
    return meta_row['treatment'], str(meta_row['concentration'])


def make_tiles_and_df(meta_to_image_dict: dict, rxrx_metadata: pd.DataFrame, n_tiles: int = 4):
    """Cut every image into a square grid of tiles, save them, and write an index CSV.

    Each image is split into ``sqrt(n_tiles) x sqrt(n_tiles)`` equally sized
    tiles along its last two axes. Every tile is written as a TIFF under
    ``<HPC_STORAGE_FILEPATH>/tiled{n_tiles}x{n_tiles}/`` and one row per tile
    is recorded in ``<HPC_STORAGE_FILEPATH>/rxrx3_core_tiled{n_tiles}x{n_tiles}.csv``
    (``len(meta_to_image_dict) * n_tiles`` rows in total).

    Note that the folder / CSV names embed ``n_tiles`` twice (e.g. ``tiled4x4``
    for a 2x2 grid); this naming is kept so existing paths stay valid.

    Args:
        meta_to_image_dict: maps ``"<experiment>/<plate>/<well>_<field>"`` keys
            to arrays indexed as ``[channel, row, col]``.
        rxrx_metadata: metadata table; the columns ``well_id``, ``treatment``,
            ``perturbation_type``, ``gene`` and ``concentration`` are read.
        n_tiles: total number of tiles per image; must be a perfect square.

    Raises:
        ValueError: if ``n_tiles`` is not a positive perfect square.
        IndexError: if an image has no metadata row with a matching ``well_id``.
    """
    # Validate before touching the filesystem. The original truncated the square
    # root silently, so e.g. n_tiles=5 produced 4 tiles in a "tiled5x5" folder.
    x_tiles = int(round(math.sqrt(n_tiles))) if n_tiles >= 1 else 0
    if x_tiles < 1 or x_tiles * x_tiles != n_tiles:
        raise ValueError(f"n_tiles must be a positive perfect square, got {n_tiles!r}")

    tiled_folder_path = os.path.join(HPC_STORAGE_FILEPATH, f"tiled{n_tiles}x{n_tiles}")
    os.makedirs(tiled_folder_path, exist_ok=True)

    columns = ['filepath', 'plate', 'compound_id', 'compound_name', 'compound_uM', 'treatment',
               'moa', 'rxrx_metadata_idx', 'well', 'rxrx_well_id', 'field', 'cell_type', 'tile']

    # Map each well_id to the position of its FIRST metadata row (same row the
    # original `.iloc[0]` on a boolean mask selected), so each lookup is O(1)
    # instead of a full scan of the metadata table.
    first_row_by_well = {}
    for pos, meta_well_id in enumerate(rxrx_metadata['well_id']):
        first_row_by_well.setdefault(meta_well_id, pos)

    # Rows are collected in a list and turned into a DataFrame once at the end;
    # concatenating a one-row frame per tile is quadratic in the number of tiles.
    rows = []
    for key, image in tqdm(meta_to_image_dict.items(), total=len(meta_to_image_dict)):
        # Everything derived from the key / metadata is identical for all tiles
        # of an image, so compute it once per image rather than once per tile.
        experiment, plate, well_field = key.split('/')
        well, field = well_field.split('_')

        # plate[5:] drops the first five characters of the plate component
        # (presumably a "Plate" prefix -- TODO confirm against the dataset keys).
        well_id = f'{experiment}_{plate[5:]}_{well}'
        if well_id not in first_row_by_well:
            # IndexError matches the failure mode of the original `.iloc[0]`.
            raise IndexError(f"no metadata row with well_id {well_id!r} (image key {key!r})")
        rxrx_meta_row = rxrx_metadata.iloc[first_row_by_well[well_id]]
        compound_name, compound_um = _compound_name_and_dose(rxrx_meta_row)

        # Derive the tile size from the image instead of hard-coding 256, so any
        # square grid works. Trailing pixels are dropped if the image size is
        # not divisible by the grid size.
        tile_h = image.shape[1] // x_tiles
        tile_w = image.shape[2] // x_tiles

        for i in range(x_tiles):
            for j in range(x_tiles):
                tile = image[:, i * tile_h:(i + 1) * tile_h, j * tile_w:(j + 1) * tile_w]
                tile_output_fpath = os.path.join(
                    tiled_folder_path, f"{experiment}_{plate}_{well}_{field}_{i}_{j}.tiff"
                )
                # save the tile to disk
                tiff.imwrite(tile_output_fpath, tile)

                rows.append({
                    'filepath': tile_output_fpath,
                    'plate': f"{experiment}__{plate}",
                    'compound_id': compound_name,
                    'compound_name': compound_name,
                    'compound_uM': compound_um,
                    'treatment': compound_name + "@" + compound_um,
                    'moa': compound_name,
                    'rxrx_metadata_idx': rxrx_meta_row.name,
                    'well': well,
                    'rxrx_well_id': rxrx_meta_row['well_id'],
                    'field': field,
                    'cell_type': 'HUVEC',
                    'tile': f"{i}_{j}",
                })

    # Passing `columns` keeps the column order fixed (and the header present)
    # even when no tiles were produced.
    output_df = pd.DataFrame(rows, columns=columns)

    # save the dataframe to disk
    output_df.to_csv(os.path.join(HPC_STORAGE_FILEPATH, f"rxrx3_core_tiled{n_tiles}x{n_tiles}.csv"), index=False)


def main():
    """Run the pipeline: fetch metadata, stack image channels, then tile and index.

    Downloads ``metadata_rxrx3_core.csv`` from the ``recursionpharma/rxrx3-core``
    dataset repository, writes a copy of it under ``HPC_STORAGE_FILEPATH``,
    builds the stacked images, and finally writes the tiles and their index
    CSV with ``n_tiles=4``.
    """
    downloaded_csv = hf_hub_download(
        "recursionpharma/rxrx3-core",
        filename="metadata_rxrx3_core.csv",
        repo_type="dataset",
    )
    metadata = pd.read_csv(downloaded_csv)

    # Keep a copy of the metadata next to the generated tiles.
    metadata_copy_path = os.path.join(HPC_STORAGE_FILEPATH, "metadata_rxrx3_core.csv")
    metadata.to_csv(metadata_copy_path, index=False)

    print("Extracting images from rxrx3-core dataset, and stacking channels")
    stacked_images = extract_meta_to_image_dict()

    print("Extracted images, now making tiles and dataframe")
    make_tiles_and_df(stacked_images, metadata, n_tiles=4)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

