import torch
import numpy as np
import pandas as pd

import os
import tqdm
import json
import torch
from clustering_and_picking.tomotwin_evaluation_routine import SIZE_DICT, get_stats
from clustering_and_picking.clustering_and_cleaning import get_cluster_centroids
import multiprocessing
from multiprocessing import Pool

def get_best_case_cluster_based_picking_performance(
    pred_locmap_dict,
    gt_positions,
    n_size_steps,
    n_thresh_steps=None,
    metric="F1",
    iou_thresh=0.6,
    optimize_thresh=True,
    outfile=None,
    min_thresh=None,
    max_thresh=None,
    skip_pdbs=["background"],
    num_workers=None,
):
    """
    Run the best-case cluster-based picking search for every class in a locmap dictionary.

    One job is built per key of ``pred_locmap_dict`` that is not listed in
    ``skip_pdbs``. Jobs are evaluated serially when ``num_workers`` is 0 or 1
    and through a ``multiprocessing.Pool`` otherwise.

    Args:
        pred_locmap_dict: Mapping from class (pdb) name to its predicted locmap.
        gt_positions: Ground-truth positions, forwarded unchanged to every job.
        n_size_steps: Number of grid steps per cluster-size bound.
        n_thresh_steps: Number of threshold steps, forwarded to every job.
        metric: Key of the statistic to maximise.
        iou_thresh: IoU threshold forwarded to the evaluation.
        optimize_thresh: Whether the per-class search also scans thresholds.
        outfile: Optional path; when given, the result is dumped there as JSON.
        min_thresh: Optional lower threshold bound, forwarded to every job.
        max_thresh: Optional upper threshold bound, forwarded to every job.
        skip_pdbs: Keys of ``pred_locmap_dict`` to leave out.
        num_workers: Number of worker processes; ``None`` means one per job.

    Returns:
        dict mapping each evaluated class name to its best statistics.
    """
    # Classes to evaluate, in dictionary order; the index doubles as the
    # progress-bar row of the corresponding job.
    pdbs = [name for name in pred_locmap_dict.keys() if name not in skip_pdbs]
    jobs = [
        {
            "pred_locmap": pred_locmap_dict[name],
            "gt_positions": gt_positions,
            "pdb": name,
            "optimize_thresh": optimize_thresh,
            "n_thresh_steps": n_thresh_steps,
            "n_size_steps": n_size_steps,
            "iou_thresh": iou_thresh,
            "min_thresh": min_thresh,
            "max_thresh": max_thresh,
            "metric": metric,
            "pbar_position": row,
        }
        for row, name in enumerate(pdbs)
    ]

    if num_workers is None:
        num_workers = len(jobs)

    if num_workers in (0, 1):
        print("Running evaluation with 1 workers")
        results = [
            _get_best_case_cluster_based_picking_performance_wrapper(job)
            for job in jobs
        ]
    else:
        with Pool(num_workers) as pool:
            results = pool.map(
                _get_best_case_cluster_based_picking_performance_wrapper, jobs
            )

    best_stats = dict(zip(pdbs, results))

    # Persist the per-class statistics as JSON when a target file was given.
    if outfile is not None:
        with open(outfile, "w") as handle:
            json.dump(best_stats, handle, indent=4)
    return best_stats

def _get_best_case_cluster_based_picking_performance_wrapper(args):
    """
    Call the per-class search with a job dictionary expanded into keyword arguments.

    Taking a single positional argument makes this usable as the callable of
    ``Pool.map``, which passes exactly one item per call.
    """
    job_kwargs = args
    return _get_best_case_cluster_based_picking_performance(**job_kwargs)

