from typing import Tuple

import numpy as np
import torch
from gmm_util.gmm import GMM
from multi_daft_vi.util_multi_daft import create_initial_gmm_parameters

from ltsgns_mp.architectures.loss_functions.mse import mse
from ltsgns_mp.envs.eval_iterator import EvalIterator


def reshape_tensor_for_padded_chamfer_distance(point_cloud1: torch.Tensor, padded_point_cloud2: torch.Tensor) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Size]:
    """
    Flatten the batch dimensions of two point clouds so they can be used with the padded_chamfer_distance function.

    :param point_cloud1: (aggregation_dim, [Batch_dims], P1, D)
    :param padded_point_cloud2: ([Batch_dims], P2, D)
    :return: Tuple ``(point_cloud1, padded_point_cloud2, batch_shape)``:
        - ``point_cloud1`` reshaped to (aggregation_dim, prod(Batch_dims), P1, D),
        - ``padded_point_cloud2`` reshaped to (prod(Batch_dims), P2, D),
        - ``batch_shape``: the original ``[Batch_dims]`` shape (read from ``point_cloud1``), so the caller can
          restore the batch dimensions of the result afterwards.
    """
    # Remember the original [Batch_dims] before flattening them.
    batch_shape = point_cloud1.shape[1:-2]
    point_cloud1 = point_cloud1.reshape(point_cloud1.shape[0], -1, *point_cloud1.shape[-2:])
    # ``reshape`` rather than ``view``: ``view`` raises for non-contiguous inputs (e.g. after transpose/expand),
    # whereas ``reshape`` returns a view when possible and only copies when it must.
    padded_point_cloud2 = padded_point_cloud2.reshape(-1, *padded_point_cloud2.shape[-2:])
    return point_cloud1, padded_point_cloud2, batch_shape


def padded_chamfer_distance(point_cloud1: torch.Tensor, padded_point_cloud2: torch.Tensor,
                            density_aware: bool = False,
                            forward_only: bool = False,
                            point_reduction: str = "mean") -> torch.Tensor:
    """
    Compute the Chamfer Distance between two point clouds. The second point cloud can be padded with NaN values.

    A point of ``padded_point_cloud2`` counts as padding if ANY of its coordinates is NaN. Padding points are never
    chosen as nearest neighbours (forward direction) and are skipped by the reduction (backward direction).

    Args:
        point_cloud1 (torch.Tensor): Point cloud 1, shape ([Batch_dims], P1, D), without NaN values. At least one
            batch dimension is required; any number of leading batch dimensions is supported.
        padded_point_cloud2 (torch.Tensor): Point cloud 2, shape (Single_Batch_dim, P2, D). Single_Batch_dim has to
            be equal to the last batch dimension of point_cloud1; it is broadcast over the remaining batch dimensions.
        density_aware (bool): If True, the Chamfer Distance is "density aware", i.e. the point-wise distance is
            scaled exponentially. Instead of the squared Euclidean distance d(x,y) = ||x-y||^2, the distance is
            d(x,y) = 1-exp(-||x-y||^2). This is useful if the point clouds have different densities.
        forward_only (bool): If True, only compute the forward distance (from point_cloud1 to point_cloud2).
        point_reduction (str): How to reduce the point dimension. Can be "mean" or "sum".
    Returns:
        chamfer_dist (torch.Tensor): Shape ([Batch_dims]) tensor representing the Chamfer Distance between the point
          clouds of each batch. Without forward_only this is 0.5 * (forward + backward).
    Raises:
        ValueError: If point_reduction is neither "mean" nor "sum".
        AssertionError: If the last batch dimension of point_cloud1 does not match the batch dimension of
            padded_point_cloud2.
    """
    if point_reduction == "mean":
        aggregation = lambda x: torch.nanmean(x, dim=-1)
    elif point_reduction == "sum":
        aggregation = lambda x: torch.nansum(x, dim=-1)
    else:
        raise ValueError(f"Unknown point reduction {point_reduction}")

    # Replace NaN coordinates with 0 so that cdist (and its gradient) stays finite. The distances to these
    # points are overwritten with +inf below, so the placeholder value never influences the result.
    nan_mask = torch.isnan(padded_point_cloud2)
    padded_point_cloud2 = torch.where(nan_mask, torch.zeros_like(padded_point_cloud2), padded_point_cloud2)
    # A point is padding if any of its D coordinates was NaN -> shape (Single_Batch_dim, P2).
    padding_mask = torch.any(nan_mask, dim=-1)

    # Squared Euclidean distance matrix, shape ([Batch_dims], P1, P2). cdist broadcasts padded_point_cloud2 over
    # the leading batch dimensions of point_cloud1.
    dist_matrix = torch.cdist(point_cloud1, padded_point_cloud2) ** 2
    if density_aware:
        dist_matrix = 1 - torch.exp(-dist_matrix)

    pc1_batch_shape = point_cloud1.shape[:-2]
    # the last batch dimension of the first point cloud has to be the same as the batch dimension of the second one
    assert pc1_batch_shape[-1] == padded_point_cloud2.shape[0], (
        f"Last batch dimension of point_cloud1 ({pc1_batch_shape[-1]}) has to match the batch dimension of "
        f"padded_point_cloud2 ({padded_point_cloud2.shape[0]})."
    )

    # Distances to padding points become +inf, so they are never the minimum. The mask has shape
    # (Single_Batch_dim, 1, P2) and broadcasts over P1 and every leading batch dimension. No copies are made,
    # and any batch rank works.
    dist_matrix = dist_matrix.masked_fill(padding_mask.unsqueeze(-2), float("inf"))

    # Forward: nearest (non-padding) neighbour in point_cloud2 for each point of point_cloud1, shape ([Batch_dims], P1).
    forward_distance = torch.min(dist_matrix, dim=-1).values
    # An all-padding point_cloud2 yields +inf here; map it to NaN so the nan-aware aggregation skips it.
    forward_distance = aggregation(torch.nan_to_num(forward_distance, posinf=float("nan")))  # shape ([Batch_dims],)

    if forward_only:
        return forward_distance

    # Backward: nearest neighbour in point_cloud1 for each point of point_cloud2, shape ([Batch_dims], P2).
    # Padding columns are entirely +inf -> NaN, so they are ignored by the aggregation.
    backward_distance = torch.min(dist_matrix, dim=-2).values
    backward_distance = aggregation(torch.nan_to_num(backward_distance, posinf=float("nan")))  # shape ([Batch_dims],)
    return 0.5 * (forward_distance + backward_distance)


