from .setup_picking_data import setup_segmentation_subtomos
from .utils import load_mrc_data
import pandas as pd


def setup_segmentation_tomos_shrec2021_model(
    shrec_model_dir,
    out_dir,
    subtomo_size=37,
    subtomo_extraction_strides=None,
    save_full_locmaps=False,
    tomotwin_model_file=None,
    setup_tomotwin_reference_embeddings=True,
    add_background_class_to_locmap=False,
    skip_existing=True,
    crop_tomo_fn=None,
):
    """
    Run setup_segmentation_subtomos on the contents of one SHREC2021 model directory.

    Three files are read from shrec_model_dir: class_mask.mrc (per-particle
    localization maps), reconstruction.mrc (the tomogram) and
    particle_locations.txt (particle coordinates). The tomogram is restricted
    to slices 170:350 along its first axis and multiplied by -1. All other
    arguments are handed to setup_segmentation_subtomos unchanged.

    Example for shrec_model_dir: .../full_dataset/model_0
    """
    class_mask_file = f"{shrec_model_dir}/class_mask.mrc"
    reconstruction_file = f"{shrec_model_dir}/reconstruction.mrc"
    particle_locations_file = f"{shrec_model_dir}/particle_locations.txt"

    locmaps = get_shrec2021_prtcl_locmaps(class_mask_file)

    # Keep only slices 170:350 of the reconstruction, then flip the sign:
    # tomotwin needs reversed contrast.
    cropped_reconstruction = load_mrc_data(reconstruction_file)[170:350]
    tomo = -1 * cropped_reconstruction

    pdb_coord_dict = get_shrec2021_pdb_coord_dict(particle_locations_file)

    setup_segmentation_subtomos(
        tomo=tomo,
        locmaps=locmaps,
        pdb_coord_dict=pdb_coord_dict,
        out_dir=out_dir,
        subtomo_size=subtomo_size,
        subtomo_extraction_strides=subtomo_extraction_strides,
        save_full_locmaps=save_full_locmaps,
        tomotwin_model_file=tomotwin_model_file,
        setup_tomotwin_reference_embeddings=setup_tomotwin_reference_embeddings,
        add_background_class_to_locmap=add_background_class_to_locmap,
        skip_existing=skip_existing,
        crop_tomo_fn=crop_tomo_fn,
    )


def get_shrec2021_prtcl_locmaps(shrec_occupancy_class_file):
    """
    Build one binary localization map per particle class from a SHREC2021 class mask.

    The class mask is loaded, restricted to slices 170:350 along its first
    axis, and compared against each integer class label. Label 0 (background)
    gets no map.

    Returns a dict mapping PDB id -> float mask that is 1 where the class mask
    equals that particle's label and 0 elsewhere.
    """
    # Position in this tuple (counted from 1) is the integer label used in the class mask.
    pdb_ids_by_label = (
        "4v94",  # somehow not in the dataset
        "4cr2",
        "1qvr",
        "1bxn",
        "3cf3",
        "1u6g",
        "3d2f",
        "2cg9",
        "3h84",
        "3gl1",
        "3qm1",
        "1s3x",
        "5mrc",
    )
    class_mask = load_mrc_data(shrec_occupancy_class_file)[170:350]

    locmaps = {}
    for label, pdb_id in enumerate(pdb_ids_by_label, start=1):
        # .float() presumably turns a boolean torch tensor into 0./1. values — confirm load_mrc_data's return type
        locmaps[pdb_id] = (class_mask == label).float()
    return locmaps

def get_shrec2021_pdb_coord_dict(coord_file, z_offset=170):
    """
    Read a SHREC2021 particle location file and group its rows by PDB id.

    The file is parsed as single-space-separated text without a header row;
    its seven columns are named pdb, x, y, z, alpha, beta, gamma.

    Args:
        coord_file: Path (or file-like object) accepted by pandas.read_csv,
            e.g. .../model_0/particle_locations.txt
        z_offset: Value subtracted from every z coordinate. The default of 170
            equals the start index of the [170:350] crop that the other
            SHREC2021 helpers in this module apply to the tomogram and the
            class mask; pass a different value when a different crop is used.

    Returns:
        dict mapping each PDB id to a DataFrame holding that particle's rows,
        with keys in order of first appearance in the file. Rows whose pdb
        entry is "vesicle" or "fiducial" are dropped.
    """
    coord = pd.read_csv(coord_file, sep=" ", header=None)
    coord.columns = ["pdb", "x", "y", "z", "alpha", "beta", "gamma"]
    # Shift z so coordinates are relative to the start of the cropped volume.
    coord["z"] = coord["z"] - z_offset
    # Remove rows with pdb = vesicle or fiducial
    coord = coord[~coord.pdb.isin(["vesicle", "fiducial"])]
    pdb_coord_dict = {
        pdb: coord[coord.pdb == pdb] for pdb in coord.pdb.unique()
    }
    return pdb_coord_dict