from copy import copy
from dataclasses import dataclass
from typing import Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np

from ol.utils import Array, ScalarArray

# Small constant added to denominators (the ground-truth norm in
# rel_lp_error) so an all-zero reference does not cause a division by zero.
EPSILON = 1e-10

@dataclass
class BatchMetrics:
    """Container of per-sample metric arrays for a batch.

    Every field defaults to None, meaning "not computed". Arithmetic and
    reshaping skip unset fields instead of failing on them.
    """

    mse: Union[Array, None] = None
    l1: Union[Array, None] = None
    l2: Union[Array, None] = None
    recall: Union[Array, None] = None
    chamfer: Union[Array, None] = None

    def map(self, f):
        """Replace every field value ``v`` by ``f(v)`` in place.

        ``f`` is applied to all fields, including those that are None.
        """
        # Snapshot the keys: we rebind attributes while iterating.
        for key in list(vars(self)):
            setattr(self, key, f(getattr(self, key)))

    def reshape(self, shape):
        """Reshape every metric that is set (non-None) to ``shape`` in place."""
        self.map(lambda m: None if m is None else m.reshape(shape))

    def __add__(self, obj):
        """Return a new BatchMetrics holding the field-wise sum of two.

        A None field acts as "not computed": None + x == x + None == x and
        None + None == None. ``self`` and ``obj`` are not modified.
        """
        out = copy(self)
        for key, mine in vars(self).items():
            theirs = getattr(obj, key)
            if mine is None:
                total = theirs
            elif theirs is None:
                total = mine
            else:
                total = mine + theirs
            setattr(out, key, total)
        return out

@dataclass
class Metrics:
    """Scalar summary value of each metric; a field is None when not computed."""

    mse: Union[float, None] = None
    l1: Union[float, None] = None
    l2: Union[float, None] = None
    recall: Union[float, None] = None
    chamfer: Union[float, None] = None

@dataclass
class EvalMetrics:
    """Summary statistics (median, std, maximum) of evaluation metrics."""

    median: Union[Metrics, None] = None
    std: Union[Metrics, None] = None
    maximum: Union[Metrics, None] = None

    def to_dict(self):
        """Return a nested ``{statistic: {metric: value}}`` dictionary.

        Statistics that are not set (None) map to None. The inner dicts are
        shallow copies, so mutating the result does not alter this object.
        """
        return {
            key: (None if val is None else dict(vars(val)))
            for key, val in vars(self).items()
        }

def lp_norm(arr: Array, p: int = 2, chunks: Union[None, Sequence[int]] = None, num_chunks: int = None) -> Array:
    """
    Returns the (discrete, unweighted) Bochner Lp-norm of an array.

    The p-th powers of the absolute values are summed over every axis between
    the leading batch axis and the trailing variable axis. For the documented
    4-D layout these are the time and space axes. The sums are then added up
    within each variable chunk, and finally the p-th root is taken. No
    quadrature weights are applied, so the result is a grid-size dependent
    multiple of the continuous norm; the factor cancels in relative errors.

    Args:
        arr: Point-wise values on a uniform grid with the dimensions
            [batch, time, space, var] (more generally [batch, ..., var]).
        p: Order of the norm. Defaults to 2.
        chunks: Chunk index of each variable (one entry per entry of the last
            axis) for vectorial functions. If None, the entries of the last
            axis are interpreted as values of independent scalar-valued
            functions. Defaults to None.
        num_chunks: Number of chunks. Ignored when ``chunks`` is None. When
            ``chunks`` is given and this is None, ``jax.ops.segment_sum``
            infers it as ``max(chunks) + 1``, which requires concrete chunks.

    Returns:
        Norms of shape [batch, num_chunks]. When ``chunks`` is None and the
        input has a single variable, the variable axis is squeezed and the
        result has shape [batch].
    """

    # Set the default chunks: one chunk per variable
    if chunks is None:
        chunks = jnp.arange(arr.shape[-1])
        num_chunks = arr.shape[-1]
        keep_var_dim = False
    else:
        # segment_sum expects an array of segment ids; convert sequences.
        chunks = jnp.asarray(chunks)
        keep_var_dim = True

    # Compute power of absolute value
    pow_abs = jnp.power(jnp.abs(arr), p)
    # Sum over all grid axes (between batch and var); (1, 2) for 4-D input
    grid_axes = tuple(range(1, arr.ndim - 1))
    abs_pow_sum_vars = jnp.sum(pow_abs, axis=grid_axes)
    # Sum on variable chunks, independently for each sample
    abs_pow_sum = jax.vmap(jax.ops.segment_sum, in_axes=(0, None, None))(abs_pow_sum_vars, chunks, num_chunks)
    # Take the p-th root
    pth_root = jnp.power(abs_pow_sum, (1 / p))
    # Squeeze the variable axis only when it is a singleton; with several
    # independent variables the per-variable norms are returned as-is.
    if not keep_var_dim and pth_root.shape[-1] == 1:
        pth_root = jnp.squeeze(pth_root, axis=-1)

    return pth_root

