from abc import ABC, abstractmethod
from typing import Mapping

import jax
import jax.numpy as jnp

from ol.models.common import AbstractOperator, Inputs
from ol.utils import Array, normalize, unnormalize
from ol.dataset.dataset import Stats


class Stepper(ABC):
  """
  Base class that couples an operator with input normalization and stepping.

  Subclasses implement `apply` and `get_loss_inputs`. This class provides the
  shared input normalization, the multi-step `unroll`, and access to the
  intermediates registered by the operator.
  """

  def __init__(self, operator: AbstractOperator):
    # Only the `apply` callable of the operator is needed afterwards.
    self._apply_operator = operator.apply

  def normalize_inputs(self, stats: Mapping[str, Stats], inputs: Inputs) -> Inputs:
    """
    Returns a copy of `inputs` with every field mapped to its normalized range.

    - `s` and `a` are normalized with the mean/std of stats['geo'] and stats['dom'].
    - Each entry of `q` is normalized with the mean/std of stats['seg'][key].
    - `x_inp` and `x_out` are rescaled to [-1, 1] using the min/max of stats['x'].
    - `t` is rescaled to [0, 1] using the min/max of stats['t'].
    - `tau` is a time difference, so it is only divided by the range of stats['t'].
    - `m` is passed through unchanged.

    `t` and `tau` may be None, in which case they stay None and stats['t'] is
    not accessed for them.
    """

    def _rescale_coordinates(x):
      # Min-max scaling to [-1, 1]
      x_min = stats['x'].min
      x_max = stats['x'].max
      return 2 * ((x - x_min) / (x_max - x_min)) - 1

    s_nrm = normalize(inputs.s, shift=stats['geo'].mean, scale=stats['geo'].std)
    a_nrm = normalize(inputs.a, shift=stats['dom'].mean, scale=stats['dom'].std)
    q_nrm = {
      key: normalize(val, shift=stats['seg'][key].mean, scale=stats['seg'][key].std)
      for key, val in inputs.q.items()
    }

    # NOTE: stats['t'] is only looked up when a time quantity is actually given
    if inputs.t is None:
      t_nrm = None
    else:
      t_nrm = (inputs.t - stats['t'].min) / (stats['t'].max - stats['t'].min)
    if inputs.tau is None:
      tau_nrm = None
    else:
      tau_nrm = inputs.tau / (stats['t'].max - stats['t'].min)

    return Inputs(
      s=s_nrm,
      a=a_nrm,
      q=q_nrm,
      m=inputs.m,  # NOTE: Binary masks are not normalized
      x_inp=_rescale_coordinates(inputs.x_inp),
      x_out=_rescale_coordinates(inputs.x_out),
      t=t_nrm,
      tau=tau_nrm,
    )

  @abstractmethod
  def apply(self,
    variables,
    stats: Mapping[str, Stats],
    inputs: Inputs,
    **kwargs,
  ):
    """
    Normalizes raw inputs and applies the operator on it.

    inputs.t is the time of the input and must be a non-negative integer.
    inputs.tau is the time difference and must be an integer greater than zero.
    Extra keyword arguments are meant to be forwarded to the operator.
    """

  def unroll(self,
    variables,
    stats: Mapping[str, Stats],
    num_steps: int,
    inputs: Inputs,
    **kwargs,
  ):
    """
    Apply the stepper multiple times to reach inputs.t + inputs.tau by dividing tau.

    inputs.tau is split into `num_steps` equal parts. Each step calls
    `self.apply` with the output of the previous step as the new `a`, and the
    time advanced by one fractional tau. Only the output of the last step is
    returned; the outputs of the intermediate steps are discarded.
    """
    # NOTE: Assuming constant x
    # NOTE: s, q, and m are kept fixed over all the steps as well

    def scan_fn_fractional(carry, forcing):
      s, u_inp, t_inp = carry
      tau = forcing
      _inputs = Inputs(
        s=s,
        a=u_inp,
        # q and m are read by normalize_inputs, so they must be forwarded
        # NOTE(review): assumes q and m do not change between steps -- confirm
        q=inputs.q,
        m=inputs.m,
        x_inp=inputs.x_inp,
        x_out=inputs.x_out,
        t=t_inp,
        tau=tau,
      )
      u_out = self.apply(
        variables,
        stats,
        inputs=_inputs,
        **kwargs,
      )
      # The prediction becomes the input of the next step
      carry = (s, u_out, t_inp + tau)
      return carry, u_out

    # Split tau in num_steps fractional parts
    tau_fract = jnp.repeat(inputs.tau, repeats=num_steps) / num_steps

    # The carry has three entries (s, u, t); the prediction is the second one
    (_, u_out, _), _ = jax.lax.scan(
      f=scan_fn_fractional,
      init=(inputs.s, inputs.a, inputs.t),
      xs=tau_fract,
      length=num_steps,
    )

    return u_out

  @abstractmethod
  def get_loss_inputs(self,
    variables,
    stats: Mapping[str, Stats],
    inputs: Inputs,
    **kwargs,
  ):
    """
    Calculates prediction and target variables, ready to be given as input to the loss function.

    inputs.t is the time of the input and must be a non-negative integer.
    inputs.tau is the time difference and must be an integer greater than zero.
    """

  def get_intermediates(self,
    variables,
    stats: Mapping[str, Stats],
    inputs: Inputs,
    **kwargs,
  ):
    """
    Applies the operator on the normalized inputs and returns the
    'intermediates' entry of the state returned alongside the operator output.
    The operator output itself is discarded.
    """
    # Normalize inputs
    inputs_nrm = self.normalize_inputs(stats, inputs)

    # The filter rejects every method, so only the registered intermediates are kept
    _, state = self._apply_operator(
      variables,
      inputs=inputs_nrm,
      capture_intermediates=(lambda mdl, method_name: False),
      **kwargs,
    )

    return state['intermediates']

