from typing import Literal

import numpy as np
import torch

from app.SB import SchrodingerBridge
from app.sinkhorn import Sinkhorn


def calc_mse_comparison(
    sb: SchrodingerBridge,
    device: str,
    Ntest: int,
    MNsamples: list[tuple[int, int]],
    tauz: list[float],
    n_trials: int,
    eps: float,
    sinkhorn_n_iters: list[int],
    compare_target: Literal["true_drift", "m_n_star_drift"],
    reference_n_iter: int = 100_000,
) -> np.ndarray:
    """Compare the drift MSE of truncated Sinkhorn solvers at fixed time moments.

    For every ``(m, n)`` in ``MNsamples`` and every trial, ``m`` source points and
    ``n`` target points are drawn from ``sb``. One ``Sinkhorn`` solver with uniform
    weights is built per entry of ``sinkhorn_n_iters``. At each ``τ`` in ``tauz``,
    ``Ntest`` evaluation points are drawn and the solver's drift is compared with
    a target drift:

    * ``"true_drift"``: the points come from ``sb.sample_at_time_moment`` and the
      target is ``sb.get_drift``.
    * ``"m_n_star_drift"``: a reference ``Sinkhorn`` solver runs on the same
      ``(xA, xB)`` for ``reference_n_iter`` iterations. The points are sampled
      from it, and its drift is the target.

    Args:
        sb: Bridge used to draw the empirical source/target samples and, for
            ``"true_drift"``, to supply the target drift.
        device: Device on which the uniform weight vectors are allocated.
        Ntest: Number of evaluation points drawn per ``τ``.
        MNsamples: ``(m, n)`` source/target sample sizes to sweep.
        tauz: Time moments at which the drifts are compared.
        n_trials: Number of independent ``(xA, xB)`` draws per ``(m, n)``.
        eps: Value passed as ``eps`` to every ``Sinkhorn`` solver.
        sinkhorn_n_iters: Iteration counts of the solvers under test.
        compare_target: ``"true_drift"`` or ``"m_n_star_drift"`` (see above).
        reference_n_iter: Iteration count of the reference solver used for
            ``"m_n_star_drift"``. Defaults to the previously hard-coded 100000.

    Returns:
        Array of shape
        ``(len(MNsamples), len(tauz), len(sinkhorn_n_iters), n_trials)``. Each entry
        is the mean, over evaluation points, of the squared Euclidean norm (taken
        over ``dim=1``) of ``estimated_drift - target_drift``.

    Raises:
        ValueError: If ``compare_target`` is not a supported value.
    """
    # Validate before any expensive work. Previously an unknown target surfaced
    # only as a bare ``Exception`` after a solver had already been built.
    if compare_target not in ("true_drift", "m_n_star_drift"):
        raise ValueError(
            "compare_target must be 'true_drift' or 'm_n_star_drift', "
            f"got {compare_target!r}"
        )

    def _uniform_sinkhorn(x_spt, m: int, y_spt, n: int, n_iter: int):
        # Sinkhorn solver between uniformly weighted empirical supports.
        return Sinkhorn(
            x_w=torch.full(size=(m,), fill_value=1 / m, device=device),
            x_spt=x_spt,
            y_w=torch.full(size=(n,), fill_value=1 / n, device=device),
            y_spt=y_spt,
            eps=eps,
            n_iter=n_iter,
        )

    mse_comparison = np.zeros((len(MNsamples), len(tauz), len(sinkhorn_n_iters), n_trials))
    for ni, (m, n) in enumerate(MNsamples):
        for nti in range(n_trials):
            xA = sb.sample_from_source(m)
            xB = sb.sample_from_target(n)

            # The reference solver depends only on (xA, xB), not on the iteration
            # count under test. Build it once per trial instead of once per entry
            # of ``sinkhorn_n_iters``.
            sinkhorn_m_n_star = (
                _uniform_sinkhorn(xA, m, xB, n, reference_n_iter)
                if compare_target == "m_n_star_drift"
                else None
            )

            for sinki, sinkhorn_n_iter in enumerate(sinkhorn_n_iters):
                print(f"{MNsamples[ni]=}, {sinkhorn_n_iters[sinki]=}")
                sinkhorn = _uniform_sinkhorn(xA, m, xB, n, sinkhorn_n_iter)

                for taui, τ in enumerate(tauz):
                    if compare_target == "true_drift":
                        xtau = sb.sample_at_time_moment(τ, Ntest)
                        target_drift = sb.get_drift(xtau, τ)
                    else:  # "m_n_star_drift" (validated above)
                        xtau = sinkhorn_m_n_star.sample_at_time_moment(τ, Ntest)
                        target_drift = sinkhorn_m_n_star.get_drift(xtau, τ)
                    estimated_drift = sinkhorn.get_drift(x=xtau, t=τ)

                    mse_comparison[ni, taui, sinki, nti] = torch.mean(
                        torch.linalg.norm(estimated_drift - target_drift, dim=1) ** 2
                    ).item()

    return mse_comparison


