#%%
import sys
sys.path.append(".")

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)  # suppress nasty torch FutureWarning when loading checkpoints

import pandas as pd
import os
import json
import torch
from data.utils import load_h5_tensor_dict, save_h5_tensor_dict
import copy

from train import MyModel
from utils.mrctools import load_mrc_data
from utils.tomotwin import load_gt_positions_from_tomotwin_coords_file

from clustering_and_picking.tomotwin_evaluation_routine import SIZE_DICT, _add_size
from backup_tools import try_load_model_from_backup_file


from inference.get_pred_locmap_dict import get_pred_locmap_dict
from clustering_and_picking import get_best_case_cluster_based_picking_performance, get_best_case_tomotwin_picking_performance

from data_paths import TOMOTWIN_TOMO_BASE_DIR, TOMOTWIN_TRAIN_BASE_DIR, TROMOTWIN_VAL_BASE_DIR, FIXED_TOMOTWIN_TOMOS_PROMPTS_JSON, TOMOTWIN_DEMO_DIR
from evaluate_cfg import device, out_dir, subtomo_size, subtomo_overlap, batch_size, model_ckpts, backup_files, runs, model_ckpt_to_name_fn, get_best_case_cluster_based_picking_performance_args, save_pred_locmap_dicts, metric_of_interest, compute_tomotwin_picking_performance, limit_to_pdbs
from matplotlib import pyplot as plt
import numpy as np
import sys


def plot_metric_boxplots(all_metrics_dict):
    """Plot one horizontal boxplot per round plus one over all rounds.

    For every pdb the metric values of all its runs are averaged; each round's
    boxplot shows the distribution of these per-pdb means. A final boxplot
    labelled ``all_rounds`` pools the per-pdb means of every round. The
    individual means are overlaid as jittered points, and the median (red),
    minimum and maximum (blue) of each row are annotated.

    The input dict is not modified.

    Args:
        all_metrics_dict: Nested dict ``{round_name: {pdb: [value, ...]}}``
            holding one metric value per run for each pdb. The part of
            ``round_name`` after the last ``"_"`` is used for ordering.

    Returns:
        The matplotlib figure containing the plot. The caller owns the figure
        and is responsible for saving/closing it.
    """
    def _round_order(round_name):
        # Numeric suffixes are ordered as integers so that e.g. round_10 does
        # not sort between round_1 and round_2; non-numeric suffixes ("demo")
        # rank above all numeric ones, as in a plain string comparison.
        suffix = round_name.split("_")[-1]
        try:
            return (0, int(suffix), "")
        except ValueError:
            return (1, 0, suffix)

    # reverse=True: boxplot row 1 is drawn at the bottom, so the lowest round
    # number ends up at the top of the figure
    round_names = sorted(all_metrics_dict, key=_round_order, reverse=True)

    n_runs_per_round, n_pdbs_per_round = {}, {}  # bookkeeping for the tick labels
    mean_metrics_per_round = {}  # round name -> list of per-pdb means
    all_metrics_list = []  # per-pdb means pooled over all rounds

    for round_name in round_names:
        round_metrics = all_metrics_dict[round_name]
        values_per_pdb = list(round_metrics.values())
        # the number of runs is read off the first pdb of the round
        n_runs_per_round[round_name] = len(values_per_pdb[0]) if values_per_pdb else 0
        n_pdbs_per_round[round_name] = len(round_metrics)
        pdb_means = [sum(values) / len(values) for values in values_per_pdb]
        mean_metrics_per_round[round_name] = pdb_means
        all_metrics_list.extend(pdb_means)

    fig, ax = plt.subplots(1, figsize=(5, 1*(len(round_names)+1)))

    ax.set_title(f"{metric_of_interest} per tomotwin round")
    labels = [
        f"{round_name}\nruns={n_runs_per_round[round_name]}\npdbs={n_pdbs_per_round[round_name]}"
        for round_name in round_names
    ]
    labels += [f"all_rounds\nruns={sum(n_runs_per_round.values())}\npdbs={len(all_metrics_list)}"]

    boxplot_data = [mean_metrics_per_round[round_name] for round_name in round_names] + [all_metrics_list]
    # NOTE: matplotlib 3.9 renamed `labels` to `tick_labels`; `labels` is kept
    # here for compatibility with older matplotlib versions.
    ax.boxplot(boxplot_data, labels=labels, vert=False, showmeans=True)

    # NOTE(review): the fixed limits assume the metric lies in [0, 1] - confirm
    ax.set_xlim(-0.05, 1.05)

    # overlay the individual values and annotate median, min and max per row
    for position, data in enumerate(boxplot_data, start=1):
        if len(data) == 0:
            # nothing to scatter or annotate; np.min/np.max would raise here
            continue

        y_jittered = np.random.normal(position, 0.04, size=len(data))
        ax.scatter(data, y_jittered, alpha=0.3, color='black', zorder=3, marker='o')

        annotations = (
            (np.median(data), 'red'),
            (np.min(data), 'blue'),
            (np.max(data), 'blue'),
        )
        for value, color in annotations:
            ax.annotate(f'{value:.2f}',
                        xy=(value, position),
                        xytext=(0, -25),
                        textcoords='offset points',
                        ha='center',
                        color=color)
    plt.tight_layout()
    return fig

