import random
from functools import partial
from typing import Sequence

import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from ml_collections import ConfigDict

from bpref_v2.data.instruct import TASK_TO_MAX_EPISODE_STEPS, TASK_TO_PHASE

from bpref_v2.networks.cnn_reward_model import MultiStageCNNRewardModel

# from bpref_v2.networks.cnn_reward_model import MultiStagePVRRewardModel
from bpref_v2.third_party.openai.model import IMAGE_RESOLUTION, load_clip_model, load_liv_model, normalize_image
from bpref_v2.utils.jax_utils import JaxRNG, TrainState, next_rng, sync_state_fn

from .core import RewardLearner


def random_crop(key, img, padding):
    """Randomly shift a 4-D image inside an edge-padded copy of itself.

    The first two axes of ``img`` are padded by ``padding`` on each side
    (replicating edge values), then a window of the original shape is cut out
    at a random offset along those two axes. The last two axes are never
    padded or shifted.

    Args:
        key: JAX PRNG key used to draw the two offsets.
        img: 4-D array; the pad specification below fixes the rank to 4.
        padding: Number of cells added on each side of the first two axes.

    Returns:
        An array with the same shape as ``img``.
    """
    pad_widths = ((padding, padding), (padding, padding), (0, 0), (0, 0))
    padded = jnp.pad(img, pad_widths, mode="edge")

    # Offsets lie in [0, 2 * padding] so the window always stays in bounds.
    spatial_start = jax.random.randint(key, (2,), 0, 2 * padding + 1)
    trailing_start = jnp.zeros((2,), dtype=jnp.int32)
    start_indices = jnp.concatenate([spatial_start, trailing_start])

    return jax.lax.dynamic_slice(padded, start_indices, img.shape)


def batched_random_crop(key, obs, pixel_key, padding=4):
    """Apply an independent ``random_crop`` to every sample of ``obs[pixel_key]``.

    Args:
        key: JAX PRNG key; it is split into one sub-key per sample.
        obs: Mutable mapping of observations. It is modified in place.
        pixel_key: Entry of ``obs`` holding the image batch (batch on axis 0).
        padding: Edge padding forwarded to ``random_crop``.

    Returns:
        The same ``obs`` object, with ``obs[pixel_key]`` replaced by the
        cropped batch.
    """
    batch = obs[pixel_key]
    per_sample_keys = jax.random.split(key, batch.shape[0])
    # Map over keys and images; the padding amount is shared by all samples.
    crop_each = jax.vmap(random_crop, in_axes=(0, 0, None))
    obs[pixel_key] = crop_each(per_sample_keys, batch, padding)
    return obs


