#%%
import pandas as pd
import torch
import tqdm

from .setup_picking_data import setup_segmentation_subtomos, extract_prtcls
from .utils import load_mrc_data

def load_shrec2019_prtcl_positions(coords_file):
    """Parse a SHREC2019 particle-locations text file into a DataFrame.

    Each non-blank line is split on whitespace into 7 fields:
    ``pdb, z, y, x, alpha, beta, gamma``.

    Args:
        coords_file: path to the whitespace-separated locations file.

    Returns:
        pandas.DataFrame with columns ``pdb`` (str), ``z``/``y``/``x`` (int) and
        ``alpha``/``beta``/``gamma`` (float), one row per particle.
    """
    with open(coords_file, "r", encoding="utf-8") as f:
        # Skip blank lines: they would otherwise become empty rows that make the
        # int/float conversions below fail.
        rows = [line.split() for line in f if line.strip()]
    # NOTE(review): the spatial columns are named z, y, x in file order. The original
    # comment said "x and y are switched in the file"; confirm the axis convention
    # against the SHREC2019 format description.
    coords = pd.DataFrame(rows, columns=["pdb", "z", "y", "x", "alpha", "beta", "gamma"])
    coords[["z", "y", "x"]] = coords[["z", "y", "x"]].astype(int)
    coords[["alpha", "beta", "gamma"]] = coords[["alpha", "beta", "gamma"]].astype(float)
    return coords

def get_shrec2019_pdb_coord_dict(coords_file):
    """Group SHREC2019 particle positions by PDB id.

    Args:
        coords_file: path to a particle-locations file readable by
            ``load_shrec2019_prtcl_positions``.

    Returns:
        dict mapping each PDB id found in the file to the sub-DataFrame of its rows.
    """
    positions = load_shrec2019_prtcl_positions(coords_file)
    return {pdb_id: rows for pdb_id, rows in positions.groupby("pdb")}

def _fill_ball_(locmap, center, radius):
    """In place, set to 1 every voxel of the 3-D ``locmap`` whose squared distance to
    ``center`` (z, y, x) is strictly less than ``radius**2``.

    Only the ball's axis-aligned bounding box, clipped to the volume, is evaluated.
    This marks exactly the same voxels as evaluating the distance on the full grid.
    """
    lo = [max(c - radius, 0) for c in center]
    hi = [min(c + radius + 1, size) for c, size in zip(center, locmap.shape)]
    if any(start >= stop for start, stop in zip(lo, hi)):
        return  # ball lies completely outside the volume
    device = locmap.device
    # Per-axis squared offsets, combined by broadcasting into the window's distance map.
    sq = [
        (torch.arange(start, stop, device=device) - c) ** 2
        for start, stop, c in zip(lo, hi, center)
    ]
    dist2 = sq[0].view(-1, 1, 1) + sq[1].view(1, -1, 1) + sq[2].view(1, 1, -1)
    # Basic slicing returns a view, so the masked assignment writes into locmap.
    window = locmap[lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]]
    window[dist2 < radius ** 2] = 1


