from pytorch_lightning.loggers import WandbLogger
from typing import Optional

from pathlib import Path

import numpy as np
import torch


def get_diam_bounds(data: torch.Tensor, var_dims: np.ndarray, *, margin: float = 0.1):
    """Compute the data diameter and per-variable, symmetric, widened bounds.

    The columns of ``data`` are split into consecutive groups whose sizes are
    given by ``var_dims``: group ``i`` spans columns
    ``sum(var_dims[:i]):sum(var_dims[:i + 1])``. For every column the bound is
    symmetric around zero, ``M = max(|column max|, |column min|)``, and the
    returned bounds ``[-M, M]`` are each pushed away from zero by ``margin``
    (relative), so the data is strictly enclosed when ``margin > 0``.

    Args:
        data: tensor reduced over ``dim=0`` (rows); the remaining axis holds
            the columns grouped by ``var_dims``. May live on any device and
            may require grad.
        var_dims: sizes of the consecutive column groups (array-like of ints).
        margin: relative amount each bound is moved away from zero. The
            default ``0.1`` gives a 10% widening.

    Returns:
        ``(diam, lower_bound, upper_bound)``:
        ``diam`` is the Euclidean norm (NumPy scalar) of the per-column range
        ``max - min``; ``lower_bound`` and ``upper_bound`` are dicts mapping
        the group index ``i`` to a tensor holding that group's bounds.
    """
    var_dims = np.asarray(var_dims)
    maxx, minn = data.max(dim=0)[0], data.min(dim=0)[0]
    # offsets[i] is the first column of group i. Computing them once avoids
    # re-summing the prefix of var_dims on every iteration.
    offsets = np.concatenate(([0], np.cumsum(var_dims)))
    grow, shrink = 1.0 + margin, 1.0 - margin

    upper_bound, lower_bound = {}, {}
    for i in range(var_dims.shape[0]):
        # int() so float-typed var_dims still yield valid slice indices.
        start, end = int(offsets[i]), int(offsets[i + 1])
        upper = torch.max(maxx[start:end], minn[start:end].abs())
        lower = torch.min(minn[start:end], -maxx[start:end].abs())
        # Move each bound away from zero. A bound's scale factor must depend
        # on that bound's own sign: a negative lower bound and a positive
        # upper bound both grow in magnitude.
        lower_bound[i] = torch.where(lower < 0, lower * grow, lower * shrink)
        upper_bound[i] = torch.where(upper > 0, upper * grow, upper * shrink)

    # Reuse maxx/minn rather than reducing `data` again. detach/cpu lets NumPy
    # consume tensors that live on an accelerator or track gradients.
    diam = np.linalg.norm((maxx - minn).detach().cpu().numpy())
    return diam, lower_bound, upper_bound