# Lookup table from round identifier ("demo" or the round number as a string)
# to the pdb codes of that round. Each row lists the codes separated by
# whitespace; they are split into lists below.
_ROUND_PDB_CODES = (
    ("demo", "1avo 1e9r 1fpy 1fzg 1jz8 1oao 2df7"),
    ("1", "1ss8 4wrm 6gy6 6ahu 6tps 6vz8 6x9q 6zqj 7b7u 7s7k"),
    ("2", "1g3i 6id1 6pif 6wzt 6z80 6ziu 7blq 7e6g 7nyz 7qj0"),
    ("3", "4uic 5jh9 6igc 6vgr 6x5z 7k5x 7kj2 7o01 7q21 7wbt"),
    ("4", "1ul1 5g04 6cnj 6mrc 6z3a 6vn1 7kfe 7shk"),
    ("5", "1n9g 2vz9 2ww2 3ulv 6klh 6scj 6tav 7ege 7etm 7ey7"),
    ("6", "2dfs 5a20 6f8l 6jy0 6krk 6ksp 6ta5 6tgc 7jsn 7niu"),
    ("7", "3lue 3mkq 5h0s 5ljo 5ool 6bq1 6i0d 6lx3 6up6 7sfw"),
    ("8", "2rhs 4xk8 5csa 6duz 6lxk 6m04 6u8q 6xf8 7b5s 7sgm"),
    ("9", "2r9r 2xnx 5o32 6ces 6emk 6gym 6lmt 6w6m 7blr 7r04"),
    ("10", "6yt5 6z6o 7bkc 7eep 7egd 7mei 7sn7 7t3u 7vtq 7woo"),
    ("11", "5vkq 6lxv 7amv 7dd9 7e1y 7e8h 7egq 7kdv 7lsy 7nhs"),
)

tomotwin_round_pdbs_dict = {
    round_key: pdb_codes.split() for round_key, pdb_codes in _ROUND_PDB_CODES
}




