
import copy
import datetime
import json
import os
import shlex

import pandas as pd
import torch
import tqdm
from my_tomotwin.modules.networks.networkmanager import NetworkManager

from .mrctools import save_mrc_data


def load_tomotwin_model(weightspth, load_weights=True, device=None):
    """
    Build a TomoTwin network from a checkpoint and optionally load its weights.

    Adapted from tomotwin.modules.inference.embedor

    Parameters
    ----------
    weightspth
        Checkpoint passed to ``torch.load``. It must contain the key
        "tomotwin_config" and, when ``load_weights`` is true, the key
        "model_state_dict".
    load_weights
        If true, load "model_state_dict" into the network; otherwise the
        checkpoint weights are not loaded.
    device
        Used as ``map_location`` for ``torch.load`` and as the target of
        ``model.to``. ``None`` leaves both at their defaults.

    Returns
    -------
    The model returned by ``NetworkManager.create_network(...).get_model()``.

    Raises
    ------
    RuntimeError
        If the state dict does not fit the network, even after removing a
        ``module.`` key prefix.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location=None is torch.load's default, so one call covers both cases.
    checkpoint = torch.load(weightspth, map_location=device)
    tomotwin_config = checkpoint["tomotwin_config"]
    print("Model config:")
    print(tomotwin_config)
    model = NetworkManager.create_network(tomotwin_config).get_model()
    if load_weights:
        state_dict = checkpoint["model_state_dict"]
        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            print("Load before failed")
            # Checkpoints saved from a torch.nn.DataParallel wrapper prefix
            # every key with "module.". The model is not wrapped here, so
            # retrying with the same keys could never succeed; strip the
            # prefix instead. Without such a prefix there is nothing to fix.
            prefix = "module."
            if not any(key.startswith(prefix) for key in state_dict):
                raise
            unwrapped_state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }
            model.load_state_dict(unwrapped_state_dict)
        print("Successfully loaded model weights")
    else:
        print("Model weights not loaded")
    # Nothing to move when no device was requested.
    return model if device is None else model.to(device)

def pass_subtomos_through_tomotwin(subtomos, tomotwin_model_file, batch_size=32, device="cpu"):
    """
    Embed a stack of 37x37x37 subtomograms with a TomoTwin model.

    Each subtomogram is normalised to zero mean and unit standard deviation
    (the input tensor itself is left untouched), batched, and passed through
    the model loaded from ``tomotwin_model_file``.

    Parameters
    ----------
    subtomos
        Torch tensor of shape (N, 37, 37, 37).
    tomotwin_model_file
        Checkpoint handed to ``load_tomotwin_model``.
    batch_size
        Number of subtomograms per forward pass.
    device
        Device the model and each batch are moved to.

    Returns
    -------
    torch.Tensor
        The per-batch model outputs concatenated along dim 0, on the CPU.
        Their second dimension is 32. An empty input yields a (0, 32) tensor.

    Raises
    ------
    ValueError
        If ``subtomos`` is not a 4-dim tensor of 37x37x37 volumes, or if the
        model output does not have 32 columns.
    """
    if not torch.is_tensor(subtomos):
        raise ValueError(f"Subtomos must be a 4-dim torch tensor! Got type: {type(subtomos)}")
    if subtomos.ndim != 4:
        raise ValueError(f"Subtomos must be a 4-dim torch tensor! Input has dim: {subtomos.ndim}")
    if tuple(subtomos.shape[-3:]) != (37, 37, 37):
        raise ValueError(f"Subtomos must be 37x37x37 for TomoTwin! Input has shape: {tuple(subtomos.shape)}")
    if len(subtomos) == 0:
        # Nothing to embed; torch.cat on an empty list would raise below.
        return torch.empty((0, 32))
    # tomotwin was trained with mean 0 and std 1 subtomos.
    # Out-of-place arithmetic so the caller's tensor is not modified.
    # A constant subtomogram has zero std and therefore produces NaN values.
    subtomos = subtomos - subtomos.mean(dim=(-1, -2, -3), keepdim=True)
    subtomos = subtomos / subtomos.std(dim=(-1, -2, -3), keepdim=True)
    # Add the channel dimension; unlike view() this also works on
    # non-contiguous input.
    subtomos = subtomos.unsqueeze(1)
    loader = torch.utils.data.DataLoader(subtomos, batch_size=batch_size, shuffle=False)
    tt = load_tomotwin_model(weightspth=tomotwin_model_file, device=device).eval()

    with torch.no_grad():
        tt_outputs = []
        for batch in tqdm.tqdm(loader, "Passing subtomos through TomoTwin"):
            tt_output = tt(batch.to(device))
            if isinstance(tt_output, tuple):
                tt_output = tt_output[0]
            if tt_output.shape[1] != 32:
                raise ValueError("Tomotwin output is not shape 32! Something went wrong!")
            # Always collect on the CPU so the result device does not depend
            # on whether the model returned a tuple.
            tt_outputs.append(tt_output.cpu())
        tt_outputs = torch.cat(tt_outputs)

    return tt_outputs

def eval_tomotwin_on_mask(tomo, mask, tomotwin_model_file, batch_size, device="cpu"):
    """
    Embed one 37x37x37 subtomogram for every non-zero voxel of ``mask``.

    ``tomo`` is zero-padded by 19 voxels on each side of its last three
    dimensions so that a window can be cut around voxels at the border. The
    cut windows are stacked and handed to ``pass_subtomos_through_tomotwin``,
    whose result is returned unchanged.

    Presumably ``tomo`` and ``mask`` are 3-dim tensors of identical shape;
    verify against callers.
    """
    padded = torch.nn.functional.pad(tomo, (19, 19, 19, 19, 19, 19), mode="constant")
    # Shift the mask coordinates into the padded volume.
    centres = mask.nonzero() + 19
    # NOTE(review): in unpadded coordinates each window spans [c-19, c+18),
    # i.e. its middle voxel is c-1 rather than c -- confirm this is intended.
    windows = [
        padded[c[0] - 19:c[0] + 18, c[1] - 19:c[1] + 18, c[2] - 19:c[2] + 18]
        for c in centres
    ]
    return pass_subtomos_through_tomotwin(torch.stack(windows), tomotwin_model_file, batch_size, device)
 
def load_gt_positions_from_tomotwin_coords_file(coords_file, skiprows=0):
    """
    Read a header-less coordinates file into a DataFrame with box sizes.

    The file is parsed with ``pd.read_csv``; its columns are named
    "class", "X", "Y", "Z", "rx", "ry", "rz". ``_add_size`` then appends
    "width", "height" and "depth" using ``SIZE_DICT`` with 37 as fallback.

    NOTE: an optional rescaling of X/Y/Z about the centre of a 512x512x200
    volume by a factor of 10.2 / 10 used to be sketched here as commented-out
    code; it was never active and coordinates are returned as read.
    """
    column_names = ["class", "X", "Y", "Z", "rx", "ry", "rz"]
    positions = pd.read_csv(coords_file, header=None, skiprows=skiprows)
    positions.columns = column_names
    return _add_size(positions, size=37, size_dict=SIZE_DICT)

def mask_tomotwin_map(df_map, mask):
    """
    Keep only the rows of a TomoTwin map that lie inside ``mask``.

    Parameters
    ----------
    df_map
        DataFrame with integer voxel coordinates in columns "Z", "Y", "X".
    mask
        3-dim mask indexed as (z, y, x); non-zero / True voxels are kept.
        Anything accepted by ``torch.as_tensor`` works.

    Returns
    -------
    pd.DataFrame
        Inner join of ``df_map`` with the selected coordinates. ``attrs`` is a
        deep copy of ``df_map.attrs``.

    Raises
    ------
    ValueError
        If ``mask`` is not 3-dimensional.
    """
    mask_tensor = torch.as_tensor(mask)
    if mask_tensor.ndim != 3:
        raise ValueError(f"Mask must be a 3-dim (z, y, x) volume! Got {mask_tensor.ndim} dims")
    # nonzero() lists the (z, y, x) index of every selected voxel directly,
    # which avoids materialising three full coordinate grids.
    coords = torch.nonzero(mask_tensor).cpu().numpy()
    # Plain numpy columns keep the key dtypes numeric for the merge.
    df_mask = pd.DataFrame({"Z": coords[:, 0], "Y": coords[:, 1], "X": coords[:, 2]})
    # intersect df_map with df_mask
    df_map_mask = pd.merge(df_map, df_mask, how="inner", on=["Z", "Y", "X"])
    # merge does not carry attrs over, but tomotwin needs them later on
    df_map_mask.attrs = copy.deepcopy(df_map.attrs)
    # Guard against an empty map to avoid a ZeroDivisionError.
    selected_percent = len(df_map_mask) / len(df_map) * 100 if len(df_map) else 0.0
    print(f"Mask selected {selected_percent:.2f}% of the heatmap")
    return df_map_mask

def run_tomotwin_locate_command(tt_map_file, out_dir, gt_positions_file=None, mask=None, findmax_global_min=0.5, findmax_tolerance=0.2, limit_to_pdbs=None, save_tomotwin_heatmaps=False, optmize_parameters=True, tomotwin_conda_env="tomotwin", boxsizes_dict=None, time_commands=False):
    """
    Run ``tomotwin_locate.py findmax`` on a map and optionally evaluate it.

    The pickled map is loaded, optionally restricted to ``mask`` and to the
    references named in ``limit_to_pdbs``, written to
    ``{out_dir}/map/map.tmap`` and passed to ``tomotwin_locate.py`` through
    ``conda run``. If ``gt_positions_file`` is given,
    ``tomotwin_scripts_evaluate.py positions`` is run on the resulting
    ``*.tloc`` files. Nothing is returned; results are files in ``out_dir``.

    Parameters
    ----------
    tt_map_file
        Pickled map DataFrame whose ``attrs["references"]`` lists the
        reference names and whose distance columns are named ``d_class_{i}``.
    out_dir
        Output directory; ``map/`` and ``locate/`` are created inside it.
    gt_positions_file
        Ground-truth positions for the evaluation step; ``None`` skips it.
    mask
        Optional mask forwarded to ``mask_tomotwin_map``.
    findmax_global_min, findmax_tolerance
        Passed as ``--global_min`` and ``--tolerance``.
    limit_to_pdbs
        Optional names; only references containing one of them (as given,
        upper-cased or lower-cased) are kept and renumbered from 0.
    save_tomotwin_heatmaps
        Adds ``--write_heatmaps`` to the locate command.
    optmize_parameters
        Adds ``--optim`` to the evaluation command.
    tomotwin_conda_env
        Conda environment the commands are run in.
    boxsizes_dict
        Box sizes written to ``{out_dir}/boxsizes.json`` for the evaluation;
        defaults to ``SIZE_DICT``.
    time_commands
        Print the wall-clock duration of each command.
    """
    # setup output directories
    if not os.path.exists(out_dir):
        print(f"Creating output directory: {out_dir}")
    # Create the sub-directories even when out_dir already exists, otherwise
    # writing the map file below fails for a pre-existing output directory.
    os.makedirs(f"{out_dir}/map", exist_ok=True)
    os.makedirs(f"{out_dir}/locate", exist_ok=True)

    # NOTE: read_pickle unpickles the file; only load map files you trust.
    df_map = pd.read_pickle(tt_map_file)
    df_map_mask = mask_tomotwin_map(df_map, mask) if mask is not None else copy.deepcopy(df_map)

    # reset references according to limit_to_pdbs
    if limit_to_pdbs is not None:
        refs = df_map.attrs["references"]
        selected_ref_ids = []
        for pdb_of_interest in limit_to_pdbs:
            spellings = (pdb_of_interest, pdb_of_interest.upper(), pdb_of_interest.lower())
            matching_ids = [
                ref_id
                for ref_id, ref in enumerate(refs)
                if any(spelling in ref for spelling in spellings)
            ]
            if not matching_ids:
                print(f"Could not find reference for pdb '{pdb_of_interest}' in the map file. Skipping.")
            for ref_id in matching_ids:
                # A reference matched by several names must be kept only once,
                # otherwise its column would be selected twice.
                if ref_id not in selected_ref_ids:
                    selected_ref_ids.append(ref_id)
        limited_attrs = copy.deepcopy(df_map_mask.attrs)
        limited_attrs["references"] = [refs[ref_id] for ref_id in selected_ref_ids]
        df_map_mask = df_map_mask[["X", "Y", "Z"] + [f"d_class_{ref_id}" for ref_id in selected_ref_ids]]
        # rename the kept distance columns to d_class_0, d_class_1, ...
        rename_dict = {f"d_class_{ref_id}": f"d_class_{i}" for i, ref_id in enumerate(selected_ref_ids)}
        df_map_mask = df_map_mask.rename(columns=rename_dict)
        # Re-attach attrs explicitly instead of relying on pandas carrying
        # them through column selection and rename.
        df_map_mask.attrs = limited_attrs

    # save the new map file
    new_map_file = f"{out_dir}/map/map.tmap"
    df_map_mask.to_pickle(new_map_file)

    # Quote everything interpolated into the shell so that paths containing
    # spaces or shell metacharacters are passed through literally.
    quoted_env = shlex.quote(str(tomotwin_conda_env))
    quoted_out_dir = shlex.quote(str(out_dir))

    # picking
    locate_parts = [
        f"conda run -n {quoted_env} tomotwin_locate.py findmax",
        f"-m {shlex.quote(new_map_file)}",
        f"-o {quoted_out_dir}/locate/",
        f"--global_min {findmax_global_min}",
        f"--tolerance {findmax_tolerance}",
    ]
    if save_tomotwin_heatmaps:
        locate_parts.append("--write_heatmaps")
    start = datetime.datetime.now()
    locate_status = os.system(" ".join(locate_parts))
    end = datetime.datetime.now()
    if locate_status != 0:
        print(f"WARNING: tomotwin_locate.py exited with non-zero status {locate_status}")
    if time_commands:
        print(f"Locating took {end-start}")

    # evaluation wrt ground truth
    if gt_positions_file is not None:
        if boxsizes_dict is None:
            print("WARNING: No boxsizes_dict given! Evaluation will be done with default tomotwin boxsizes which may not contain your particles!")
            boxsizes_dict = SIZE_DICT

        boxsizes_json = f"{out_dir}/boxsizes.json"
        with open(boxsizes_json, "w") as f:
            json.dump(boxsizes_dict, f)
        evaluate_parts = [
            f"conda run -n {quoted_env} tomotwin_scripts_evaluate.py positions",
            f"-p {shlex.quote(str(gt_positions_file))}",
            # The glob stays unquoted so the shell expands it.
            f"-l {quoted_out_dir}/locate/*.tloc",
            f"-s {shlex.quote(boxsizes_json)}",
        ]
        if optmize_parameters:
            evaluate_parts.append("--optim")
        start = datetime.datetime.now()
        evaluate_status = os.system(" ".join(evaluate_parts))
        end = datetime.datetime.now()
        if evaluate_status != 0:
            print(f"WARNING: tomotwin_scripts_evaluate.py exited with non-zero status {evaluate_status}")
        if time_commands:
            print(f"Evaluation took {end-start}")


def run_tomotwin_map_command(tomo, tomotwin_model_file, out_dir, references, stride=2, batch_size=128, devices="0", skip_tomo_embedding_if_exists=False, tomotwin_conda_env="tomotwin"):
    """
    Embed references and a tomogram with TomoTwin and compute the distance map.

    Runs, through ``conda run`` in ``tomotwin_conda_env``:
    ``tomotwin_embed.py subvolumes`` on the references,
    ``tomotwin_embed.py tomogram`` on the tomogram (unless skipped) and
    ``tomotwin_map.py distance`` on the two embeddings. Nothing is returned;
    all results are written below ``out_dir``. Durations of the tomogram
    embedding and the mapping are appended to ``{out_dir}/runtime.txt``.

    Parameters
    ----------
    tomo
        Either a torch tensor, which is saved to
        ``{out_dir}/tomogram/tomo.mrc`` first, or a value that is passed to
        ``tomotwin_embed.py`` as the volume path.
    tomotwin_model_file
        Model file passed as ``-m`` to ``tomotwin_embed.py``.
    out_dir
        Output directory; created if missing.
    references
        Dict mapping reference names to volumes, or a list of volumes, which
        is turned into ``{"0ref": ..., "1ref": ...}``. Each volume is written
        with ``save_mrc_data``.
    stride, batch_size
        Passed as ``--stride`` and ``-b`` to the tomogram embedding.
    devices
        Value assigned to ``CUDA_VISIBLE_DEVICES`` for the embedding commands.
    skip_tomo_embedding_if_exists
        Skip the tomogram embedding when ``{out_dir}/embed/tomogram`` already
        contains files.
    tomotwin_conda_env
        Conda environment the commands are run in.

    Raises
    ------
    ValueError
        If ``references`` is neither a list nor a dict.
    """
    print("WARNING: References have to be scaled such that particles appear white on a dark background! In the tomogram particles are assumed to be black on a white background. Make sure to invert the references if necessary!")
    if not isinstance(references, (list, dict)):
        raise ValueError("References must be a list or a dictionary!")
    if isinstance(references, list):
        # the format {i}ref is important because tomotwin needs this for its evaluation
        references = {f"{i}ref": ref for i, ref in enumerate(references)}

    if not os.path.exists(out_dir):
        print(f"Creating output directory: {out_dir}")
    for subdir in ("embed/references", "embed/tomogram", "map", "references"):
        os.makedirs(f"{out_dir}/{subdir}", exist_ok=True)

    # Quote everything interpolated into the shell so that paths containing
    # spaces or shell metacharacters are passed through literally.
    quoted_env = shlex.quote(str(tomotwin_conda_env))
    quoted_out_dir = shlex.quote(str(out_dir))
    quoted_model = shlex.quote(str(tomotwin_model_file))
    quoted_devices = shlex.quote(str(devices))
    runtime_log = f"{out_dir}/runtime.txt"

    for ref_name, ref in references.items():
        save_mrc_data(ref, f"{out_dir}/references/{ref_name}.mrc")
    # NOTE: the glob embeds every .mrc present in the references directory,
    # including files left there by earlier runs.
    ref_embed_cmd = " ".join(
        [
            f"CUDA_VISIBLE_DEVICES={quoted_devices} conda run -n {quoted_env} tomotwin_embed.py subvolumes",
            f"-m {quoted_model}",
            f"-v {quoted_out_dir}/references/*.mrc",
            f"-o {quoted_out_dir}/embed/references",
        ]
    )
    ref_embed_status = os.system(ref_embed_cmd)
    if ref_embed_status != 0:
        print(f"WARNING: reference embedding exited with non-zero status {ref_embed_status}")

    run_tomo_embedding = True
    # The directory was created above, so only its content needs checking.
    if skip_tomo_embedding_if_exists and os.listdir(f"{out_dir}/embed/tomogram/"):
        run_tomo_embedding = False
        print(f"Skipping tomogram embedding because {out_dir}/embed/tomogram/ is not empty")

    if run_tomo_embedding:
        if torch.is_tensor(tomo):
            print(f"Passed tensor of shape {tomo.shape} to as tomo. Saving to {out_dir}/tomogram/tomo.mrc")
            os.makedirs(f"{out_dir}/tomogram", exist_ok=True)
            tomo_file = f"{out_dir}/tomogram/tomo.mrc"
            save_mrc_data(tomo, tomo_file)
            tomo = tomo_file

        tomo_embed_cmd = " ".join(
            [
                f"CUDA_VISIBLE_DEVICES={quoted_devices} conda run -n {quoted_env} tomotwin_embed.py tomogram",
                f"-m {quoted_model}",
                f"-v {shlex.quote(str(tomo))}",
                f"-b {batch_size}",
                f"-o {quoted_out_dir}/embed/tomogram",
                f"--stride {stride}",
            ]
        )
        start = datetime.datetime.now()
        tomo_embed_status = os.system(tomo_embed_cmd)
        end = datetime.datetime.now()
        if tomo_embed_status != 0:
            print(f"WARNING: tomogram embedding exited with non-zero status {tomo_embed_status}")
        # Append the timing from Python rather than shelling out to echo.
        with open(runtime_log, "a") as log:
            log.write(f"{datetime.datetime.now()}: Embedding on GPUS {devices} took {end-start}\n")

    map_cmd = " ".join(
        [
            f"conda run -n {quoted_env} tomotwin_map.py distance",
            f"-v {quoted_out_dir}/embed/tomogram/*.temb",
            f"-r {quoted_out_dir}/embed/references/*.temb",
            f"-o {quoted_out_dir}/map/",
        ]
    )
    start = datetime.datetime.now()
    map_status = os.system(map_cmd)
    end = datetime.datetime.now()
    if map_status != 0:
        print(f"WARNING: tomotwin_map.py exited with non-zero status {map_status}")
    with open(runtime_log, "a") as log:
        log.write(f"{datetime.datetime.now()}: Mapping took {end-start}\n")



# Default box size per particle class. Keys are upper-case class names (PDB
# identifiers plus "VESICLE" and "FIDUCIAL"); _add_size looks classes up here
# and run_tomotwin_locate_command dumps this table as the box-size JSON when
# the caller supplies none. "VESICLE" deliberately maps to None.
SIZE_DICT = {
    "1AVO": 18,
    "1FZG": 28,
    "1JPM": 18,
    "2HMI": 14,
    "2VYR": 18,
    "3EWF": 18,
    "1E9R": 21,
    "1OAO": 20,
    "2DF7": 33,
    "5XNL": 33,
    "1UL1": 18,
    "2RHS": 19,
    "3MKQ": 33,
    "7EY7": 25,
    "3ULV": 28,
    "1N9G": 19,
    "7BLQ": 28,
    "6WZT": 27,
    "7EGQ": 25,
    "5VKQ": 30,
    "7LSY": 30,
    "7KDV": 29,
    "6LXV": 28,
    "7DD9": 25,
    "7AMV": 25,
    "7NHS": 24,
    "7E8H": 25,
    "7E1Y": 25,

    "2WW2": 20,
    "7VTQ": 28,
    "6YT5": 30,
    "7EGD": 32,
    "7SN7": 32,
    "7WOO": 35,
    "7MEI": 32,
    "7T3U": 30,
    "6Z6O": 35,
    "7BKC": 31,
    "7EEP": 34,

    "7E8S": 35,
    "7QJ0": 30,
    "7NYZ": 35,
    "6VQK": 22,
    "6ZIU": 30,
    "6X02": 26,
    "7E6G": 21,
    "7O01": 35,
    "6X5Z": 30,
    "7WBT": 21,
    "6VGR": 22,
    "4UIC": 23,
    "6Z3A": 28,
    "7KFE": 18,
    "7WI6": 23,
    "7SHK": 17,
    "5TZS": 37,
    "7EGE": 30,
    "7ETM": 21,
    "6SCJ": 30,
    "6TAV": 20,
    "2VZ9": 23,
    "6KLH": 21,
    "1KB9": 20,
    "3PXI": 18,
    "4YCZ": 18,
    "6IGC": 30,

    "6F8L": 18,
    "6JY0": 25,
    "6TA5": 37,
    "6TGC": 28,
    "2DFS": 30,
    "6KSP": 27,
    "7JSN": 24,
    "6KRK": 20,
    "7NIU": 23,
    "5A20": 35,

    "5OOL": 30,
    "6UP6": 33,
    "6I0D": 25,
    "6BQ1": 30,
    "7SFW": 26,
    "3LUE": 37,
    "6JK8": 20,
    "5H0S": 22,
    "6LX3": 17,
    "5LJO": 21,

    "6DUZ": 32,
    "4XK8": 23,
    "6XF8": 29,
    "6M04": 22,
    "6U8Q": 23,
    "6LXK": 24,
    "6CE7": 20,
    "5CSA": 28,
    "7SGM": 25,
    "7B5S": 25,

    "6GYM": 28,
    "6EMK": 27,
    "6W6M": 19,
    "7R04": 35,
    "5O32": 22,
    "6CES": 23,
    "2XNX": 25,
    "6LMT": 17,
    "7BLR": 25,
    "2R9R": 18,

    "6ZQJ": 24,
    "4WRM": 22,
    "7S7K": 23,
    "4V94": 37,
    "4CR2": 33,
    "1QVR": 25,
    "1BXN": 19,
    "3CF3": 25,
    "1U6G": 18,
    "3D2F": 22,
    "2CG9": 18,
    "3H84": 18,
    "3GL1": 13,
    "3QM1": 12,
    "1S3X": 12,
    "5MRC": 37,

    "1FPY": 18,
    "1FO4": 23,
    "1FZ8": 19,
    "1JZ8": 20,
    "4ZIT": 17,
    "5BK4": 25,
    "5BW9": 25,

    "1CU1": 17,
    "1SS8": 22,
    "6AHU": 21,
    "6TPS": 28,
    "6X9Q": 37,
    "6GY6": 31,
    "6NI9": 12,
    "6VZ8": 25,
    "4HHB": 12,
    "7B7U": 20,

    "6Z80": 18,
    "6PWE": 14,
    "6PIF": 20,
    "6O9Z": 21,
    "6ID1": 30,
    "5YH2": 16,
    "4RKM": 16,
    "1G3I": 16,
    "1DGS": 14,
    "1CLQ": 15,

    "7Q21": 20,
    "7KJ2": 25,
    "7K5X": 18,
    "7FGF": 14,
    "7CRQ": 22,
    "6YBS": 25,
    "5JH9": 23,
    "5A8L": 20,
    "3IF8": 15,
    "2B3Y": 14,

    "6VN1": 14,
    "6MRC": 23,
    "6CNJ": 25,
    "5G04": 26,
    "4QF6": 17,
    "1SVM": 18,
    "1O9J": 17,
    "1ASZ": 19,
    "VESICLE": None,
    "FIDUCIAL": 18,
}


def _add_size(df, size, size_dict=SIZE_DICT) -> pd.DataFrame:
    """
    Add "width", "height" and "depth" columns to ``df`` in place and return it.

    Origin noted by the original author: tomotwin.scripts.evaluation

    With ``size_dict=None`` every row gets ``size``. Otherwise each row gets
    the entry of ``size_dict`` for its upper-cased "class" value, falling
    back to ``size`` whenever that lookup raises ``KeyError``.
    """
    box_columns = ("width", "height", "depth")

    if size_dict is None:
        for column in box_columns:
            df[column] = size
        return df

    for column in box_columns:
        df[column] = 0
    for idx, record in df.iterrows():
        try:
            box = size_dict[str(record["class"]).upper()]
        except KeyError:
            # Class missing from the table: use the default size.
            box = size
        for column in box_columns:
            df.at[idx, column] = box

    return df