import jax
import jax.numpy as jnp
import equinox as eqx
from jaxrl_m.typing import *
from tqdm.auto import tqdm

import flax
import flax.linen as nn
import optax

from src.agents.icvf_flax.icvf_learner import ICVFAgent
from src.agents.icvf_flax.icvf_networks import create_icvf
from src.agents.icvf_flax import icvf_learner as learner
from jaxrl_m.common import TrainState, target_update, nonpytree_field

# OT
from ott.neural.networks.potentials import PotentialTrainState
from ott.neural.networks.potentials import PotentialMLP
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

from jax.numpy import ndarray
from ott.geometry import costs

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None))
def eval_ensemble_psi(ensemble, s):
    """Evaluate every ensemble member's ``psi_net`` on each row of ``s``.

    The outer ``filter_vmap`` maps over axis 0 of the ensemble's array leaves
    while sharing ``s``; the inner ``filter_vmap`` maps ``psi_net`` over
    axis 0 of ``s``.
    """
    per_row_psi = eqx.filter_vmap(ensemble.psi_net)
    return per_row_psi(s)

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None))
def eval_ensemble_phi(ensemble, s):
    """Evaluate every ensemble member's ``phi_net`` on each row of ``s``.

    The outer ``filter_vmap`` maps over axis 0 of the ensemble's array leaves
    while sharing ``s``; the inner ``filter_vmap`` maps ``phi_net`` over
    axis 0 of ``s``.
    """
    per_row_phi = eqx.filter_vmap(ensemble.phi_net)
    return per_row_phi(s)

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None))
def eval_ensemble_icvf_viz(ensemble, s, g, z):
    """Evaluate ``classic_icvf_initial`` for every ensemble member.

    The ensemble's array leaves are mapped over axis 0 (``s``, ``g``, ``z``
    are shared), and within each member the method is mapped row-wise over
    ``s``, ``g`` and ``z`` together.
    """
    per_row_value = eqx.filter_vmap(ensemble.classic_icvf_initial)
    return per_row_value(s, g, z)

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None)) # V(s, g, z), g - dim 29, z - dim 256
def eval_ensemble_icvf_latent_z(ensemble, s, g, z):
    """Evaluate ``classic_icvf`` for every ensemble member, row-wise.

    Per the authors' note on the decorator, ``g`` appears to be a raw goal
    (dim 29) and ``z`` a latent intent (dim 256) — confirm against callers.
    """
    per_row_value = eqx.filter_vmap(ensemble.classic_icvf)
    return per_row_value(s, g, z)

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None)) # V(s, g ,z ), g, z - dim 256
def eval_ensemble_icvf_latent_zz(ensemble, s, g, z):
    """Evaluate ``icvf_zz`` for every ensemble member, row-wise.

    Per the authors' note on the decorator, both ``g`` and ``z`` appear to be
    latents of dim 256 — confirm against callers.
    """
    per_row_value = eqx.filter_vmap(ensemble.icvf_zz)
    return per_row_value(s, g, z)
    
@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None))
def eval_ensemble_icvf_latent_zzz(ensemble, s, g, z):
    """Evaluate ``icvf_zzz`` for every ensemble member, row-wise over s, g, z."""
    per_row_value = eqx.filter_vmap(ensemble.icvf_zzz)
    return per_row_value(s, g, z)

@eqx.filter_jit
def get_gcvalue(agent, s, g, z):
    """Average the two ensemble heads of ``classic_icvf_initial``.

    ``agent.value_learner.model`` is evaluated with
    ``eval_ensemble_icvf_viz``; the leading (ensemble) axis is unpacked into
    exactly two members, so the ensemble must have size 2.
    """
    head_a, head_b = eval_ensemble_icvf_viz(agent.value_learner.model, s, g, z)
    return (head_a + head_b) / 2

