#!/usr/bin/env python

import argparse
import random
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Union

import dawgz
import numpy as np
from omegaconf import DictConfig

from lola.crps import CRPS
from lola.emulation import emulate_crps
from lola.hydra import compose


def evaluate(
    run: str,
    server: DictConfig,
    indices: Sequence[Union[int, float]],
    target: str = "state",
    split: str = "test",
    destination: str = "results",
    start: int = 0,
    context: int = 1,
    overlap: int = 1,
    samples: int = 1,
    guidance: Optional[str] = None,
    sampling: Dict[str, Any] = {},  # noqa: B006  (read-only, never mutated)
    seed: Optional[int] = None,
    record: int = 0,
    ensemble_sizes_to_save: Sequence[int] = [],  # noqa: B006  (read-only, never mutated)
    debug: bool = False,
    **ignore,
):
    """Evaluate a trained run on a dataset split and append per-frame metrics to a CSV.

    For every dataset index the ground-truth trajectory is preprocessed, encoded
    with the run's autoencoder (if any), rolled out with the run's emulator
    (diffusion denoiser, deterministic surrogate, or CRPS surrogate with noise
    embedding), decoded, postprocessed and compared against the ground truth.
    One CSV row is written per (index, field, time step from ``context - 1``
    onwards, ensemble size). For at most five randomly chosen "plot indices",
    the first ``record`` members are also saved to ``.npz`` and rendered as a
    movie when ``record > 0``.

    Results go to
    ``{server.storage}/{destination}/{cfg.dataset.name}/{run name}/``.

    Arguments:
        run: Path of the run directory; must contain ``config.yaml`` and
            ``{target}.pth``.
        server: Server configuration; ``server.datasets`` and
            ``server.storage`` are read.
        indices: Dataset indices to evaluate.
        target: Checkpoint file stem to load (``{target}.pth``).
        split: Dataset split to evaluate.
        destination: Results sub-directory under ``server.storage``.
        start: Ignored; the start frame is chosen from the dataset name
            (0 for Euler, 5 for Rayleigh-Benard / shear flow).
        context: Number of conditioning frames.
        overlap: Number of frames shared between consecutive rollout windows.
        samples: Number of ensemble members to generate.
        guidance: Optional observation operator for guided diffusion
            (``"subsample"`` or ``"downscale"``).
        sampling: Diffusion sampler keyword arguments; must provide
            ``algorithm`` and ``steps`` when the run has a denoiser.
        seed: Base RNG seed; defaults to ``torch.initial_seed()``.
        record: Number of members to save / render for plot indices.
        ensemble_sizes_to_save: Ensemble sizes metrics are reported for;
            ``samples`` is always added. The argument itself is not modified.
        debug: Write to a ``*_debug.csv`` file and use smaller decode chunks
            for Euler datasets.
        ignore: Extra configuration keys, ignored.

    Raises:
        ValueError: If the dataset has no default start index or the guidance
            operator is unknown.
        AssertionError: If an ensemble size to save exceeds ``samples``.
    """
    # Imports are local to the function, presumably so a scheduled job resolves
    # them on the worker; TODO confirm.
    import time
    from functools import partial
    from pathlib import Path

    import numpy as np
    import ot
    import torch
    from azula.guidance import MMPSDenoiser
    from einops import rearrange, reduce, repeat
    from filelock import FileLock
    from omegaconf import OmegaConf
    from tqdm import tqdm

    from lola.autoencoder import get_autoencoder
    from lola.data import (
        field_postprocess,
        field_preprocess,
        get_well_inputs,
        get_well_multi_dataset,
    )
    from lola.diffusion import get_denoiser
    from lola.emulation import (
        decode_traj,
        emulate_diffusion,
        emulate_rollout,
        emulate_surrogate,
        encode_traj,
    )
    from lola.fourier import isotropic_power_spectrum
    from lola.plot import draw_movie
    from lola.surrogate import get_surrogate
    from lola.utils import randseed

    device = torch.device("cuda")

    # Performance
    torch.set_float32_matmul_precision("high")

    # Config
    runpath = Path(run).expanduser().resolve()
    runname = runpath.name

    cfg = OmegaConf.load(runpath / "config.yaml")

    # Data (augmentations whose name contains "random" are dropped)
    dataset = get_well_multi_dataset(
        path=server.datasets,
        physics=cfg.dataset.physics,
        split=split,
        steps=-1,
        include_filters=cfg.dataset.include_filters,
        augment=[s for s in cfg.dataset.augment if "random" not in s],
    )

    preprocess = partial(
        field_preprocess,
        mean=torch.as_tensor(cfg.dataset.stats.mean, device=device),
        std=torch.as_tensor(cfg.dataset.stats.std, device=device),
        transform=cfg.dataset.transform,
    )

    postprocess = partial(
        field_postprocess,
        mean=torch.as_tensor(cfg.dataset.stats.mean, device=device),
        std=torch.as_tensor(cfg.dataset.stats.std, device=device),
        transform=cfg.dataset.transform,
    )

    # Number of spatial dimensions (2 unless the dataset config lists them)
    if hasattr(cfg.dataset, "dimensions"):
        spatial = len(cfg.dataset.dimensions)
    else:
        spatial = 2

    # Autoencoder: either stored next to the emulator in "autoencoder/", or the
    # run itself is an autoencoder run (then trajectories use stride 1).
    if (runpath / "autoencoder").exists():
        cfg.ae = OmegaConf.load(runpath / "autoencoder/config.yaml").ae
        ae_state_path = runpath / "autoencoder/state.pth"
    elif hasattr(cfg, "ae"):
        cfg.trajectory = {"stride": 1}
        ae_state_path = runpath / "state.pth"
    else:
        ae_state_path = None

    if ae_state_path is not None:
        state = torch.load(ae_state_path, weights_only=True, map_location=device)

        autoencoder = get_autoencoder(**cfg.ae)
        autoencoder.load_state_dict(state)
        autoencoder.to(device)
        autoencoder.requires_grad_(False)
        autoencoder.eval()
    else:
        autoencoder = None

    # Emulator. Keys starting with "module." have every "module." removed
    # (presumably a wrapper prefix from distributed training; TODO confirm).
    state = torch.load(
        runpath / f"{target}.pth", weights_only=True, map_location=device
    )
    new_state_dict = {
        (k.replace("module.", "") if k.startswith("module.") else k): v
        for k, v in state.items()
    }

    if hasattr(cfg, "denoiser"):
        denoiser = get_denoiser(**cfg.denoiser)
        denoiser.load_state_dict(new_state_dict)
        denoiser.to(device)
        denoiser.requires_grad_(False)
        denoiser.eval()
        denoiser = torch.compile(denoiser)
    elif hasattr(cfg, "surrogate"):
        surrogate = get_surrogate(**cfg.surrogate)
        surrogate.load_state_dict(new_state_dict)
        surrogate.to(device)
        surrogate.eval()

    del state

    # Original deterministic surrogate prior to fine-tuning (if applicable),
    # used as an extra baseline row in the rendered movies.
    surrogate_deterministic = None
    if hasattr(cfg, "surrogate_run"):
        surrogate_runpath = Path(cfg.surrogate_run)
        surrogate_cfg = OmegaConf.load(surrogate_runpath / "config.yaml")
        state = torch.load(
            surrogate_runpath / f"{target}.pth", weights_only=True, map_location=device
        )
        if hasattr(surrogate_cfg, "surrogate"):
            surrogate_deterministic = get_surrogate(**surrogate_cfg.surrogate)
            surrogate_deterministic.load_state_dict(state)
            surrogate_deterministic.to(device)
            surrogate_deterministic.eval()
        del state

    # CRPS loss
    surrogate_loss = CRPS()

    # Ensemble sizes to report. A fresh list is built so neither the caller's
    # sequence nor the shared default argument is mutated across calls.
    ensemble_sizes_to_save = sorted({*ensemble_sizes_to_save, samples})
    assert max(ensemble_sizes_to_save) <= samples, (
        f"max ensemble size to save ({max(ensemble_sizes_to_save)}) "
        f"cannot be larger than total samples ({samples})"
    )
    field_names = cfg.dataset.fields

    # RNG
    if seed is None:
        seed = torch.initial_seed()

    # The start frame is fixed per dataset family and overrides `start`.
    if "euler" in cfg.dataset.name:
        start = 0
    elif "rayleigh" in cfg.dataset.name or "shear_flow" in cfg.dataset.name:
        start = 5
    else:
        raise ValueError(
            f"No default number of start index for dataset {cfg.dataset.name}."
        )

    # Evaluation
    print("Test indices", indices)
    random.seed(seed)
    plot_indices = random.sample(list(indices), min(len(indices), 5))
    print("Plot indices", plot_indices)

    outdir = Path(f"{server.storage}/{destination}/{cfg.dataset.name}")
    outdir = outdir.expanduser().resolve()
    (outdir / runname).mkdir(parents=True, exist_ok=True)

    if debug:
        stats_path = outdir / runname / f"Nov27_test_stats_{target}_debug.csv"
    else:
        stats_path = outdir / runname / f"Nov27_test_stats_{target}_all.csv"

    # Start from a fresh stats file.
    # NOTE(review): when this runs as one job of the array scheduled by
    # `launch` in __main__, every job shares this path and unlinks it on
    # startup, so a late-starting job can wipe rows already appended by
    # others. Consider deleting the file once before scheduling instead.
    if stats_path.exists():
        stats_path.unlink()
        print(f"Deleted existing {stats_path}")

    for index in tqdm(indices, ncols=88, ascii=True):
        _ = torch.manual_seed(randseed(f"{seed},{index},{start}"))

        x, label = get_well_inputs(dataset[index], device=device)
        # Keep (context - 1) strided frames before `start` (clamped at 0).
        x = x[
            max(
                0, start - (context - 1) * cfg.trajectory.stride
            ) :: cfg.trajectory.stride
        ]
        x = preprocess(x)
        x = rearrange(x, "L ... C -> C L ...")

        with torch.no_grad():
            z = encode_traj(autoencoder, x)

        compression = x.numel() / z.numel()

        # Label values are appended as trailing CSV columns; the same list
        # drives both the row and the header so their widths always agree.
        label_values = label.tolist() if hasattr(label, "tolist") else []

        ## Emulate
        emulate = None

        if hasattr(cfg, "denoiser"):
            method = "diffusion"
            # Item access works for both plain dicts and DictConfig.
            settings = f"{sampling['algorithm']}{sampling['steps']}"

            if guidance is None:
                emulate = lambda mask, z_obs, noise, i: emulate_diffusion(
                    denoiser,
                    mask,
                    z_obs,
                    label=label,  # noqa: B023
                    **sampling,
                )
            else:
                # fmt: off
                def D(z):
                    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
                        return decode_traj(autoencoder, z, batched=True, noisy=False)
                # fmt: on

                if guidance == "subsample" and spatial == 2:
                    A = lambda x: x[..., ::32, ::32]
                elif guidance == "subsample" and spatial == 3:
                    A = lambda x: x[..., ::8, ::8, ::8]
                elif guidance == "downscale" and spatial == 2:
                    A = lambda x: reduce(
                        x, "... (H h) (W w) -> ... H W", "mean", h=32, w=32
                    )
                elif guidance == "downscale" and spatial == 3:
                    A = lambda x: reduce(
                        x, "... (H h) (W w) (Z z) -> ... H W Z", "mean", h=8, w=8, z=8
                    )
                else:
                    raise ValueError(f"unknown operator '{guidance}'")

                # Noisy observations of the (preprocessed) ground truth
                y = A(x)
                y = y + 1e-1 * torch.randn_like(y)
                var_y = torch.tensor(1e-2, device=device)

                def emulate(mask, z_obs, noise, i):
                    # Same signature as the other emulators passed to
                    # emulate_rollout; `noise` is unused here.
                    j = overlap if i > 0 else context
                    y_i = y[:, i + j : i + cfg.trajectory.length]  # noqa: B023
                    A_i = lambda z: A(D(z[:, :, j : j + y_i.shape[1]])).flatten(  # noqa: B023
                        1
                    )

                    return emulate_diffusion(
                        MMPSDenoiser(
                            denoiser,
                            y=y_i.flatten(),
                            A=A_i,
                            var_y=var_y,  # noqa: B023
                            iterations=1,
                        ),
                        mask,
                        z_obs,
                        label=label,  # noqa: B023
                        **sampling,
                    )

        elif hasattr(cfg, "surrogate"):
            settings = None
            if (
                hasattr(cfg.surrogate, "noise_emb_features")
                and cfg.surrogate.noise_emb_features > 0
            ):
                method = "crps"
                emulate = lambda mask, z_obs, noise, i: emulate_crps(
                    surrogate,
                    mask,
                    z_obs,
                    label=label,  # noqa: B023
                    noise=noise,
                )
            else:
                method = "surrogate"
                emulate = lambda mask, z_obs, noise, i: emulate_surrogate(
                    surrogate,
                    mask,
                    z_obs,
                    label=label,  # noqa: B023
                )
        else:
            method = "autoencoder"
            settings = None

        # Baseline rollout with the pre-fine-tuning surrogate, available for
        # any emulation method whenever `surrogate_run` was configured.
        if surrogate_deterministic is not None:
            emulate_deterministic = lambda mask, z_obs, noise, i: emulate_surrogate(
                surrogate_deterministic,
                mask,
                z_obs,
                label=label,  # noqa: B023
            )
        else:
            emulate_deterministic = None

        tic = time.time()

        z_hat_deterministic = None
        x_hat_deterministic = None

        with torch.no_grad():
            if method in ("diffusion", "surrogate", "crps"):
                z_hat = emulate_rollout(
                    emulate,
                    z,
                    window=cfg.trajectory.length,
                    rollout=z.shape[1],
                    context=context,
                    overlap=overlap,
                    batch=samples if method != "surrogate" else 1,
                    crps_noise_emb=cfg.get("noise_emb_features", 0),
                )
                if (
                    record > 0
                    and index in plot_indices
                    and emulate_deterministic is not None
                ):
                    z_hat_deterministic = emulate_rollout(
                        emulate_deterministic,
                        z,
                        window=cfg.trajectory.length,
                        rollout=z.shape[1],
                        context=context,
                        overlap=overlap,
                        batch=1,
                        crps_noise_emb=0,
                    )
            else:
                # Autoencoder-only run: every "member" is the encoded ground truth.
                z_hat = z.expand(samples, *z.shape)

            # Decoding chunk size per dataset family
            if "euler" in cfg.dataset.name:
                chunk_size = 64 if debug else 128
            elif "gravity" in cfg.dataset.name or "shear_flow" in cfg.dataset.name:
                chunk_size = 128
            else:
                chunk_size = 256

            x_hat = decode_traj(
                autoencoder, z_hat, batched=True, noisy=False, chunk_size=chunk_size
            )
            if z_hat_deterministic is not None:
                x_hat_deterministic = decode_traj(
                    autoencoder,
                    z_hat_deterministic,
                    batched=True,
                    noisy=False,
                    chunk_size=chunk_size,
                )

        tac = time.time()

        del z_hat, z_hat_deterministic

        ## Speed: wall time / (x_hat.shape[0] * x_hat.shape[1]).
        # NOTE(review): x_hat is indexed as [member, field, time] below, so
        # dim 1 looks like the field axis; confirm per-frame (dim 2) wasn't meant.
        speed = (tac - tic) / (x_hat.shape[0] * x_hat.shape[1])

        ## Postprocess
        x = postprocess(x, dim=-spatial - 2)
        x_hat = postprocess(x_hat, dim=-spatial - 2)
        if x_hat_deterministic is not None:
            x_hat_deterministic = postprocess(x_hat_deterministic, dim=-spatial - 2)

        if method == "surrogate":
            # The deterministic surrogate is rolled out with batch=1, so only a
            # single row per frame (labelled with `samples`) is reported.
            ensemble_sizes_to_save = [samples]

        ## Stats
        noise_emb_size = cfg.get("noise_emb_features", 0)
        compute_sliced_emd = False  # sliced EMD on density fields, disabled
        lines = []

        for field in range(x.shape[0]):
            for t in range(context - 1, x.shape[1]):
                # Ground-truth frame and all predicted members for (field, t)
                u, v = x[field, t], x_hat[:, field, t]

                # Moments of the ground truth
                m1 = torch.mean(u)
                m2 = torch.mean(u**2)
                var_u = torch.var(u)

                # Fourier analysis: reference spectrum and 4 log-spaced band edges
                p_u, k = isotropic_power_spectrum(u, spatial=spatial)
                bins = torch.logspace(k[0].log2(), -1.0, steps=4, base=2)

                for ensemble_i in ensemble_sizes_to_save:
                    v_subset = v[:ensemble_i]  # first ensemble_i members

                    # Spread: member variance around the ensemble mean,
                    # scaled by (n + 1) / (n - 1), then square-rooted.
                    if ensemble_i > 1:
                        spread = torch.mean(
                            torch.square(v_subset - torch.mean(v_subset, dim=0))
                        )
                        spread = torch.sqrt(
                            (ensemble_i + 1) / (ensemble_i - 1) * spread
                        )
                    else:
                        spread = 0.0

                    # Skill of the ensemble mean
                    se = torch.square(u - torch.mean(v_subset, dim=0))
                    mse = torch.mean(se)
                    rmse = torch.sqrt(mse)
                    nrmse = torch.sqrt(mse / (m2 + 1e-6))
                    vrmse = torch.sqrt(mse / (var_u + 1e-6))

                    # Spread-skill ratio
                    spread_skill = (spread + 1e-3) / (rmse + 1e-3)

                    # Relative error of the mean power spectrum, overall and per band
                    p_v, _ = isotropic_power_spectrum(v_subset, spatial=spatial)
                    p_v = torch.mean(p_v, dim=0)
                    se_p = torch.square(1 - (p_v + 1e-6) / (p_u + 1e-6))
                    rmse_p = torch.sqrt(torch.mean(se_p))

                    fourier_extras = []
                    for b in range(4):
                        if b < 3:
                            band = torch.logical_and(bins[b] <= k, k <= bins[b + 1])
                        else:
                            band = bins[b] <= k
                        fourier_extras.append(torch.sqrt(torch.mean(se_p[band])).item())

                    # 1D Wasserstein distance between value distributions
                    w_uv = ot.lp.wasserstein_1d(
                        u.flatten(),
                        v_subset.flatten(),
                        p=1.0,
                    )

                    # Sliced EMD (only makes sense for density)
                    emd = None
                    if compute_sliced_emd and "density" in cfg.dataset.fields[field]:
                        coo = torch.cartesian_prod(
                            *(
                                torch.linspace(0, 1, size, device=u.device)
                                for size in u.shape
                            )
                        )
                        emd = ot.sliced.sliced_wasserstein_distance(
                            coo,
                            coo,
                            a=u.flatten() / u.sum(),
                            b=v_subset.mean(dim=0).flatten()
                            / v_subset.mean(dim=0).sum(),
                            p=1.0,
                            n_projections=16,
                            seed=42,
                        ).item()

                    # CRPS
                    crps_i = surrogate_loss(
                        predictions=v_subset[None, :, None, None, :],
                        target=u[None, None, None, :],
                        mask=None,
                        mem_efficient=False,
                    ).mean()

                    # One CSV row, column-for-column with the header below
                    row = [
                        runname,
                        target,
                        f"{compression:.1f}",
                        method,
                        noise_emb_size,
                        settings,
                        guidance,
                        context,
                        overlap,
                        speed,
                        split,
                        index,
                        start,
                        seed,
                        field_names[field],
                        (t - context + 1) * cfg.trajectory.stride,
                        False,  # relative
                        ensemble_i,
                        m1.item(),
                        m2.item(),
                        spread if isinstance(spread, float) else spread.item(),
                        spread_skill.item(),
                        rmse.item(),
                        nrmse.item(),
                        vrmse.item(),
                        rmse_p.item(),
                        *fourier_extras,
                        crps_i.item(),
                        w_uv.item(),
                        emd,
                        *label_values,
                    ]
                    lines.append(",".join(map(str, row)) + "\n")

        # Debug dump of the first trajectory (written to the working directory)
        if index == 0:
            torch.save(x_hat, f"crps_x_hat_{target}_{split}_{index:06d}.pth")
            torch.save(x, f"crps_x_{target}_{split}_{index:06d}.pth")

        with FileLock(str(stats_path) + ".lock"):
            with open(stats_path, "a") as f:
                # Only write header if file is empty/new
                if stats_path.stat().st_size == 0:
                    header = "run,target,compression,method,noise_emb_size,settings,guidance,context,overlap,speed,"
                    header += "split,index,start,seed,field,time,relative,"
                    header += "ensemble_size,"
                    header += "m1,m2,spread,spread_skill,rmse,nrmse,vrmse,rmse_p,"
                    header += "rmse_p_low,rmse_p_mid,rmse_p_high,rmse_p_sub,"
                    header += "crps,"
                    header += "wasserstein,emd"
                    if label_values:
                        param_names = getattr(
                            cfg.dataset,
                            "param_names",
                            [f"param_{i}" for i in range(len(label_values))],
                        )
                        header += f",{','.join(param_names)}"
                    header += "\n"
                    f.write(header)
                f.writelines(lines)

        if record > 0 and index in plot_indices:
            # NumPy dump of ground truth and the first `record` members
            np.savez(
                outdir
                / runname
                / f"{runname}_{target}_{split}_{index:06d}_{start:03d}_{context}_{overlap}_{settings}_{guidance}_{seed}.npz",
                x=x.numpy(force=True),
                x_hat=x_hat[:record].numpy(force=True),
            )

            # Video: only the recorded members are rendered
            x_hat = x_hat[:record]

            if spatial == 3:
                # Render the middle slice along the last spatial axis
                x = x[..., x.shape[-1] // 2]
                x_hat = x_hat[..., x_hat.shape[-1] // 2]
                if x_hat_deterministic is not None:
                    x_hat_deterministic = x_hat_deterministic[
                        ..., x_hat_deterministic.shape[-1] // 2
                    ]

            # Upsample 64x64 frames by 4 in each direction
            if x.shape[-1] == x.shape[-2] == 64:
                x = repeat(x, "... H W -> ... (H h) (W w)", h=4, w=4)
                x_hat = repeat(x_hat, "... H W -> ... (H h) (W w)", h=4, w=4)
                if x_hat_deterministic is not None:
                    x_hat_deterministic = repeat(
                        x_hat_deterministic, "... H W -> ... (H h) (W w)", h=4, w=4
                    )

            # Transpose frames that are taller than wide
            if x.shape[-1] < x.shape[-2]:
                x, x_hat = x.mT, x_hat.mT
                if x_hat_deterministic is not None:
                    x_hat_deterministic = x_hat_deterministic.mT

            if x_hat_deterministic is not None:
                frames = torch.stack((x, x_hat_deterministic[0], *x_hat))
            else:
                frames = torch.stack((x, *x_hat))
            frames = rearrange(frames, "N C L H W -> L N C H W")
            # Get rid of conditioning frames
            frames = frames[context:]

            draw_movie(
                frames,
                file=(
                    outdir
                    / runname
                    / f"{runname}_{target}_{split}_{index:06d}_{start:03d}_{context}_{overlap}_{settings}_{guidance}_{seed}.mp4"
                ),
                fps=4.0 / cfg.trajectory.stride,
                isolate={2},
            )

            del frames

        del x, x_hat, x_hat_deterministic


if __name__ == "__main__":
    # Command line: config overrides plus execution-mode switches.
    cli = argparse.ArgumentParser()
    cli.add_argument("overrides", nargs="*", type=str)
    cli.add_argument(
        "--local", action="store_true", help="Run locally instead of on SLURM"
    )
    cli.add_argument("--debug", action="store_true", help="Debug flag")
    args = cli.parse_args()

    # Evaluation config: eval.yaml with command-line overrides applied.
    cfg = compose(
        config_file="./experiments/configs/eval.yaml",
        overrides=args.overrides,
    )

    # Number of trajectories to evaluate, keyed by a substring of the run
    # path. Entries are checked in order; the first match wins.
    trajectory_counts = (
        ("rayleigh_benard", 175),
        ("euler", 1000),
        ("shear_flow", 112),
    )
    count = next(
        (n for marker, n in trajectory_counts if marker in cfg["run"]),
        None,
    )
    if count is None:
        raise NotImplementedError(f"unknown experiment {cfg['run']}")

    # Debug runs only look at a handful of trajectories.
    array = np.arange(7) if args.debug else np.arange(count)

    if not args.local:
        # SLURM array: job i evaluates every `cfg.compute.jobs`-th index from i.
        def launch(i: int):
            evaluate(indices=array[i :: cfg.compute.jobs], **cfg, debug=args.debug)

        dawgz.schedule(
            dawgz.job(
                f=launch,
                name="eval",
                array=cfg.compute.jobs,
                cpus=cfg.compute.cpus,
                gpus=cfg.compute.gpus,
                ram=cfg.compute.ram,
                time=cfg.compute.time,
                partition=cfg.server.partition,
                constraint=cfg.server.constraint,
                exclude=cfg.server.exclude,
            ),
            name=f"eval {Path(cfg.run).name}",
            backend="async" if args.debug else "slurm",
            env=[
                "export XDG_CACHE_HOME=$HOME/.cache",
            ],
        )
    else:
        # Single-process evaluation of every index on the local GPU.
        print(f"Running evaluation locally with {len(array)} indices...")
        evaluate(indices=array, **cfg, debug=args.debug)
