from collections import defaultdict, namedtuple
import logging
import os, json, random
from pathlib import Path
import sys
import time
import PIL.Image as Image
import copy
from collections import deque
from moviepy.editor import ImageSequenceClip
from calvin_agent.models.calvin_base_model import CalvinBaseModel
import time
import re
import cv2
import time

sys.path.insert(0, Path(__file__).absolute().parents[2].as_posix())
from calvin_agent.evaluation.multistep_sequences import get_sequences
from calvin_agent.evaluation.utils import (
    collect_plan,
    count_success,
    create_tsne,
    get_env_state_for_initial_condition,
    get_log_dir,
    print_and_save,
)
import hydra
import numpy as np
from omegaconf import OmegaConf
from termcolor import colored
import torch
from tqdm.auto import tqdm
from utils.data_utils import preprocess_image, preprocess_text_calvin
import functools
from utils.train_utils import get_cast_dtype
import calvin_env
from numpy import pi
import pyhash
hasher = pyhash.fnv1_32()
import contextlib
import math
import torch.nn.functional as F
from torch.distributions import Categorical
from utils.eval_utils_calvin import print_and_save_name

# Select PyOpenGL's EGL platform (headless rendering).
os.environ['PYOPENGL_PLATFORM'] = 'egl'
logger = logging.getLogger(__name__)

# Maximum number of environment steps in one subtask rollout.
EP_LEN = 360
# Number of candidate instruction sequences requested from get_sequences().
NUM_SEQUENCES = 1000
# Observation modalities requested from the CALVIN environment. get_env() keeps only the
# cameras named (after the first "_") in the 'rgb_obs' / 'depth_obs' entries.
observation_space = dict(
    rgb_obs=['rgb_static', 'rgb_gripper'],
    depth_obs=[],
    state_obs=['robot_obs'],
    actions=['rel_actions'],
    language=['language'],
)


def get_env(conf_path, obs_space=None, show_gui=True, **kwargs):
    """Instantiate a CALVIN environment from a YAML config.

    Args:
        conf_path: Path to the environment config; it must contain ``cameras`` and ``env`` nodes.
        obs_space: Optional observation-space dict with ``rgb_obs`` / ``depth_obs`` lists
            (e.g. ``observation_space``). Cameras not referenced by any of those entries
            (the part after the first ``"_"``, e.g. ``rgb_static`` -> ``static``) are
            removed from the config before instantiation.
        show_gui: Forwarded to the environment constructor.
        **kwargs: ``scene`` (optional) names a ``conf/scene/<scene>.yaml`` file next to the
            ``calvin_env`` package whose contents are merged over the loaded config.

    Returns:
        The object produced by ``hydra.utils.instantiate(render_conf.env, ...)``.
    """
    render_conf = OmegaConf.load(conf_path)

    if obs_space is not None:
        wanted_cameras = {
            re.split("_", key)[1] for key in obs_space["rgb_obs"] + obs_space["depth_obs"]
        }
        for k in set(render_conf.cameras.keys()) - wanted_cameras:
            del render_conf.cameras[k]
    if "scene" in kwargs:
        scene_cfg = OmegaConf.load(Path(calvin_env.__file__).parents[1] / "conf/scene" / f"{kwargs['scene']}.yaml")
        # OmegaConf.merge returns a new config rather than updating its inputs in place,
        # so the result must be kept for the scene overrides to take effect.
        render_conf = OmegaConf.merge(render_conf, scene_cfg)
    if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
        hydra.initialize(".")
    return hydra.utils.instantiate(render_conf.env, show_gui=show_gui, use_vr=False, use_scene_info=True)


def make_env(conf_path, obs_space):
    """Create an environment from ``conf_path`` with the GUI disabled (thin wrapper over ``get_env``)."""
    return get_env(conf_path, obs_space, show_gui=False)


@contextlib.contextmanager
def temp_seed(seed):
    """Seed NumPy's global RNG for the duration of a ``with`` block.

    The RNG state that was active on entry is restored on exit, including when the
    body raises, so code outside the block is unaffected by the temporary seed.
    """
    with contextlib.ExitStack() as restore_on_exit:
        # Register the restore first, capturing the pre-seed state.
        restore_on_exit.callback(np.random.set_state, np.random.get_state())
        np.random.seed(seed)
        yield