def chamfer_distance(point_cloud1: torch.Tensor, point_cloud2: torch.Tensor,
                     density_aware: bool = False,
                     forward_only: bool = False,
                     point_reduction: str = "mean") -> torch.Tensor:
    """
    Compute the Chamfer Distance between two point clouds that contain no NaN padding.

    Args:
        point_cloud1 (torch.Tensor): Point cloud 1, shape ([Batch_dims], P1, D).
        point_cloud2 (torch.Tensor): Point cloud 2, shape (Single_Batch_dim, P2, D). Its batch dimension must be
            broadcast-compatible with the batch dimensions of point_cloud1 (as required by torch.cdist), typically
            equal to the last batch dimension of point_cloud1.
        density_aware (bool): If True, the point-wise distance is passed through 1 - exp(-d), i.e.
            d(x,y) = 1-exp(-||x-y||^2) instead of the plain squared Euclidean distance d(x,y) = ||x-y||^2.
            This is useful if the point clouds have different densities.
        forward_only (bool): If True, only the forward term (point_cloud1 -> point_cloud2) is returned.
        point_reduction (str): Reduction over the point dimension, either "mean" or "sum".
    Returns:
        torch.Tensor: Shape ([Batch_dims]) Chamfer Distance per batch entry. Without forward_only this is
        0.5 * (forward + backward).
    Raises:
        ValueError: If point_reduction is neither "mean" nor "sum".
    """
    if point_reduction not in ("mean", "sum"):
        raise ValueError(f"Unknown point reduction {point_reduction}")
    reduce_points = torch.nanmean if point_reduction == "mean" else torch.nansum

    # squared Euclidean distance between every pair of points, shape ([Batch_dims], P1, P2)
    pairwise = torch.cdist(point_cloud1, point_cloud2).pow(2)
    if density_aware:
        pairwise = 1 - torch.exp(-pairwise)

    # for each point of point_cloud1: distance to its nearest neighbour in point_cloud2, reduced over P1
    forward = reduce_points(pairwise.min(dim=-1).values, dim=-1)
    if forward_only:
        return forward

    # for each point of point_cloud2: distance to its nearest neighbour in point_cloud1, reduced over P2
    backward = reduce_points(pairwise.min(dim=-2).values, dim=-1)
    return 0.5 * (forward + backward)