def get_shrec2019_prtcl_locmaps(shrec_prtcl_locations_file, device="cpu", limit_to_pdbs=None, shape=(200, 512, 512)):
    """Build one binary location map per PDB class from a SHREC2019 locations file.

    For every particle, a ball of a class-specific radius (in voxels) centred at its
    (z, y, x) position is set to 1 in that class's map.

    Args:
        shrec_prtcl_locations_file: path to the particle-locations text file.
        device: torch device on which the maps are built.
        limit_to_pdbs: optional collection of PDB ids; if given, only those classes
            are built.
        shape: (z, y, x) size of each map. Defaults to the previously hard-coded
            (200, 512, 512).

    Returns:
        dict mapping PDB id -> float32 tensor of ``shape`` on the CPU, with values 0/1.
    """
    coords = load_shrec2019_prtcl_positions(shrec_prtcl_locations_file)

    # Ball radius in voxels per PDB id.
    # NOTE(review): hard-coded values, presumably chosen per particle size; confirm.
    pdb_to_radius = {
        "1bxn": 6, "1qvr": 6, "1s3x": 3, "1u6g": 6, "2cg9": 6, "3cf3": 7,
        "3d2f": 6, "3gl1": 4, "3h84": 4, "3qm1": 3, "4b4t": 10, "4d8q": 8,
    }
    if limit_to_pdbs is not None:
        # Filter the mapping itself so each PDB keeps its own radius.
        pdb_to_radius = {pdb: r for pdb, r in pdb_to_radius.items() if pdb in limit_to_pdbs}

    locmap_dict = {}
    for pdb, radius in pdb_to_radius.items():
        pdb_coords = coords[coords["pdb"] == pdb]
        pdb_locmap = torch.zeros(tuple(shape), device=device)
        for _, row in tqdm.tqdm(pdb_coords.iterrows(), total=len(pdb_coords), desc=f"Get locmap for {pdb}"):
            center = (int(row["z"]), int(row["y"]), int(row["x"]))
            _fill_ball_(pdb_locmap, center, radius)
        locmap_dict[pdb] = pdb_locmap.cpu()
    return locmap_dict


def setup_segmentation_tomos_shrec2019_model(shrec_model_dir, out_dir, subtomo_size=37, subtomo_extraction_strides=None, save_full_locmaps=False, tomotwin_model_file=None, setup_tomotwin_reference_embeddings=True, add_background_class_to_locmap=False, skip_existing=True, device="cuda:0", crop_tomo_fn=None):
    """
    Wrapper for setup_segmentation_subtomos that takes a SHREC2019 model directory as input.
    Example for shrec_model_dir: .../full_dataset/model_0

    The directory is expected to contain ``particle_locations_model_<id>.txt`` and
    ``reconstruction_model_<id>.mrc``. ``<id>`` is the directory name with any leading
    ``model_`` removed, so ``model_0`` and ``0`` both resolve to id ``0``.

    Args:
        shrec_model_dir: path to one SHREC2019 model directory. A trailing separator
            is allowed.
        device: torch device used to build the per-class location maps.
        All other arguments are forwarded unchanged to ``setup_segmentation_subtomos``.
    """
    import os

    # normpath drops a trailing separator, which would otherwise produce an empty id.
    dir_name = os.path.basename(os.path.normpath(os.fspath(shrec_model_dir)))
    model_id = dir_name[len("model_"):] if dir_name.startswith("model_") else dir_name
    particle_locations_file = os.path.join(shrec_model_dir, f"particle_locations_model_{model_id}.txt")
    reconstruction_file = os.path.join(shrec_model_dir, f"reconstruction_model_{model_id}.mrc")

    locmaps = get_shrec2019_prtcl_locmaps(particle_locations_file, device=device)
    # -1 because tomotwin needs reversed contrast.
    # The [156:-156] crop drops 156 slices from each end of axis 0. NOTE(review): for a
    # 512-slice reconstruction this leaves 200 slices, matching the default locmap
    # depth of 200; confirm the reconstruction size.
    tomo = -1 * load_mrc_data(reconstruction_file)[156:-156]
    pdb_coord_dict = get_shrec2019_pdb_coord_dict(particle_locations_file)
    setup_segmentation_subtomos(
        tomo=tomo,
        locmaps=locmaps,
        pdb_coord_dict=pdb_coord_dict,
        out_dir=out_dir,
        subtomo_size=subtomo_size,
        subtomo_extraction_strides=subtomo_extraction_strides,
        save_full_locmaps=save_full_locmaps,
        tomotwin_model_file=tomotwin_model_file,
        setup_tomotwin_reference_embeddings=setup_tomotwin_reference_embeddings,
        add_background_class_to_locmap=add_background_class_to_locmap,
        skip_existing=skip_existing,
        crop_tomo_fn=crop_tomo_fn,
    )