#%%
if __name__ == "__main__":
    # Evaluation entry point: for every checkpoint in `model_ckpts`, run
    # inference on every tomogram in `runs`, compute picking performance and
    # keep a per-model JSON/PNG summary of `metric_of_interest` up to date.

    # only the script entry point needs shutil, so it is imported locally
    import shutil

    # pdb restriction from evaluate_cfg; None means "use all pdbs of the round"
    LIMIT_TO_PDBS = limit_to_pdbs

    # raw prompt embeddings from the JSON file; read lazily on first use and
    # then shared by all models and runs instead of being re-read for each run
    raw_prompt_embeds = None

    for model_ckpt in model_ckpts:
        all_metrics_dict = {}

        # setup directory in which results for this model will be saved
        model_run_name = model_ckpt_to_name_fn(model_ckpt)
        out_dir_model = f"{out_dir}/{model_run_name}"

        # if the checkpoint name contains an epoch, use a subdirectory for this
        # epoch; checkpoints without a parsable epoch keep the plain directory
        epoch_marker = "epoch="
        for ckpt_tag in ("min_val_loss", "min_exclusive_val_loss"):
            if ckpt_tag in model_ckpt:
                epoch_marker = f"{ckpt_tag}_epoch="
                break
        ckpt_parts = model_ckpt.split(epoch_marker)
        if len(ckpt_parts) > 1:
            epoch = ckpt_parts[1].split(".ckpt")[0]
            out_dir_model = f"{out_dir_model}/{epoch_marker}{epoch}"

        out_dir_model = f"{out_dir_model}/stride={subtomo_size-subtomo_overlap}"

        # make the output directory
        print(f"Saving all outputs to '{out_dir_model}'")
        os.makedirs(out_dir_model, exist_ok=True)

        # copy this file and evaluate_cfg.py to the output directory; this is a
        # best-effort backup, so a failed copy is reported but does not abort
        eval_code_backup_dir = f"{out_dir_model}/backup_eval_script_and_cfg"
        os.makedirs(eval_code_backup_dir, exist_ok=True)
        for backup_src in (__file__, "evaluate_cfg.py"):
            try:
                shutil.copy(backup_src, eval_code_backup_dir)
            except OSError as err:
                print(f"Could not copy '{backup_src}' to '{eval_code_backup_dir}': {err}")

        model = try_load_model_from_backup_file(model_ckpt, backup_file=None, model_class_name="MyModel")
        if model is None:
            print(sys.path)
            print("Could not load model from backup file. Trying to load from default path.")
            model = MyModel.load_from_checkpoint(model_ckpt)
        else:
            print("Succcessfully loaded model from backup.")

        model = model.to(device).eval()
        model.freeze()

        for run in runs:
            out_dir_run = f"{out_dir_model}/{run}"
            os.makedirs(out_dir_run, exist_ok=True)

            # resolve tomogram file and ground truth particle positions
            if run == "tomo_simulation_round_demo/tomo_01":
                tomo_round = "demo"
                tomo_id = "01"
                tomo_file = f"{TOMOTWIN_DEMO_DIR}/data/tomogram/tomo.mrc"
                gt_positions_file = f"{TOMOTWIN_DEMO_DIR}/data/gt.txt"
                gt_skiprows = 1  # the first row is a header, this is only the case for the demo run
            else:
                tomo_file = f"{TOMOTWIN_TOMO_BASE_DIR}/{run}/tiltseries_rec.mrc"
                tomo_round = run.split("/")[0].split("_")[-1]
                tomo_id = run.split("/")[1].split(".")[0].split("tomo_")[-1]
                # look in the validation set first, then in the training set;
                # the file that is checked is exactly the file that is loaded
                gt_candidates = [
                    f"{base_dir}/round_{tomo_round}/tomo_{tomo_id}/particle_positions.txt"
                    for base_dir in (TROMOTWIN_VAL_BASE_DIR, TOMOTWIN_TRAIN_BASE_DIR)
                ]
                gt_positions_file = next(
                    (candidate for candidate in gt_candidates if os.path.exists(candidate)), None
                )
                if gt_positions_file is None:
                    print(f"Ground truth coordinates for run '{run}' were not found in '{TOMOTWIN_TRAIN_BASE_DIR}' or '{TROMOTWIN_VAL_BASE_DIR}'! Skipping run.")
                    continue
                gt_skiprows = 0
            gt_positions = load_gt_positions_from_tomotwin_coords_file(gt_positions_file, skiprows=gt_skiprows)

            # load tomogram (values are sign-flipped)
            tomo = -1 * load_mrc_data(tomo_file)

            # load prompts
            if raw_prompt_embeds is None:
                with open(FIXED_TOMOTWIN_TOMOS_PROMPTS_JSON, "r") as f:
                    raw_prompt_embeds = json.load(f)
            pdbs = tomotwin_round_pdbs_dict[tomo_round]
            # Decide which pdbs to evaluate for THIS checkpoint and run. This is
            # recomputed every time so that the pdb list parsed from one
            # fine-tune checkpoint never leaks into later checkpoints.
            if "fine_tune" in model_ckpt:
                allowed_pdbs = model_ckpt.split("pdbs=")[1].split("/")[0].split(",")
            elif LIMIT_TO_PDBS is not None:
                allowed_pdbs = LIMIT_TO_PDBS
            else:
                allowed_pdbs = pdbs
            # only the selected embeddings are converted to tensors
            prompt_embeds_dict = {
                pdb: torch.tensor(raw_prompt_embeds[pdb]).float()
                for pdb in raw_prompt_embeds
                if pdb in pdbs and pdb in allowed_pdbs
            }

            # run inference
            print(f"Running inference for pdbs {list(prompt_embeds_dict.keys())} on tomogram {run}")
            pred_locmap_dict = get_pred_locmap_dict(
                model=model,
                tomo=tomo,
                prompt_embeds_dict=prompt_embeds_dict,
                subtomo_size=subtomo_size,
                subtomo_overlap=subtomo_overlap,
                batch_size=batch_size,
                zero_border=0,
            )
            if save_pred_locmap_dicts:
                save_h5_tensor_dict(pred_locmap_dict, f"{out_dir_run}/pred_locmaps.h5")

            # optional: picking performance with the tomotwin pipeline
            if compute_tomotwin_picking_performance:
                out_dir_tt = f"{out_dir_run}/tomotwin_pipeline_picking_results"
                os.makedirs(out_dir_tt, exist_ok=True)

                get_best_case_tomotwin_picking_performance(
                    pred_locmap_dict,
                    gt_positions,
                    out_dir=out_dir_tt,
                    subtomo_size=subtomo_size,
                    undersampling_stride=2,  # this is not our stride, but the stride used in tomotwin
                    global_min=0.0,
                    tolerance=0.2,
                )

            # picking performance with the clustering pipeline
            outfile = f"{out_dir_run}/clustering_pipeline_picking_results/best_stats.json"
            os.makedirs(os.path.dirname(outfile), exist_ok=True)

            best_stats = get_best_case_cluster_based_picking_performance(
                pred_locmap_dict=pred_locmap_dict,
                gt_positions=gt_positions,
                **get_best_case_cluster_based_picking_performance_args,
                outfile=outfile
            )

            # collect the metric of interest per round and pdb (one value per run)
            for pdb, metrics in best_stats.items():
                round_metrics = all_metrics_dict.setdefault(f"round_{tomo_round}", {})
                round_metrics.setdefault(pdb, []).append(metrics[metric_of_interest])

            # save metrics to model output directory, do this each time a run is finished so it is not lost if the script crashes
            with open(f"{out_dir_model}/all_runs_{metric_of_interest}.json", "w") as f:
                json.dump(all_metrics_dict, f, indent=4)

            # there is nothing to plot as long as no metrics have been collected
            if all_metrics_dict:
                fig = plot_metric_boxplots(all_metrics_dict)
                fig.savefig(f"{out_dir_model}/all_runs_{metric_of_interest}.png")
                # release the figure; otherwise one figure per run stays open
                plt.close(fig)