class OutputStepper(Stepper):
  """
  Stepper that treats the operator result as the normalized output and maps
  it to and from physical units with stats['out'].
  """

  def _predict_normalized(self,
    variables,
    stats: Mapping[str, Stats],
    inputs: Inputs,
    **kwargs,
  ):
    """Normalizes the raw inputs and returns the operator result as is."""
    return self._apply_operator(
      variables,
      inputs=self.normalize_inputs(stats, inputs),
      **kwargs,
    )

  def apply(self,
    variables,
    stats: Mapping[str, Stats],
    inputs: Inputs,
    **kwargs,
  ):
    """
    Normalizes raw inputs, applies the operator on them, and returns the
    prediction unnormalized with the mean and std of stats['out'].

    t_inp is the time of the input and must be a non-negative integer.
    tau is the time difference and must be an integer greater than zero.
    """

    prediction_nrm = self._predict_normalized(variables, stats, inputs, **kwargs)

    # Bring the prediction back to the scale of the raw outputs
    out_stats = stats['out']
    return unnormalize(prediction_nrm, mean=out_stats.mean, std=out_stats.std)

  def get_loss_inputs(self,
    variables,
    stats: Mapping[str, Stats],
    u_tgt: Array,
    inputs: Inputs,
    **kwargs,
  ):
    """
    Calculates prediction and target variables, ready to be given as input to the loss function.

    Both are kept in normalized form: the prediction is the operator result as
    is, and the target is normalized with the mean and std of stats['out'].
    Returns the pair (normalized target, normalized prediction).

    t_inp is the time of the input and must be a non-negative integer.
    tau is the time difference and must be an integer greater than zero.
    """

    prediction_nrm = self._predict_normalized(variables, stats, inputs, **kwargs)

    # The target is compared against the prediction in the normalized space
    out_stats = stats['out']
    target_nrm = normalize(u_tgt, shift=out_stats.mean, scale=out_stats.std)

    return (target_nrm, prediction_nrm)