def get_v_gz(agent, initial_state, target_goal, observations):
    """Negated ensemble-mean value with ``observations`` in the g slot.

    ``initial_state`` and ``target_goal`` are tiled ``observations.shape[0]``
    times along axis 0 (for 1-D inputs: one copy per observation) and passed
    as s and z respectively.
    """
    n_rows = observations.shape[0]
    repeated_start = jnp.tile(initial_state, (n_rows, 1))
    repeated_goal = jnp.tile(target_goal, (n_rows, 1))
    values = get_gcvalue(agent, repeated_start, observations, repeated_goal)
    return -1 * values
    
def get_v_zz(agent, goal, observations):
    """Ensemble-mean value per observation, with ``goal`` used as both g and z.

    ``goal`` is tiled ``observations.shape[0]`` times along axis 0.
    """
    goal_rows = jnp.tile(goal, (observations.shape[0], 1))
    return get_gcvalue(agent, observations, goal_rows, goal_rows)

@eqx.filter_vmap(in_axes=dict(agent=None, obs=None, goal=0))
def get_v_zz_heatmap(agent, obs, goal): # goal - traj
    """For each goal along axis 0 of ``goal``, evaluate every row of ``obs``.

    Inside the vmap, the single goal is tiled ``obs.shape[0]`` times and used
    as both g and z of ``get_gcvalue``.
    """
    goal_rows = jnp.tile(goal, (obs.shape[0], 1))
    return get_gcvalue(agent, obs, goal_rows, goal_rows)

def expectile_loss(adv, diff, expectile=0.99):
    """Asymmetric (expectile) squared loss, elementwise, without reduction.

    Each squared residual ``diff ** 2`` is weighted by ``expectile`` where
    ``adv >= 0`` and by ``1 - expectile`` elsewhere (including NaN ``adv``).
    """
    squared = diff ** 2
    upper_weight, lower_weight = expectile, 1 - expectile
    return jnp.where(adv >= 0, upper_weight, lower_weight) * squared

def icvf_loss(value_fn, target_value_fn, batch, config):

    assert all([k in config for k in ['no_intent', 'min_q', 'expectile', 'discount']]), 'Missing ICVF config keys'

    if config['no_intent']:
        batch['desired_goals'] = jax.tree_map(jnp.ones_like, batch['desired_goals'])

    ###
    # Compute TD error for outcome s_+
    # 1(s == s_+) + V(s', s_+, z) - V(s, s_+, z)
    ###

    (next_v1_gz, next_v2_gz) = target_value_fn(batch['next_observations'], batch['goals'], batch['desired_goals'])
    q1_gz = batch['rewards'] + config['discount'] * batch['masks'] * next_v1_gz
    q2_gz = batch['rewards'] + config['discount'] * batch['masks'] * next_v2_gz
    q1_gz, q2_gz = jax.lax.stop_gradient(q1_gz), jax.lax.stop_gradient(q2_gz)

    (v1_gz, v2_gz) = value_fn(batch['observations'], batch['goals'], batch['desired_goals'])

    ###
    # Compute the advantage of s -> s' under z
    # r(s, z) + V(s', z, z) - V(s, z, z)
    ###

    (next_v1_zz, next_v2_zz) = target_value_fn(batch['next_observations'], batch['desired_goals'], batch['desired_goals'])
    if config['min_q']:
        next_v_zz = jnp.minimum(next_v1_zz, next_v2_zz)
    else:
        next_v_zz = (next_v1_zz + next_v2_zz) / 2
    
    q_zz = batch['desired_rewards'] + config['discount'] * batch['desired_masks'] * next_v_zz

    (v1_zz, v2_zz) = target_value_fn(batch['observations'], batch['desired_goals'], batch['desired_goals'])
    v_zz = (v1_zz + v2_zz) / 2
    adv = q_zz - v_zz

    if config['no_intent']:
        adv = jnp.zeros_like(adv)
    
    ###
    #
    # If advantage is positive (next state is better than current state), then place additional weight on
    # the value loss. 
    #
    ##
    value_loss1 = expectile_loss(adv, q1_gz-v1_gz, config['expectile']).mean()
    value_loss2 = expectile_loss(adv, q2_gz-v2_gz, config['expectile']).mean()
    value_loss = value_loss1 + value_loss2

    advantage = adv
    return value_loss, {
        'value_loss': value_loss,
        'v_gz max': v1_gz.max(),
        'v_gz min': v1_gz.min(),
        'v_zz': v_zz.mean(),
        'v_gz': v1_gz.mean(),
        'abs adv mean': jnp.abs(advantage).mean(),
        'reward mean': batch['rewards'].mean()
    }