def _main() -> None:
    """
    Run a manual sanity check of ``padded_chamfer_distance`` and of a freshly initialised GMM prior.

    This is an exploratory script, not an automated unit test; it only prints values. It:
    1. loads the "val" split of the ``deformable_plate`` dataset and builds an EvalIterator over it,
    2. computes the padded Chamfer distance between the context node positions of task 0 and the context point
       clouds of task 0 ("good"), and between the node positions of task 1 and the same point clouds ("bad"),
    3. converts these, and a direct squared error between the two node-position tensors, into Gaussian-style
       log-likelihoods with a fixed standard deviation, sums the first 10 entries and prints them,
    4. evaluates a GMM prior on a 10x10 grid over [-5, 5]^2 and prints its log-density and density.
    """
    # Script-only dependencies are imported here so that importing this module does not require them.
    import ltsgns_mp.util.keys as keys
    from ltsgns_mp.envs.data_loader.sofa_dataloader import SofaDataloader
    from omegaconf import OmegaConf

    torch.set_printoptions(linewidth=200)

    env_config = {
        "path_to_datasets": "../datasets/lts_gns/",
        "name": "deformable_plate",
        "dataset_name": "deformable_plate",
        "use_collider_velocities": False,
        "use_canonic_mesh_positions": True,
        "use_point_cloud": True,
        "use_poisson_ratio_as_node_feature": False,
        "eval_only": False,
        "second_order_dynamics": False,
        "use_point_cloud_as_graph": False,
        "debug": {
            "max_tasks_per_split": 2},
        "connectivity_setting": {
            "collider_mesh_edge_creation": "knn",
            "collider_mesh_k": 2,
            "collider_collider_radius": 0.08,
            "collider_mesh_radius": 0.3,
            "point_cloud_mesh_radius": 0.3,
            "point_cloud_point_cloud_radius": 0.1,
        },
    }
    eval_iterator_config = {
        "evaluation_split": "val",
        "indices": {
            "mesh": {
                "indices": [0],  # if indices is present, take that over start/stop indices
            },
            "point_cloud": {
                "indices": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            },
            "evaluation": {
                "start_idx": 9,  # which index to start the evaluation at, usually should include the anchor idx or the start of the trajectory
                "stop_idx": None,
                "step": 1,
            },
            "anchor_idx": 0,  # which index has initial x features and edge features
        },
    }
    # wrap the plain dicts into OmegaConf configs
    env_config = OmegaConf.create(env_config)
    eval_iterator_config = OmegaConf.create(eval_iterator_config)
    # Distinct names for class and instance, so the loader class is not shadowed by its instance.
    data_loader = SofaDataloader(config=env_config)
    traj_dict = data_loader.load()
    val_tasks = traj_dict["val"]
    eval_iterator = EvalIterator(eval_iterator_config, val_tasks, device="cpu")
    # NOTE(review): relies on the private ``_data_list`` attribute of EvalIterator.
    task_list = eval_iterator._data_list
    task = task_list[0]
    task_2 = task_list[1]
    mesh = task[keys.CONTEXT_NODE_POSITIONS]
    mesh_2 = task_2[keys.CONTEXT_NODE_POSITIONS]
    pcs = task[keys.CONTEXT_POINT_CLOUD_POSITIONS][0]  # remove batch dim

    likelihood_std = 0.1

    # compute chamfer distance between mesh and point clouds
    good_result = padded_chamfer_distance(mesh, pcs, density_aware=False, forward_only=False, point_reduction="mean")
    bad_result = padded_chamfer_distance(mesh_2, pcs, density_aware=False, forward_only=False, point_reduction="mean")
    # direct Gaussian-style log-likelihood of the wrong mesh given the correct one
    log_like_bad_mesh = - 0.5 * torch.sum((mesh - mesh_2) ** 2 / likelihood_std ** 2, dim=-1)
    log_like_bad_mesh = torch.mean(log_like_bad_mesh, dim=-1)
    log_like_good = - 0.5 * good_result / (likelihood_std ** 2)
    log_like_bad = - 0.5 * bad_result / (likelihood_std ** 2)

    # add first 10 terms together
    log_like_good = torch.sum(log_like_good[:10])
    log_like_bad = torch.sum(log_like_bad[:10])
    log_like_bad_mesh = torch.sum(log_like_bad_mesh[:10])
    like_good = torch.exp(log_like_good)
    like_bad = torch.exp(log_like_bad)
    like_bad_mesh = torch.exp(log_like_bad_mesh)
    print("good result: ", good_result)
    print("bad result: ", bad_result)
    print("---------------")
    print("log like good: ", log_like_good)
    print("log like bad: ", log_like_bad)
    print("log like bad mesh: ", log_like_bad_mesh)
    print("---------------")
    print("like good: ", like_good)
    print("like bad: ", like_bad)
    print("like bad mesh: ", like_bad_mesh)
    print("---------------")
    print("stop")

    # Comparison with prior
    # create new params
    prior_w, prior_mean, prior_cov = create_initial_gmm_parameters(
        n_tasks=1,
        d_z=2,
        n_components=3,
        prior_scale=1.0,
        initial_var=1.0,
    )
    prior = GMM(
        log_w=prior_w,
        mean=prior_mean,
        prec=torch.linalg.inv(prior_cov),
        device="cpu"
    )
    # evaluate on different z from -5 to 5 in a 2d grid
    z1 = np.linspace(-5, 5, 10)
    z2 = np.linspace(-5, 5, 10)
    z1, z2 = np.meshgrid(z1, z2)
    z = np.stack([z1, z2], axis=-1)  # shape (10, 10, 2)
    z = torch.tensor(z, dtype=torch.float32)
    prior_log_like = prior.log_density(z, compute_grad=False)[0]
    prior_like = torch.exp(prior_log_like)
    print("prior log like: ", prior_log_like)
    print("prior like: ", prior_like)


if __name__ == "__main__":
    _main()

