from functools import partial
from typing import Sequence

import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from ml_collections import ConfigDict

from bpref_v2.data.instruct import TASK_TO_MAX_EPISODE_STEPS
from bpref_v2.networks.cnn_reward_model import CNNRewardModel
from bpref_v2.third_party.openai.model import IMAGE_RESOLUTION, load_clip_model, load_liv_model
from bpref_v2.utils.jax_utils import TrainState, next_rng, sync_state_fn

from .core import RewardLearner


class R2RRankLearner(RewardLearner):
    """Reward learner trained with a pairwise ranking objective on image windows.

    A ``CNNRewardModel`` produces a score per image window.  ``loss_fn`` trains
    it with a two-way softmax cross-entropy whose target class is always the
    ``random_next_image`` window, i.e. that window is pushed to score higher
    than the ``image`` window of the same sample.  Training and evaluation are
    executed through ``jax.pmap`` with the axis name ``"pmap"``.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Build the default hyper-parameter config.

        Args:
            updates: Optional mapping/``ConfigDict`` of overrides applied on
                top of the defaults.

        Returns:
            A ``ConfigDict`` holding the defaults merged with ``updates``.
        """
        config = ConfigDict()
        config.lr = 3e-4
        config.embd_dim = 128
        config.optimizer_type = "adam"
        config.scheduler_type = "none"
        # Which pretrained visual model to load: "liv" or "clip_<variant>".
        config.transfer_type = "clip_vit_b16"

        config.image_keys = "image"
        config.num_images = 1
        config.window_size = 4
        config.features = (32, 64, 64)
        config.filters = (8, 4, 3)
        config.strides = (4, 2, 1)

        # Optimizer parameters
        config.adam_beta1 = 0.9
        config.adam_beta2 = 0.98
        config.weight_decay = 0.02
        config.max_grad_norm = 1.0

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(
        self,
        config: ConfigDict = None,
        env_name: str = "one_leg",
        observation_dim: Sequence[int] = (224, 224, 3),
        action_dim: int = 8,
        state: flax.training.train_state.TrainState = None,
        jax_devices: Sequence[jax.Device] = None,
    ):
        """Set up the network, the pretrained visual model and the train state.

        Args:
            config: Hyper-parameters; ``get_default_config()`` is used when
                omitted.  The object is mutated in place
                (``max_episode_steps`` and ``vision_embd_dim`` are added).
            env_name: Task name; only the part before the first ``"|"`` is
                used to look up the maximum episode length.
            observation_dim: Trailing dimensions of a single observation, used
                to build the dummy input for parameter initialisation.
            action_dim: Stored on the instance; not used elsewhere in this
                class.
            state: Existing train state to load.  A fresh one is initialised
                when ``None``.
            jax_devices: Devices handed to ``jax.pmap`` and used to replicate
                the state.  ``None`` lets JAX/Flax pick their defaults.
        """
        # Previously a missing config crashed on the attribute assignment below.
        if config is None:
            config = self.get_default_config()
        self.config = config
        self.config.max_episode_steps = TASK_TO_MAX_EPISODE_STEPS[env_name.split("|")[0]]
        self.network = self._define_network()
        self.observation_dim = observation_dim
        self.action_dim = action_dim
        self.pvr_model, self.pvr_model_var = self._load_pvr_network()

        self._total_steps = 0
        # The pmapped step functions must exist regardless of where the state
        # comes from; otherwise train_step/eval_step fail with AttributeError
        # on a learner that was built from a supplied state.
        self.train_pmap = jax.pmap(self._train_step, axis_name="pmap", devices=jax_devices)
        self.eval_pmap = jax.pmap(self._eval_step, axis_name="pmap", devices=jax_devices)
        if state is None:
            state = self._init_train_state(jax_devices)

        self._train_states = {}
        self._model_keys = ("disc",)
        self.load_state(state, jax_devices=jax_devices)

    def load_state(self, state, jax_devices=None):
        """Install ``state`` as the ``"disc"`` train state.

        Args:
            state: Either a bare train state, or (only when ``jax_devices`` is
                ``None``) a mapping that already contains the entry ``"disc"``,
                which is stored as-is.
            jax_devices: Devices to replicate a bare state onto.  With ``None``
                the state is replicated over Flax's default device order.
        """
        from collections.abc import Mapping

        if jax_devices is None and isinstance(state, Mapping):
            # Mapping of ready-to-use states: keep the original pass-through.
            self._train_states["disc"] = state["disc"]
            return

        # A bare train state is not subscriptable, so it is always replicated;
        # `replicate` falls back to its default devices when given None.
        replicated = flax.jax_utils.replicate(state, jax_devices)
        self._train_states["disc"] = sync_state_fn(replicated)

    def _load_pvr_network(self):
        """Load the pretrained visual model selected by ``config.transfer_type``.

        Side effects: sets ``self.image_size`` and ``config.vision_embd_dim``.

        Returns:
            ``(model, model_variables)`` of the loaded network.

        Raises:
            ValueError: If ``config.transfer_type`` is neither ``"liv"`` nor
                prefixed with ``"clip"``.
        """
        transfer_type = self.config.transfer_type
        if transfer_type == "liv":
            clip_model, clip_model_var, _ = load_liv_model()
            self.image_size = 224
        elif transfer_type.startswith("clip"):
            # "clip_vit_b16" -> "vit_b16"
            clip_type = transfer_type.split("_", 1)[-1]
            clip_model, clip_model_var, _ = load_clip_model(clip_type)
            self.image_size = IMAGE_RESOLUTION[clip_type]
        else:
            raise ValueError(
                f"Unsupported transfer_type {transfer_type!r}; expected 'liv' or 'clip_<variant>'."
            )
        self.config.vision_embd_dim = clip_model.vision_features
        return clip_model, clip_model_var

    def _define_network(self):
        """Instantiate the reward network from the current config."""
        # return PVRRewardModel(config=self.config)
        return CNNRewardModel(config=self.config)

    def _init_train_state(self, jax_devices):
        """Initialise network parameters and wrap them in a ``TrainState``.

        Args:
            jax_devices: Accepted for interface symmetry; not used here.

        Returns:
            An un-replicated ``TrainState`` using an Adam optimizer.
        """
        # num_patches = 1 + (
        #     self.image_size // self.pvr_model.vision_patch_size * self.image_size // self.pvr_model.vision_patch_size
        # )
        dummy_input = jnp.ones(
            # (self.config.num_images, 1, self.config.window_size, self.pvr_model.vision_features),
            (self.config.num_images, 1, self.config.window_size, *self.observation_dim),
            dtype=jnp.float32,
        )
        variables = self.network.init({"params": next_rng(), "dropout": next_rng()}, dummy_input)

        variables = flax.core.frozen_dict.unfreeze(variables)
        params = flax.core.frozen_dict.unfreeze(variables["params"])
        batch_stats = variables.get("batch_stats")
        if batch_stats is not None:
            batch_stats = flax.core.frozen_dict.unfreeze(batch_stats)

        # Only lr and the two Adam betas are consulted; optimizer_type,
        # scheduler_type, weight_decay and max_grad_norm are not used here.
        tx = optax.adam(learning_rate=self.config.lr, b1=self.config.adam_beta1, b2=self.config.adam_beta2)
        return TrainState.create(params=params, batch_stats=batch_stats, tx=tx, apply_fn=self.network.apply)

    @partial(jax.jit, static_argnames=("self",))
    def _eval_pref_step(self, train_states, rng, batch):
        """Placeholder: preference evaluation is not implemented; returns ``None``."""
        return None

    @partial(jax.jit, static_argnames=("self",))
    def _train_pref_step(self, train_states, rng, batch):
        """Placeholder: preference training is not implemented; returns ``None``."""
        return None

    def get_reward(self, batch, get_video_feature=False, get_text_feature=False):
        """Score ``batch`` with the current model and return a NumPy array.

        Args:
            batch: Mapping whose ``"image"`` entry is itself a mapping of
                image arrays.
            get_video_feature: Forwarded to ``_get_reward_step``, which does
                not use it.
            get_text_feature: Forwarded to ``_get_reward_step``, which does
                not use it.

        Returns:
            ``np.ndarray`` of sigmoid-squashed scores.
        """
        # NOTE(review): self._train_states holds the state prepared for pmap;
        # confirm predict_reward copes with it when called outside pmap.
        rewards = self._get_reward_step(
            self._train_states, batch, get_video_feature=get_video_feature, get_text_feature=get_text_feature
        )
        return np.asarray(rewards)

    @partial(jax.jit, static_argnames=("self", "get_video_feature", "get_text_feature"))
    def _get_reward_step(self, train_states, batch, get_video_feature=False, get_text_feature=False):
        """Jitted reward computation: sigmoid of the network score.

        ``get_video_feature`` and ``get_text_feature`` are static arguments
        that are currently ignored.
        """
        obs = jnp.array(list(batch["image"].values()))
        # Rank-6 input is routed through the feature hook, anything else is
        # used directly.  The hook is currently the identity.
        image_feature = self._get_pvr_feature(obs) if obs.ndim == 6 else obs

        score = self._extract_score(train_states["disc"].params, image_feature, training=False)
        return jax.nn.sigmoid(score)

    def _get_pvr_feature(self, images):
        """Feature hook applied to raw images; currently returns them unchanged."""
        return images

    # Disabled reference implementation that encodes images with the loaded
    # pretrained model (it uses `normalize_image`, which is not among the
    # imports at the top of this file):
    #
    # def _get_pvr_feature(self, images):
    #     original_shape = images.shape[:-3]
    #     images = jnp.reshape(images, (-1,) + images.shape[-3:])
    #     images = (images / 255.0).astype(jnp.float32)
    #     if images.shape[-3] != 224:
    #         images = jax.image.resize(
    #             images, (images.shape[0], 224, 224, images.shape[-1]), method="bicubic"
    #         )  # to meet the input size of the clip model
    #     images = normalize_image(images)
    #     image_feature = self.pvr_model.apply(
    #         self.pvr_model_var,
    #         images,
    #         method=self.pvr_model.encode_image,
    #         normalize=False,
    #     )[:, 0]
    #     # image_feat = image_feature_map[:, 0]
    #     image_feature = jnp.reshape(image_feature, original_shape + (-1,))
    #     return image_feature
    #     # image_feature = jnp.reshape(image_feature, original_shape + image_feature.shape[-2:])
    #     # return image_feature

    def _extract_score(self, params, image_feature, training=False):
        """Run the network's ``predict_reward`` method with the given params."""
        return self.network.apply(
            {"params": params}, image_feature, training=training, method=self.network.predict_reward
        )

    def loss_fn(self, params, batch, rng, training=False):
        """Pairwise ranking loss between ``image`` and ``random_next_image``.

        Args:
            params: Network parameters to differentiate with respect to.
            batch: Mapping with ``"image"`` and ``"random_next_image"``
                entries, each a mapping of image arrays.
            rng: Accepted for interface compatibility; not used.
            training: Forwarded to the network.

        Returns:
            ``(loss, aux)`` where ``aux`` reports the same value under
            ``"rank_loss"`` and ``"total_loss"``.
        """
        obs = jnp.array(list(batch["image"].values()))
        # assumes obs is laid out as (num_images, batch, window, ...) — this
        # matches the dummy input built in _init_train_state.
        batch_size = obs.shape[1]
        current_logits = self._extract_score(params, self._get_pvr_feature(obs), training=training)

        other_obs = jnp.array(list(batch["random_next_image"].values()))
        other_logits = self._extract_score(params, self._get_pvr_feature(other_obs), training=training)

        # Class 0 = current window, class 1 = random_next_image window.
        # assumes each score has a trailing axis of size 1 — TODO confirm.
        pair_logits = jnp.concatenate([current_logits, other_logits], axis=-1)
        # The target is always class 1.
        target = jnp.ones((batch_size,), dtype=jnp.int32)

        # disc_labels = jnp.concatenate(
        #     [jnp.zeros((batch_size,), dtype=jnp.int32), jnp.ones((batch_size,), dtype=jnp.int32)], axis=0
        # )

        # def custom_sigmoid_binary_cross_entropy(logits, labels):
        #     log_p = jax.nn.log_sigmoid(logits)
        #     log_not_p = jax.nn.log_sigmoid(-logits)
        #     return -labels * log_p - (1.0 - labels) * log_not_p

        # ce_loss = lambda x, y: jnp.mean(custom_sigmoid_binary_cross_entropy(logits=x, labels=y))
        # score = self._extract_score(params, total_image_feature, training=training)
        rank_loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits=pair_logits, labels=target))
        # disc_loss = ce_loss(score, disc_labels)
        aux = {"rank_loss": rank_loss, "total_loss": rank_loss}

        return rank_loss, aux

    def train_step(self, batch, rng, neg_batch=None):
        """Run one pmapped optimisation step and update the stored state.

        Args:
            batch: Per-device batch passed to the pmapped step.
            rng: Per-device PRNG keys.
            neg_batch: Unused.

        Returns:
            ``(metrics, rng)`` with the advanced PRNG keys.
        """
        self._total_steps += 1
        self._train_states, metrics, rng = self.train_pmap(self._train_states, batch, rng)
        return metrics, rng

    @partial(jax.jit, static_argnames=("self",))
    def _train_step(self, train_states, batch, rng):
        """Per-device training step: cross-device-averaged gradient update."""
        carry_rng, step_rng = jax.random.split(rng)
        grad_fn = jax.value_and_grad(self.loss_fn, has_aux=True)
        # Average loss, metrics and gradients across the "pmap" axis.
        (_, metrics), grads = jax.lax.pmean(
            grad_fn(train_states["disc"].params, batch, step_rng, training=True),
            axis_name="pmap",
        )

        new_train_states = {"disc": train_states["disc"].apply_gradients(grads=grads)}
        return new_train_states, metrics, carry_rng

    def eval_step(self, batch, rng, neg_batch=None):
        """Evaluate the loss on ``batch`` through the pmapped eval step.

        Args:
            batch: Per-device batch passed to the pmapped step.
            rng: Per-device PRNG keys.
            neg_batch: Unused.

        Returns:
            ``(metrics, rng)`` with the advanced PRNG keys.
        """
        metrics, rng = self.eval_pmap(self._train_states, batch, rng)
        return metrics, rng

    def _eval_step(self, train_states, batch, rng):
        """Per-device evaluation step: cross-device-averaged loss metrics."""
        carry_rng, step_rng = jax.random.split(rng)
        _, metrics = jax.lax.pmean(
            self.loss_fn(train_states["disc"].params, batch, step_rng, training=False),
            axis_name="pmap",
        )
        return metrics, carry_rng