def periodic_target_update(
    model: TrainState, target_model: TrainState, period: int
) -> TrainState:
    """Hard-copy ``model.params`` into the target every ``period`` steps.

    Each parameter leaf is passed through ``optax.periodic_update`` with
    ``model.step`` as the step counter, so the target takes the online value
    only on steps that are multiples of ``period``.

    Returns:
        ``target_model`` with its params replaced.
    """
    # ``jax.tree_util.tree_map`` instead of the deprecated ``jax.tree_map``
    # alias, which recent JAX releases no longer provide.
    new_target_params = jax.tree_util.tree_map(
        lambda p, tp: optax.periodic_update(p, tp, model.step, period),
        model.params, target_model.params
    )
    return target_model.replace(params=new_target_params)

@jax.jit
def step_fn(state_f, state_g, batch):
    """Take one joint gradient step on both OT potentials.

    Differentiates ``loss_fn`` (with ``to_optimize='both'``) with respect to
    the params of ``state_f`` and ``state_g`` and applies each gradient to its
    own train state.

    Returns:
        ``(new_state_f, new_state_g, loss, dual_loss, amor_loss, W2_dist)``.
    """
    potentials = dict(
        f_value=state_f.potential_value_fn,
        f_gradient=state_f.potential_gradient_fn,
        g_value=state_g.potential_value_fn,
        g_gradient=state_g.potential_gradient_fn,
    )
    loss_and_grads = jax.value_and_grad(loss_fn, argnums=[0, 1], has_aux=True)
    (total_loss, aux), (grads_f, grads_g) = loss_and_grads(
        state_f.params,
        state_g.params,
        batch=batch,
        to_optimize='both',
        **potentials,
    )
    dual_loss, amor_loss, w2_dist = aux

    next_state_f = state_f.apply_gradients(grads=grads_f)
    next_state_g = state_g.apply_gradients(grads=grads_g)
    return next_state_f, next_state_g, total_loss, dual_loss, amor_loss, w2_dist
