"""
Serializable distance metric class. Allows the user provided metric for an
environment to be serialized and loaded again.
"""

from typing import List

import copy
import gymnasium as gym
import numpy as np
import torch

def euclidean_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return the element-wise absolute difference ``|a - b|``.

    This is the one-dimensional Euclidean distance along each coordinate; no
    reduction over elements is performed.
    """
    delta = a - b
    return delta.abs()

def euclidean_clamp(input: torch.Tensor, min: torch.Tensor, max: torch.Tensor) -> torch.Tensor:
    """Clamp ``input`` element-wise into the closed interval ``[min, max]``.

    The parameter names shadow the ``input``/``min``/``max`` builtins; they are
    kept as-is because callers pass the bounds by keyword.
    """
    lower, upper = min, max
    return torch.clamp(input, min=lower, max=upper)

def radial_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return the element-wise shortest angular distance between ``a`` and ``b``.

    Inputs are angles in radians. The absolute difference is first reduced
    modulo 2π, and then the shorter way around the circle is chosen, so the
    result lies in ``[0, π]``.
    """
    full_turn = 2 * torch.pi
    # Distance travelling one way around the circle, in [0, 2π).
    one_way = torch.fmod((a - b).abs(), full_turn)
    # Going the other way covers the remainder of the turn; keep the shorter one.
    return torch.minimum(one_way, full_turn - one_way)

def radial_clamp(input: torch.Tensor, radial_min: torch.Tensor) -> torch.Tensor:
    """
    Wrap angles into the half-open window ``[radial_min, radial_min + 2π)``.

    ``radial_min`` is expected to lie between -π and 0. The mapping is

        mod(x - m, 2π) + m

    where ``mod(*, 2π)`` returns values in ``[0, 2π)``, so every angle is moved
    by whole turns only and its direction is preserved.
    """
    full_turn = 2 * torch.pi
    shifted = input - radial_min
    # fmod keeps the sign of the dividend, so one pass lands in (-2π, 2π).
    # Adding a full turn and reducing again yields a value in [0, 2π).
    partially_wrapped = torch.fmod(shifted, full_turn) + full_turn
    wrapped = torch.fmod(partially_wrapped, full_turn)
    return wrapped + radial_min


