from functools import partial
from typing import Sequence

import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from ml_collections import ConfigDict

from bpref_v2.data.instruct import TASK_TO_MAX_EPISODE_STEPS
from bpref_v2.networks.cnn_reward_model import PVRRewardModel
from bpref_v2.third_party.openai.model import IMAGE_RESOLUTION, load_clip_model, load_liv_model, normalize_image
from bpref_v2.utils.jax_utils import TrainState, next_rng, sync_state_fn

from .core import RewardLearner


class DiscriminatorLearner(RewardLearner):
    """Discriminator-based reward learner on top of a pretrained vision encoder.

    Images are encoded by a CLIP or LIV model (``self.pvr_model``). Its variables are kept
    outside the train state, so they are never optimized. A ``PVRRewardModel`` head is
    trained with a sigmoid cross-entropy loss to separate samples from ``batch``
    (label 0) and ``neg_batch`` (label 1). The head's raw output is what ``get_reward``
    returns.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default ``ConfigDict`` for this learner.

        Args:
            updates: Optional mapping of overrides. Its references are resolved before it
                is merged into the defaults.
        """
        config = ConfigDict()
        config.lr = 3e-4
        config.embd_dim = 512
        config.optimizer_type = "adam"
        config.scheduler_type = "none"
        # transfer type: "liv" or "clip_<variant>" (see _load_pvr_network)
        config.transfer_type = "clip_vit_b16"

        config.image_keys = "image"
        config.num_images = 1
        config.window_size = 4
        config.features = (32, 32, 32, 32)
        config.filters = (2, 1, 1, 1)
        config.strides = (2, 1, 1, 1)

        # Optimizer parameters (the betas are only consumed by the "adamw" optimizer).
        config.adam_beta1 = 0.9
        config.adam_beta2 = 0.98
        config.weight_decay = 0.02
        config.max_grad_norm = 1.0

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(
        self,
        config: ConfigDict = None,
        env_name: str = "one_leg",
        observation_dim: Sequence[int] = (224, 224, 3),
        action_dim: int = 8,
        state: flax.training.train_state.TrainState = None,
        jax_devices: Sequence[jax.Device] = None,
    ):
        """Build the encoder, the reward head, the train state and the pmapped steps.

        Args:
            config: Learner config. ``None`` uses ``get_default_config()``. The config is
                mutated in place: ``max_episode_steps`` and ``vision_embd_dim`` are set.
            env_name: Task name. The part before the first ``"|"`` selects
                ``max_episode_steps`` from ``TASK_TO_MAX_EPISODE_STEPS``.
            observation_dim: Stored on the instance.
            action_dim: Stored on the instance.
            state: Existing discriminator state, either a ``TrainState`` or a mapping
                with key ``"disc"``. ``None`` initializes a fresh one.
            jax_devices: Devices used for ``pmap`` and for replicating the state.
        """
        # Fall back to defaults instead of crashing on ``None.max_episode_steps``.
        self.config = self.get_default_config() if config is None else config
        self.config.max_episode_steps = TASK_TO_MAX_EPISODE_STEPS[env_name.split("|")[0]]
        self.observation_dim = observation_dim
        self.action_dim = action_dim
        # Load the encoder before building the head. It sets ``config.vision_embd_dim``
        # and ``self.image_size``, which the head and ``_init_train_state`` rely on.
        self.pvr_model, self.pvr_model_var = self._load_pvr_network()
        self.network = self._define_network()

        self._total_steps = 0
        if state is None:
            state = self._init_train_state(jax_devices)
        # Always build the pmapped steps, so that a learner restored from ``state`` can
        # still call ``train_step`` / ``eval_step``.
        self.train_pmap = jax.pmap(self._train_step, axis_name="pmap", devices=jax_devices)
        self.eval_pmap = jax.pmap(self._eval_step, axis_name="pmap", devices=jax_devices)

        self._train_states = {}
        self._model_keys = ("disc",)
        self.load_state(state, jax_devices=jax_devices)

    def load_state(self, state, jax_devices=None):
        """Install ``state`` as the discriminator train state.

        Args:
            state: Either a ``TrainState`` or a mapping that holds it under ``"disc"``.
            jax_devices: If given, the state is replicated across these devices and passed
                through ``sync_state_fn``. Otherwise it is stored as-is.
        """
        from collections.abc import Mapping

        # ``_init_train_state`` returns a bare TrainState, while restored states may be
        # a {"disc": TrainState} mapping. Accept both forms.
        disc_state = state["disc"] if isinstance(state, Mapping) else state
        if jax_devices is not None:
            disc_state = sync_state_fn(flax.jax_utils.replicate(disc_state, jax_devices))
        self._train_states["disc"] = disc_state

    def _load_pvr_network(self):
        """Load the pretrained vision encoder selected by ``config.transfer_type``.

        Side effects: sets ``self.image_size`` and ``config.vision_embd_dim``.

        Returns:
            ``(model, model_variables)`` of the encoder.

        Raises:
            ValueError: If ``transfer_type`` is neither ``"liv"`` nor ``"clip_<variant>"``.
        """
        transfer_type = self.config.transfer_type
        if transfer_type == "liv":
            clip_model, clip_model_var, _ = load_liv_model()
            self.image_size = 224
        elif transfer_type.startswith("clip"):
            clip_type = transfer_type.split("_", 1)[-1]
            clip_model, clip_model_var, _ = load_clip_model(clip_type)
            self.image_size = IMAGE_RESOLUTION[clip_type]
        else:
            # Previously this fell through and raised an UnboundLocalError on clip_model.
            raise ValueError(
                f"Unsupported transfer_type {transfer_type!r}; expected 'liv' or 'clip_<variant>'."
            )
        self.config.vision_embd_dim = clip_model.vision_features
        return clip_model, clip_model_var

    def _define_network(self):
        """Construct the ``PVRRewardModel`` reward head from ``self.config``."""
        return PVRRewardModel(config=self.config)

    def _init_train_state(self, jax_devices):
        """Initialize the reward-head parameters and wrap them in a ``TrainState``.

        Args:
            jax_devices: Unused here; kept for interface compatibility.

        Returns:
            A ``TrainState``. The optimizer is global-norm gradient clipping followed by
            the configured optimizer, applied to every parameter.
        """
        patch_size = self.pvr_model.vision_patch_size
        # One class token plus a (size // patch) x (size // patch) grid of patch tokens.
        # Written explicitly: the old ``s // p * s // p`` evaluated as ``((s // p) * s) // p``.
        grid_size = self.image_size // patch_size
        num_patches = 1 + grid_size * grid_size
        variables = self.network.init(
            {"params": next_rng(), "dropout": next_rng()},
            jnp.ones(
                (self.config.num_images, 1, self.config.window_size, num_patches, self.pvr_model.vision_features),
                dtype=jnp.float32,
            ),
        )

        variables = flax.core.frozen_dict.unfreeze(variables)
        params = flax.core.frozen_dict.unfreeze(variables["params"])
        batch_stats = (
            flax.core.frozen_dict.unfreeze(variables["batch_stats"])
            if variables.get("batch_stats") is not None
            else None
        )

        optimizer_class = {
            # NOTE(review): "adam" uses optax's default betas, not config.adam_beta1/2.
            "adam": optax.adam,
            "adamw": partial(
                optax.adamw,
                weight_decay=self.config.weight_decay,
                b1=self.config.adam_beta1,
                b2=self.config.adam_beta2,
            ),
            "sgd": optax.sgd,
        }[self.config.optimizer_type]

        partition_optimizers = {
            "trainable": optimizer_class(self.config.lr),
            "adapter": optimizer_class(self.config.lr * 0.1),
            "frozen": optax.set_to_zero(),
        }

        def param_partition_condition(path, _):
            # Every head parameter is currently trainable; the other labels are unused.
            return "trainable"

        param_partitions = flax.traverse_util.path_aware_map(param_partition_condition, params)
        tx = optax.chain(
            optax.clip_by_global_norm(self.config.max_grad_norm),
            optax.multi_transform(partition_optimizers, param_partitions),
        )
        return TrainState.create(params=params, batch_stats=batch_stats, tx=tx, apply_fn=self.network.apply)

    @partial(jax.jit, static_argnames=("self",))
    def _eval_pref_step(self, train_states, rng, batch):
        """Unused preference-eval hook; returns ``None``."""
        return None

    @partial(jax.jit, static_argnames=("self",))
    def _train_pref_step(self, train_states, rng, batch):
        """Unused preference-train hook; returns ``None``."""
        return None

    def get_reward(self, batch, get_video_feature=False, get_text_feature=False):
        """Return the reward head's output for ``batch`` as a NumPy array.

        ``get_video_feature`` and ``get_text_feature`` are accepted for interface
        compatibility; the computation does not read them.
        """
        # NOTE(review): this path is jitted, not pmapped. If the state was replicated with
        # jax_devices, its params carry a device axis. Confirm callers unreplicate first.
        return np.asarray(
            self._get_reward_step(
                self._train_states, batch, get_video_feature=get_video_feature, get_text_feature=get_text_feature
            )
        )

    @partial(jax.jit, static_argnames=("self", "get_video_feature", "get_text_feature"))
    def _get_reward_step(self, train_states, batch, get_video_feature=False, get_text_feature=False):
        """Jitted reward computation.

        ``batch["image"]`` values are stacked along a new leading axis. A 6-D stack is
        treated as raw images and encoded; anything else is assumed to be precomputed
        features.
        """
        obs = jnp.array(list(batch["image"].values()))
        image_feature = self._get_pvr_feature(obs) if obs.ndim == 6 else obs
        return self._extract_score(train_states["disc"].params, image_feature, training=False)

    def _get_pvr_feature(self, images):
        """Encode images into feature maps with the pretrained encoder.

        Args:
            images: Array whose last three axes are (H, W, C), with pixel values assumed
                to be in [0, 255]. Any number of leading axes is allowed.

        Returns:
            Array shaped as ``images``' leading axes followed by the last two axes of the
            encoder output (presumably tokens x features).
        """
        leading_shape = images.shape[:-3]
        images = jnp.reshape(images, (-1,) + images.shape[-3:])
        images = (images / 255.0).astype(jnp.float32)
        target_size = self.image_size
        # Resize both spatial axes to the encoder's input resolution. This is the same
        # size _init_train_state uses, so token counts agree with the head's init.
        if tuple(images.shape[1:3]) != (target_size, target_size):
            images = jax.image.resize(
                images, (images.shape[0], target_size, target_size, images.shape[-1]), method="bicubic"
            )
        images = normalize_image(images)
        feature_map = self.pvr_model.apply(
            self.pvr_model_var,
            images,
            method=self.pvr_model.encode_image,
            normalize=False,
        )
        return jnp.reshape(feature_map, leading_shape + feature_map.shape[-2:])

    def _extract_score(self, params, image_feature, training=False, rng=None):
        """Apply the head's ``predict_reward`` to precomputed image features.

        Args:
            params: Reward-head parameters.
            image_feature: Output of ``_get_pvr_feature`` (or equivalently shaped features).
            training: Forwarded to ``predict_reward``.
            rng: Optional PRNG key, supplied as the ``"dropout"`` stream when given.
        """
        rngs = {"dropout": rng} if rng is not None else None
        return self.network.apply(
            {"params": params},
            image_feature,
            training=training,
            method=self.network.predict_reward,
            rngs=rngs,
        )

    def loss_fn(self, params, batch, rng, neg_batch=None, training=False):
        """Compute the sigmoid cross-entropy discriminator loss.

        Samples from ``batch`` get label 0 and samples from ``neg_batch`` get label 1.

        Returns:
            ``(loss, aux)``, where ``aux`` holds ``"disc_loss"`` and ``"total_loss"``.

        Raises:
            ValueError: If ``neg_batch`` is ``None``.
        """
        if neg_batch is None:
            raise ValueError("DiscriminatorLearner.loss_fn requires neg_batch (negative samples).")
        # Values of batch["image"] are stacked on a new leading axis; axis 1 is the batch.
        obs = jnp.array(list(batch["image"].values()))
        neg_obs = jnp.array(list(neg_batch["image"].values()))
        pos_size, neg_size = obs.shape[1], neg_obs.shape[1]
        image_feature = self._get_pvr_feature(obs)
        neg_image_feature = self._get_pvr_feature(neg_obs)
        total_image_feature = jnp.concatenate([image_feature, neg_image_feature], axis=1)

        # Size each label segment by its own batch so unequal batch sizes stay aligned.
        disc_labels = jnp.concatenate(
            [jnp.zeros((pos_size,), dtype=jnp.int32), jnp.ones((neg_size,), dtype=jnp.int32)], axis=0
        )

        def custom_sigmoid_binary_cross_entropy(logits, labels):
            # Numerically stable BCE on logits: log(1 - sigmoid(x)) == log_sigmoid(-x).
            labels = labels.astype(logits.dtype)
            log_p = jax.nn.log_sigmoid(logits)
            log_not_p = jax.nn.log_sigmoid(-logits)
            return -labels * log_p - (1.0 - labels) * log_not_p

        score = self._extract_score(
            params, total_image_feature, training=training, rng=rng if training else None
        )
        # NOTE(review): labels have shape (pos+neg,) and broadcast against the trailing
        # axis of score. Confirm predict_reward's output layout matches.
        disc_loss = jnp.mean(custom_sigmoid_binary_cross_entropy(logits=score, labels=disc_labels))
        aux = {"disc_loss": disc_loss, "total_loss": disc_loss}
        return disc_loss, aux

    def train_step(self, batch, rng, neg_batch=None):
        """Run one pmapped optimization step and return ``(metrics, rng)``."""
        self._total_steps += 1
        self._train_states, metrics, rng = self.train_pmap(self._train_states, batch, rng, neg_batch=neg_batch)
        return metrics, rng

    @partial(jax.jit, static_argnames=("self",))
    def _train_step(self, train_states, batch, rng, neg_batch=None):
        """Per-device step.

        Computes gradients, averages loss, metrics and grads over the ``"pmap"`` axis,
        then applies the update.
        """
        new_rng, split_rng = jax.random.split(rng)
        grad_fn = jax.value_and_grad(self.loss_fn, has_aux=True)
        (loss, metrics), grads = jax.lax.pmean(
            grad_fn(train_states["disc"].params, batch, split_rng, neg_batch=neg_batch, training=True),
            axis_name="pmap",
        )

        new_train_states = {"disc": train_states["disc"].apply_gradients(grads=grads)}
        return new_train_states, metrics, new_rng

    def eval_step(self, batch, rng, neg_batch=None):
        """Run the pmapped evaluation and return ``(metrics, rng)``."""
        metrics, rng = self.eval_pmap(self._train_states, batch, rng, neg_batch=neg_batch)
        return metrics, rng

    def _eval_step(self, train_states, batch, rng, neg_batch=None):
        """Per-device evaluation; loss metrics are averaged over the ``"pmap"`` axis."""
        new_rng, split_rng = jax.random.split(rng)
        _, metrics = jax.lax.pmean(
            self.loss_fn(train_states["disc"].params, batch, split_rng, neg_batch=neg_batch, training=False),
            axis_name="pmap",
        )
        return metrics, new_rng