def _get_best_case_cluster_based_picking_performance(pred_locmap, gt_positions, pdb, optimize_thresh, n_thresh_steps, n_size_steps, iou_thresh, min_thresh, max_thresh, metric, pbar_position=0):
    """
    Evaluate the performance of a predicted locmap using clustering and picking.

    Performs a grid search over binarisation thresholds and over
    (min, max) cluster-size bounds. For every combination the thresholded map
    is clustered, the cluster centroids within the size bounds are treated as
    picks, and the picks are scored with ``get_stats``. The statistics of the
    combination with the highest ``metric`` are returned.

    Args:
        pred_locmap: Predicted locmap for one class. It is compared with
            ``>=`` and then converted via ``.float().numpy()``, so a CPU torch
            tensor is expected.
        gt_positions: Ground-truth positions, passed unchanged to ``get_stats``.
        pdb: Class name; looked up (upper-cased) in ``SIZE_DICT`` for the box
            size, falling back to 37 when it is not present.
        optimize_thresh: If true, scan ``n_thresh_steps`` thresholds between
            ``min_thresh`` and ``max_thresh`` (both ends excluded); otherwise
            only the threshold 1.0 is used.
        n_thresh_steps: Number of thresholds to scan; required when
            ``optimize_thresh`` is true.
        n_size_steps: Number of grid steps for each of the two size bounds.
        iou_thresh: IoU threshold forwarded to ``get_stats``.
        min_thresh: Lower threshold bound; defaults to ``pred_locmap.min()``.
        max_thresh: Upper threshold bound; defaults to ``pred_locmap.max()``.
        metric: Key of the ``get_stats`` result that is maximised.
        pbar_position: Row of this job's tqdm progress bar.

    Returns:
        The dict returned by ``get_stats`` for the best combination, extended
        with the keys ``"min_size"``, ``"max_size"``, ``"thresh"`` and
        ``"positions"``. If no combination scores above 0.0 the result is just
        ``{metric: 0.0}``.

    Raises:
        TypeError: If ``optimize_thresh`` is true but ``n_thresh_steps`` is None.
    """
    if optimize_thresh:
        if n_thresh_steps is None:
            # Same exception type as the arithmetic below would raise, but
            # with a message that names the actual problem.
            raise TypeError("n_thresh_steps must be an integer when optimize_thresh is True")
        min_thresh = pred_locmap.min() if min_thresh is None else min_thresh
        max_thresh = pred_locmap.max() if max_thresh is None else max_thresh
        # n_thresh_steps + 1 points without the end point, minus the first one:
        # both min_thresh and max_thresh themselves are excluded.
        threshs = np.linspace(min_thresh, max_thresh, n_thresh_steps+1, endpoint=False)[1:]
    else:
        threshs = [1.0]

    # Loop invariants: box size of this class and work items per threshold.
    size = SIZE_DICT.get(pdb.upper(), 37)
    steps_per_thresh = n_size_steps**2

    best_stats = {metric: 0.0}
    # The context manager guarantees the bar is closed, also on exceptions.
    with tqdm.tqdm(total=len(threshs) * steps_per_thresh, desc=f"{pdb} (Best {metric}: ?.??)", position=pbar_position, leave=True) as pbar:
        for thresh in threshs:
            prediction_ds_thr = (pred_locmap >= thresh).float().numpy()
            if prediction_ds_thr.sum() == 0:
                # Nothing above the threshold: account for the skipped size
                # grid so the bar can still reach its total.
                pbar.update(steps_per_thresh)
                continue
            _, centroids_list, cluster_size_list = get_cluster_centroids(
                dataset=prediction_ds_thr,
                min_cluster_size=0,
                max_cluster_size=np.inf,
                connectivity=1
            )

            df_located = pd.DataFrame(columns=["X", "Y", "Z", "size", "metric_best", "predicted_class", "width", "height", "depth"])
            # Centroid entries are read as (index 0, 1, 2) -> (Z, Y, X).
            df_located["X"] = [c[2] for c in centroids_list]
            df_located["Y"] = [c[1] for c in centroids_list]
            df_located["Z"] = [c[0] for c in centroids_list]
            df_located["size"] = cluster_size_list
            df_located["metric_best"] = 1.0
            df_located["predicted_class"] = 0  # important, only 1 class in this case
            df_located["width"] = size
            df_located["height"] = size
            df_located["depth"] = size
            df_located.attrs["references"] = [pdb]

            # Lower bounds span 1 .. 45% quantile, upper bounds span the
            # 55% quantile .. largest cluster.
            min_size_range = np.linspace(1, np.quantile(cluster_size_list, 0.45), n_size_steps)
            max_size_range = np.linspace(np.quantile(cluster_size_list, 0.55), np.max(cluster_size_list), n_size_steps)

            sizes = df_located["size"]
            for min_size in min_size_range:
                for max_size in max_size_range:
                    # Only filter the frame when the bounds form a valid range.
                    if min_size < max_size:
                        df_located_ = df_located[sizes.between(min_size, max_size)]
                        if len(df_located_) > 0:
                            stats = get_stats(df_located_, gt_positions, iou_thresh=iou_thresh)
                            if stats[metric] > best_stats[metric]:
                                best_stats = stats
                                best_stats["min_size"] = min_size
                                best_stats["max_size"] = max_size
                                best_stats["thresh"] = thresh
                                best_stats["positions"] = df_located_.to_string()
                                pbar.set_description(f"{pdb} (Best {metric}: {best_stats[metric]:.2f})")
                    pbar.update(1)
    return best_stats


