"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

from functools import partial
from typing import Optional

import torch

__all__ = [
    "normalise_surf_var",
    "normalise_atmos_var",
    "unnormalise_surf_var",
    "unnormalise_atmos_var",
]


def normalise_surf_var(
    x: torch.Tensor,
    name: str,
    stats: Optional[dict[str, tuple[float, float]]] = None,
    unnormalise: bool = False,
) -> torch.Tensor:
    """Apply (or invert) the affine normalisation of a surface-level variable.

    Args:
        x: Values of the variable.
        name: Variable name; the lookup key into ``stats`` and into the
            module-level ``locations``/``scales`` tables.
        stats: Optional mapping ``name -> (location, scale)``. When it is
            non-empty and contains ``name``, it takes precedence over the
            module-level tables.
        unnormalise: If ``True``, compute ``x * scale + location`` (the
            inverse); otherwise compute ``(x - location) / scale``.

    Returns:
        The transformed tensor.

    Note:
        On the module-table fallback path, a ``name`` missing from
        ``locations`` / ``scales`` is inserted there with location ``0`` /
        scale ``1`` (an identity transform) before being read.
    """
    if stats and name in stats:
        shift, spread = stats[name]
    else:
        # setdefault inserts the identity defaults on first use, then returns
        # whatever value is stored under ``name``.
        shift = locations.setdefault(name, 0)
        spread = scales.setdefault(name, 1)

    if not unnormalise:
        return (x - shift) / spread
    return x * spread + shift


def normalise_atmos_var(
    x: torch.Tensor,
    name: str,
    atmos_levels: tuple[int | float, ...],
    unnormalise: bool = False,
) -> torch.Tensor:
    """Normalise an atmospheric variable."""
    level_locations: list[int | float] = []
    level_scales: list[int | float] = []
    for level in atmos_levels:
        if f"{name}_{level}" not in locations:
            locations[f"{name}_{level}"] = 0
        if f"{name}_{level}" not in scales:
            scales[f"{name}_{level}"] = 1
        level_locations.append(locations[f"{name}_{level}"])
        level_scales.append(scales[f"{name}_{level}"])
    location = torch.tensor(level_locations, dtype=x.dtype, device=x.device)
    scale = torch.tensor(level_scales, dtype=x.dtype, device=x.device)

    if unnormalise:
        return x * scale[..., None, None] + location[..., None, None]
    else:
        return (x - location[..., None, None]) / scale[..., None, None]


# Inverse transforms: call with the same arguments as the forward functions;
# ``unnormalise=True`` is pre-bound, so they compute ``x * scale + location``.
unnormalise_surf_var = partial(normalise_surf_var, unnormalise=True)
unnormalise_atmos_var = partial(normalise_atmos_var, unnormalise=True)


# Per-variable location (offset) table used as the fallback by the normalise
# functions: forward normalisation computes ``(x - location) / scale``.
# Surface variables are keyed by their name; atmospheric variables are keyed by
# ``f"{name}_{level}"``. This dict is mutated at call time: the normalise
# functions insert a location of 0 for any key they look up that is missing.
# NOTE(review): units and provenance of these statistics are not visible here.
locations: dict[str, float] = {
    "z": -1.386496e03,
    "lsm": 0.000000e00,
    "slt": 0.000000e00,
    "2t": 2.785140e02,
    "10u": -5.135059e-02,
    "10v": 1.891580e-01,
    "msl": 1.009578e05,
    "z_50": 1.993730e05,
    "z_100": 1.576421e05,
    "z_150": 1.331414e05,
    "z_200": 1.153300e05,
    "z_250": 1.012231e05,
    "z_300": 8.941415e04,
    "z_400": 6.998038e04,
    "z_500": 5.411537e04,
    "z_600": 4.064833e04,
    "z_700": 2.892882e04,
    "z_850": 1.374978e04,
    "z_925": 7.015005e03,
    "z_1000": 7.381545e02,
    "u_50": 5.653076e00,
    "u_100": 1.027951e01,
    "u_150": 1.354061e01,
    "u_200": 1.420915e01,
    "u_250": 1.334584e01,
    "u_300": 1.180173e01,
    "u_400": 8.817291e00,
    "u_500": 6.563273e00,
    "u_600": 4.814521e00,
    "u_700": 3.345237e00,
    "u_850": 1.418379e00,
    "u_925": 6.172657e-01,
    "u_1000": -3.328723e-02,
    "v_50": 4.226111e-03,
    "v_100": 1.411897e-02,
    "v_150": -3.697671e-02,
    "v_200": -4.507801e-02,
    "v_250": -2.980338e-02,
    "v_300": -2.294770e-02,
    "v_400": -1.771003e-02,
    "v_500": -2.387986e-02,
    "v_600": -2.716674e-02,
    "v_700": 2.153583e-02,
    "v_850": 1.428150e-01,
    "v_925": 2.053480e-01,
    "v_1000": 1.867637e-01,
    "t_50": 2.124864e02,
    "t_100": 2.084042e02,
    "t_150": 2.133201e02,
    "t_200": 2.180615e02,
    "t_250": 2.227710e02,
    "t_300": 2.288696e02,
    "t_400": 2.421368e02,
    "t_500": 2.529492e02,
    "t_600": 2.611347e02,
    "t_700": 2.674010e02,
    "t_850": 2.745600e02,
    "t_925": 2.773572e02,
    "t_1000": 2.810130e02,
    "q_50": 2.678180e-06,
    "q_100": 2.633677e-06,
    "q_150": 5.254625e-06,
    "q_200": 1.940632e-05,
    "q_250": 5.773618e-05,
    "q_300": 1.273861e-04,
    "q_400": 3.855659e-04,
    "q_500": 8.529599e-04,
    "q_600": 1.541429e-03,
    "q_700": 2.431637e-03,
    "q_850": 4.575618e-03,
    "q_925": 6.033134e-03,
    "q_1000": 7.030342e-03,
    "cmorph": 0.109883055,
    "tp": 0.00059495,
    "r_50": 6.28536672,
    "r_100": 24.77944122,
    "r_150": 24.98625128,
    "r_200": 33.18924404,
    "r_250": 43.84053888,
    "r_300": 49.83448177,
    "r_400": 48.85782405,
    "r_500": 46.9529386,
    "r_600": 48.08965347,
    "r_700": 51.35312916,
    "r_850": 64.34232992,
    "r_925": 73.56247303,
    "r_1000": 72.96043561,
}