def get_env_state_for_initial_condition_abcd(initial_condition, env_name):
    """Build the ``(robot_obs, scene_obs)`` reset state for an initial condition in scene a/b/c/d.

    Args:
        initial_condition: Mapping read for the keys ``slider``, ``drawer``, ``lightbulb``,
            ``led``, ``red_block``, ``blue_block`` and ``pink_block``.
        env_name: Scene identifier, one of ``'a'``, ``'b'``, ``'c'``, ``'d'``; it selects the
            slider-slot and table-slot block coordinates.

    Returns:
        ``(robot_obs, scene_obs)``: a fixed 15-element robot state and a 24-element scene
        state. The two table slots are shuffled and the block yaw angles sampled under a
        seed hashed from ``initial_condition.values()``, so a given condition always
        produces the same state.

    Raises:
        ValueError: If ``env_name`` is not one of a/b/c/d.
    """
    # NOTE(review): presumably EE position (3), EE orientation (3), gripper width,
    # 7 joint angles, gripper action — confirm against the CALVIN robot_obs layout.
    robot_obs = np.array([
        0.02586889, -0.2313129, 0.5712808,
        3.09045411, -0.02908596, 1.50013585,
        0.07999963,
        -1.21779124, 1.03987629, 2.11978254, -2.34205014, -0.87015899, 1.64119093, 0.55344928,
        1.0,
    ])
    block_rot_z_range = (pi / 2 - pi / 8, pi / 2 + pi / 8)

    if env_name not in ('a', 'b', 'c', 'd'):
        raise ValueError('only abcd env')

    # Scenes b and c move both slider slots +0.2 along x.
    if env_name in ('b', 'c'):
        left_x, right_x = -2.40851662e-01 + 0.2, 7.03416330e-02 + 0.2
    else:
        left_x, right_x = -2.40851662e-01, 7.03416330e-02
    block_slider_left = np.array([left_x, 9.24044687e-02, 4.60990009e-01])
    block_slider_right = np.array([right_x, 9.24044687e-02, 4.60990009e-01])

    # x coordinates of the two table slots in each scene.
    first_x, second_x = {
        'a': (-0.021428430, 0.201421361),
        'b': (-0.188571300, 0.021435125),
        'c': (5.00000896e-02, 2.29995412e-01),
        'd': (5.00000896e-02, 2.29995412e-01),
    }[env_name]
    block_table = [
        np.array([first_x, -1.20000177e-01, 4.59990009e-01]),
        np.array([second_x, -1.19995140e-01, 4.59990010e-01]),
    ]
    slider_slots = {"slider_right": block_slider_right, "slider_left": block_slider_left}

    # Deterministic per-condition seed: identical conditions give identical layouts.
    seed = hasher(str(initial_condition.values()))
    with temp_seed(seed):
        np.random.shuffle(block_table)

        scene_obs = np.zeros(24)
        if initial_condition["slider"] == "left":
            scene_obs[0] = 0.28
        if initial_condition["drawer"] == "open":
            scene_obs[1] = 0.22
        if initial_condition["lightbulb"] == 1:
            scene_obs[3] = 0.088
        scene_obs[4] = initial_condition["lightbulb"]
        scene_obs[5] = initial_condition["led"]

        # Red block: its slider slot, else the first table slot.
        red = initial_condition["red_block"]
        scene_obs[6:9] = slider_slots[red] if red in slider_slots else block_table[0]
        scene_obs[11] = np.random.uniform(*block_rot_z_range)

        # Blue block: its slider slot, else the second table slot when red is on the
        # table, otherwise the first table slot.
        blue = initial_condition["blue_block"]
        if blue in slider_slots:
            scene_obs[12:15] = slider_slots[blue]
        else:
            scene_obs[12:15] = block_table[1] if red == "table" else block_table[0]
        scene_obs[17] = np.random.uniform(*block_rot_z_range)

        # Pink block: its slider slot, else the second table slot.
        pink = initial_condition["pink_block"]
        scene_obs[18:21] = slider_slots[pink] if pink in slider_slots else block_table[1]
        scene_obs[23] = np.random.uniform(*block_rot_z_range)

    return robot_obs, scene_obs


class PGReplayBuffer:
    """Episode-scoped transition store: transitions are pushed one by one and read back all at once."""

    def __init__(self) -> None:
        self.buffer = deque()

    def push(self, transitions):
        """Append a single transition tuple to the end of the buffer.

        Args:
            transitions (tuple): One transition, e.g. ``(state, probs, value)``.
        """
        self.buffer.append(transitions)

    def sample(self):
        """Return all stored transitions transposed field-by-field.

        The result is ``zip(*transitions)``: one tuple per field, so three-field
        transitions unpack as ``states, probs, values = buf.sample()``.
        """
        snapshot = tuple(self.buffer)
        return zip(*snapshot)

    def clear(self):
        """Remove every stored transition."""
        self.buffer.clear()

    def __len__(self):
        """Number of stored transitions."""
        return len(self.buffer)


