# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Functions to scale and unscale model parameters.
"""

import numpy as np
import torch

METHODS = ['std', 'std-param', 'min-max', 'min-max-param']


def scale_params(x, model_dict, scales=None, method='std', eps=1e-6):
    """Normalize stacked model parameters, per layer or per individual parameter.

    Args:
        x: tensor of shape (psz, seq_len, state_dim) or (psz, state_dim); a 2-d
            input is handled as seq_len == 1 and returned 2-d again. The first
            axis is consumed layer by layer in the order of ``model_dict``.
            NOTE: the normalized values are written back into ``x`` in place;
            pass a clone if the unscaled values are still needed.
        model_dict: iterable of ``(name, p)`` pairs, where ``p`` is either a
            tensor or a shape (``torch.Size``, tuple, list). Only the number of
            elements of each entry is used, to know how many rows of ``x``
            belong to that layer.
        scales: optional list of ``(mn, sd)`` pairs, indexed by layer, as
            returned by a previous call. When None the statistics are computed
            from ``x`` and returned.
        method: one of ``METHODS``.
            ``'std'``: subtract the mean and divide by the standard deviation,
            both taken over all rows of a layer and the last axis.
            ``'min-max'``: divide by ``max - min`` over the same axes; the
            offset is zero, i.e. values are rescaled but not shifted.
            A ``'-param'`` suffix computes the statistics for every row of
            ``x`` separately (over the last axis only). In that mode all rows
            are processed at once and only a single ``(mn, sd)`` pair is
            produced / read (``scales[0]``).
        eps: small constant added to ``sd`` before dividing.

    Returns:
        Tuple ``(x, scales)``: the normalized tensor in the same shape as the
        input, and the list of ``(mn, sd)`` pairs that were used.

    Raises:
        NotImplementedError: if ``method`` is not in ``METHODS``.
    """
    # x: psz, seq_len/batch (>=1 for training), state_dim
    assert x.dim() in [2, 3], x.shape
    squeeze_back = x.dim() == 2
    if squeeze_back:
        x = x.unsqueeze(1)  # (psz, 1, state_dim)

    if method not in METHODS:
        raise NotImplementedError(method, 'supported methods:', METHODS)

    compute_scales = scales is None
    if compute_scales:
        scales = []
    per_param = method.endswith('-param')
    is_std = method.startswith('std')
    # Reduce over the last axis only (per parameter) or over rows + last axis (per layer).
    dims = 2 if per_param else (0, 2)

    offset = 0
    for layer, (_, p) in enumerate(model_dict):
        shape = p.shape if isinstance(p, torch.Tensor) else p
        # int(...) matters: np.prod of an empty shape is the float 1.0, which
        # cannot be used as a slice bound.
        n = len(x) if per_param else int(np.prod(tuple(shape)))
        w = x[offset: offset + n]  # (n, seq_len, state_dim)
        if compute_scales:
            if is_std:
                mn = torch.mean(w, dim=dims, keepdim=True)
                sd = torch.std(w, dim=dims, keepdim=True)  # (n, seq_len, 1) if per_param else (1, seq_len, 1)
            else:
                # amax/amin return the reduced tensor itself (no indices), so
                # the result must not be indexed with [0].
                sd = torch.amax(w, dim=dims, keepdim=True) - torch.amin(w, dim=dims, keepdim=True)
                mn = torch.zeros_like(sd)

            if not is_std or per_param:
                # Floor the scale so near-constant slices are not blown up.
                sd = sd.clamp(min=1e-2)
            scales.append((mn, sd))
        else:
            mn, sd = scales[layer]

        x[offset: offset + n] = (w - mn) / (sd + eps)
        offset += n
        if per_param:
            # Every row has already been handled in this single pass.
            break

    if squeeze_back:
        x = x.squeeze(1)

    return x, scales


def unscale_params(x, model_dict, scales, method='std', eps=0.0):
    """Undo the normalization applied by ``scale_params``.

    Args:
        x: tensor of shape (psz,), (psz, state_dim) or (psz, seq_len, state_dim).
            1-d and 2-d inputs are expanded to 3-d with singleton axes.
            NOTE: the result is written back into ``x`` in place.
        model_dict: iterable of ``(name, p)`` pairs, where ``p`` is either a
            tensor or a shape (``torch.Size``, tuple, list). Only the number of
            elements of each entry is used, to know how many rows of ``x``
            belong to that layer.
        scales: list of ``(mn, sd)`` tensor pairs indexed by layer, as returned
            by ``scale_params``. They are moved to the dtype/device of ``x``.
        method: one of ``METHODS``. With a ``'-param'`` method all rows are
            processed at once using ``scales[0]`` only.
        eps: constant added to ``sd`` before multiplying. ``scale_params``
            divides by ``sd + eps``, so passing the same ``eps`` here makes the
            round trip exact. The default of 0 multiplies by ``sd`` alone.

    Returns:
        The unscaled tensor. It is always 3-d (psz, seq_len, state_dim); the
        singleton axes added for 1-d / 2-d inputs are not removed.

    Raises:
        NotImplementedError: if ``method`` is not in ``METHODS``.
    """
    # x: psz, seq_len/batch (>=1 for training), state_dim (>=1)
    assert x.dim() in [1, 2, 3], x.shape
    if x.dim() == 1:
        x = x.unsqueeze(1)  # (psz, 1)
    if x.dim() == 2:
        x = x.unsqueeze(1)  # (psz, 1, state_dim)

    if method not in METHODS:
        raise NotImplementedError(method, 'supported methods:', METHODS)

    per_param = method.endswith('-param')

    offset = 0
    for layer, (_, p) in enumerate(model_dict):
        shape = p.shape if isinstance(p, torch.Tensor) else p
        # Same element count as scale_params, so plain tuples/lists work here
        # too (they have no .numel()).
        n = len(x) if per_param else int(np.prod(tuple(shape)))
        w = x[offset: offset + n]  # (n, seq_len, state_dim)
        mn, sd = scales[layer]
        sd = sd.to(w)
        if eps:
            sd = sd + eps
        x[offset: offset + n] = w * sd + mn.to(w)
        offset += n
        if per_param:
            # Every row has already been handled in this single pass.
            break

    return x