def rel_lp_error(gtr: Array, prd: Array, p: int = 2, chunks: Union[None, Sequence[int]] = None, num_chunks: int = None, exclude_percentile: float = 0.0) -> Array:
    """
    Returns the relative Bochner Lp-norm of an array with respect to a ground truth.

    Computes ||prd - gtr||_p / (||gtr||_p + EPSILON) for each sample and each
    variable chunk. Optionally, the points where |gtr| is among the largest
    ``exclude_percentile`` percent are removed from both norms.

    Args:
        gtr: Point-wise values of a ground truth function on a uniform
            grid with the dimensions [batch, time, space, var]
        prd: Point-wise values of a predicted function on a uniform
            grid with the dimensions [batch, time, space, var]
        p: Order of the norm. Defaults to 2.
        chunks: Index of variable chunks for vectorial functions.
            If None, the entries of the last axis are interpreted as values of
            independent scalar-valued functions. Defaults to None.
        num_chunks: Number of chunks when ``chunks`` is given. If None, it is
            inferred from ``chunks`` by ``lp_norm``. Ignored when ``chunks``
            is None.
        exclude_percentile: Percentage (0-100) of the largest values of |gtr|
            (per sample, per component) excluded from both norms.
            Defaults to 0.0, i.e. nothing is excluded.

    Returns:
        A scalar value for each sample and chunk [batch, num_chunks]
    """

    if chunks is None:
        chunks = jnp.arange(gtr.shape[-1])
        num_chunks = gtr.shape[-1]
    else:
        chunks = jnp.array(chunks)

    # Points where |gtr| exceeds its (100 - exclude_percentile)-th percentile
    # (per sample, per component) are zeroed in both the error and the
    # reference, so they contribute to neither norm. With the default 0.0
    # the threshold is the maximum and nothing is excluded.
    excluded = get_critical_values_mask(gtr, q=exclude_percentile)
    err_corrected = jnp.where(excluded, 0.0, prd - gtr)
    gtr_corrected = jnp.where(excluded, 0.0, gtr)
    # Calculate the norms
    err_norm = lp_norm(err_corrected, p=p, chunks=chunks, num_chunks=num_chunks)
    gtr_norm = lp_norm(gtr_corrected, p=p, chunks=chunks, num_chunks=num_chunks)

    # EPSILON keeps the ratio finite for an all-zero reference.
    return (err_norm / (gtr_norm + EPSILON))

def rel_lp_error_norm(gtr: Array, prd: Array, p: int = 2, chunks: Union[None, Sequence[int]] = None, num_chunks: int = None, exclude_percentile: float = 0.0) -> Array:
    """
    Collapses the per-chunk relative Lp errors into a single value per sample.

    The relative Bochner Lp error is evaluated for every variable chunk with
    ``rel_lp_error``. The resulting error vector is then reduced with the
    vector p-norm of the same order ``p``.

    Args:
        gtr: Ground-truth values on a uniform grid, [batch, time, space, var].
        prd: Predicted values on the same grid, [batch, time, space, var].
        p: Order of both the function norm and the vector norm. Defaults to 2.
        chunks: Chunk index of each variable for vectorial functions; None
            treats every variable as an independent scalar function.
        num_chunks: Number of chunks when ``chunks`` is given.
        exclude_percentile: Percentage of the largest |gtr| values (per sample,
            per component) left out of the norms.

    Returns:
        The vector p-norm of the error vector, shape [batch,].
    """
    per_chunk = rel_lp_error(
        gtr,
        prd,
        p=p,
        chunks=chunks,
        num_chunks=num_chunks,
        exclude_percentile=exclude_percentile,
    )
    return jnp.linalg.norm(per_chunk, ord=p, axis=1)