def loss_fn(params_f, params_g, f_value, f_gradient, g_value, g_gradient, batch, to_optimize,
            expectile=0.999):
      """Expectile-regularised dual loss for the two OT potentials f and g.

      Args:
        params_f, params_g: parameters of the potentials f and g.
        f_value: ``f_value(params, other_value_fn)`` returns a callable that
          evaluates f. It is also called with params only, in ``conj_f``.
        f_gradient: ``f_gradient(params)`` returns a callable that evaluates
          the gradient of f.
        g_value, g_gradient: the same, for g.
        batch: mapping with ``"source"`` and ``"target"`` point arrays.
        to_optimize: ``"both"``, ``"f"`` or ``"g"``. This is branched on with a
          Python ``if``, so it must be a static (non-traced) value.
        expectile: expectile passed to ``expectile_loss`` for both regularisers.

      Returns:
        ``(loss, (dual_loss, amor_loss, W2_dist))``.

      Raises:
        ValueError: if ``to_optimize`` is not one of the accepted values.
      """
      # get two distributions
      source, target = batch["source"], batch["target"]

      def g_value_partial(y: jnp.ndarray) -> jnp.ndarray:
        """Lazy way of evaluating g if f's computation needs it."""
        return g_value(params_g)(y)

      f_value_partial = f_value(params_f, g_value_partial)
      # Row-wise dot product: jnp.dot vmapped over the leading axis.
      batch_dot = jax.vmap(jnp.dot)

      # Gradient maps of each potential. The detached copies are used where
      # the mapped points should act as constants (no gradient at all).
      source_hat = g_gradient(params_g)(target)
      source_hat_detach = jax.lax.stop_gradient(source_hat)
      target_hat = f_gradient(params_f)(source)
      target_hat_detach = jax.lax.stop_gradient(target_hat)

      f_source = f_value_partial(source)
      # <x_hat, y> - f(x_hat), evaluated at the detached g-gradient points.
      f_star_target = batch_dot(source_hat_detach, target) - f_value_partial(source_hat_detach)
      g_target = g_value_partial(target)
      # <x, y_hat> - g(y_hat), evaluated at the detached f-gradient points.
      g_star_source = batch_dot(source, target_hat_detach) - g_value_partial(target_hat_detach)


      # <x, y> - f(x) with f's params detached. Gradients can still flow through
      # x, e.g. into params_g via source_hat in the amortisation term below.
      def conj_f(x, y):
        return batch_dot(x, y) - f_value(jax.lax.stop_gradient(params_f))(x)

      # <x, y> - g(y) with g's params detached; gradients can still flow through y.
      def conj_g(x, y):
        return batch_dot(x, y) - g_value(jax.lax.stop_gradient(params_g))(y)

      # Expectile regression of f(x) towards the detached estimate <x, y_hat> - g(y_hat).
      # The residual itself is the "advantage": non-negative residuals get
      # weight ``expectile``, negative ones ``1 - expectile``.
      u_f = jax.lax.stop_gradient(conj_g(source, target_hat_detach)) - f_source
      f_loss = expectile_loss(u_f, u_f, expectile).mean()

      # Symmetric regulariser for g.
      u_g = jax.lax.stop_gradient(conj_f(source_hat_detach, target)) - g_target
      g_loss = expectile_loss(u_g, u_g, expectile).mean()
      dual_loss = f_loss
      amor_loss = g_loss

      dual_loss = dual_loss + (f_source.mean() + f_star_target.mean())
      # Gradient of this term reaches params_g only through source_hat.
      amor_loss = amor_loss - conj_f(source_hat, target).mean()

      # Gradient of this term reaches params_f only through target_hat.
      dual_loss = dual_loss - conj_g(source, target_hat).mean()
      amor_loss = amor_loss + (g_target.mean() + g_star_source.mean())

      dual_loss = dual_loss / 2
      amor_loss = amor_loss / 2

      if to_optimize == "both":
        loss = dual_loss + amor_loss
      elif to_optimize == "f":
        loss = dual_loss
      elif to_optimize == "g":
        loss = amor_loss
      else:
        raise ValueError(
            f"Optimization target {to_optimize} has been misspecified."
        )

      # compute Wasserstein-2 distance
      # Estimate: E|x|^2 + E|y|^2 - 2 (E f(x) + E g(y)).
      C = jnp.mean(jnp.sum(source ** 2, axis=-1)) + \
          jnp.mean(jnp.sum(target ** 2, axis=-1))

      W2_dist = C - 2. * (f_source.mean() + g_target.mean())

      return loss, (dual_loss, amor_loss, W2_dist)