def calc_expect_mse_comparison(
    sb: SchrodingerBridge,
    device: str,
    MNsamples: list[tuple[int, int]],
    tauz: list[float],
    expect_n_marginals: int,
    expect_n_per_marginals: int,
    n_trials: int,
    eps: float,
    sinkhorn_n_iters: list[int],
    compare_target: Literal["true_drift", "m_n_star_drift"],
    reference_n_iter: int = 100_000,
) -> np.ndarray:
    """Compare time-integrated drift MSE of truncated Sinkhorn solvers over ``[0, τ]``.

    For every ``(m, n)`` in ``MNsamples`` and every trial, ``m`` source points and
    ``n`` target points are drawn from ``sb``. One ``Sinkhorn`` solver with uniform
    weights is built per entry of ``sinkhorn_n_iters``.

    For each ``τ`` in ``tauz``, ``expect_n_marginals`` times are drawn as
    ``torch.rand(...) * τ``. At each time, ``expect_n_per_marginals`` points are
    sampled, and every solver's drift is compared with the target drift on the
    *same* points. The targets are:

    * ``"true_drift"``: points from ``sb.sample_at_time_moment``, target
      ``sb.get_drift``.
    * ``"m_n_star_drift"``: points from, and target drift of, a reference
      ``Sinkhorn`` solver run for ``reference_n_iter`` iterations on the same
      ``(xA, xB)``.

    The stored value is ``τ * mean(||estimated - target||²)`` over all sampled
    times and points. This is a Monte Carlo estimate of the time integral of the
    squared error over ``[0, τ]``.

    Args:
        sb: Bridge used to draw the empirical source/target samples and, for
            ``"true_drift"``, to supply the target drift.
        device: Device on which the uniform weight vectors are allocated.
        MNsamples: ``(m, n)`` source/target sample sizes to sweep.
        tauz: Upper integration limits ``τ``.
        expect_n_marginals: Number of random times drawn per ``τ`` (must be >= 1).
        expect_n_per_marginals: Number of points sampled per time (must be >= 1).
        n_trials: Number of independent ``(xA, xB)`` draws per ``(m, n)``.
        eps: Value passed as ``eps`` to every ``Sinkhorn`` solver.
        sinkhorn_n_iters: Iteration counts of the solvers under test.
        compare_target: ``"true_drift"`` or ``"m_n_star_drift"`` (see above).
        reference_n_iter: Iteration count of the reference solver used for
            ``"m_n_star_drift"``. Defaults to the previously hard-coded 100_000.

    Returns:
        ``float64`` array of shape
        ``(len(MNsamples), len(tauz), len(sinkhorn_n_iters), n_trials)``.

    Raises:
        ValueError: If ``compare_target`` is unsupported, or either sample count
            is < 1. Previously these surfaced as ``UnboundLocalError`` or a
            silent NaN from a 0/0 division.
    """
    if compare_target not in ("true_drift", "m_n_star_drift"):
        raise ValueError(
            "compare_target must be 'true_drift' or 'm_n_star_drift', "
            f"got {compare_target!r}"
        )
    if expect_n_marginals < 1 or expect_n_per_marginals < 1:
        raise ValueError(
            "expect_n_marginals and expect_n_per_marginals must be >= 1, "
            f"got {expect_n_marginals} and {expect_n_per_marginals}"
        )

    def _uniform_sinkhorn(x_spt, m: int, y_spt, n: int, n_iter: int):
        # Sinkhorn solver between uniformly weighted empirical supports.
        return Sinkhorn(
            x_w=torch.full((m,), 1 / m, device=device),
            x_spt=x_spt,
            y_w=torch.full((n,), 1 / n, device=device),
            y_spt=y_spt,
            eps=eps,
            n_iter=n_iter,
        )

    mse_comparison = np.zeros(
        (len(MNsamples), len(tauz), len(sinkhorn_n_iters), n_trials),
        dtype=np.float64,
    )
    # Total number of (time, point) pairs contributing to each error sum.
    num_samples = expect_n_marginals * expect_n_per_marginals

    for ni, (m, n) in enumerate(MNsamples):
        for nti in range(n_trials):
            xA = sb.sample_from_source(m)
            xB = sb.sample_from_target(n)

            sinkhorns = [_uniform_sinkhorn(xA, m, xB, n, n_iter) for n_iter in sinkhorn_n_iters]
            sinkhorn_m_n_star = (
                _uniform_sinkhorn(xA, m, xB, n, reference_n_iter)
                if compare_target == "m_n_star_drift"
                else None
            )

            for taui, τ in enumerate(tauz):
                # Times uniform on [0, τ); scaling the mean by τ below turns it
                # into an integral estimate over [0, τ].
                t_samples = torch.rand(expect_n_marginals) * τ
                total_error = np.zeros(len(sinkhorns), dtype=np.float64)

                for t in t_samples:
                    t_val = t.item()

                    if compare_target == "true_drift":
                        xt = sb.sample_at_time_moment(t_val, expect_n_per_marginals)
                        target_drift = sb.get_drift(xt, t_val)
                    else:  # "m_n_star_drift" (validated above)
                        xt = sinkhorn_m_n_star.sample_at_time_moment(t_val, expect_n_per_marginals)
                        target_drift = sinkhorn_m_n_star.get_drift(xt, t_val)

                    # Every solver is scored on the same points (paired comparison).
                    for sinki, sinkhorn in enumerate(sinkhorns):
                        estimated_drift = sinkhorn.get_drift(x=xt, t=t_val)
                        err = torch.sum(torch.linalg.norm(estimated_drift - target_drift, dim=1) ** 2).item()
                        total_error[sinki] += err

                mse_comparison[ni, taui, :, nti] = total_error / num_samples * τ

                print(
                    f"ni={ni}, trial={nti + 1}/{n_trials}, τ={τ:.3f} →",
                    [
                        f"iter={sinkhorn_n_iters[sinki]}: MSE={mse_comparison[ni, taui, sinki, nti]:.4e}"
                        for sinki in range(len(sinkhorns))
                    ],
                )

    return mse_comparison