def rel_lp_error_mean(gtr: Array, prd: Array, p: int = 2, chunks: Union[None, Sequence[int]] = None, num_chunks: int = None, exclude_percentile: float = 0.0) -> Array:
    """
    Averages the per-chunk relative Lp errors into a single value per sample.

    The relative Bochner Lp error is evaluated for every variable chunk with
    ``rel_lp_error``, and the arithmetic mean over the chunks is returned.

    Args:
        gtr: Ground-truth values on a uniform grid, [batch, time, space, var].
        prd: Predicted values on the same grid, [batch, time, space, var].
        p: Order of the norm. Defaults to 2.
        chunks: Chunk index of each variable for vectorial functions; None
            treats every variable as an independent scalar function.
        num_chunks: Number of chunks when ``chunks`` is given.
        exclude_percentile: Percentage of the largest |gtr| values (per sample,
            per component) left out of the norms.

    Returns:
        The mean of the error vector, shape [batch,].
    """
    per_chunk = rel_lp_error(
        gtr,
        prd,
        p=p,
        chunks=chunks,
        num_chunks=num_chunks,
        exclude_percentile=exclude_percentile,
    )
    return per_chunk.mean(axis=1)

def rel_lp_loss(gtr: Array, prd: Array, p: int = 2, q: float = 0.0) -> ScalarArray:
    """
    Batch-averaged relative Lp error, usable as a training loss.

    Every variable is treated as an independent scalar function; the
    per-sample value comes from ``rel_lp_error_norm`` and is averaged over
    the batch.

    Args:
        gtr: Ground-truth values, [batch, time, space, var].
        prd: Predicted values, [batch, time, space, var].
        p: Order of the norm. Defaults to 2.
        q: Percentage of the largest |gtr| values (per sample, per component)
            left out of the norms.

    Returns:
        A scalar: the mean over the batch.
    """
    per_sample = rel_lp_error_norm(gtr, prd, p=p, exclude_percentile=q)
    return per_sample.mean()

def mse_error(gtr: Array, prd: Array) -> Array:
    """
    Returns the mean squared error of each sample.

    The squared difference is averaged over every axis except the leading
    batch axis. For the [batch, time, space, var] layout that means time,
    space and var, so the variables are pooled rather than reported
    separately.

    Args:
        gtr: Ground-truth values, [batch, ...] (typically [batch, time, space, var]).
        prd: Predicted values with the same shape as ``gtr``.

    Returns:
        The per-sample mean squared error, shape [batch,].
    """

    sq_err = jnp.power(prd - gtr, 2)
    return jnp.mean(sq_err, axis=tuple(range(1, sq_err.ndim)))

def mse_loss(gtr: Array, prd: Array) -> ScalarArray:
    """
    Mean of the squared prediction error over all entries.

    Args:
        gtr: Ground-truth values, [batch, time, space, var].
        prd: Predicted values, [batch, time, space, var].

    Returns:
        A scalar.
    """
    residual = prd - gtr
    return (residual ** 2).mean()

def get_critical_values_mask(arr: Array, q: float) -> Array:
  """
  Returns a boolean mask of the critical (largest-magnitude) values.

  For each sample (first axis) and each component (last axis), the threshold
  is the (100 - q)-th percentile of |arr| over all remaining axes. An entry is
  marked True when its magnitude is strictly greater than that threshold, so
  entries tied with the threshold are not marked.

  Args:
    arr: Values of shape [batch, ..., var], e.g. [batch, time, space, var].
    q: Percentage (0-100) of the largest magnitudes to mark.

  Returns:
    Boolean array with the same shape as ``arr``.
  """
  magnitude = jnp.abs(arr)
  batch, num_vars = magnitude.shape[0], magnitude.shape[-1]
  # Flatten the grid axes so one percentile is taken over all of them.
  flat = magnitude.reshape(batch, -1, num_vars)
  thresholds = jnp.percentile(flat, q=(100 - q), axis=1)  # [batch, var]
  # Re-insert singleton grid axes so the thresholds broadcast against arr.
  thresholds = thresholds.reshape((batch,) + (1,) * (magnitude.ndim - 2) + (num_vars,))
  return magnitude > thresholds