class JointTrainAgent(flax.struct.PyTreeNode):
    """Jointly trains an ICVF value function and a pair of neural OT potentials."""

    # Online and target ICVF train states.
    icvf_value: TrainState
    icvf_target_value: TrainState
    # Separately built expert ICVF agent; its ``value`` network embeds expert observations.
    icvf_expert: ICVFAgent

    # Train states of the OT potentials f and g.
    state_f: PotentialTrainState
    state_g: PotentialTrainState
    # Static hyper-parameters (not part of the pytree).
    config: dict = nonpytree_field()

    @jax.jit
    def update(agent, agent_batch):
        """One gradient step on the ICVF value plus a small OT shaping term.

        The target network is updated from the pre-step online params.

        Returns:
            ``(new_agent, info)``.
        """
        def loss_fn(network_params):
            info = {}

            # ICVF loss
            value_fn = lambda s, g, z: agent.icvf_value(s, g, z, params=network_params)
            target_value_fn = lambda s, g, z: agent.icvf_target_value(s, g, z)
            ailot_loss, icvf_info = icvf_loss(value_fn, target_value_fn, agent_batch, agent.config)
            for k, v in icvf_info.items():
                info[f'icvf_value/{k}'] = v

            # OT loss. psi must be computed with ``network_params``; using the
            # stored ``agent.icvf_value.params`` would make this term constant
            # w.r.t. the optimised params, so it would contribute no gradient.
            psi_agent_1 = agent.icvf_value(agent_batch['observations'], method='get_psi', params=network_params).mean(0)
            psi_agent_2 = agent.icvf_value(agent_batch['next_observations'], method='get_psi', params=network_params).mean(0)

            # Same feature layout as the OT ``source`` built in ``update_not``.
            intents_pair = jnp.concatenate((psi_agent_1, psi_agent_2), axis=-1)
            # E|x|^2 - 2 E f(x): square only inside the norm term. f must see the
            # unsquared features it was trained on.
            f_values = agent.state_f.potential_value_fn(agent.state_f.params)(intents_pair)
            intents_loss = jnp.mean(jnp.sum(intents_pair ** 2, axis=-1)) - 2.0 * f_values.mean()

            info.update({'OT_value/Intents_loss': intents_loss})
            # NOTE: logged unweighted; the optimised objective uses 0.01 * intents_loss.
            info.update({'OT+ICVF_loss': ailot_loss + intents_loss})

            return ailot_loss + 0.01 * intents_loss, info

        if agent.config['periodic_target_update']:
            icvf_new_target_value = periodic_target_update(agent.icvf_value, agent.icvf_target_value, int(1.0 / agent.config['target_update_rate']))
        else:
            icvf_new_target_value = target_update(agent.icvf_value, agent.icvf_target_value, agent.config['target_update_rate'])

        icvf_new_value, icvf_info = agent.icvf_value.apply_loss_fn(loss_fn=loss_fn, has_aux=True)
        return agent.replace(icvf_value=icvf_new_value, icvf_target_value=icvf_new_target_value), icvf_info

    @jax.jit
    def update_not(agent, batch_agent, batch_expert):
        """One training step of the OT potentials f and g.

        Source points are ``concat(psi(s), psi(s'))`` from the agent's ICVF;
        target points are the same features from the expert ICVF. Both use the
        ensemble mean (``.mean(0)``).

        Returns:
            ``(new_agent, info)``.
        """
        not_batch = {}
        not_info = {}

        source_agent_obs1 = agent.icvf_value(batch_agent['observations'], method='get_psi').mean(0)
        source_agent_obs2 = agent.icvf_value(batch_agent['next_observations'], method='get_psi').mean(0)

        target_expert_obs1 = agent.icvf_expert.value(batch_expert['observations'], method='get_psi').mean(0)
        target_expert_obs2 = agent.icvf_expert.value(batch_expert['next_observations'], method='get_psi').mean(0)

        not_batch['source'] = jnp.concatenate((source_agent_obs1, source_agent_obs2), axis=-1) # source intents
        not_batch['target'] = jnp.concatenate((target_expert_obs1, target_expert_obs2), axis=-1) # target intents

        (new_state_f, new_state_g, loss, loss_f, loss_g, w_dist) = step_fn(
            agent.state_f,
            agent.state_g,
            not_batch,
        )
        not_info.update({"Not_loss": loss,
                         "f_loss": loss_f,
                         "g_loss": loss_g,
                         "w_dist": w_dist})

        return agent.replace(state_f=new_state_f, state_g=new_state_g), not_info


