import abc
import warnings
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

import jax
from jax._src.basearray import Array as Array
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax import struct
from flax.core import frozen_dict
from flax.training import train_state
from optax._src.base import OptState

from ott import utils
from ott.geometry import costs
from ott.neural import models
from ott.neural.solvers import conjugate
from ott.neural.solvers.conjugate import DEFAULT_CONJUGATE_SOLVER, FenchelConjugateSolver
from ott.neural.solvers.neuraldual import BaseW2NeuralDual, W2NeuralDual
from ott.problems.linear import potentials

__all__ = ["W2StarStar", "W2StarStarSymmetric"]


def expectile_loss(adv, diff, expectile=0.9):
    """Return an asymmetrically weighted squared error.

    Every element of ``diff`` is squared and then scaled by ``expectile``
    where the matching element of ``adv`` is non-negative, and by
    ``1 - expectile`` everywhere else.

    Args:
      adv: Values whose sign selects the weight.
      diff: Values that are squared.
      expectile: Weight applied where ``adv >= 0``.

    Returns:
      The element-wise weighted squares (no reduction is applied).
    """
    weight_if_nonnegative = expectile
    weight_otherwise = 1 - expectile
    nonnegative_mask = adv >= 0
    squared = diff ** 2
    return jnp.where(
        nonnegative_mask, weight_if_nonnegative, weight_otherwise
    ) * squared