def get_cluster_based_picking_performance(pred_locmap_thresh, gt_positions, pdb, iou_thresh=0.6, min_cluster_size=0, max_cluster_size=torch.inf, skip_pdbs=["background"]):
    """
    Score cluster-based picks for one already thresholded locmap.

    The map is clustered, centroids of clusters whose size lies in
    ``[min_cluster_size, max_cluster_size]`` (both inclusive) are treated as
    picks, and the picks are evaluated with ``get_stats``.

    Args:
        pred_locmap_thresh: Thresholded locmap, passed as ``dataset`` to
            ``get_cluster_centroids``.
        gt_positions: Ground-truth positions, passed unchanged to ``get_stats``.
        pdb: Class name; looked up (upper-cased) in ``SIZE_DICT`` for the box
            size, falling back to 37 when it is not present.
        iou_thresh: IoU threshold forwarded to ``get_stats``.
        min_cluster_size: Smallest cluster size that is kept.
        max_cluster_size: Largest cluster size that is kept.
        skip_pdbs: Not used by this function; kept for interface compatibility.

    Returns:
        The dict returned by ``get_stats``, extended with the keys
        ``"positions"``, ``"min_size"`` and ``"max_size"``.

    Raises:
        ValueError: If ``min_cluster_size`` is not strictly smaller than
            ``max_cluster_size``, or if no cluster falls inside the size range.
    """
    # Reject an invalid size range before doing any clustering work, and
    # report it as such instead of as "no clusters found".
    if not min_cluster_size < max_cluster_size:
        raise ValueError(
            f"Invalid cluster size range: min_cluster_size ({min_cluster_size}) "
            f"must be smaller than max_cluster_size ({max_cluster_size})"
        )

    # Cluster without size limits; the size filter is applied on the frame below.
    _, centroids_list, cluster_size_list = get_cluster_centroids(
            dataset=pred_locmap_thresh,
            min_cluster_size=0,
            max_cluster_size=np.inf,
            connectivity=1
    )

    df_located = pd.DataFrame(columns=["X", "Y", "Z", "size", "metric_best", "predicted_class", "width", "height", "depth"])
    # Centroid entries are read as (index 0, 1, 2) -> (Z, Y, X).
    df_located["X"] = [c[2] for c in centroids_list]
    df_located["Y"] = [c[1] for c in centroids_list]
    df_located["Z"] = [c[0] for c in centroids_list]
    df_located["size"] = cluster_size_list
    df_located["metric_best"] = 1.0
    df_located["predicted_class"] = 0  # important, only 1 class in this case
    size = SIZE_DICT.get(pdb.upper(), 37)
    df_located["width"] = size
    df_located["height"] = size
    df_located["depth"] = size
    df_located.attrs["references"] = [pdb]

    df_located_ = df_located[df_located["size"].between(min_cluster_size, max_cluster_size)]
    if len(df_located_) == 0:
        raise ValueError(
            "No clusters found in the prediction with size in "
            f"[{min_cluster_size}, {max_cluster_size}]"
        )

    stats = get_stats(df_located_, gt_positions, iou_thresh=iou_thresh)
    stats["positions"] = df_located_.to_string()
    stats["min_size"] = min_cluster_size
    stats["max_size"] = max_cluster_size

    return stats