class DrSLearner(RewardLearner):
    """Multi-stage discriminator reward learner.

    One set of discriminator parameters per task stage is trained with a
    sigmoid cross-entropy loss (``loss_fn``); ``get_reward`` converts the
    per-stage scores into a single scalar reward per sample.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default hyper-parameters, optionally overridden.

        Args:
            updates: Optional mapping / ``ConfigDict`` whose entries override
                the defaults.

        Returns:
            A ``ConfigDict`` with the resulting configuration.
        """
        config = ConfigDict()
        config.lr = 3e-4
        config.embd_dim = 512
        config.optimizer_type = "adam"
        config.scheduler_type = "none"
        # transfer type
        config.transfer_type = "clip_vit_b16"

        config.image_keys = "image"
        config.num_images = 1
        config.features = (32, 32, 32, 32)
        config.filters = (7, 5, 3, 3)
        config.strides = (2, 2, 2, 1)

        # Optimizer parameters
        config.adam_beta1 = 0.9
        config.adam_beta2 = 0.98
        config.weight_decay = 0.02

        config.window_size = 4

        # seed
        config.seed = 0

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(
        self,
        config: ConfigDict = None,
        env_name: str = "one_leg",
        observation_dim: Sequence[int] = (224, 224, 3),
        action_dim: int = 8,
        state: flax.training.train_state.TrainState = None,
        jax_devices: Sequence[jax.Device] = None,
    ):
        """Build the reward network, the frozen image encoder and the train state.

        Args:
            config: Learner configuration. Falls back to
                ``get_default_config()`` when ``None``.
            env_name: Task name; only the part before the first ``"|"`` is
                used to look up the episode length and the number of stages.
            observation_dim: Image shape; its first entry scales the
                random-crop padding.
            action_dim: Stored on the instance; not used elsewhere in this class.
            state: Existing train state to load. A fresh one is initialised
                when ``None``.
            jax_devices: Devices for data-parallel training. When given, the
                pmapped train / eval steps are created and the state is
                replicated onto these devices.
        """
        # `None` is the declared default, so it must not crash on attribute access.
        self.config = self.get_default_config() if config is None else config
        task_name = env_name.split("|")[0]
        self.config.max_episode_steps = TASK_TO_MAX_EPISODE_STEPS[task_name]
        self.config.n_stages = TASK_TO_PHASE[task_name]
        self.network = self._define_network()
        self.observation_dim = observation_dim
        self.action_dim = action_dim
        self.pvr_model, self.pvr_model_var = self._load_pvr_network()
        self.n_devices = len(jax_devices) if jax_devices else 1

        self._total_steps = 0
        self._train_states = {}
        if state is None:
            state = self._init_train_state()
        # The pmapped steps are needed whenever devices are given, including
        # when a pre-existing state is supplied.
        if jax_devices is not None:
            self.train_pmap = jax.pmap(self._train_step, axis_name="pmap", devices=jax_devices)
            self.eval_pmap = jax.pmap(self._eval_step, axis_name="pmap", devices=jax_devices)

        model_keys = ["disc"]
        self._model_keys = tuple(model_keys)
        self.load_state(state, jax_devices=jax_devices)
        self.trained = [False] * self.config.n_stages

        def data_augmentation_fn(rng, observations):
            # Random-crop every configured image stream with its own PRNG key.
            for obs_key in self.config.image_keys.split("|"):
                key, rng = jax.random.split(rng)
                observations = batched_random_crop(
                    key, observations, obs_key, padding=int(self.observation_dim[0] * (4 / 64))
                )
            return observations

        self.aug_fn = data_augmentation_fn

    def set_trained(self, stage_idx):
        """Mark the discriminator of ``stage_idx`` as trained."""
        self.trained[stage_idx] = True

    def load_state(self, state, jax_devices=None, reset_optimizer=False):
        """Install ``state`` as the current discriminator train state.

        Args:
            state: Either a train state object (anything exposing ``.params``)
                or a mapping holding it under the key ``"disc"``.
            jax_devices: When given, the state is replicated onto these
                devices and synchronised with ``sync_state_fn``.
            reset_optimizer: When ``True`` (and ``jax_devices`` is given), keep
                only the parameters and attach a fresh AdamW optimizer.
        """
        # `__init__` passes the bare state from `_init_train_state`, while
        # `self._train_states` is a {"disc": state} mapping; accept both.
        disc_state = state if hasattr(state, "params") else state["disc"]

        if jax_devices is None:
            self._train_states["disc"] = disc_state
            return

        if reset_optimizer:
            print("load state with reset optimizer for fine-tuning.")
            # `optax.adam` has no `weight_decay` argument; AdamW is the
            # optimizer that takes the configured decay.
            tx = optax.adamw(
                learning_rate=self.config.lr,
                weight_decay=self.config.weight_decay,
                b1=self.config.adam_beta1,
                b2=self.config.adam_beta2,
            )
            disc_state = TrainState.create(
                params=disc_state.params, batch_stats=None, tx=tx, apply_fn=self.network.apply
            )
        disc_state = flax.jax_utils.replicate(disc_state, jax_devices)
        self._train_states["disc"] = sync_state_fn(disc_state)

    def _load_pvr_network(self):
        """Load the pretrained image encoder selected by ``config.transfer_type``.

        Sets ``self.image_size`` and ``config.vision_embd_dim`` as a side effect.

        Returns:
            ``(model, model_variables)`` of the loaded encoder.

        Raises:
            ValueError: If ``config.transfer_type`` is neither ``"liv"`` nor
                starts with ``"clip"``.
        """
        transfer_type = self.config.transfer_type
        if transfer_type == "liv":
            clip_model, clip_model_var, _ = load_liv_model()
            self.image_size = 224
        elif transfer_type.startswith("clip"):
            # "clip_vit_b16" -> "vit_b16"
            clip_type = transfer_type.split("_", 1)[-1]
            clip_model, clip_model_var, _ = load_clip_model(clip_type)
            self.image_size = IMAGE_RESOLUTION[clip_type]
        else:
            raise ValueError(f"Unsupported transfer_type: {transfer_type!r}")
        self.config.vision_embd_dim = clip_model.vision_features
        return clip_model, clip_model_var

    def _define_network(self):
        """Instantiate the multi-stage reward network from the config."""
        return MultiStageCNNRewardModel(config=self.config)

    def _init_train_state(self):
        """Initialise network parameters and wrap them in a ``TrainState``.

        Returns:
            A ``TrainState`` using a plain Adam optimizer with ``config.lr``.
        """
        # Patches per side, squared, plus one extra slot (presumably the class
        # token -- TODO confirm). Parenthesised so the result is also correct
        # when image_size is not a multiple of the patch size.
        patches_per_side = self.image_size // self.pvr_model.vision_patch_size
        num_patches = 1 + patches_per_side * patches_per_side
        # NOTE(review): the dummy input has encoder-feature shape although
        # `_get_reward_step` / `loss_fn` feed raw images to the network --
        # confirm this is what MultiStageCNNRewardModel expects.
        variables = self.network.init(
            {"params": JaxRNG.from_seed(self.config.seed)(), "dropout": JaxRNG.from_seed(self.config.seed)()},
            jnp.ones(
                (self.config.num_images, 1, self.config.window_size, num_patches, self.pvr_model.vision_features),
                dtype=jnp.float32,
            ),
        )

        variables = flax.core.frozen_dict.unfreeze(variables)
        params = flax.core.frozen_dict.unfreeze(variables["params"])
        batch_stats = (
            flax.core.frozen_dict.unfreeze(variables["batch_stats"])
            if variables.get("batch_stats") is not None
            else None
        )
        tx = optax.adam(self.config.lr)
        states = TrainState.create(params=params, batch_stats=batch_stats, tx=tx, apply_fn=self.network.apply)
        return states

    @partial(jax.jit, static_argnames=("self"))
    def _eval_pref_step(self, train_states, rng, batch):
        """Placeholder; preference evaluation is not implemented here."""
        return None

    @partial(jax.jit, static_argnames=("self"))
    def _train_pref_step(self, train_states, rng, batch):
        """Placeholder; preference training is not implemented here."""
        return None

    def get_reward(self, batch, stage_idx, get_video_feature=False, get_text_feature=False):
        """Compute rewards for ``batch`` with the current train state.

        ``get_video_feature`` and ``get_text_feature`` are forwarded but not
        used by ``_get_reward_step``.
        """
        return self._get_reward_step(
            self._train_states, batch, stage_idx, get_video_feature=get_video_feature, get_text_feature=get_text_feature
        )

    @partial(jax.jit, static_argnames=("self", "get_video_feature", "get_text_feature"))
    def _get_reward_step(self, train_states, batch, stage_idx, get_video_feature=False, get_text_feature=False):
        """Jitted reward computation.

        Args:
            train_states: Mapping holding the discriminator state under ``"disc"``.
            batch: Mapping whose ``"image"`` entry is a dict of image arrays.
            stage_idx: Per-sample stage index array; the value ``n_stages``
                selects the appended zero column.

        Returns:
            One reward per sample, shifted by -2.
        """
        obs = jnp.array(list(batch["image"].values()))
        _, B = obs.shape[:2]
        image_feature = obs

        # One tanh-squashed score column per stage, plus a zero column for
        # samples that are past the last stage.
        stage_rewards = [
            jnp.tanh(self._extract_score(train_states["disc"].params, image_feature, i)).squeeze(0)
            for i in range(self.config.n_stages)
        ]
        stage_rewards = jnp.concatenate(stage_rewards + [jnp.zeros((B, 1))], axis=1)

        # Later stages always outrank earlier ones: each stage adds an offset
        # of k, larger than the tanh range of the per-stage score.
        k = 3
        reward = k * stage_idx + stage_rewards[jnp.arange(B), stage_idx.astype(jnp.int32)]
        reward = reward / (k * self.config.n_stages)  # normalise to roughly [0, 1]
        reward = reward - 2  # make the reward negative

        return reward

    def _get_pvr_feature(self, images):
        """Encode images with the pretrained encoder.

        Args:
            images: Array whose last three axes are (height, width, channels)
                with pixel values on a 0-255 scale; leading axes are preserved.

        Returns:
            The first token of the encoder output for every image, reshaped
            to the original leading axes plus a feature axis.
        """
        original_shape = images.shape[:-3]
        images = jnp.reshape(images, (-1,) + images.shape[-3:])
        images = (images / 255.0).astype(jnp.float32)
        # Resize to the resolution of the encoder that was actually loaded.
        target_size = self.image_size
        if images.shape[-3] != target_size:
            images = jax.image.resize(
                images, (images.shape[0], target_size, target_size, images.shape[-1]), method="bicubic"
            )
        images = normalize_image(images)
        image_feature_map = self.pvr_model.apply(
            self.pvr_model_var,
            images,
            method=self.pvr_model.encode_image,
            normalize=False,
        )
        image_feat = image_feature_map[:, 0]
        image_feat = jnp.reshape(image_feat, original_shape + (-1,))
        return image_feat

    def subsample_reward_model_params(self, params, stage_idx):
        """Return a copy of ``params`` whose ``"nets"`` subtree keeps only stage ``stage_idx``.

        The selected slice keeps a leading axis of length one.
        """
        new_params = params.copy()
        new_params["nets"] = jax.tree_util.tree_map(lambda param: param[stage_idx][jnp.newaxis], params["nets"])
        return new_params

    def _extract_score(self, params, image_feature, stage_idx, training=False):
        """Run the reward network with only the parameters of stage ``stage_idx``."""
        target_params = self.subsample_reward_model_params(params, stage_idx)
        reward = self.network.apply(
            {"params": target_params},
            image_feature,
            training=training,
            method=self.network.predict_reward,
        )
        return reward

    def loss_fn(self, params, batch, stage_idx, rng, neg_batch=None, training=False):
        """Sum of per-stage binary cross-entropy discriminator losses.

        For every stage ``i`` a sample is labelled positive when its stage
        index is at least ``i``.

        Args:
            params: Discriminator parameters.
            batch: Mapping whose ``"image"`` entry is a dict of image arrays;
                the images are randomly cropped before scoring.
            stage_idx: Stage index of every sample (scalar or array).
            rng: JAX PRNG key for the augmentation.
            neg_batch: Accepted for interface compatibility; not used by the
                current loss.
            training: Forwarded to the network.

        Returns:
            ``(loss, aux)`` where ``aux`` reports the loss under
            ``"disc_loss"`` and ``"total_loss"``.
        """
        rng, key = jax.random.split(rng)
        images = self.aug_fn(key, batch["image"])
        obs = jnp.array(list(images.values()))
        # Must be bound whether or not a negative batch is supplied.
        total_image_feature = obs

        def custom_sigmoid_binary_cross_entropy(logits, labels):
            log_p = jax.nn.log_sigmoid(logits)
            log_not_p = jax.nn.log_sigmoid(-logits)
            return -labels * log_p - (1.0 - labels) * log_not_p

        def ce_loss(logits, labels):
            return jnp.mean(custom_sigmoid_binary_cross_entropy(logits=logits, labels=labels))

        sample_stage = jnp.asarray(stage_idx)
        disc_loss = 0.0
        for head_idx in range(self.config.n_stages):
            disc_labels = jnp.asarray(sample_stage >= head_idx, dtype=jnp.int32)
            score = self._extract_score(params, total_image_feature, head_idx, training=training).squeeze()
            disc_loss += ce_loss(score, disc_labels)
        aux = {"disc_loss": disc_loss, "total_loss": disc_loss}

        return disc_loss, aux

    def train_step(self, batch, rng, neg_batch=None):
        """Run one data-parallel training step.

        Requires the learner to have been built with ``jax_devices``. Stage
        indices are read from ``batch["ep_phase"]``.

        Returns:
            ``(metrics, rng)`` with the advanced PRNG key.
        """
        self._total_steps += 1
        stage_idx = batch["ep_phase"]
        self._train_states, metrics, rng = self.train_pmap(
            self._train_states, batch, stage_idx, rng, neg_batch=neg_batch
        )
        return metrics, rng

    # `stage_idx` is an array (and a tracer under pmap), so it must be traced
    # rather than declared static.
    @partial(jax.jit, static_argnames=("self",))
    def _train_step(self, train_states, batch, stage_idx, rng, neg_batch=None):
        """Per-device training step; averages loss, metrics and gradients over ``"pmap"``."""
        next_rng, split_rng = jax.random.split(rng)
        grad_fn = jax.value_and_grad(self.loss_fn, has_aux=True)
        (loss, metrics), grads = jax.lax.pmean(
            grad_fn(train_states["disc"].params, batch, stage_idx, split_rng, neg_batch=neg_batch, training=True),
            axis_name="pmap",
        )

        new_train_states = {"disc": train_states["disc"].apply_gradients(grads=grads)}
        return new_train_states, metrics, next_rng

    def eval_step(self, batch, rng, neg_batch=None):
        """Run one data-parallel evaluation step.

        Requires the learner to have been built with ``jax_devices``. Stage
        indices are read from ``batch["ep_phase"]``.

        Returns:
            ``(metrics, rng)`` with the advanced PRNG key.
        """
        stage_idx = batch["ep_phase"]
        metrics, rng = self.eval_pmap(self._train_states, batch, stage_idx, rng, neg_batch=neg_batch)
        return metrics, rng

    def _eval_step(self, train_states, batch, stage_idx, rng, neg_batch=None):
        """Per-device evaluation step; averages the loss outputs over ``"pmap"``."""
        next_rng, split_rng = jax.random.split(rng)
        _, metrics = jax.lax.pmean(
            self.loss_fn(train_states["disc"].params, batch, stage_idx, split_rng, neg_batch=neg_batch, training=False),
            axis_name="pmap",
        )
        return metrics, next_rng

    def train_single_step(self, batch, stage_idx, neg_batch=None):
        """Run one single-device training step with a key from ``next_rng()``.

        Returns:
            ``(metrics, rng)`` with the advanced PRNG key.
        """
        rng = next_rng()
        self._total_steps += 1
        self._train_states, metrics, rng = self._train_single_step(
            self._train_states, batch, stage_idx, rng, neg_batch=neg_batch
        )
        return metrics, rng

    # `stage_idx` is traced so that both Python ints and arrays are accepted.
    @partial(jax.jit, static_argnames=("self",))
    def _train_single_step(self, train_states, batch, stage_idx, rng, neg_batch=None):
        """Jitted single-device training step (no cross-device averaging)."""
        next_rng, split_rng = jax.random.split(rng)
        grad_fn = jax.value_and_grad(self.loss_fn, has_aux=True)
        (loss, metrics), grads = grad_fn(
            train_states["disc"].params, batch, stage_idx, split_rng, neg_batch=neg_batch, training=True
        )

        new_train_states = {"disc": train_states["disc"].apply_gradients(grads=grads)}
        return new_train_states, metrics, next_rng
