import os
from glob import glob
import pickle
import json


import jax
import jax.numpy as jnp

from tqdm import tqdm
import ott

import h5py
import numpy as np

from functools import partial

# Number of leading components kept along the last axis of every rollout
# array; load_rollout slices its arrays with [..., :DIM].
DIM = 2

@partial(jnp.vectorize, signature="(N,D),(N,D)->()")
def emd_fn(pred, ground):
    """Return the Sinkhorn transport cost between two point clouds.

    The core computation works on a pair of ``(N, D)`` arrays and yields a
    scalar; ``jnp.vectorize`` applies it over any leading batch dimensions.

    Args:
        pred: Predicted point cloud(s), core shape ``(N, D)``.
        ground: Reference point cloud(s), core shape ``(N, D)``.

    Returns:
        The ``reg_ot_cost`` reported by the Sinkhorn solver, one value per
        batch entry.
    """
    euclidean = ott.geometry.costs.Euclidean()
    cloud = ott.geometry.pointcloud.PointCloud(pred, ground, cost_fn=euclidean)
    problem = ott.problems.linear.linear_problem.LinearProblem(cloud)
    solver = ott.solvers.linear.sinkhorn.Sinkhorn()
    return solver(problem).reg_ot_cost


def compute_metric(pred, true, use_emd=False, num_chunks=4):
    """Compute the MSE and, optionally, the mean EMD between two rollouts.

    Args:
        pred: Predicted rollout array. It is sliced along its first axis
            when the EMD is computed (presumably the time axis -- confirm
            against the caller).
        true: Ground-truth rollout array, expected to match ``pred`` in shape.
        use_emd: When true, also evaluate ``emd_fn`` on the rollout.
        num_chunks: Maximum number of slices along the first axis that are
            passed to ``emd_fn`` one after another, so that the whole
            rollout does not have to be processed in a single call.

    Returns:
        A tuple ``(mse, emd)``. ``mse`` is the mean of the squared
        element-wise difference. ``emd`` is a Python float holding the EMD
        averaged over every entry of the first axis, or ``None`` when
        ``use_emd`` is false.

    Raises:
        ValueError: If ``use_emd`` is true and ``num_chunks`` is below 1.
    """
    mse = ((true - pred) ** 2).mean()
    if not use_emd:
        return mse, None

    if num_chunks < 1:
        raise ValueError(f"num_chunks must be a positive integer, got {num_chunks}")

    num_frames = len(pred)
    if num_frames == 0:
        # Nothing to average over; mirror the "mean of an empty set" result.
        return mse, float("nan")

    # Ceil division: at most `num_chunks` slices that together cover every
    # frame. Flooring here would silently drop the trailing
    # `num_frames % num_chunks` frames and yield empty slices whenever
    # `num_frames < num_chunks`.
    chunk_size = -(-num_frames // num_chunks)

    # Sum the per-frame costs and divide by the true frame count so that a
    # shorter final chunk is weighted correctly.
    total = sum(
        emd_fn(pred[start:start + chunk_size], true[start:start + chunk_size]).sum()
        for start in range(0, num_frames, chunk_size)
    )

    return mse, float(total / num_frames)


def load_rollout(file):
    """Load a predicted/ground-truth rollout pair from an ``.h5`` or ``.pkl`` file.

    Args:
        file: Path to the rollout file. The text after the last ``.``
            selects the loader (``h5`` or ``pkl``).

    Returns:
        A tuple ``(pred, true, types)``: the predicted rollout, the
        ground-truth rollout (both restricted to the first ``DIM`` entries of
        their last axis) and the particle types.

    Raises:
        ValueError: If the file extension is neither ``h5`` nor ``pkl``.
    """
    suffix = file.split(".")[-1]
    if suffix == "h5":
        with h5py.File(file, "r") as f:
            # Slicing / [()] reads the data out of the file before it closes.
            true = f["ground_truth_rollout"][..., :DIM]
            pred = f["predicted_rollout"][..., :DIM]
            types = f["types"][()]

    elif suffix == "pkl":
        with open(file, "rb") as f:
            # NOTE: unpickling can execute arbitrary code; only load rollout
            # files that come from a trusted source.
            data = pickle.load(f)
        true = data["ground_truth_rollout"][..., :DIM]
        pred = data["predicted_rollout"][..., :DIM]
        types = data["particle_types"][()]

    else:
        # Raising a bare string is itself a TypeError in Python 3; raise a
        # real exception that names the offending file instead.
        raise ValueError(f"Unknown file type: {file!r}")

    return pred, true, types


def main(list_of_files, save_path, use_emd, num_chunks):
    """Evaluate every rollout file and write the averaged metrics to disk.

    Args:
        list_of_files: Paths of the rollout files to evaluate. Entries whose
            extension is ``json`` are skipped.
        save_path: Directory in which ``metrics.json`` is written; it is
            created if it does not exist.
        use_emd: Whether to compute the EMD in addition to the MSE.
        num_chunks: Forwarded to ``compute_metric``.

    Raises:
        ValueError: If no rollout file was evaluated.
    """
    mse_list = []
    emd_list = []
    for file in tqdm(list_of_files):
        # JSON files are not rollouts (e.g. a previously written
        # metrics.json sitting next to the data -- presumably).
        if file.split('.')[-1] == 'json':
            continue
        pred, true, types = load_rollout(file)

        # Keep only the entries along axis 1 whose type is non-zero.
        mask = np.nonzero(types)[0]
        true = true[:, mask, :]
        pred = pred[:, mask, :]

        mse, wst = compute_metric(pred, true, use_emd=use_emd, num_chunks=num_chunks)

        # Store a plain Python float: NumPy scalars such as np.float32 are
        # rejected by json.dump.
        mse_list.append(float(mse))
        if use_emd:
            emd_list.append(wst)

    if not mse_list:
        raise ValueError("No rollout files were found to evaluate.")

    metrics = {
        'mse': sum(mse_list) / len(mse_list),
        'emd': sum(emd_list) / len(emd_list) if use_emd else None,
    }
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, 'metrics.json'), 'w') as f:
        json.dump(metrics, f)


if __name__ == '__main__':
    # Command-line entry point: collect the rollout files and run `main`.
    import argparse

    parser = argparse.ArgumentParser("NeuralMPM metrics")
    parser.add_argument("--use-emd", action="store_true",
                        help="also compute the EMD between rollouts")
    parser.add_argument("--save-path", type=str, required=True,
                        help="directory in which metrics.json is written")
    parser.add_argument("--data", type=str, required=True,
                        help="a single rollout file or a directory of rollout files")
    # Not used by this script; kept so existing command lines keep working.
    parser.add_argument("--batch", type=int)
    parser.add_argument("--num-chunks", type=int, default=4,
                        help="number of slices used when computing the EMD")
    args = parser.parse_args()

    # Ask the filesystem whether --data is a directory instead of counting
    # dots in the path, which misclassifies paths such as '../run/file.h5'
    # (treated as a directory) or './rollouts' (treated as a file).
    if os.path.isdir(args.data):
        list_of_files = sorted(glob(os.path.join(args.data, '*')))
    else:
        list_of_files = [args.data]

    main(list_of_files, args.save_path, args.use_emd, args.num_chunks)