def create_ailot_learner(
    seed: int,
    observations: np.array,
    icvf_expert: ICVFAgent,
    icvf_hypers: dict,
    lr: float = 1e-4,
    icvf_hidden_dims: Sequence[int] = (256, 256, 256),
    ot_mlp_hidden_dims: Sequence[int] = (512, 512, 512),
):
    """Build a ``JointTrainAgent`` with freshly initialised ICVF and OT states.

    Args:
        seed: PRNG seed for all parameter initialisation.
        observations: example observations used to initialise the ICVF
            network (passed as s, g and z).
        icvf_expert: pre-built expert ICVF agent, stored unchanged.
        icvf_hypers: must contain ``optim_kwargs`` (forwarded to
            ``optax.adam``), ``discount``, ``target_update_rate``,
            ``expectile``, ``no_intent``, ``min_q`` and
            ``periodic_target_update``.
        lr: initial learning rate of the cosine schedule shared by both
            potentials' AdamW optimisers.
        icvf_hidden_dims: hidden sizes of the multilinear ICVF.
        ot_mlp_hidden_dims: hidden sizes of both OT potential MLPs.

    Returns:
        A new ``JointTrainAgent``.
    """
    # One independent key per consumer. The parent key is not reused after
    # being split, as JAX's PRNG design requires.
    rng = jax.random.PRNGKey(seed)
    rng_icvf, rng_f, rng_g = jax.random.split(rng, 3)

    icvf_value_def = create_icvf("multilinear", hidden_dims=icvf_hidden_dims, use_layer_norm=True)

    neural_f = PotentialMLP(dim_hidden=ot_mlp_hidden_dims, act_fn=jax.nn.gelu, is_potential=True)
    neural_g = PotentialMLP(dim_hidden=ot_mlp_hidden_dims, act_fn=jax.nn.gelu, is_potential=True)
    lr_schedule = optax.cosine_decay_schedule(
        init_value=lr, decay_steps=1_000_000, alpha=1e-2
    )
    optimizer_f = optax.adamw(learning_rate=lr_schedule)
    optimizer_g = optax.adamw(learning_rate=lr_schedule)

    # The potentials consume concat(psi(s), psi(s')) (see
    # JointTrainAgent.update_not). psi is assumed to have the width of the
    # ICVF MLP's last layer (icvf_hidden_dims[-1]) — TODO confirm against
    # create_icvf. This equals the previous icvf_hidden_dims[0] for
    # equal-width dims such as the default.
    ot_input_shape = (1, icvf_hidden_dims[-1] * 2)
    state_f = neural_f.create_train_state(
        rng_f,
        optimizer_f,
        ot_input_shape,
    )
    state_g = neural_g.create_train_state(
        rng_g,
        optimizer_g,
        ot_input_shape,
    )

    icvf_params = icvf_value_def.init(rng_icvf, observations, observations, observations)['params']
    icvf_value = TrainState.create(icvf_value_def, icvf_params, tx=optax.adam(**icvf_hypers['optim_kwargs']))
    # The target starts from the same params and is created without an optimiser.
    icvf_target_value = TrainState.create(icvf_value_def, icvf_params)

    config = flax.core.FrozenDict(dict(
        discount=icvf_hypers['discount'],
        target_update_rate=icvf_hypers['target_update_rate'],
        expectile=icvf_hypers['expectile'],
        no_intent=icvf_hypers['no_intent'],
        min_q=icvf_hypers['min_q'],
        periodic_target_update=icvf_hypers['periodic_target_update'],
    ))

    return JointTrainAgent(icvf_value=icvf_value, icvf_target_value=icvf_target_value,
                           icvf_expert=icvf_expert,
                           state_f=state_f,
                           state_g=state_g,
                           config=config)
    

@jax.tree_util.register_pytree_node_class
class MyCost(costs.CostFn):
    """OTT squared-Euclidean cost, clipped from above at a fixed cap."""

    # Upper bound applied to every pairwise cost.
    _CAP = 5000

    def __init__(self) -> None:
        super().__init__()
        # The uncapped cost is delegated to OTT's SqEuclidean cost.
        self.cost = costs.SqEuclidean()

    def pairwise(self, x: ndarray, y: ndarray) -> float:
        """Return ``min(5000, self.cost(x, y))``."""
        return jnp.minimum(self._CAP, self.cost(x, y))

