"""Prediction forward related functions."""

import time
import torch
from torch import Tensor
import torch.nn.functional as F
from TAMO.model.tamo import TAMO
from model.module import get_prediction_mean_std, get_prediction_nll
from utils.metrics import kendalltau_correlation
from utils.metrics import reduce
from data.data_masking import generate_dim_mask, gather_by_indices
from data.function_preprocessing import add_gaussian_noise
from data.sampler import sample_nc
from typing import Optional


def predict_with_metrics(
    model: TAMO,
    x_ctx: Tensor,
    y_ctx: Tensor,
    x_tar: Tensor,
    y_tar: Tensor,
    x_mask: Tensor,
    y_mask: Tensor,
    y_mask_tar: Optional[Tensor] = None,
    compute_nll: bool = True,
    compute_mse: bool = True,
    compute_ktt: bool = False,
    reduce_nll: bool = True,
    reduce_mse: bool = True,
    read_cache: bool = False,
    write_cache: bool = False,
):
    """Run ``model.predict`` once and compute the requested evaluation metrics.

    Args:
        model: TAMO model
        x_ctx: context inputs, [B, nc, max_x_dim]
        y_ctx: context function values, [B, nc, max_y_dim]
        x_tar: target locations, [B, nt, max_x_dim]
        y_tar: ground truth target function values, [B, nt, max_y_dim]
        x_mask: [B, dx_max]
        y_mask: [B, dy_max]
        y_mask_tar: optional mask forwarded to ``model.predict`` as
            ``y_dim_mask_tar``, defaults to None
        compute_nll: whether to compute negative log likelihood (NLL), defaults to True
        compute_mse: whether to compute mean squared error (MSE), defaults to True
        compute_ktt: whether to compute Kendall's tau correlation, defaults to False
        reduce_nll: whether to reduce NLL, defaults to True
        reduce_mse: whether to reduce MSE, defaults to True
        read_cache: forwarded unchanged to ``model.predict``, defaults to False
        write_cache: forwarded unchanged to ``model.predict``, defaults to False

    Returns:
        A 4-tuple ``(nll, mse, kt_tau, inference_time)``. Any metric that was
        not requested is returned as ``None``.

        nll: [B, nt, max_y_dim] or [1] if reduced
        mse: [B, nt, max_y_dim] or [max_y_dim] if reduced
        kt_tau: [max_y_dim]
        inference_time: elapsed seconds (float) spent inside ``model.predict``
    """
    # perf_counter() is monotonic and high-resolution, so the measured interval
    # cannot be distorted by system clock adjustments the way time.time() can.
    # NOTE(review): on CUDA devices kernels are launched asynchronously; without
    # an explicit synchronize this may under-report GPU time -- confirm whether
    # callers need device-accurate timings.
    start = time.perf_counter()
    out = model.predict(
        x_ctx=x_ctx,
        y_ctx=y_ctx,
        x_tar=x_tar,
        x_dim_mask=x_mask,
        y_dim_mask=y_mask,
        y_dim_mask_tar=y_mask_tar,
        read_cache=read_cache,
        write_cache=write_cache,
    )  # [B, nt, max_y_dim, M, out_dims]
    inference_time = time.perf_counter() - start

    nll = None
    mse = None
    kt_tau = None

    if compute_nll:
        nll = get_prediction_nll(out=out, target=y_tar)
        if reduce_nll:
            nll = reduce(nll, dim=None)

    # The predictive mean is shared by MSE and Kendall's tau; compute it once
    # and detach it so neither metric keeps the autograd graph alive.
    if compute_mse or compute_ktt:
        mean, _ = get_prediction_mean_std(out)
        mean = mean.detach()

    if compute_mse:
        mse = F.mse_loss(input=mean, target=y_tar, reduction="none")
        if reduce_mse:
            # Reduce over batch and target axes, keeping the output dimension.
            mse = reduce(mse, dim=(0, 1))

    if compute_ktt:
        kt_tau = kendalltau_correlation(input=mean, target=y_tar)

    return nll, mse, kt_tau, inference_time


def prepare_prediction_dataset(
    x: Tensor,
    y: Tensor,
    valid_x_counts: Tensor | int,
    valid_y_counts: Tensor | int,
    dim_scatter_mode: str,
    min_nc: int,
    max_nc: int,
    nc_fixed: Optional[int] = None,
    warmup: bool = True,
    sigma: float = 0.0,
) -> tuple[Tensor, Tensor, Tensor, Tensor, int]:
    """Make data for prediction.

    Generates dimension masks for x and y, rearranges both tensors according
    to the indices returned with the masks, picks a context size, and
    optionally adds Gaussian noise to y.

    Args:
        x: [B, N, max_x_dim]
        y: [B, N, max_y_dim]
        valid_x_counts: [B] | int
        valid_y_counts: [B] | int
        dim_scatter_mode: ["random_k", "top_k"]
        min_nc: Minimum number of contexts
        max_nc: Maximum number of contexts
        nc_fixed: Optional fixed number of contexts, if any. When given,
            no context size is sampled.
        warmup: Forwarded to ``sample_nc`` when the context size is sampled
        sigma: If greater than 0, ``y`` is passed through
            ``add_gaussian_noise(y, sigma)`` after rearrangement

    Returns:
        x: (Rearranged) x, [B, N, max_x_dim]
        y: (Rearranged) y, [B, N, max_y_dim]
        x_mask: [B, max_x_dim] | [max_x_dim]
        y_mask: [B, max_y_dim] | [max_y_dim]
        nc: Context size (int)
    """
    max_x_dim = x.shape[-1]
    max_y_dim = y.shape[-1]

    x_mask, x_indices = generate_dim_mask(
        max_dim=max_x_dim,
        device=x.device,
        k=valid_x_counts,
        dim_scatter_mode=dim_scatter_mode,
    )

    y_mask, y_indices = generate_dim_mask(
        max_dim=max_y_dim,
        device=y.device,
        k=valid_y_counts,
        dim_scatter_mode=dim_scatter_mode,
    )

    # NOTE Rearrange x and y based on indices
    x = gather_by_indices(x, x_indices)
    y = gather_by_indices(y, y_indices)

    if nc_fixed is not None:
        nc = nc_fixed
    else:
        # valid_x_counts may be a plain int (see signature); torch.max() only
        # accepts tensors, so an int is already the maximum and is used as-is.
        if isinstance(valid_x_counts, Tensor):
            max_valid_x_dim = torch.max(valid_x_counts)
        else:
            max_valid_x_dim = valid_x_counts
        nc = sample_nc(
            min_nc=min_nc, max_nc=max_nc, x_dim=max_valid_x_dim, warmup=warmup
        )

    if sigma > 0.0:
        y = add_gaussian_noise(y, sigma)

    return x, y, x_mask, y_mask, nc
