from functools import partial

import flax
import jax
import jax.numpy as jnp
import optax
from ml_collections import ConfigDict

from bpref_v2.networks.trans_reward_model import ARPV1RewardModel
from bpref_v2.third_party.openai.model import load_liv_model
from bpref_v2.utils.jax_utils import (
    TrainState,
    mse_loss,
    next_rng,
    value_and_multi_grad,
)

from .core import RewardLearner


class ARPV1Learner(RewardLearner):
    """Reward learner that fine-tunes a pretrained LIV/CLIP backbone with ARP objectives.

    A single train state (key ``"trans"``) holds the parameters of an
    ``ARPV1RewardModel`` that wraps the CLIP model returned by
    ``load_liv_model``. Training minimizes a weighted sum of up to three
    losses, each enabled by a positive weight in the config:

    * ``lambda_liv``: LIV value loss computed from CLIP visual features,
    * ``lambda_contrastive``: symmetric image/text cross entropy with the
      matching pair of each row as target,
    * ``lambda_idm``: inverse-dynamics (action prediction) MSE.

    The preference-learning hooks (``_train_pref_step`` etc.) are stubs that
    return ``None``.
    """

    # Metric names reported by both the training and the evaluation step.
    _ARP_METRIC_KEYS = ("liv_loss", "contrastive_loss", "idm_loss", "loss")

    @staticmethod
    def get_default_config(updates=None):
        """Return the default ``ConfigDict``, optionally overridden by ``updates``.

        Note: ``warmup_steps`` and ``total_steps`` (read by the learning-rate
        schedules in ``__init__``) have no default here and must be supplied
        via ``updates`` or by the caller.
        """
        config = ConfigDict()
        config.lr = 1e-4
        config.optimizer_type = "adamw"
        config.scheduler_type = "CosineDecay"
        config.vocab_size = 1
        config.n_layer = 3
        config.embd_dim = 256
        # Snapshot of embd_dim at creation time; not a live reference.
        config.n_embd = config.embd_dim
        config.n_head = 1
        config.n_positions = 1024
        config.resid_pdrop = 0.1
        config.attn_pdrop = 0.1
        config.pref_attn_embd_dim = 256

        config.train_type = "last"

        # Weighted Sum option
        config.use_weighted_sum = False

        config.activation = "relu"
        config.activation_final = "none"

        # transfer type
        config.transfer_type = "liv"

        # frozen visual/textual representations
        config.visual_only = False
        config.frozen_visual = False
        config.frozen_textual = False

        # Make Bidirectional Transformer for temporal understanding.
        config.use_bidirectional = True

        # liv
        config.gamma = 0.98
        config.epsilon = 1e-8
        config.lambda_liv = 0.0

        # IDM
        config.lambda_idm = 1.0

        # Contrastive
        config.lambda_contrastive = 0.0

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, observation_dim, action_dim, num_ensembles=1):
        """Build the network, the optimizer and the initial train state.

        Args:
            config: learner config (see ``get_default_config``). The
                ``"CosineDecay"`` schedule reads ``config.warmup_steps`` and
                ``config.total_steps``. ``"OnlyWarmup"`` reads
                ``config.warmup_steps``. ``"none"`` uses a constant rate.
            observation_dim: per-sample image shape; a zero batch of shape
                ``(1, *observation_dim)`` is used to initialize the network.
            action_dim: forwarded to the reward model.
            num_ensembles: only ``1`` is supported by this class's
                ``_define_network``; subclasses may override it.
        """
        self.config = config
        if num_ensembles == 1:
            self.network = self._define_network(observation_dim, action_dim)
        else:
            self.network = self._define_network(observation_dim, action_dim, num_ensembles)
        self.observation_dim = observation_dim
        self.action_dim = action_dim
        self.num_ensembles = num_ensembles
        self._train_states = {}

        optimizer_class = {
            "adam": optax.adam,
            "adamw": optax.adamw,
            "sgd": optax.sgd,
        }[self.config.optimizer_type]

        scheduler_class = {
            "CosineDecay": lambda lr: optax.warmup_cosine_decay_schedule(
                init_value=0.0,
                peak_value=lr,
                warmup_steps=self.config.warmup_steps,
                decay_steps=self.config.total_steps,
                end_value=lr * 0.01,
            ),
            "OnlyWarmup": lambda lr: optax.join_schedules(
                [
                    optax.linear_schedule(
                        init_value=0.0,
                        end_value=lr,
                        transition_steps=self.config.warmup_steps,
                    ),
                    optax.constant_schedule(value=lr),
                ],
                [self.config.warmup_steps],
            ),
            # Constant learning rate: optax optimizers accept a plain float.
            "none": lambda lr: lr,
        }[self.config.scheduler_type]

        def flattened_traversal(fn):
            """Build an optax mask from ``fn(path_tuple, leaf) -> bool`` over a nested param dict."""

            def mask(data):
                flat = flax.traverse_util.flatten_dict(flax.core.frozen_dict.unfreeze(data))
                return flax.traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()})

            return mask

        # Parameters outside the CLIP backbone use the configured learning rate;
        # parameters under any "clip_model" path component use a 100x smaller one.
        self.tx = optax.chain(
            optax.masked(
                optimizer_class(scheduler_class(self.config.lr)),
                mask=flattened_traversal(lambda path, _: "clip_model" not in path),
            ),
            optax.masked(
                optimizer_class(scheduler_class(self.config.lr * 0.01)),
                mask=flattened_traversal(lambda path, _: "clip_model" in path),
            ),
        )

        self._total_steps = 0
        self._init_train_state()

    def _define_network(self, observation_dim, action_dim, num_ensembles=1):
        """Load the pretrained LIV/CLIP model and wrap it in an ``ARPV1RewardModel``.

        Side effect: stores the pretrained CLIP variables in
        ``self.clip_model_var`` so that ``_init_train_state`` can splice them
        into the freshly initialized variables.

        Raises:
            NotImplementedError: if ``num_ensembles != 1``.
        """
        if num_ensembles != 1:
            # Fail before the (expensive) pretrained-model load.
            raise NotImplementedError(
                f"{type(self).__name__} does not support ensembles (num_ensembles={num_ensembles})."
            )
        clip_model, self.clip_model_var, logit_scale_val = load_liv_model()
        return ARPV1RewardModel(
            clip_model=clip_model,
            logit_scale_val=logit_scale_val,
            config=self.config,
            observation_dim=observation_dim,
            action_dim=action_dim,
            activation=self.config.activation,
            activation_final=self.config.activation_final,
        )

    def _init_train_state(self):
        """Initialize the network, load pretrained CLIP weights and create the ``"trans"`` train state."""
        variables = self.network.init(
            {"params": next_rng(), "dropout": next_rng()},
            jnp.zeros((1, *self.observation_dim)),
            # Dummy token batch of length 77.
            jnp.zeros((1, 77), dtype=jnp.int32),
        )

        # Overwrite the randomly initialized CLIP sub-tree with the pretrained weights.
        variables = flax.core.frozen_dict.unfreeze(variables)
        variables["params"]["clip_model"] = self.clip_model_var["params"]
        variables["batch_stats"]["clip_model"] = self.clip_model_var["batch_stats"]

        params = flax.core.frozen_dict.unfreeze(variables["params"])
        batch_stats = flax.core.frozen_dict.unfreeze(variables["batch_stats"])

        states = TrainState.create(params=params, batch_stats=batch_stats, tx=self.tx, apply_fn=self.network.apply)
        self._train_states["trans"] = states

        model_keys = ["trans"]
        self._model_keys = tuple(model_keys)

    @partial(jax.jit, static_argnames=("self",))
    def _get_reward_step(self, train_states, batch):
        """Score each image against its own instruction.

        Encodes ``batch["images"]`` and ``batch["instruct"]`` with
        ``training=False``, scores all encodings with ``image_score`` and
        returns the diagonal of the result. This assumes ``image_score``
        returns a square images-by-texts matrix (TODO confirm).
        """
        variables = {"params": train_states["trans"].params, "batch_stats": train_states["trans"].batch_stats}

        encoded_images = self.network.apply(
            variables,
            batch["images"],
            training=False,
            method=self.network.encode_image,
        )
        encoded_texts = self.network.apply(
            variables,
            batch["instruct"],
            training=False,
            method=self.network.encode_text,
        )
        rewards = self.network.apply(
            variables,
            encoded_images,
            encoded_texts,
            method=self.network.image_score,
        )
        return jnp.diag(rewards)

    @partial(jax.jit, static_argnames=("self",))
    def _eval_pref_step(self, train_states, rng, batch):
        """Preference evaluation is not implemented for this learner; returns ``None``."""
        return None

    @partial(jax.jit, static_argnames=("self",))
    def _train_pref_step(self, train_states, rng, batch):
        """Preference training is not implemented for this learner; returns ``None``."""
        return None

    def train_arp(self, batch):
        """Run one ARP training step on ``batch``, update the stored states and return its metrics."""
        self._total_steps += 1
        self._train_states, metrics = self._train_arp_step(self._train_states, next_rng(), batch)
        return metrics

    def _arp_losses(self, train_states, train_params, batch, rng, normalize_obs):
        """Compute the weighted ARP loss and its components for one batch.

        Shared by ``_train_arp_step`` (differentiated) and ``_eval_arp_step``
        (forward only) so that the two paths cannot drift apart.

        Args:
            train_states: dict of train states; only ``train_states["trans"].batch_stats``
                is read (parameters come from ``train_params``).
            train_params: dict of parameters keyed by model key; ``"trans"`` is used.
            batch: dict with ``images``, ``actions`` and ``instruct``, indexed
                with ``[:, -1]`` (the IDM loss also uses ``images[:, -2]``).
                ``next_image``, ``initial_image``, ``goal_image`` and ``r``
                are read only when ``lambda_liv > 0``.
            rng: PRNG key passed as the ``"dropout"`` rng to every network call.
            normalize_obs: ``normalize`` flag for the current observation's
                CLIP feature, which is also fed to ``encode_image``.

        Returns:
            dict with ``loss``, ``liv_loss``, ``contrastive_loss`` and
            ``idm_loss`` (disabled terms are ``0.0``), plus ``batch_stats``:
            the ``batch_stats`` collection returned by ``encode_text``.
        """
        config = self.config
        param_dict = {"params": train_params["trans"]}
        if train_states["trans"].batch_stats is not None:
            param_dict["batch_stats"] = train_states["trans"].batch_stats

        # NOTE(review): the same ``rng`` drives dropout in every call below, so
        # same-shaped inputs get identical dropout masks. Consider jax.random.split.
        # NOTE(review): every call uses training=True, including during
        # evaluation, so dropout is active in eval metrics. Confirm this is intended.
        apply_net = partial(self.network.apply, param_dict, rngs={"dropout": rng})

        act = batch["actions"][:, -1]
        tokens = batch["instruct"][:, -1]
        batch_size = act.shape[0]

        liv_loss, contrastive_loss, idm_loss = 0.0, 0.0, 0.0

        # NOTE(review): training requests normalize=False and evaluation
        # normalize=True here, so encode_image sees differently scaled
        # features in the two paths. Confirm which one is intended.
        clip_vis_obs = apply_net(
            batch["images"][:, -1],
            normalize=normalize_obs,
            method=self.network.get_clip_visual_feature,
        )

        if config.lambda_liv > 0:
            clip_feature = partial(apply_net, normalize=True, method=self.network.get_clip_visual_feature)
            clip_vis_next_obs = clip_feature(batch["next_image"])
            clip_vis_initial = clip_feature(batch["initial_image"])
            clip_vis_goal = clip_feature(batch["goal_image"])

            # The current observation may be unnormalized (training path), so
            # both it and the next observation are L2-normalized explicitly.
            clip_vis_current_obs = clip_vis_obs / jnp.linalg.norm(clip_vis_obs, axis=-1, keepdims=True)
            clip_vis_next_obs = clip_vis_next_obs / jnp.linalg.norm(clip_vis_next_obs, axis=-1, keepdims=True)

            # Values are feature/goal dot products scaled by 1 / (1 - gamma).
            one_minus_gamma = 1 - config.gamma
            V_0 = clip_vis_initial @ clip_vis_goal.T / one_minus_gamma
            V_s = clip_vis_current_obs @ clip_vis_goal.T / one_minus_gamma
            V_s_next = clip_vis_next_obs @ clip_vis_goal.T / one_minus_gamma

            liv_loss = one_minus_gamma * -V_0.mean() + jnp.log(
                config.epsilon + jnp.mean(jnp.exp(-(batch["r"] + config.gamma * V_s_next - V_s)))
            )

        encoded_images = apply_net(
            None,
            training=True,
            image_features=clip_vis_obs,
            method=self.network.encode_image,
        )
        encoded_texts, text_updates = apply_net(
            tokens,
            training=True,
            method=self.network.encode_text,
            mutable=["batch_stats"],
        )

        if config.lambda_contrastive > 0:
            image_score = apply_net(encoded_images, encoded_texts, method=self.network.image_score)
            text_score = apply_net(encoded_images, encoded_texts, method=self.network.text_score)
            # Row i's matching partner is column i.
            labels = jax.lax.stop_gradient(jnp.arange(batch_size, dtype=jnp.int32))

            def _cross_entropy(logits):
                return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels))

            contrastive_loss = (_cross_entropy(image_score) + _cross_entropy(text_score)) / 2

        if config.lambda_idm > 0:
            encoded_before_images = apply_net(
                batch["images"][:, -2],
                training=True,
                method=self.network.encode_image,
            )
            a_hat = apply_net(
                encoded_before_images,
                encoded_images,
                encoded_texts,
                training=True,
                method=self.network.predict_action,
            )
            idm_loss = mse_loss(a_hat, act)

        loss = (
            config.lambda_liv * liv_loss
            + config.lambda_contrastive * contrastive_loss
            + config.lambda_idm * idm_loss
        )

        return {
            "loss": loss,
            "liv_loss": liv_loss,
            "contrastive_loss": contrastive_loss,
            "idm_loss": idm_loss,
            "batch_stats": text_updates["batch_stats"],
        }

    @partial(jax.jit, static_argnames=("self",))
    def _train_arp_step(self, train_states, rng, batch):
        """Take one jitted gradient step on the ARP objective.

        Returns:
            ``(new_train_states, metrics)``. Each state has the gradients applied
            and its ``batch_stats`` replaced by the collection returned from
            ``encode_text``. ``metrics`` holds the keys in ``_ARP_METRIC_KEYS``.
        """

        def loss_fn(train_params, rng):
            aux = self._arp_losses(train_states, train_params, batch, rng, normalize_obs=False)
            loss_collection = {"trans": aux["loss"]}
            return tuple(loss_collection[key] for key in self.model_keys), aux

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key]).replace(batch_stats=aux["batch_stats"])
            for i, key in enumerate(self.model_keys)
        }

        metrics = {name: aux[name] for name in self._ARP_METRIC_KEYS}
        return new_train_states, metrics

    def evaluation_arp(self, batch):
        """Compute ARP loss metrics on ``batch`` without updating the stored train states."""
        metrics = self._eval_arp_step(self._train_states, next_rng(), batch)
        return metrics

    @partial(jax.jit, static_argnames=("self",))
    def _eval_arp_step(self, train_states, rng, batch):
        """Evaluate the ARP losses with a forward pass only (no gradients).

        The computation matches ``_train_arp_step`` except that the current
        observation's CLIP feature is requested with ``normalize=True``.
        Returns a dict holding the keys in ``_ARP_METRIC_KEYS``.
        """
        train_params = {key: train_states[key].params for key in self.model_keys}
        aux = self._arp_losses(train_states, train_params, batch, rng, normalize_obs=True)
        return {name: aux[name] for name in self._ARP_METRIC_KEYS}

    @partial(jax.jit, static_argnames=("self",))
    def _train_semi_pref_step(self, train_states, rng, labeled_batch, unlabeled_batch, lmd, tau):
        """Semi-supervised preference training is not implemented for this learner; returns ``None``."""
        return None
