from .setup_picking_data import setup_segmentation_subtomos
from .utils import load_mrc_data
import pandas as pd


def setup_segmentation_tomos_shrec2020_model(shrec_model_dir, out_dir, subtomo_size=37, subtomo_extraction_strides=None, save_full_locmaps=False, tomotwin_model_file=None, setup_tomotwin_reference_embeddings=True, add_background_class_to_locmap=False, skip_existing=True, crop_tomo_fn=None):
    """
    Run setup_segmentation_subtomos on a single SHREC2020 model directory.

    Reads three files from ``shrec_model_dir``:
    ``class_mask.mrc`` (turned into per-class location maps),
    ``reconstruction.mrc`` (the tomogram) and ``particle_locations.txt``
    (per-class particle coordinates). The results are passed to
    setup_segmentation_subtomos, together with every other argument of this
    function, which is forwarded unchanged.

    Example for shrec_model_dir: .../full_dataset/model_0
    """
    mask_path = f"{shrec_model_dir}/class_mask.mrc"
    reconstruction_path = f"{shrec_model_dir}/reconstruction.mrc"
    locations_path = f"{shrec_model_dir}/particle_locations.txt"

    locmaps = get_shrec2020_prtcl_locmaps(mask_path)
    # Keep slices 156:356 along the first axis, then negate:
    # TomoTwin needs reversed contrast. Slicing inline avoids holding the
    # full reconstruction in a local for the rest of the function.
    tomo = -1 * load_mrc_data(reconstruction_path)[156:356]
    pdb_coord_dict = get_shrec2020_pdb_coord_dict(locations_path)

    setup_segmentation_subtomos(
        tomo=tomo,
        locmaps=locmaps,
        pdb_coord_dict=pdb_coord_dict,
        out_dir=out_dir,
        subtomo_size=subtomo_size,
        subtomo_extraction_strides=subtomo_extraction_strides,
        save_full_locmaps=save_full_locmaps,
        tomotwin_model_file=tomotwin_model_file,
        setup_tomotwin_reference_embeddings=setup_tomotwin_reference_embeddings,
        add_background_class_to_locmap=add_background_class_to_locmap,
        skip_existing=skip_existing,
        crop_tomo_fn=crop_tomo_fn,
    )


def get_shrec2020_prtcl_locmaps(shrec_occupancy_class_file):
    """
    Split a SHREC2020 class-mask volume into one binary float map per label.

    Integer label ``i`` in the mask is keyed by the ``i``-th name in the list
    below. Label 0 is keyed as ``'0'`` (presumably background; confirm
    against the dataset docs). Each value is ``(mask == i).float()``, so the
    loaded mask is assumed to be a torch tensor.

    Returns a dict mapping class name to binary map, ordered by label 0..12.
    """
    class_names = ['0', '3cf3', '1s3x', '1u6g', '4cr2', '1qvr', '3h84', '2cg9', '3qm1', '3gl1', '3d2f', '4d8q', '1bxn']
    occupancy = load_mrc_data(shrec_occupancy_class_file)
    return {
        name: (occupancy == label).float()
        for label, name in enumerate(class_names)
    }

def get_shrec2020_pdb_coord_dict(coord_file):
    """
    Read a SHREC2020 ``particle_locations.txt`` and group particles by class.

    Each line holds: ``pdb x y z alpha beta gamma``. Fields may be separated
    by any run of whitespace. Rows whose ``pdb`` is ``"vesicle"`` or
    ``"fiducial"`` are dropped.

    Returns a dict mapping each remaining pdb id (as a string) to the
    DataFrame of its rows, with the columns listed above. Keys appear in
    order of first occurrence in the file.
    """
    coord = pd.read_csv(
        coord_file,
        # Any whitespace run; a literal " " breaks on doubled/trailing spaces.
        sep=r"\s+",
        header=None,
        names=["pdb", "x", "y", "z", "alpha", "beta", "gamma"],
        # Keep ids as strings so they always match the string class names.
        dtype={"pdb": str},
    )
    # NOTE(review): z is not shifted by the 156-slice crop applied to the
    # reconstruction elsewhere. Presumably these coordinates are already in
    # the cropped frame of class_mask.mrc; confirm.
    # Remove non-protein entries.
    coord = coord[~coord.pdb.isin(["vesicle", "fiducial"])]
    # sort=False keeps first-appearance order, matching the previous unique() loop.
    return {pdb: group for pdb, group in coord.groupby("pdb", sort=False)}