def recall(gtr: Array, prd: Array, q: float, zero_division: float = float("nan")) -> Array:
  """
  Returns the recall score for two given functions.

  The critical values of each field are its top q% magnitudes (see
  ``get_critical_values_mask``). Recall is the fraction of the ground-truth
  critical points that are also critical in the prediction, counted over the
  time and space axes.

  Args:
    gtr: Ground truth function [batch, time, space, var]
    prd: Prediction function [batch, time, space, var]
    q: Based on top q% percentage (in magnitude, per sample, per component)
    zero_division: Value returned where ``gtr`` has no critical values (e.g.
      its magnitude is constant, so nothing exceeds the threshold). Defaults
      to NaN, which is what the plain 0/0 division produced.

  Returns:
    A scalar value between 0 and 1 for each sample in the batch [batch, var]
    (or ``zero_division`` where there are no ground-truth critical values).
  """

  mask_gtr = get_critical_values_mask(gtr, q=q)
  mask_prd = get_critical_values_mask(prd, q=q)
  true_positive = (mask_gtr & mask_prd).sum(axis=(1, 2))
  actual_positive = mask_gtr.sum(axis=(1, 2))
  # Divide by at least 1 so the no-positive case is never evaluated as 0/0.
  score = true_positive / jnp.maximum(actual_positive, 1)
  return jnp.where(actual_positive > 0, score, zero_division)

def recall_mean(gtr: Array, prd: Array, q: float) -> Array:
  """
  Recall (see ``recall``) averaged over the variable axis.

  Args:
    gtr: Ground truth function [batch, time, space, var]
    prd: Prediction function [batch, time, space, var]
    q: Top q% (in magnitude) defining the critical values

  Returns:
    One value per sample [batch,]
  """
  return recall(gtr, prd, q=q).mean(axis=1)

def recall_loss(gtr: Array, prd: Array, q: float = 0.1) -> ScalarArray:
  """One minus the recall (see ``recall``) averaged over batch and variables."""
  mean_recall = recall(gtr, prd, q=q).mean()
  return 1 - mean_recall

def iou(gtr: Array, prd: Array, q: float, zero_division: float = float("nan")) -> Array:
  """
  Returns the Intersection over Union score for two given functions.

  The critical regions are the top q% magnitudes of each field (see
  ``get_critical_values_mask``). The score is |intersection| / |union| of
  the two regions, counted over the time and space axes.

  Args:
    gtr: Ground truth function [batch, time, space, var]
    prd: Prediction function [batch, time, space, var]
    q: Based on top q% percentage (in magnitude, per sample, per component)
    zero_division: Value returned where the union is empty (neither field
      has a critical value, e.g. both have constant magnitude). Defaults to
      NaN, which is what the plain 0/0 division produced.

  Returns:
    A scalar value between 0 and 1 for each sample in the batch [batch, var]
    (or ``zero_division`` where the union is empty).
  """

  mask_gtr = get_critical_values_mask(gtr, q=q)
  mask_prd = get_critical_values_mask(prd, q=q)
  intersection = (mask_gtr & mask_prd).sum(axis=(1, 2))
  union = (mask_gtr | mask_prd).sum(axis=(1, 2))
  # Divide by at least 1 so the empty-union case is never evaluated as 0/0.
  score = intersection / jnp.maximum(union, 1)
  return jnp.where(union > 0, score, zero_division)

def _chamfer_distance_single_instance(x, m_gtr, m_prd, size) -> ScalarArray:
  """
  Computes the Chamfer distance between two segment masks of a spatial field.
  Since the false positives are already taken into account in the relative error,
  here we only take the average distance of the false negatives.

  Args:
    x: Space coordinates of shape [space, dim]
    m_gtr: Ground truth mask of shape [space,]
    m_prd: Assessed mask of shape [space,]
    size: Expected size of the segments
  """

  x_gtr = x[jnp.where(m_gtr, size=size)]
  x_prd = x[jnp.where(m_prd, size=size)]
  d = jnp.linalg.norm(x_gtr[:, None, :] - x_prd[None, :, :], axis=-1)
  return jnp.min(d, axis=1).mean()