class W2StarStar(W2NeuralDual):
  """:class:`W2NeuralDual` subclass with an expectile-penalised loss.

  The class defines :meth:`loss_fn` as an overridable method and overrides
  :meth:`get_step_fn` so that the jitted step uses it. Compared with the plain
  dual and amortization terms, the dual loss here additionally contains two
  penalties computed with :func:`expectile_loss`.

  Args:
    dim_data: Forwarded to the parent constructor.
    neural_f: Forwarded to the parent constructor.
    neural_g: Forwarded to the parent constructor.
    optimizer_f: Forwarded to the parent constructor.
    optimizer_g: Forwarded to the parent constructor.
    num_train_iters: Forwarded to the parent constructor.
    num_inner_iters: Forwarded to the parent constructor.
    back_and_forth: Forwarded to the parent constructor.
    valid_freq: Forwarded to the parent constructor.
    log_freq: Forwarded to the parent constructor.
    logging: Forwarded to the parent constructor.
    rng: Forwarded to the parent constructor.
    pos_weights: Forwarded to the parent constructor.
    beta: Forwarded to the parent constructor.
    conjugate_solver: Forwarded to the parent constructor.
    amortization_loss: Forwarded to the parent constructor.
    parallel_updates: Forwarded to the parent constructor.
    expectile: Weight handed to :func:`expectile_loss` for the penalty terms
      of :meth:`loss_fn`.
  """

  def __init__(
      self,
      dim_data: int,
      neural_f: Optional[BaseW2NeuralDual] = None,
      neural_g: Optional[BaseW2NeuralDual] = None,
      optimizer_f: Optional[optax.OptState] = None,
      optimizer_g: Optional[optax.OptState] = None,
      num_train_iters: int = 1,
      num_inner_iters: int = 1,
      back_and_forth: Optional[bool] = None,
      valid_freq: int = 1000,
      log_freq: int = 1000,
      logging: bool = False,
      rng: Optional[jax.Array] = None,
      pos_weights: bool = True,
      beta: float = 1.0,
      conjugate_solver: Optional[conjugate.FenchelConjugateSolver
                                ] = conjugate.DEFAULT_CONJUGATE_SOLVER,
      amortization_loss: Literal["objective", "regression"] = "regression",
      parallel_updates: bool = True,
      expectile: float = 0.99,
  ):
    # Everything except ``expectile`` is handed to the parent positionally, in
    # declaration order, so this list must stay aligned with the parent's
    # signature.
    super().__init__(
        dim_data,
        neural_f,
        neural_g,
        optimizer_f,
        optimizer_g,
        num_train_iters,
        num_inner_iters,
        back_and_forth,
        valid_freq,
        log_freq,
        logging,
        rng,
        pos_weights,
        beta,
        conjugate_solver,
        amortization_loss,
        parallel_updates,
    )
    self.expectile = expectile

  def loss_fn(
      self, params_f, params_g, f_value, f_gradient, g_value, g_gradient,
      batch, to_optimize
  ):
    """Loss function for both potentials.

    Args:
      params_f: Parameters given to ``f_value``.
      params_g: Parameters given to ``g_gradient``.
      f_value: Callable such that ``f_value(params)`` returns a function that
        evaluates ``f`` on a batch of points.
      f_gradient: Not used by this implementation; accepted because
        ``step_fn`` always supplies it.
      g_value: Not used by this implementation; accepted because ``step_fn``
        always supplies it.
      g_gradient: Callable such that ``g_gradient(params)`` returns a function
        mapping a batch of target points to a batch of source-space points.
      batch: Mapping with ``"source"`` and ``"target"`` entries. Both are
        combined row-wise through ``jax.vmap(jnp.dot)``, so they need a
        leading batch axis.
      to_optimize: ``"both"``, ``"f"`` or ``"g"``; selects which terms make up
        the returned scalar loss.

    Returns:
      ``(loss, (dual_loss, amor_loss, W2_dist))``.

    Raises:
      ValueError: If ``to_optimize`` is not one of the accepted values.
    """
    # get two distributions
    source, target = batch["source"], batch["target"]

    f_value_partial = f_value(params_f)
    batch_dot = jax.vmap(jnp.dot)

    # The detached copy keeps the dual terms from back-propagating into g;
    # only the amortization term below uses the differentiable ``source_hat``.
    source_hat = g_gradient(params_g)(target)
    source_hat_detach = jax.lax.stop_gradient(source_hat)

    f_source = f_value_partial(source)
    f_star_target = (
        batch_dot(source_hat_detach, target) -
        f_value_partial(source_hat_detach)
    )

    def conj_f(x, y):
      # <x, y> - f(x), with f's parameters frozen.
      return batch_dot(x, y) - f_value(jax.lax.stop_gradient(params_f))(x)

    # Expectile penalties: each residual is a frozen target minus the quantity
    # being trained, and is used both as the sign selector and as the value
    # that gets squared.
    u_f = jax.lax.stop_gradient(
        batch_dot(source, target) - conj_f(source_hat_detach, target)
    ) - f_source
    f_loss = expectile_loss(u_f, u_f, self.expectile).mean()

    u_g = jax.lax.stop_gradient(conj_f(source, target)) - f_star_target
    g_loss = expectile_loss(u_g, u_g, self.expectile).mean()

    dual_loss = (f_loss + g_loss) + (f_source.mean() + f_star_target.mean())
    amor_loss = -conj_f(source_hat, target).mean()

    if to_optimize == "both":
      loss = dual_loss + amor_loss
    elif to_optimize == "f":
      loss = dual_loss
    elif to_optimize == "g":
      loss = amor_loss
    else:
      raise ValueError(
          f"Optimization target {to_optimize} has been misspecified."
      )

    if not self.pos_weights:
      # The weight penalty is added to the optimised loss only; the reported
      # ``dual_loss`` / ``amor_loss`` stay unpenalised.
      loss += self.beta * self._penalize_weights_icnn(params_f) + \
          self.beta * self._penalize_weights_icnn(params_g)

    # compute Wasserstein-2 distance
    C = jnp.mean(jnp.sum(source ** 2, axis=-1)) + \
        jnp.mean(jnp.sum(target ** 2, axis=-1))

    W2_dist = C - 2. * (f_source.mean() + f_star_target.mean())

    return loss, (dual_loss, amor_loss, W2_dist)

  def get_step_fn(
      self, train: bool, to_optimize: Literal["f", "g", "parallel", "both"]
  ):
    """Create a jitted training or validation step built on :meth:`loss_fn`.

    Args:
      train: When true the returned step applies gradients and returns the
        updated state(s); otherwise it only reports losses.
      to_optimize: Forwarded to :meth:`loss_fn` and used to pick the return
        layout. Only ``"both"``, ``"f"`` and ``"g"`` are handled; any other
        value raises ``ValueError`` the first time the step is traced.

    Returns:
      A function ``step_fn(state_f, state_g, batch)``. With ``train=True`` it
      returns ``(new_state_f, new_state_g, loss, loss_f, loss_g, W2_dist)``
      for ``"both"``, ``(new_state_f, loss_f, W2_dist)`` for ``"f"`` and
      ``(new_state_g, loss_g, W2_dist)`` for ``"g"``. With ``train=False`` it
      returns the same tuples without the states and without the total loss.
    """

    @jax.jit
    def step_fn(state_f, state_g, batch):
      """Step function of either training or validation."""
      loss_args = (
          state_f.params,
          state_g.params,
          state_f.potential_value_fn,
          state_f.potential_gradient_fn,
          state_g.potential_value_fn,
          state_g.potential_gradient_fn,
          batch,
          to_optimize,
      )

      if not train:
        # Validation only reports metrics, so evaluate the loss directly
        # instead of differentiating it and discarding the gradients.
        _, (loss_f, loss_g, W2_dist) = self.loss_fn(*loss_args)

        # do not update state
        if to_optimize == "both":
          return loss_f, loss_g, W2_dist
        if to_optimize == "f":
          return loss_f, W2_dist
        if to_optimize == "g":
          return loss_g, W2_dist
        raise ValueError("Optimization target has been misspecified.")

      # compute loss and gradients w.r.t. params_f and params_g
      grad_fn = jax.value_and_grad(self.loss_fn, argnums=[0, 1], has_aux=True)
      (loss, (loss_f, loss_g, W2_dist)), (grads_f, grads_g) = grad_fn(
          *loss_args
      )

      # update state
      if to_optimize == "both":
        return (
            state_f.apply_gradients(grads=grads_f),
            state_g.apply_gradients(grads=grads_g), loss, loss_f, loss_g,
            W2_dist
        )
      if to_optimize == "f":
        return state_f.apply_gradients(grads=grads_f), loss_f, W2_dist
      if to_optimize == "g":
        return state_g.apply_gradients(grads=grads_g), loss_g, W2_dist
      raise ValueError("Optimization target has been misspecified.")

    return step_fn