class OTRewardsExpert:
    """Assign per-step OT rewards to dataset episodes against an expert trajectory.

    Observations are embedded with the ensemble-mean ``psi`` of the ICVF
    model. Each embedding is paired with the one ``sub_steps`` steps later, and
    the resulting point cloud is matched with Sinkhorn to the expert's cloud.
    The expert cloud starts at the expert step whose embedding is closest to
    the episode's first embedding.
    """

    def __init__(
        self, expert_traj, icvf_model, sub_steps=5
    ):
        self.expert_states = expert_traj
        # Ensemble-mean psi embedding of every expert state.
        self.expert_z = eval_ensemble_psi(icvf_model.value_learner.model, expert_traj).mean(axis=0)
        # Offset (in steps) of the second element of each (z_t, z_{t+k}) pair.
        self.sub_steps = sub_steps
        self.icvf_model = icvf_model

    def make_subs(self, z, sub_steps):
        """Return ``z`` shifted forward by ``sub_steps`` rows, clamped to the last row."""
        n_rows = z.shape[0]
        sub_indx = jnp.minimum(jnp.arange(0, n_rows) + sub_steps, n_rows - 1)
        # ``jax.tree_util.tree_map``: the ``jax.tree_map`` alias is deprecated
        # and no longer available in recent JAX releases.
        return jax.tree_util.tree_map(lambda arr: arr[sub_indx], z)

    @eqx.filter_jit
    def get_z_and_start_index(self, obs):
        """Embed one episode and find the closest expert step to its first state.

        Returns:
            ``(z, i_min, diff)``: the episode's ensemble-mean psi embeddings, the
            index of the expert embedding with the smallest squared distance to
            ``z[0]``, and the per-expert-step differences ``z[0] - expert_z``.
        """
        z = eval_ensemble_psi(self.icvf_model.value_learner.model, obs).mean(axis=0)
        diff = z[0][jnp.newaxis,] - self.expert_z
        i_min = jnp.argmin((diff**2).sum(-1)).squeeze()
        return z, i_min, diff

    def compute_rewards(
        self,
        dataset,
        gc_agent_dataset
    ):
        """Compute OT rewards for every episode and concatenate them.

        Episode boundaries come from
        ``gc_agent_dataset.dataset._trajectory_boundaries_and_returns()`` and
        are used to slice ``dataset.dataset_dict['observations']``; the slice
        end is assumed to be exclusive — TODO confirm.
        """
        rewards = []
        observations = dataset.dataset_dict['observations']
        episode_starts, episode_ends, _ = gc_agent_dataset.dataset._trajectory_boundaries_and_returns()

        # Index-based on purpose: a start without a matching end raises
        # IndexError instead of being silently dropped (as zip would do).
        for i in tqdm(range(len(episode_starts))):
            episode = observations[episode_starts[i]:episode_ends[i]]
            zi, start_index, _ = self.get_z_and_start_index(episode)
            ri = self.compute_rewards_one_episode(zi, self.expert_z[start_index:])
            rewards.append(jax.device_get(ri))

        return np.concatenate(rewards)

    @eqx.filter_jit
    def compute_rewards_one_episode(
        self, episode_obs, expert_obs
    ):
        """Per-step OT rewards for one episode against an expert segment.

        ``compute_rewards`` passes psi embeddings here (not raw observations).
        Each row is concatenated with the row ``self.sub_steps`` later
        (clamped), an entropic OT problem (epsilon=0.001) between the agent and
        expert point clouds is solved with Sinkhorn, and step t receives
        ``-(sum_j P[t, j] * C[t, j]) * T / 10`` where T is the episode length.
        """
        za_1 = episode_obs
        za_2 = self.make_subs(za_1, self.sub_steps)
        x = jnp.concatenate([za_1, za_2], axis=1)

        ze_1 = expert_obs
        ze_2 = self.make_subs(ze_1, self.sub_steps)
        y = jnp.concatenate([ze_1, ze_2], axis=1)

        geom = pointcloud.PointCloud(x, y, epsilon=0.001)
        ot_prob = linear_problem.LinearProblem(geom)
        solver = sinkhorn.Sinkhorn(max_iterations=250, use_danskin=True)

        ot_sink = solver(ot_prob)
        # Transport cost attributed to each agent step (row of the plan).
        transp_cost = jnp.sum(ot_sink.matrix * geom.cost_matrix, axis=1)
        rewards = -transp_cost * episode_obs.shape[0] / 10

        return rewards