class RLFinetuneModelWrapper(CalvinBaseModel):
    """Adapter between CALVIN observations and the policy / value networks used for RL fine-tuning.

    ``step`` runs the policy on a rolling window of the most recent ``history_len``
    observations, ``q_step`` scores a single observation with ``q_model``, and
    ``state_list_step`` re-runs the policy (with gradients) on stored observation
    windows for the PPO update.

    Args:
        model: Policy network, called with ``image_primary``, ``image_wrist``, ``state``,
            ``text_token`` and ``action`` keyword arguments and unpacked as an 8-tuple whose
            third element holds the discrete arm-action scores.
        q_model: Value network unpacked as a 5-tuple whose first element is the value
            prediction; may be ``None`` when ``q_step`` is never called.
        tokenizer: Bound into ``preprocess_text_calvin``.
        image_processor: Bound into ``preprocess_image``.
        cast_dtype: dtype applied to image and state tensors.
        history_len: Number of frames in the window fed to the networks.
        calvin_eval_max_steps: Stored on the instance; not used by this class.
        action_pred_steps: Stored on the instance; not used by this class.
    """

    def __init__(self, model, q_model, tokenizer, image_processor, cast_dtype, history_len=10,
                 calvin_eval_max_steps=360, action_pred_steps=3):
        super().__init__()
        self.model = model
        self.q_model = q_model
        self.cast_type = cast_dtype
        self.use_diff = False
        self.text_process_fn = functools.partial(preprocess_text_calvin, tokenizer=tokenizer)
        self.image_process_fn = functools.partial(preprocess_image, image_processor=image_processor)
        self.action_hist_queue = []
        self.history_len = history_len
        self.calvin_eval_max_steps = calvin_eval_max_steps
        self.action_pred_steps = action_pred_steps
        self.device = "cuda"
        self.reset()

    def reset(self):
        """Clear the rolling observation / text history used by ``step``."""
        self.img_queue = deque(maxlen=self.history_len)
        self.gripper_queue = deque(maxlen=self.history_len)
        self.state_queue = deque(maxlen=self.history_len)
        self.mask_queue = deque(maxlen=self.history_len)
        self.text_queue = deque(maxlen=self.history_len)
        self.act_queue = deque(maxlen=self.history_len - 1)

    def _encode_obs(self, obs):
        """Turn one observation into ``(static_image, wrist_image, state)`` tensors on ``self.device``.

        Each tensor gets a singleton time axis at dim 1. The state keeps
        ``robot_obs[:6]`` followed by ``robot_obs[-1]``.
        """
        static = self.image_process_fn([Image.fromarray(obs["rgb_obs"]['rgb_static'])])
        static = static.unsqueeze(1).to(dtype=self.cast_type)
        wrist = self.image_process_fn([Image.fromarray(obs["rgb_obs"]['rgb_gripper'])])
        wrist = wrist.unsqueeze(1).to(dtype=self.cast_type)
        state = torch.from_numpy(np.stack([obs['robot_obs']]))
        state = state.unsqueeze(1).to(dtype=self.cast_type)
        state = torch.cat([state[..., :6], state[..., [-1]]], dim=-1)
        return static.to(self.device), wrist.to(self.device), state.to(self.device)

    def _encode_text(self, goal):
        """Tokenize ``goal`` and add a singleton time axis at dim 1, on ``self.device``."""
        return self.text_process_fn([goal]).unsqueeze(1).to(self.device)

    def _pad_to_history(self, image_primary, image_wrist, state):
        """Pad a window to ``history_len`` frames by repeating its last frame.

        Returns the (possibly padded) tensors plus the number of real frames. Windows that
        already have at least ``history_len`` frames are returned unchanged. The
        ``repeat`` pattern below is only shape-correct for a batch of size 1 on dim 0.
        """
        num_step = image_primary.shape[1]
        pad = self.history_len - num_step
        if pad <= 0:
            return image_primary, image_wrist, state, num_step
        image_primary = torch.cat([image_primary, image_primary[:, -1].repeat(1, pad, 1, 1, 1)], dim=1)
        image_wrist = torch.cat([image_wrist, image_wrist[:, -1].repeat(1, pad, 1, 1, 1)], dim=1)
        state = torch.cat([state, state[:, -1].repeat(1, pad, 1)], dim=1)
        return image_primary, image_wrist, state, num_step

    def _output_index(self, num_step):
        """Time index of the newest real frame in a window holding ``num_step`` real frames."""
        return num_step - 1 if num_step < self.history_len else -1

    def step(self, obs, goal, timestep):
        """Append ``obs`` to the rolling history and return the policy output for the newest frame.

        Args:
            obs: Observation dict with ``rgb_obs['rgb_static']``, ``rgb_obs['rgb_gripper']``
                and ``robot_obs``.
            goal: Language instruction. Only the goal given on the first call after
                ``reset`` is placed in the text history.
            timestep: Unused here.

        Returns:
            ``(action, arm_scores)``: a 7-vector (arm dims followed by a gripper command of
            -1 or 1) and the discrete arm-action scores for the same frame. Callers in this
            file apply softmax to ``arm_scores``.
        """
        image_x, gripper, state = self._encode_obs(obs)
        text_x = self._encode_text(goal)
        with torch.no_grad():
            self.img_queue.append(image_x)
            self.gripper_queue.append(gripper)
            self.state_queue.append(state)
            if len(self.text_queue) == 0:
                # The instruction is fixed for a rollout: fill the whole text window once.
                self.text_queue.extend([text_x] * self.history_len)
            input_image_primary, input_image_wrist, input_state, num_step = self._pad_to_history(
                torch.cat(list(self.img_queue), dim=1),
                torch.cat(list(self.gripper_queue), dim=1),
                torch.cat(list(self.state_queue), dim=1),
            )
            input_text_token = torch.cat(list(self.text_queue), dim=1)
            (arm_action, discrete_arm_action, discrete_arm_action_porbs, gripper_action,
             image_pred, arm_pred_state, gripper_pred_state, _) = self.model(
                image_primary=input_image_primary,
                image_wrist=input_image_wrist,
                state=input_state,
                text_token=input_text_token,
                action=torch.zeros(1, self.history_len, 7).to(input_state.device),
            )  # arm_action: [1, 10, 3, 6]; discrete_arm_action_porbs: [1, 10, 3, 6, 1000]
            # Keep the first predicted step of every frame; binarize the gripper at 0.5.
            action = torch.concat((discrete_arm_action[0, :, 0, :], gripper_action[0, :, 0, :] > 0.5), dim=-1)  # [10, 7]
            action[:, -1] = (action[:, -1] - 0.5) * 2  # scale to -1 or 1
            idx = self._output_index(num_step)
            return action[idx], discrete_arm_action_porbs[0, :, 0, :, :][idx]

    def q_step(self, obs, goal):
        """Return ``q_model``'s value prediction for the single observation ``obs``.

        ``obs`` alone is repeated to fill a ``history_len`` window (independent of the
        rolling history kept by ``step``). The prediction at its position is returned as
        a float32 NumPy array.
        """
        image_primary, image_wrist, state = self._encode_obs(obs)
        text_token = self._encode_text(goal)
        with torch.no_grad():
            input_image_primary, input_image_wrist, input_state, num_step = self._pad_to_history(
                image_primary, image_wrist, state)
            input_text_token = torch.cat([text_token] * self.history_len, dim=1)
            value_pred, _, _, _, _ = self.q_model(
                image_primary=input_image_primary,
                image_wrist=input_image_wrist,
                state=input_state,
                text_token=input_text_token,
            )
            # float32 rather than float16: these values feed the GAE advantages, and
            # float16 keeps only about 3 significant decimal digits.
            value = value_pred.squeeze(0).cpu().detach().to(dtype=torch.float32).numpy()
        return value[self._output_index(num_step)]

    def state_list_step(self, batch_state_list):
        """Re-run the policy, with gradients enabled, on a batch of stored observation windows.

        Args:
            batch_state_list: Sequence of windows, each an oldest-to-newest sequence of
                ``(obs, goal)`` pairs. Only the newest ``history_len`` pairs of each window
                are used (matching the bounded history in ``step``), and the goal is taken
                from the newest pair.

        Returns:
            Discrete arm-action scores for the newest frame of every window, stacked along
            dim 0 (``[b, 6, 1000]`` per the model's shape comments).
        """
        primaries, wrists, states, texts, num_steps = [], [], [], [], []
        for state_list in batch_state_list:
            window = state_list[-self.history_len:]
            encoded = [self._encode_obs(pair[0]) for pair in window]
            input_image_primary, input_image_wrist, input_state, num_step = self._pad_to_history(
                torch.cat([e[0] for e in encoded], dim=1),
                torch.cat([e[1] for e in encoded], dim=1),
                torch.cat([e[2] for e in encoded], dim=1),
            )
            text_x = self._encode_text(window[-1][1])
            primaries.append(input_image_primary)
            wrists.append(input_image_wrist)
            states.append(input_state)
            texts.append(torch.cat([text_x] * self.history_len, dim=1))
            num_steps.append(num_step)

        batch_state = torch.cat(states, dim=0)
        (arm_action, discrete_arm_action, discrete_arm_action_porbs, gripper_action,
         image_pred, arm_pred_state, gripper_pred_state, _) = self.model(
            image_primary=torch.cat(primaries, dim=0),
            image_wrist=torch.cat(wrists, dim=0),
            state=batch_state,
            text_token=torch.cat(texts, dim=0),
            action=torch.zeros(len(batch_state_list), self.history_len, 7).to(batch_state.device),
        )  # arm_action: [b, 10, 3, 6]; discrete_arm_action_porbs: [b, 10, 3, 6, 1000]
        first_step_scores = discrete_arm_action_porbs[:, :, 0, :, :]  # [b, 10, 6, 1000]
        return torch.stack(
            [first_step_scores[b][self._output_index(n)] for b, n in enumerate(num_steps)], dim=0)