class W2StarStarSymmetric(W2StarStar):
  """:class:`W2StarStar` variant whose loss uses both potentials' gradients.

  Only :meth:`loss_fn` is overridden. Here the ``f`` side and the ``g`` side
  each receive one expectile penalty, one pair of value terms and one
  frozen-parameter term, and each side's total is halved.
  """

  def loss_fn(
      self, params_f, params_g, f_value, f_gradient, g_value, g_gradient,
      batch, to_optimize
  ):
    """Loss function for both potentials.

    Args:
      params_f: Parameters given to ``f_value`` and ``f_gradient``.
      params_g: Parameters given to ``g_value`` and ``g_gradient``.
      f_value: Callable; ``f_value(params, other_value_fn)`` and
        ``f_value(params)`` both return a function evaluating ``f`` on a batch.
      f_gradient: Callable; ``f_gradient(params)`` returns a function applied
        to the source batch.
      g_value: Callable; ``g_value(params)`` returns a function evaluating
        ``g`` on a batch.
      g_gradient: Callable; ``g_gradient(params)`` returns a function applied
        to the target batch.
      batch: Mapping with ``"source"`` and ``"target"`` entries, combined
        row-wise through ``jax.vmap(jnp.dot)``.
      to_optimize: ``"both"``, ``"f"`` or ``"g"``.

    Returns:
      ``(loss, (dual_loss, amor_loss, W2_dist))``.

    Raises:
      ValueError: If ``to_optimize`` is not one of the accepted values.
    """
    source = batch["source"]
    target = batch["target"]
    rowwise_dot = jax.vmap(jnp.dot)

    def eval_g(y: jnp.ndarray) -> jnp.ndarray:
      """Evaluate g lazily, in case f's computation needs it."""
      return g_value(params_g)(y)

    eval_f = f_value(params_f, eval_g)

    # Gradient maps, each with a detached twin so that the value terms below
    # do not back-propagate through the opposite potential's gradient.
    pulled = g_gradient(params_g)(target)
    pulled_sg = jax.lax.stop_gradient(pulled)
    pushed = f_gradient(params_f)(source)
    pushed_sg = jax.lax.stop_gradient(pushed)

    f_on_source = eval_f(source)
    f_conj_on_target = rowwise_dot(pulled_sg, target) - eval_f(pulled_sg)
    g_on_target = eval_g(target)
    g_conj_on_source = rowwise_dot(source, pushed_sg) - eval_g(pushed_sg)

    def frozen_conj_f(x, y):
      # <x, y> - f(x), with f's parameters frozen.
      return rowwise_dot(x, y) - f_value(jax.lax.stop_gradient(params_f))(x)

    def frozen_conj_g(x, y):
      # <x, y> - g(y), with g's parameters frozen.
      return rowwise_dot(x, y) - g_value(jax.lax.stop_gradient(params_g))(y)

    # Expectile penalties: frozen target minus the trained potential value.
    resid_f = (
        jax.lax.stop_gradient(frozen_conj_g(source, pushed_sg)) - f_on_source
    )
    penalty_f = expectile_loss(resid_f, resid_f, self.expectile).mean()

    resid_g = (
        jax.lax.stop_gradient(frozen_conj_f(pulled_sg, target)) - g_on_target
    )
    penalty_g = expectile_loss(resid_g, resid_g, self.expectile).mean()

    dual_loss = (
        penalty_f + (f_on_source.mean() + f_conj_on_target.mean()) -
        frozen_conj_g(source, pushed).mean()
    ) / 2
    amor_loss = (
        penalty_g - frozen_conj_f(pulled, target).mean() +
        (g_on_target.mean() + g_conj_on_source.mean())
    ) / 2

    if to_optimize == "f":
      loss = dual_loss
    elif to_optimize == "g":
      loss = amor_loss
    elif to_optimize == "both":
      loss = dual_loss + amor_loss
    else:
      raise ValueError(
          f"Optimization target {to_optimize} has been misspecified."
      )

    if not self.pos_weights:
      weight_penalty = (
          self.beta * self._penalize_weights_icnn(params_f) +
          self.beta * self._penalize_weights_icnn(params_g)
      )
      loss = loss + weight_penalty

    # compute Wasserstein-2 distance
    source_sq_norm = jnp.mean(jnp.sum(source ** 2, axis=-1))
    target_sq_norm = jnp.mean(jnp.sum(target ** 2, axis=-1))
    C = source_sq_norm + target_sq_norm
    W2_dist = C - 2. * (f_on_source.mean() + g_on_target.mean())

    return loss, (dual_loss, amor_loss, W2_dist)