def chamfer(x, u_gtr, u_prd, q):
  """
  One-sided Chamfer distance between the critical regions of two fields.

  The critical regions are the top q% magnitudes (per sample, time step and
  component) of ``u_gtr`` and ``u_prd``. The distance from the ground-truth
  region to the predicted region is computed for every sample, time step and
  variable.

  Args:
    x: Coordinates [batch, time, space, dim]; mapped over batch and time
      alongside the fields.
    u_gtr: Ground truth function [batch, time, space, var]
    u_prd: Prediction function [batch, time, space, var]
    q: Based on top q% percentage (in magnitude)

  Returns:
    Distances of the first time step only, [batch, var].
    NOTE(review): other time steps are discarded; presumably time has
    length 1 here — confirm with callers.
  """
  masks = (
      get_critical_values_mask(u_gtr, q=q),
      get_critical_values_mask(u_prd, q=q),
  )
  # Static number of slots per segment: ceil(q% of the spatial points).
  size = np.ceil(q / 100 * u_gtr.shape[2]).astype(int)

  dist_fn = _chamfer_distance_single_instance
  dist_fn = jax.vmap(dist_fn, in_axes=(None, 1, 1, None))  # over variables
  dist_fn = jax.vmap(dist_fn, in_axes=(0, 0, 0, None))     # over time
  dist_fn = jax.vmap(dist_fn, in_axes=(0, 0, 0, None))     # over samples
  per_time = dist_fn(x, *masks, size)  # [batch, time, var]

  return per_time[:, 0, :]

def chamfer_mean(x, u_gtr, u_prd, q):
  """Chamfer distance (see ``chamfer``) averaged over variables, [batch,]."""
  per_var = chamfer(x, u_gtr, u_prd, q)
  return per_var.mean(axis=-1)

def chamfer_loss(x, u_gtr, u_prd, q):
  """Chamfer distance (see ``chamfer``) averaged over batch and variables."""
  per_sample_var = chamfer(x, u_gtr, u_prd, q)
  return per_sample_var.mean()

def _recall_tol_single_instance(x, m_gtr, m_prd, tol, size) -> ScalarArray:
  """
  Returns the recall score for two given functions with a given tolerance.

  Args:
    x: Space coordinates of shape [space, dim]
    m_gtr: Ground truth mask of shape [space,]
    m_prd: Assessed mask of shape [space,]
    tol: Tolerance for counting a true positive
    size: Expected size of the segments
  """

  x_gtr = x[jnp.where(m_gtr, size=size)]
  x_prd = x[jnp.where(m_prd, size=size)]
  d = jnp.linalg.norm(x_gtr[:, None, :] - x_prd[None, :, :], axis=-1)
  minimum_positive_distance = jnp.min(d, axis=1)
  true_positive = minimum_positive_distance <= tol
  score = true_positive.sum(axis=0) / size
  return score

def recall_tol(x, u_gtr, u_prd, q, tol):
  """
  Tolerance-based recall between the critical regions of two fields.

  The critical regions are the top q% magnitudes of ``u_gtr`` and ``u_prd``.
  A ground-truth critical point is recovered when a predicted critical
  point lies within ``tol`` of it (see ``_recall_tol_single_instance``).

  Args:
    x: Coordinates [batch, time, space, dim]; mapped over batch and time
      alongside the fields.
    u_gtr: Ground truth function [batch, time, space, var]
    u_prd: Prediction function [batch, time, space, var]
    q: Based on top q% percentage (in magnitude)
    tol: Distance tolerance for a true positive

  Returns:
    Scores of the first time step only, [batch, var].
    NOTE(review): other time steps are discarded; presumably time has
    length 1 here — confirm with callers.
  """
  mask_pair = (
      get_critical_values_mask(u_gtr, q=q),
      get_critical_values_mask(u_prd, q=q),
  )
  # Static number of slots per segment: ceil(q% of the spatial points).
  size = np.ceil(q / 100 * u_gtr.shape[2]).astype(int)

  score_fn = _recall_tol_single_instance
  # Innermost first: variables, then time, then samples.
  for axes in (
      (None, 1, 1, None, None),
      (0, 0, 0, None, None),
      (0, 0, 0, None, None),
  ):
    score_fn = jax.vmap(score_fn, in_axes=axes)
  scores = score_fn(x, *mask_pair, tol, size)  # [batch, time, var]

  return scores[:, 0, :]

def recall_tol_mean(x, u_gtr, u_prd, q, tol):
  """Tolerance recall (see ``recall_tol``) averaged over variables, [batch,]."""
  per_var = recall_tol(x, u_gtr, u_prd, q, tol)
  return per_var.mean(axis=-1)

def recall_tol_loss(x, u_gtr, u_prd, q, tol):
  """Tolerance recall (see ``recall_tol``) averaged over batch and variables."""
  per_sample_var = recall_tol(x, u_gtr, u_prd, q, tol)
  return per_sample_var.mean()