def rl_finetune_policy_ddp(args, env_conf_dir, env_name_list, model, ref_model, optimizer,
                           rl_finetune_log_dir=None, debug=False, reset=False, diverse_inst=False):
    """Run RL fine-tuning rollouts over this rank's share of randomly sampled CALVIN sequences.

    ``args.num_finetune_seq`` sequences are sampled and split evenly across the
    ``torch.distributed`` ranks. Each rank gets a budget of
    ``args.update_episode_per_sep * <its number of sequences>`` rollout/update episodes and
    stops once the budget is spent. The sequence during which the budget runs out is not
    added to the results.

    Args:
        args: Namespace read for ``calvin_conf_path``, ``num_finetune_seq``,
            ``update_episode_per_sep`` and ``rank`` (plus fields read by ``rollout``/``update``).
        env_conf_dir: Directory holding ``<env_name>_merged_config.yaml`` files.
        env_name_list: Scene names, cycled over by global sequence index.
        model: Wrapper being fine-tuned (must implement the ``CalvinBaseModel`` methods).
        ref_model: Reference wrapper used for the KL term in ``update``.
        optimizer: Optimizer stepped inside ``update``.
        rl_finetune_log_dir: Passed through ``get_log_dir``.
        debug: If True, no tqdm bar is used and progress lines are printed instead.
        reset: If True, the env is reset to the sequence's initial state before every subtask.
        diverse_inst: If True, instructions are read from ``./utils/lang_annotation_cache.json``.

    Returns:
        List of per-sequence success counts (consecutive subtasks solved).

    Raises:
        AssertionError: If ``num_finetune_seq`` is not divisible by the world size, or if
            the episode budget is not used up by the end of the loop.
    """
    calvin_conf_path = args.calvin_conf_path
    num_finetune_seq = args.num_finetune_seq
    update_episode_per_sep = args.update_episode_per_sep

    conf_dir = Path(calvin_conf_path)
    task_cfg = OmegaConf.load(conf_dir / "callbacks/rollout/tasks/new_playtable_tasks.yaml")
    task_oracle = hydra.utils.instantiate(task_cfg)

    if diverse_inst:
        with open('./utils/lang_annotation_cache.json', 'r') as f:
            finetune_annotations = json.load(f)
    else:
        finetune_annotations = OmegaConf.load(conf_dir / "annotations/new_playtable.yaml")

    rl_finetune_log_dir = get_log_dir(rl_finetune_log_dir)
    total_sequences = get_sequences(NUM_SEQUENCES)
    # NOTE(review): every rank samples independently; the per-rank slices below are only
    # disjoint if Python's `random` is seeded identically on all ranks — confirm.
    finetune_sequences = random.sample(total_sequences, num_finetune_seq)
    device_num = int(torch.distributed.get_world_size())
    device_id = torch.distributed.get_rank()
    assert num_finetune_seq % device_num == 0, \
        f'num_finetune_seq ({num_finetune_seq}) must be divisible by the world size ({device_num})'
    interval_len = int(num_finetune_seq // device_num)
    base_sequence_i = device_id * interval_len
    finetune_sequences = finetune_sequences[base_sequence_i:base_sequence_i + interval_len]
    rest_update_episode = update_episode_per_sep * len(finetune_sequences)
    results = []
    plans = defaultdict(list)

    progress = finetune_sequences if debug else tqdm(finetune_sequences, position=0, leave=True)

    for local_sequence_i, (initial_state, finetune_sequence) in enumerate(progress):
        result, rest_update_episode = rl_finetune_sequence(args, env_conf_dir, env_name_list, model, ref_model,
                                                           optimizer, task_oracle, initial_state, finetune_sequence,
                                                           finetune_annotations, plans, debug, rest_update_episode,
                                                           rl_finetune_log_dir, base_sequence_i + local_sequence_i,
                                                           reset=reset, diverse_inst=diverse_inst)
        if rest_update_episode == 0:
            break
        results.append(result)
        description = f'rank: {args.rank}: ' + " ".join(
            [f"{i + 1}/5 : {v * 100:.1f}% |" for i, v in enumerate(count_success(results))]) + "|"
        if debug:
            # In debug mode `progress` is a plain list without set_description().
            print(description)
        else:
            progress.set_description(description)
    assert rest_update_episode == 0, \
        f'rank {args.rank}: {rest_update_episode} update episodes left after all sequences'
    print(f'rank: {args.rank} finished fine-tuning')

    return results


def rl_finetune_sequence(args, env_conf_dir, env_name_list, model, ref_model, optimizer, task_checker, initial_state,
                         finetune_sequence, finetune_annotations, plans, debug, rest_update_episode,
                         rl_finetune_log_dir='', sequence_i=-1, reset=False, diverse_inst=False):
    """Fine-tune on one chained instruction sequence, with one rollout + update per subtask.

    The scene is chosen from ``env_name_list`` by ``sequence_i``; a fresh environment is
    built and reset to ``initial_state``. Subtasks run in order until one fails or the
    episode budget ``rest_update_episode`` reaches zero.

    Returns:
        ``(success_counter, rest_update_episode)``: number of consecutive subtasks solved
        and the remaining episode budget.
    """
    env_name = env_name_list[sequence_i % len(env_name_list)]
    env_conf = os.path.join(env_conf_dir, f'{env_name}_merged_config.yaml')
    robot_obs, scene_obs = get_env_state_for_initial_condition_abcd(initial_state, env_name)
    env = make_env(env_conf, obs_space=observation_space)
    env.reset(robot_obs=robot_obs, scene_obs=scene_obs)

    # With reset=True, rollout() restores the sequence's initial state before each subtask;
    # otherwise each subtask continues from wherever the previous one ended.
    reset_state = {'robot_obs': robot_obs, 'scene_obs': scene_obs} if reset else {}

    solved = 0
    for subtask_i, subtask in enumerate(finetune_sequence):
        if rest_update_episode <= 0:
            break
        succeeded = rollout(args, env, model, ref_model, optimizer, task_checker, subtask, finetune_annotations,
                            plans, debug, subtask_i, sequence_i,
                            diverse_inst=diverse_inst, **reset_state)
        rest_update_episode -= 1
        print(f'rank {args.rank}: rest_update_episode = {rest_update_episode}')
        if not succeeded:
            break
        solved += 1
    del env
    return solved, rest_update_episode


def rollout(args, env, model, ref_mode, optimizer, task_oracle, subtask, finetune_annotations, plans, debug,
            subtask_i=-1, sequence_i=-1, robot_obs=None, scene_obs=None, diverse_inst=False):
    """Roll out the policy on one subtask, then run one policy update on the collected episode.

    Args:
        args: Namespace forwarded to ``update``.
        env: Environment providing ``reset``, ``get_obs``, ``get_info`` and ``step``.
        model: Wrapper providing ``reset``, ``step`` and ``q_step``.
        ref_mode: Reference wrapper forwarded to ``update`` (used for the KL term).
        optimizer: Forwarded to ``update``.
        task_oracle: Checks subtask completion via ``get_task_info_for_set``.
        subtask: Subtask identifier.
        finetune_annotations: Instruction source (see ``diverse_inst``).
        plans: Passed to ``collect_plan`` after the first step.
        debug: Unused here.
        subtask_i, sequence_i: Indices used to look up the instruction when ``diverse_inst``.
        robot_obs, scene_obs: If both are given, the env is reset to this state first;
            otherwise the rollout continues from the env's current state.
        diverse_inst: If True, the instruction is ``finetune_annotations[sequence_i][subtask_i]``;
            otherwise one is drawn at random from ``finetune_annotations[subtask]``.

    Returns:
        True if the task oracle reports ``subtask`` completed within ``EP_LEN`` steps.
    """
    planned_actions = []
    if robot_obs is not None and scene_obs is not None:
        env.reset(robot_obs=robot_obs, scene_obs=scene_obs)
    obs = env.get_obs()
    # get lang annotation for subtask
    if diverse_inst:
        lang_annotation = finetune_annotations[sequence_i][subtask_i]
    else:
        lang_annotation = random.choice(finetune_annotations[subtask])
    # Keep the first line only and normalize the typographic apostrophe to ASCII.
    # str.replace returns a new string, so its result must be assigned.
    lang_annotation = lang_annotation.split('\n')[0].replace('\u2019', '\'')
    model.reset()
    start_info = env.get_info()

    memory = PGReplayBuffer()
    success = False
    for t in range(EP_LEN):
        action, discrete_arm_action_logits = model.step(obs, lang_annotation, t)  # [7]; [6, 1000]
        value = model.q_step(obs, lang_annotation).item()
        discrete_arm_action_porbs = F.softmax(discrete_arm_action_logits, dim=-1)
        # Keep the behaviour-policy probabilities and value estimate for the update.
        memory.push(((obs, lang_annotation), discrete_arm_action_porbs, value))
        np_action = action.cpu().detach().to(dtype=torch.float16).numpy()
        if not planned_actions:
            if np_action.shape == (7,):
                planned_actions.append(np_action)
            else:
                planned_actions.extend([np_action[i] for i in range(np_action.shape[0])])
        np_action = planned_actions.pop(0)
        obs, _, _, current_info = env.step(np_action)
        if t == 0:
            collect_plan(model, plans, subtask)
        current_task_info = task_oracle.get_task_info_for_set(start_info, current_info, {subtask})
        if len(current_task_info) > 0:
            success = True
            break
    # Value of the final observation closes the value sequence; no action was taken from it.
    value = model.q_step(obs, lang_annotation).item()
    memory.push(((obs, lang_annotation), None, value))
    update(args, model, ref_mode, optimizer, memory, success=success)
    return success


def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` for a real scalar ``x``.

    Evaluated in a numerically stable way: ``math.exp`` only ever receives a
    non-positive argument, so very large ``|x|`` saturates to 0.0 / 1.0 instead of
    raising ``OverflowError``.
    """
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    z = math.exp(x)
    return z / (1 + z)


def get_gae_advantage_list(value_list, success, gamma=0.99, lbd=0.95, balance_coef=0.25):
    """Compute per-step GAE advantages for one episode, shifted by a sparse +/-1 outcome term.

    TD errors use no per-step reward: ``delta_i = gamma * V[i+1] - V[i]``. They are
    accumulated backwards as ``A_i = delta_i + lbd * gamma * A_{i+1}`` with the last
    advantage fixed at 0. Every advantage is then shifted by +1 on success or -1 on
    failure, and successful episodes are additionally scaled by ``balance_coef``.

    Args:
        value_list: Value estimates for every stored state, including the final one.
        success: Whether the episode solved its subtask.
        gamma: Discount factor.
        lbd: GAE lambda.
        balance_coef: Multiplier applied to the advantages of successful episodes.

    Returns:
        List with one advantage per entry of ``value_list``.
    """
    n = len(value_list)
    advantages = [0] * n
    running = 0
    for i in range(n - 2, -1, -1):
        td_error = gamma * value_list[i + 1] - value_list[i]
        running = td_error + lbd * gamma * running
        advantages[i] = running
    baseline = 1 if success else -1
    if success:
        return [(a + baseline) * balance_coef for a in advantages]
    return [a + baseline for a in advantages]


def update(args, model, ref_model, optimizer, memory, success, gamma=0.99, seq_len=10, eps_clip=0.1):
    """Run ``args.update_num_per_episode`` clipped-ratio policy updates on one episode, then clear ``memory``.

    Each update samples ``args.sample_num_per_episode`` time indices with replacement (never
    the final entry). It rebuilds the window of up to ``seq_len`` stored states ending at
    each index and minimizes the clipped surrogate loss plus a KL term against ``ref_model``.

    Args:
        args: Namespace read for ``update_num_per_episode`` and ``sample_num_per_episode``.
        model: Wrapper being optimized; ``state_list_step`` must be differentiable.
        ref_model: Reference wrapper, evaluated under ``torch.no_grad``.
        optimizer: Optimizer stepped once per update.
        memory: ``PGReplayBuffer`` of ``((obs, goal), action_probs, value)`` transitions.
            The last transition holds the final observation (``action_probs`` is None).
        success: Episode outcome, forwarded to ``get_gae_advantage_list``.
        gamma: Discount factor for GAE.
        seq_len: Maximum number of stored states in each rebuilt window.
        eps_clip: Ratio clipping range.
    """
    update_num_per_episode = args.update_num_per_episode
    sample_num_per_episode = args.sample_num_per_episode

    old_state_list, old_arm_prob_list, old_value_list = memory.sample()  # tuples, one per field
    total_t = len(memory)
    # Standardize the episode's value estimates before computing advantages.
    if total_t > 1:
        values = np.asarray(old_value_list)
        value_list_mean, value_list_std = np.mean(values), np.std(values)
        # All-equal values give std == 0; dividing would turn every advantage (and the
        # loss) into NaN, so only center them in that case.
        if value_list_std > 0:
            normalized_old_value_list = (values - value_list_mean) / value_list_std
        else:
            normalized_old_value_list = values - value_list_mean
    else:
        normalized_old_value_list = old_value_list
    gae_advantage_list = get_gae_advantage_list(normalized_old_value_list, success, gamma=gamma)

    for _ in range(update_num_per_episode):
        # The last entry cannot be sampled (no action was taken from it), hence total_t - 1.
        sample_index_list = sorted(np.random.choice(range(total_t - 1), size=sample_num_per_episode, replace=True))
        advantage_list = [gae_advantage_list[i] for i in sample_index_list]

        # Window of up to seq_len stored states ending at each sampled index.
        batch_memory_state_list = [old_state_list[max(0, i - seq_len + 1): i + 1] for i in sample_index_list]

        # New and old action distributions at the sampled steps.
        new_arm_logits = model.state_list_step(batch_memory_state_list)
        new_arm_probs = F.softmax(new_arm_logits, dim=-1)  # [b, 6, 1000]
        old_arm_probs = torch.stack([old_arm_prob_list[i] for i in sample_index_list], dim=0)  # [b, 6, 1000]

        # Treat the argmax of the old distribution as the action taken.
        # NOTE(review): assumes the rollout's discrete action is that argmax — confirm.
        old_max = old_arm_probs.max(dim=-1)
        old_actions = old_max.indices  # [b, 6]
        old_log_arm_probs = torch.log(old_max.values)  # [b, 6]
        new_log_arm_probs = Categorical(new_arm_probs).log_prob(old_actions)  # [b, 6]
        # Joint ratio over the action dimensions: sum log-probs before exponentiating.
        ratios = torch.exp(new_log_arm_probs.sum(dim=-1) - old_log_arm_probs.sum(dim=-1))  # [b]

        # KL term against the reference policy.
        with torch.no_grad():
            ref_new_arm_probs = F.softmax(ref_model.state_list_step(batch_memory_state_list), dim=-1)
        kl_loss = F.kl_div(torch.clamp(new_arm_probs, min=1e-6).log(),
                           torch.clamp(ref_new_arm_probs, min=1e-6)) / sample_num_per_episode

        # Clipped surrogate loss.
        advantages = torch.tensor(advantage_list, requires_grad=False).to(ratios.device)  # [b]
        surr1 = ratios * advantages  # [b]
        surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages  # [b]
        loss = -torch.min(surr1, surr2).mean() + kl_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    memory.clear()


def rl_finetune_one_epoch_calvin_ddp(args, model, q_model, ref_model, optimizer,
                                     dataset_path, image_processor, tokenizer,
                                     rl_finetune_log_dir=None, debug=False, future_act_len=-1,
                                     reset=False, diverse_inst=False, env_name_list=None, results_name='results'):
    """Wrap the policy/value and reference models, then run one round of RL fine-tuning.

    ``dataset_path``, ``future_act_len`` and ``results_name`` are accepted for interface
    compatibility but are not used here. ``env_name_list`` must be a non-empty sequence of
    scene names; it is indexed by ``rl_finetune_sequence``.

    Args:
        args: Namespace read for ``precision``, ``sequence_length`` and ``action_pred_steps``
            (plus the fields read by ``rl_finetune_policy_ddp``).

    Returns:
        The per-sequence results list returned by ``rl_finetune_policy_ddp``.
    """
    cast_dtype = get_cast_dtype(args.precision)
    wrapper_kwargs = dict(
        history_len=args.sequence_length,
        calvin_eval_max_steps=EP_LEN,
        action_pred_steps=args.action_pred_steps,
    )
    wrapped_model = RLFinetuneModelWrapper(
        model, q_model, tokenizer, image_processor, cast_dtype, **wrapper_kwargs)
    # The reference policy only supplies KL targets in update(), so it gets no value model.
    ref_wrapped_model = RLFinetuneModelWrapper(
        ref_model, None, tokenizer, image_processor, cast_dtype, **wrapper_kwargs)
    env_conf_dir = './env_conf'
    return rl_finetune_policy_ddp(args, env_conf_dir, env_name_list, wrapped_model,
                                  ref_wrapped_model, optimizer,
                                  rl_finetune_log_dir=rl_finetune_log_dir, debug=debug, reset=reset,
                                  diverse_inst=diverse_inst)