# Per-variable scale (divisor) table used as the fallback by the normalise
# functions: forward normalisation computes ``(x - location) / scale``, so every
# entry must be non-zero. Keys follow the same scheme as ``locations``
# (surface name, or ``f"{name}_{level}"`` for atmospheric variables). This dict
# is mutated at call time: the normalise functions insert a scale of 1 for any
# key they look up that is missing.
scales: dict[str, float] = {
    "z": 5.884467e04,
    "lsm": 1.000000e00,
    "slt": 7.000000e00,
    "2t": 2.122036e01,
    "10u": 5.547512e00,
    "10v": 4.765339e00,
    "msl": 1.332246e03,
    "z_50": 5.875553e03,
    "z_100": 5.510640e03,
    "z_150": 5.823912e03,
    "z_200": 5.820169e03,
    "z_250": 5.536585e03,
    "z_300": 5.091916e03,
    "z_400": 4.150851e03,
    "z_500": 3.353187e03,
    "z_600": 2.695808e03,
    "z_700": 2.136436e03,
    "z_850": 1.470321e03,
    "z_925": 1.228997e03,
    "z_1000": 1.072307e03,
    "u_50": 1.529281e01,
    "u_100": 1.352611e01,
    "u_150": 1.604335e01,
    "u_200": 1.767630e01,
    "u_250": 1.796710e01,
    "u_300": 1.711917e01,
    "u_400": 1.434276e01,
    "u_500": 1.198419e01,
    "u_600": 1.033421e01,
    "u_700": 9.168821e00,
    "u_850": 8.188043e00,
    "u_925": 7.940808e00,
    "u_1000": 6.141778e00,
    "v_50": 7.058931e00,
    "v_100": 7.479310e00,
    "v_150": 9.571990e00,
    "v_200": 1.188069e01,
    "v_250": 1.338039e01,
    "v_300": 1.334044e01,
    "v_400": 1.122955e01,
    "v_500": 9.181708e00,
    "v_600": 7.803569e00,
    "v_700": 6.871040e00,
    "v_850": 6.264443e00,
    "v_925": 6.470644e00,
    "v_1000": 5.308203e00,
    "t_50": 1.026284e01,
    "t_100": 1.252901e01,
    "t_150": 8.928709e00,
    "t_200": 7.189547e00,
    "t_250": 8.529282e00,
    "t_300": 1.071679e01,
    "t_400": 1.269102e01,
    "t_500": 1.306447e01,
    "t_600": 1.342046e01,
    "t_700": 1.476523e01,
    "t_850": 1.558880e01,
    "t_925": 1.608798e01,
    "t_1000": 1.713983e01,
    "q_50": 3.571687e-07,
    "q_100": 5.703754e-07,
    "q_150": 3.794077e-06,
    "q_200": 2.267534e-05,
    "q_250": 7.446644e-05,
    "q_300": 1.684361e-04,
    "q_400": 5.078644e-04,
    "q_500": 1.079294e-03,
    "q_600": 1.769722e-03,
    "q_700": 2.549169e-03,
    "q_850": 4.112368e-03,
    "q_925": 5.071058e-03,
    "q_1000": 5.913548e-03,
    "cmorph": 0.655754,
    "tp": 0.00194699,

    "r_50": 15.05099344,
    "r_100": 33.23311031,
    "r_150": 31.79981453,
    "r_200": 34.18327632,
    "r_250": 35.44982815,
    "r_300": 35.47569881,
    "r_400": 35.49229363,
    "r_500": 34.75275185,
    "r_600": 34.10519766,
    "r_700": 33.35618364,
    "r_850": 30.73419218,
    "r_925": 28.63213914,
    "r_1000": 26.28105185,
}