class DistanceMetric(object):
    """Per-element distance metric and clamping for a ``gym.spaces.Box`` space.

    Each element of the observation is tagged either "euclidean" or "radial":

    * "euclidean" elements use the absolute difference and are clamped to the
      Box bounds.
    * "radial" elements use the shortest distance around the circle and are
      wrapped into a 2π window.

    Inputs to ``__call__``/``clamp`` may carry any number of leading (batch)
    dimensions in front of the observation shape.
    """

    def __init__(self, config, obs_space):
        """
        Build the metric from a per-element config and an observation space.

        config is an ndarray, or a (nested) list of strings, with the same shape
        as ``obs_space``. It names the metric for each element of the
        observation. Allowed values are "euclidean" and "radial".

        Example config:
        For an environment whose observation is laid out as
        [coord, angle, velocity] we would have
        config = ["euclidean", "radial", "euclidean"].

        Idea is to JIT compile this with torchscript such that it does not
        become too sluggish.

        Raises:
            AssertionError: if ``obs_space`` is not a ``gym.spaces.Box``, if its
                shape differs from the config shape, or if the config contains
                a value other than "euclidean"/"radial".
        """
        self.obs_space = copy.copy(obs_space)
        self.config = np.array(config)
        assert isinstance(self.obs_space, gym.spaces.Box), f"Expected gym.spaces.Box space, got {self.obs_space}"
        assert self.config.shape == self.obs_space.shape, f"Shape mismatch: {self.config} != {self.obs_space}"
        config_members = set(self.config.flatten().tolist())
        valid_members = {"radial", "euclidean"}
        assert config_members.issubset(valid_members), f"Invalid config members: {config_members - valid_members}"

        # Index tuples (one index array per config dimension) selecting the
        # elements of each kind. The numpy versions index the Box bounds; the
        # torch versions index the input tensors.
        self.np_euclidean_idxs = np.where(self.config == "euclidean")
        self.np_radial_idxs = np.where(self.config == "radial")
        self.euclidean_idxs = tuple(torch.as_tensor(t) for t in self.np_euclidean_idxs)
        self.radial_idxs = tuple(torch.as_tensor(t) for t in self.np_radial_idxs)

        self.has_euclidean = bool(len(self.euclidean_idxs[0]) > 0)
        self.has_radial = bool(len(self.radial_idxs[0]) > 0)

        # Box bounds for the euclidean elements, used by clamp().
        self.euclidean_low = torch.as_tensor(self.obs_space.low[self.np_euclidean_idxs]).float()
        self.euclidean_high = torch.as_tensor(self.obs_space.high[self.np_euclidean_idxs]).float()

        # Lower edge of the 2π window that clamp() wraps each radial element into.
        self.radial_min = torch.nan_to_num(
            torch.as_tensor(self.obs_space.low[self.np_radial_idxs]),
            # If a radial lower bound is set to -inf, we default to -pi
            neginf=-torch.pi,
        ).float()

        # Per-element multipliers applied when scaled=True.
        # NOTE(review): the radial default of π (not e.g. 1/π) is kept as-is;
        # confirm whether it is meant to normalize radial distances.
        self.__euclidean_scalar = torch.ones(self.obs_space.shape)
        self.__radial_scalar = torch.full(self.obs_space.shape, torch.pi)

    # Managed properties

    @property
    def euclidean_scalar(self):
        """Per-element multiplier for euclidean distances (used when ``scaled=True``)."""
        return self.__euclidean_scalar

    @euclidean_scalar.setter
    def euclidean_scalar(self, v):
        # Accept any array-like (tensor, ndarray, nested list). An existing tensor
        # is stored as-is; everything else becomes a tensor so that __call__ can
        # move it to the input's device.
        v = torch.as_tensor(v)
        assert v.shape == self.config.shape, f"euclidean_scalar shape {tuple(v.shape)} != config shape {self.config.shape}"
        self.__euclidean_scalar = v

    @property
    def radial_scalar(self):
        """Per-element multiplier for radial distances (used when ``scaled=True``)."""
        return self.__radial_scalar

    @radial_scalar.setter
    def radial_scalar(self, v):
        # Same normalization as the euclidean_scalar setter.
        v = torch.as_tensor(v)
        assert v.shape == self.config.shape, f"radial_scalar shape {tuple(v.shape)} != config shape {self.config.shape}"
        self.__radial_scalar = v

    # Internal helpers

    def _prefix_index(self, ndim):
        """Return one full slice per leading (batch) dimension of a rank-``ndim`` tensor."""
        return (slice(None),) * (ndim - len(self.config.shape))

    def _check_trailing_shape(self, x, name):
        """Assert that the trailing dimensions of ``x`` equal the config shape."""
        n = len(self.config.shape)
        tail = x.shape[-n:]
        assert tail == self.config.shape, f"{name}.shape[{-n}:] = {tail} | self.config.shape = {self.config.shape}"

    # Public API

    def __call__(self, a : torch.Tensor, b : torch.Tensor, scaled=False):
        """Compute the element-wise distance between ``a`` and ``b``.

        Args:
            a, b: tensors of identical shape ``(*batch, *config.shape)``.
            scaled: if True, multiply each element's distance by the matching
                entry of ``euclidean_scalar`` or ``radial_scalar``.

        Returns:
            A tensor of torch's default float dtype on ``a.device``, with the
            same shape as ``a``.

        Raises:
            AssertionError: on a shape mismatch between ``a`` and ``b``, or
                between their trailing dimensions and the config shape.
        """
        assert a.shape == b.shape, f"a.shape = {a.shape} | b.shape = {b.shape}"
        self._check_trailing_shape(a, "a")

        diff = torch.zeros(a.shape, device=a.device)
        pfxidx = self._prefix_index(len(a.shape))
        # The scalars have the config shape. Prepending singleton batch dims lets
        # the same idx select a tensor that broadcasts over the batched values.
        scale_shape = (1,) * len(pfxidx) + self.config.shape
        if self.has_euclidean:
            idx = pfxidx + self.euclidean_idxs
            dist = euclidean_distance(a[idx], b[idx]).float()
            if scaled:
                dist = dist * self.euclidean_scalar.to(a.device).reshape(scale_shape)[idx]
            diff[idx] = dist
        if self.has_radial:
            idx = pfxidx + self.radial_idxs
            dist = radial_distance(a[idx], b[idx]).float()
            if scaled:
                dist = dist * self.radial_scalar.to(a.device).reshape(scale_shape)[idx]
            diff[idx] = dist

        return diff

    def scaled_distance(self, a : torch.Tensor, b : torch.Tensor):
        """Shorthand for ``self(a, b, scaled=True)``."""
        return self.__call__(a, b, scaled=True)

    def clamp(self, input):
        """Project ``input`` back into the observation space, element by element.

        Euclidean elements are clamped to the Box bounds. Radial elements are
        wrapped into ``[radial_min, radial_min + 2π)``.

        Args:
            input: tensor of shape ``(*batch, *config.shape)``. The parameter
                name shadows the builtin but is kept for keyword callers.

        Returns:
            A tensor of torch's default float dtype on ``input.device``, with
            the same shape as ``input``.

        Raises:
            AssertionError: if the trailing dimensions of ``input`` differ from
                the config shape. Without this check a mismatched input would
                have the wrong elements clamped and the rest silently zeroed.
        """
        self._check_trailing_shape(input, "input")

        output = torch.zeros(input.shape, device=input.device)
        pfxidx = self._prefix_index(len(input.shape))
        if self.has_euclidean:
            idx = pfxidx + self.euclidean_idxs
            output[idx] = euclidean_clamp(
                input[idx],
                min=self.euclidean_low.to(input.device),
                max=self.euclidean_high.to(input.device),
            ).float()
        if self.has_radial:
            idx = pfxidx + self.radial_idxs
            output[idx] = radial_clamp(
                input[idx],
                radial_min=self.radial_min.to(input.device),
            ).float()
        return output