def get_best_case_tomotwin_picking_performance(
    pred_locmap_dict,
    gt_positions,
    out_dir,
    subtomo_size=37,
    undersampling_stride=2,
    global_min=0.0,
    tolerance=0.2,
):
    """
    Run the TomoTwin locate and evaluate command-line tools on predicted locmaps.

    The locmaps are sampled on a regular grid and written as a pickled
    dataframe (``map.tmap``). ``tomotwin_locate.py findmax`` is run on that
    file, and the resulting ``*.tloc`` files are then scored against the
    ground truth with ``tomotwin_scripts_evaluate.py positions --optim``.
    The temporary map and ground-truth files are removed afterwards, also
    when one of the steps fails.

    Args:
        pred_locmap_dict: Mapping from class name to locmap. The locmaps are
            indexed with three index tensors and converted with
            ``.cpu().numpy()``; the shape of the first one defines the grid.
        gt_positions: DataFrame with at least the columns ``class``, ``X``,
            ``Y``, ``Z``, ``rx``, ``ry`` and ``rz``.
        out_dir: Output directory; created if it does not exist.
        subtomo_size: Window size; the grid starts ``subtomo_size // 2`` voxels
            from each border.
        undersampling_stride: Grid step along every axis.
        global_min: Value passed to ``tomotwin_locate.py`` as ``-g``.
        tolerance: Value passed to ``tomotwin_locate.py`` as ``-t``.

    Returns:
        None. Results are whatever the external tools write or print.

    Raises:
        subprocess.CalledProcessError: If one of the external tools exits with
            a non-zero status.
        FileNotFoundError: If the locate step produced no ``*.tloc`` file.
    """
    # Function-scope imports: only this function needs these stdlib modules.
    import glob
    import subprocess

    os.makedirs(out_dir, exist_ok=True)

    # prepare map file from locmaps
    columns = ["X", "Y", "Z"] + [f"d_class_{k}" for k in range(len(pred_locmap_dict))]
    df_map = pd.DataFrame(columns=columns)

    shape = list(pred_locmap_dict.values())[0].shape
    half_window = subtomo_size // 2
    # NOTE(review): torch.meshgrid is called without ``indexing``; this relies
    # on the "ij" default and warns on recent torch versions. Confirm the
    # minimum supported torch before passing indexing="ij" explicitly.
    gridpoints = torch.stack(torch.meshgrid(
        torch.arange(half_window, shape[0] - half_window, undersampling_stride),
        torch.arange(half_window, shape[1] - half_window, undersampling_stride),
        torch.arange(half_window, shape[2] - half_window, undersampling_stride),
    )).reshape(3, -1).T.flip(1)  # flip: tensor axis order -> (X, Y, Z) column order
    df_map[["X", "Y", "Z"]] = gridpoints.numpy()
    df_map.attrs = {"references": list(pred_locmap_dict.keys()), "stride": 3*[undersampling_stride], "window_size": subtomo_size}
    idx = torch.from_numpy(df_map[["X", "Y", "Z"]].values.astype(int)).flip(1)  # flip because z is first dimension in locmap
    for k, locmap in enumerate(pred_locmap_dict.values()):
        df_map[f"d_class_{k}"] = locmap[idx[:, 0], idx[:, 1], idx[:, 2]].cpu().numpy()

    mapfile = f"{out_dir}/map.tmap"
    gt_positions_file = f"{out_dir}/gt_positions.csv"
    boxsizes_file = f"{out_dir}/boxsizes.json"
    try:
        pd.to_pickle(df_map, mapfile)

        # save gt_positions to csv file skip header and index so that tomotwin can read it
        gt_positions[["class", "X", "Y", "Z", "rx", "ry", "rz"]].to_csv(gt_positions_file, header=False, index=False)

        # Argument lists instead of a shell string: paths containing spaces or
        # shell metacharacters are passed through verbatim, and a failing tool
        # stops the pipeline instead of being silently ignored.
        subprocess.run(
            [
                "tomotwin_locate.py", "findmax",
                "-m", mapfile,
                "-o", f"{out_dir}/locate/",
                "-g", str(global_min),
                "-t", str(tolerance),
                "--processes", str(len(pred_locmap_dict)),
            ],
            check=True,
        )

        # save SIZE_DICT to boxsizes.json
        with open(boxsizes_file, "w") as f:
            json.dump(SIZE_DICT, f)

        # Without a shell the wildcard has to be expanded here.
        tloc_files = sorted(glob.glob(f"{glob.escape(str(out_dir))}/locate/*.tloc"))
        if not tloc_files:
            raise FileNotFoundError(f"No .tloc files found in {out_dir}/locate/")

        subprocess.run(
            [
                "tomotwin_scripts_evaluate.py", "positions",
                "-p", gt_positions_file,
                "-l", *tloc_files,
                "-s", boxsizes_file,
                "--optim",
            ],
            check=True,
        )
    finally:
        # Remove the temporary inputs even if a step above failed.
        for tmp_file in (gt_positions_file, mapfile):
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
