from functools import partial

import jax
import jax.numpy as jnp
from ml_collections import ConfigDict

from bpref_v2.third_party.openai.model import (
    load_clip_model,
    load_clip_model_with_adapter,
    load_liv_model,
    normalize_image,
)
from bpref_v2.utils.jax_utils import cos_sim

from .core import RewardLearner


class CLIPLearner(RewardLearner):
    """Reward learner that scores observations with a pretrained CLIP-style model.

    No parameters are trained here: rewards are derived from the similarity
    between image features and instruction-text features produced by the
    loaded model. The preference-learning hooks (`_eval_pref_step`,
    `_train_pref_step`, `_train_semi_pref_step`) are no-ops returning ``None``.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Build the default configuration, optionally overridden by `updates`.

        Args:
            updates: Optional mapping / ConfigDict whose entries override the
                defaults.

        Returns:
            A ConfigDict with the key ``transfer_type`` (default ``"liv"``).
        """
        config = ConfigDict()

        # Selects which pretrained model `_define_network` loads.
        config.transfer_type = "liv"

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, **kwargs):
        """Store `config`, load the pretrained model and reset train states.

        Args:
            config: Configuration object exposing ``transfer_type``.
            **kwargs: Accepted for signature compatibility; unused here.
        """
        # NOTE(review): the base-class __init__ is intentionally not invoked,
        # matching the existing behaviour — confirm RewardLearner needs no setup.
        self.config = config
        self.network = self._define_network()
        self._init_train_state()

    def _define_network(self):
        """Load the pretrained model selected by ``config.transfer_type``.

        Sets ``self.clip_model_var`` to the variables returned by the loader
        as a side effect.

        Returns:
            The model object returned by the loader.

        Raises:
            ValueError: If ``transfer_type`` is neither ``"liv"`` nor starts
                with ``"clip"``.
        """
        transfer_type = self.config.transfer_type

        if transfer_type == "liv":
            clip_model, self.clip_model_var, _ = load_liv_model()
        elif transfer_type.startswith("clip"):
            # Everything after the first underscore names the CLIP variant.
            clip_type = transfer_type.split("_", 1)[-1]
            if "adapter" in transfer_type:
                clip_model, self.clip_model_var, _ = load_clip_model_with_adapter(clip_type)
            else:
                clip_model, self.clip_model_var, _ = load_clip_model(clip_type)
        else:
            # Previously this fell through and failed with an opaque
            # UnboundLocalError on `clip_model`.
            raise ValueError(
                f"Unsupported transfer_type: {transfer_type!r} (expected 'liv' or a value starting with 'clip')"
            )

        return clip_model

    def _init_train_state(self):
        """Reset the train-state container to an empty dict."""
        self._train_states = {}

    @partial(jax.jit, static_argnames=("self",))
    def _get_reward_step(self, _, batch):
        """Compute rewards from image/instruction similarity.

        Args:
            _: Unused positional slot (kept for interface compatibility).
            batch: Mapping with ``"image"`` (a dict of image arrays, one per
                image source) and ``"instruct"`` (token array).

        Returns:
            Rewards averaged over the image sources (axis 0 after reshaping
            to ``(num_images, -1)``).
        """
        images = jnp.array(list(batch["image"].values()))
        tokens = batch["instruct"]

        if images.ndim == 6:  # image: (num_images, batch_size, num_timestep, H, W, C)
            # Only the last timestep is scored.
            images = images[:, :, -1]
        num_images = images.shape[0]
        # Flatten every leading dim so the encoder sees (N, H, W, C).
        images = jnp.reshape(images, (-1,) + images.shape[-3:])
        images = jax.image.resize(
            images, (images.shape[0], 224, 224, images.shape[-1]), method="bicubic"
        )  # to meet the input size of the clip model
        images = (images / 255.0).astype(jnp.float32)
        images = normalize_image(images)
        image_features = self.network.apply(self.clip_model_var, images, method=self.network.encode_image)

        if tokens.ndim == 3:
            # Several sentences per sample: encode each one, then average
            # the sentence features per sample.
            batch_size, num_sentences = tokens.shape[:2]
            tokens = jnp.reshape(tokens, (-1, tokens.shape[-1]))
            text_features = self.network.apply(self.clip_model_var, tokens, method=self.network.encode_text)
            text_features = text_features.reshape(batch_size, num_sentences, -1)
            text_features = jnp.mean(text_features, axis=1)
        else:
            text_features = self.network.apply(self.clip_model_var, tokens, method=self.network.encode_text)
        # Repeat the text features once per image source so rows line up
        # with the flattened image features.
        text_features = jnp.tile(text_features, (num_images, 1))

        # NOTE(review): presumably cos_sim returns a pairwise matrix and the
        # diagonal picks the matching image/text pairs — confirm against
        # bpref_v2.utils.jax_utils.cos_sim.
        rewards = jnp.diag(cos_sim(image_features, text_features))
        if self.config.transfer_type.startswith("clip"):
            # CLIP variants additionally scale by exp(logit_scale).
            rewards *= jnp.exp(self.network.apply(self.clip_model_var, method=self.network.get_logit_scale))
        rewards = jnp.reshape(rewards, (num_images, -1))
        return jnp.mean(rewards, axis=0)

    def get_visual_text_feature(self, batch):
        """Return image and instruction features for `batch`.

        Thin wrapper that forwards to the jitted `_get_visual_text_feature`
        together with the current train states.
        """
        return self._get_visual_text_feature(self._train_states, batch)

    @partial(jax.jit, static_argnames=("self",))
    def _get_visual_text_feature(self, train_states, batch):
        """Encode every image source and the instruction tokens.

        Encoders are called with ``normalize=False``.

        Args:
            train_states: Unused; kept for interface compatibility.
            batch: Mapping with ``"image"`` (dict of image arrays whose last
                three dims are ``(H, W, C)``) and ``"instruct"`` (token array).

        Returns:
            Dict with ``"image"`` (per-key features, leading dims of each
            input image restored) and ``"instruct"`` (text features, reshaped
            to ``(batch, seq, -1)`` when the tokens are 3-D).
        """
        obs = batch["image"]
        tokens = batch["instruct"]

        res = {}
        image_res = {}
        for key, image in obs.items():
            # Remember this image's own leading dims. The original code only
            # bound them for 5-D inputs, so other ranks hit a NameError or
            # silently reused the previous key's shape.
            leading_shape = image.shape[:-3]
            image = jnp.reshape(image, (-1,) + image.shape[-3:])
            image = jax.image.resize(
                image, (image.shape[0], 224, 224, image.shape[-1]), method="bicubic"
            )  # to meet the input size of the clip model
            image = (image / 255.0).astype(jnp.float32)
            image = normalize_image(image)
            image_feat = self.network.apply(
                self.clip_model_var, image, method=self.network.encode_image, normalize=False
            )
            # For 5-D inputs this is (batch_size, seq_length, -1), as before.
            image_res[key] = image_feat.reshape(leading_shape + (-1,))

        if tokens.ndim == 3:
            batch_size, seq_length = tokens.shape[:2]
            tokens = jnp.reshape(tokens, (-1, tokens.shape[-1]))
            text_feat = self.network.apply(
                self.clip_model_var, tokens, method=self.network.encode_text, normalize=False
            )
            text_feat = text_feat.reshape(batch_size, seq_length, -1)
        else:
            text_feat = self.network.apply(
                self.clip_model_var, tokens, method=self.network.encode_text, normalize=False
            )

        res["image"] = image_res
        res["instruct"] = text_feat

        return res

    @partial(jax.jit, static_argnames=("self",))
    def _eval_pref_step(self, train_states, rng, batch):
        """No-op: this learner has no preference evaluation step."""
        return None

    @partial(jax.jit, static_argnames=("self",))
    def _train_pref_step(self, train_states, rng, batch):
        """No-op: this learner has no preference training step."""
        return None

    @partial(jax.jit, static_argnames=("self",))
    def _train_semi_pref_step(self, train_states, rng, labeled_batch, unlabeled_batch, lmd, tau):
        """No-op: this learner has no semi-supervised preference training step."""
        return None
