from tango import Step
import matlab.engine
from scipy.io import savemat
import numpy as np
import wandb
import json
from .utils import expand_task_and_dataset

from typing import Optional
import multiprocessing
import os
import pickle

def cpm_on_single(run_name, path_meta):
    """Run the MATLAB ``CPM2`` routine for one dataset/task and save its results.

    Starts a fresh MATLAB engine, ``cd``s into ``./matlab`` and calls
    ``CPM2(path, pairwise_path, train_test_split_path)`` with 14 outputs.
    Results are written into ``<dir of path with "hyperedges" -> "ckpts">/<run_name>/``
    (the directory is created if missing) as:

    * ``{dataset}_{task}.txt`` -- the eight scalar statistics rounded to 4
      decimals, plus the nonzero indices of ``pos_mask`` / ``neg_mask``.
    * ``{dataset}_{task}.pkl`` -- a pickled dict with the same entries (scalars
      unrounded) plus the ``r_hyperedge``, ``r_pairwise``, ``p_hyperedge`` and
      ``p_pairwise`` outputs converted to plain Python lists.

    Args:
        run_name: Run identifier; used as the output subdirectory name.
        path_meta: ``(dataset, task, path, train_test_split_path, pairwise_path)``.

    Returns:
        The dict that was pickled.
    """
    dataset, task, path, train_test_split_path, pairwise_path = path_meta

    eng = matlab.engine.start_matlab()
    try:
        eng.cd(r'./matlab')

        print(f"worker is working on {dataset} {task}")
        (
            R_pos, P_pos, R_neg, P_neg,
            test_R_pos, test_P_pos, test_R_neg, test_P_neg,
            pos_mask, neg_mask,
            r_hyperedge, r_pairwise, p_hyperedge, p_pairwise,
        ) = eng.CPM2(path, pairwise_path, train_test_split_path, nargout=14)
    finally:
        # Shut MATLAB down even if CPM2 fails, so no engine process is leaked.
        eng.quit()

    def _to_list(arr):
        # Convert a MATLAB array to a plain (squeezed) Python list for pickling.
        return np.asarray(arr).squeeze().tolist()

    pos_idx = np.asarray(pos_mask).squeeze().nonzero()[0].tolist()
    neg_idx = np.asarray(neg_mask).squeeze().nonzero()[0].tolist()

    res = {
        "R_pos": R_pos,
        "P_pos": P_pos,
        "R_neg": R_neg,
        "P_neg": P_neg,
        "test_R_pos": test_R_pos,
        "test_P_pos": test_P_pos,
        "test_R_neg": test_R_neg,
        "test_P_neg": test_P_neg,
        "pos_idx": pos_idx,
        "neg_idx": neg_idx,
    }

    output_file_path = os.path.join(
        os.path.split(path.replace("hyperedges", "ckpts"))[0],
        run_name,
        f"{dataset}_{task}.txt"
    )

    # Text summary: index lists verbatim, scalar statistics rounded to 4 places.
    output_string = "".join(
        f"{key}: {value if 'idx' in key else round(value, 4)}\n"
        for key, value in res.items()
    )

    # Plain assignments (no trailing commas) so these are lists, not 1-tuples.
    res["r_hyperedge"] = _to_list(r_hyperedge)
    res["r_pairwise"] = _to_list(r_pairwise)
    res["p_hyperedge"] = _to_list(p_hyperedge)
    res["p_pairwise"] = _to_list(p_pairwise)

    os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
    # Swap only the file extension; str.replace would also hit ".txt" in directory names.
    pickle_path = os.path.splitext(output_file_path)[0] + ".pkl"
    with open(pickle_path, "wb") as f:
        pickle.dump(res, f)

    print(output_string)
    with open(output_file_path, "w") as f:
        print(output_string, file=f)

    return res

@Step.register("cpm")
class CPM(Step):
    """Tango step that runs ``cpm_on_single`` for every dataset/task pair.

    All results are written to disk by ``cpm_on_single``; ``run`` itself
    returns nothing and the step is marked non-cacheable.
    """

    CACHEABLE = False

    def run(
        self,
        run_name: str,
        datasets: str,
        tasks: str,
        path: str,
        train_test_split_path: str,
        pairwise_path: str,
    ) -> None:
        """Expand the dataset/task/path specs and run CPM on each entry in turn.

        Args:
            run_name: Run identifier; substituted for every ``[RUN]``
                placeholder in ``path`` and forwarded to ``cpm_on_single``.
            datasets: Dataset spec, joined with ``tasks`` via ``"__"`` before
                expansion.
            tasks: Task spec.
            path: Hyperedge path pattern (may contain ``[RUN]``).
            train_test_split_path: Train/test split path pattern.
            pairwise_path: Pairwise path pattern.
        """
        # str.replace is already a no-op when the placeholder is absent.
        path = path.replace('[RUN]', run_name)
        datasets_tasks = datasets + "__" + tasks
        # NOTE(review): assumes expand_task_and_dataset turns each spec into a
        # list with one entry per dataset/task combination, in matching order.
        datasets_tasks, path, train_test_split_path, pairwise_path = list(map(
            expand_task_and_dataset,
            [datasets_tasks, path, train_test_split_path, pairwise_path],
        ))
        datasets = [dataset_task.split('__')[0] for dataset_task in datasets_tasks]
        tasks = [dataset_task.split('__')[1] for dataset_task in datasets_tasks]

        path_meta = list(zip(
            datasets,
            tasks,
            path,
            train_test_split_path,
            pairwise_path,
        ))

        # Runs sequentially: each call starts and quits its own MATLAB engine.
        for meta in path_meta:
            cpm_on_single(run_name, meta)