# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np
import torch.nn as nn
import gym
import os
from collections import deque
import random
from agent.rad_utils import *
import clip
from gym.wrappers import ResizeObservation
from PIL import Image

class eval_mode(object):
    """Context manager that temporarily switches models out of training mode.

    On entry, each model given to the constructor has its current
    ``training`` flag recorded and is then put into evaluation mode via
    ``train(False)``.  On exit the recorded flags are re-applied in the same
    order.  Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        remember = self.prev_states.append
        # Record-then-flip one model at a time (not all flags up front): the
        # order matters when the same or nested modules are passed.
        for net in self.models:
            remember(net.training)
            net.train(False)

    def __exit__(self, *args):
        for position, was_training in enumerate(self.prev_states):
            self.models[position].train(was_training)
        # Returning False lets any in-flight exception propagate.
        return False


def soft_update_params(net, target_net, tau):
    """Blend ``net``'s parameters into ``target_net`` in place.

    Each target parameter becomes ``tau * source + (1 - tau) * target``.
    Parameters are paired positionally via ``parameters()``; ``net`` itself
    is left untouched.
    """
    pairs = zip(net.parameters(), target_net.parameters())
    for source, target in pairs:
        blended = tau * source.data + (1 - tau) * target.data
        target.data.copy_(blended)


def set_seed_everywhere(seed):
    """Seed every random number generator this module relies on.

    Seeds, in order: PyTorch (CPU), all CUDA devices when CUDA is available,
    NumPy's global generator and Python's ``random`` module.
    """
    seeders = [torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    seeders.extend((np.random.seed, random.seed))
    for apply_seed in seeders:
        apply_seed(seed)


def module_hash(module):
    """Return the sum of all entries of every tensor in ``module.state_dict()``.

    This is a cheap fingerprint, not a real hash: different states can
    produce the same value.  An empty state dict yields ``0``.
    """
    state = module.state_dict()
    total = 0
    # Plain left-to-right accumulation (deliberately not builtin ``sum``) so
    # the floating-point result is formed by successive ``+`` operations.
    for key in state:
        total = total + state[key].sum().item()
    return total


def make_dir(dir_path):
    """Ensure that the directory ``dir_path`` exists and return the path.

    Missing parent directories are created as well, and an already existing
    directory is not an error.

    Args:
        dir_path: Path of the directory to create.

    Returns:
        ``dir_path`` unchanged, so the call can be used inline.

    Raises:
        OSError: If the directory cannot be created (for example because of
            insufficient permissions, or because the path exists and is not
            a directory).  Such failures used to be swallowed silently, which
            only postponed the error to the first write into the directory.
    """
    os.makedirs(dir_path, exist_ok=True)
    return dir_path


def preprocess_obs(obs, bits=5):
    """Quantise and dequantise an image tensor, see https://arxiv.org/abs/1807.03039.

    For ``bits < 8`` the input is first reduced with
    ``floor(obs / 2**(8 - bits))``.  The result is divided by ``2**bits``,
    uniform noise in ``[0, 1 / 2**bits)`` is added, and ``0.5`` is subtracted.

    ``obs`` must be a ``torch.float32`` tensor; this is checked with an
    ``assert``.
    """
    num_bins = 2**bits
    assert obs.dtype == torch.float32
    reduced = obs if bits >= 8 else torch.floor(obs / 2**(8 - bits))
    scaled = reduced / num_bins
    # Dequantisation noise spanning exactly one bin.
    noisy = scaled + torch.rand_like(scaled) / num_bins
    return noisy - 0.5


class ReplayBuffer(object):
    """Fixed-capacity ring buffer of environment transitions.

    Every ``add*`` method writes one transition at the current cursor
    (``idx``) and advances it, overwriting the oldest slot once ``capacity``
    transitions have been written.  Every ``sample*`` method draws
    ``batch_size`` slot indices uniformly at random (with replacement) from
    the filled part of the buffer and returns the requested fields as torch
    tensors on ``device``.  Observation fields (``obses``, ``pre_obses``,
    ``next_obses``) are converted to float32; all other fields keep their
    storage dtype.

    "Next actions" are derived at sampling time as the action stored in the
    following slot (wrapping from the last slot to the first).  For the most
    recently written transition that slot does not hold a true successor
    yet, so its next action is stale or uninitialised data.

    Args:
        obs_shape: Shape of one observation.  A 1-D shape is stored as
            float32, anything else as uint8.
        action_shape: Shape of one action.
        capacity: Maximum number of transitions kept.
        batch_size: Number of transitions returned by each ``sample*`` call.
        device: Torch device the sampled tensors are placed on.
        pre_image_size: Stored as both ``image_size`` and ``pre_image_size``.
        use_loss: Stored as ``use_loss``; not used by the buffer itself.
        mask_shape: Shape of one mask feature (default ``(56, 56)``).
    """

    # Fields that are cast to float32 when sampled.
    _FLOAT_FIELDS = ("obses", "pre_obses", "next_obses")

    def __init__(self, obs_shape, action_shape, capacity, batch_size, device,
                 pre_image_size, use_loss=False, mask_shape=(56, 56)):
        self.capacity = capacity
        self.batch_size = batch_size
        self.device = device
        self.image_size = pre_image_size
        self.pre_image_size = pre_image_size  # for translation
        self.use_loss = use_loss
        # the proprioceptive obs is stored as float32, pixels obs as uint8
        obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8

        self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.pre_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.pre_actions = np.empty_like(self.actions)
        # Kept so the attribute still exists for external readers; sampling
        # derives next actions directly from ``actions`` and no longer
        # refreshes this array.
        self.next_actions = np.empty_like(self.actions)
        self.vlm_actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.curr_rewards = np.empty((capacity, 1), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.time_steps = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones = np.empty((capacity, 1), dtype=np.float32)
        # ``np.int`` was removed in NumPy 1.24; use an explicit width.
        self.option = np.empty((capacity, 1), dtype=np.int64)
        self.mask_obs = np.empty((capacity, *mask_shape), dtype=np.float32)

        self.idx = 0        # next slot to write
        self.last_save = 0  # first slot not yet written by save()
        self.full = False   # True once the cursor has wrapped at least once
        self.len = 0        # number of valid transitions, capped at capacity

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _store(self, obs, RL_action, VLM_action, curr_reward, reward,
               next_obs, done, **extras):
        """Write one transition at the cursor and advance the cursor.

        ``extras`` maps the name of a storage attribute (e.g. ``pre_obses``)
        to the value to write into the same slot.
        """
        slot = self.idx
        np.copyto(self.obses[slot], obs)
        np.copyto(self.actions[slot], RL_action)
        np.copyto(self.vlm_actions[slot], VLM_action)
        np.copyto(self.curr_rewards[slot], curr_reward)
        np.copyto(self.rewards[slot], reward)
        np.copyto(self.next_obses[slot], next_obs)
        np.copyto(self.not_dones[slot], not done)
        for field, value in extras.items():
            np.copyto(getattr(self, field)[slot], value)

        self.idx = (slot + 1) % self.capacity
        self.full = self.full or self.idx == 0
        self.len = min(self.len + 1, self.capacity)

    def _sample_idxs(self):
        """Draw ``batch_size`` random slot indices from the filled region."""
        upper = self.capacity if self.full else self.idx
        if upper <= 0:
            raise ValueError("cannot sample from an empty replay buffer")
        return np.random.randint(0, upper, size=self.batch_size)

    def _field(self, name, idxs):
        """Return the rows ``idxs`` of storage field ``name`` as a tensor."""
        if name == "next_actions":
            # The successor of slot i lives in slot i + 1; wrap so the last
            # slot maps to the first instead of to unwritten memory.
            rows = self.actions[(idxs + 1) % self.capacity]
        else:
            rows = getattr(self, name)[idxs]
        tensor = torch.as_tensor(rows, device=self.device)
        return tensor.float() if name in self._FLOAT_FIELDS else tensor

    def _batch(self, *fields):
        """Sample one set of indices and gather ``fields`` in the given order."""
        idxs = self._sample_idxs()
        return tuple(self._field(name, idxs) for name in fields)

    # ------------------------------------------------------------------
    # writers
    # ------------------------------------------------------------------
    def add_raw(self, obs, RL_action, VLM_action, curr_reward, reward, next_obs, done):
        """Store a basic transition."""
        self._store(obs, RL_action, VLM_action, curr_reward, reward, next_obs, done)

    def add_mask(self, obs, RL_action, VLM_action, curr_reward, reward, next_obs, done, mask_feature):
        """Store a transition together with its mask feature."""
        self._store(obs, RL_action, VLM_action, curr_reward, reward, next_obs, done,
                    mask_obs=mask_feature)

    def add(self, obs, RL_action, VLM_action, curr_reward, reward, next_obs, done, pred_obs):
        """Store a transition together with ``pred_obs`` (kept in ``pre_obses``)."""
        self._store(obs, RL_action, VLM_action, curr_reward, reward, next_obs, done,
                    pre_obses=pred_obs)

    def add_weight(self, obs, RL_action, VLM_action, pre_action, curr_reward, reward, next_obs, done, pred_obs, mask_feature):
        """Store a transition with ``pred_obs``, ``pre_action`` and a mask feature."""
        self._store(obs, RL_action, VLM_action, curr_reward, reward, next_obs, done,
                    pre_obses=pred_obs, pre_actions=pre_action, mask_obs=mask_feature)

    def add_pre(self, obs, RL_action, VLM_action, pre_action, curr_reward, reward, next_obs, done, pred_obs):
        """Store a transition with ``pred_obs`` and ``pre_action``."""
        self._store(obs, RL_action, VLM_action, curr_reward, reward, next_obs, done,
                    pre_obses=pred_obs, pre_actions=pre_action)

    def add_time(self, obs, RL_action, VLM_action, pre_action, curr_reward, reward, next_obs, done, pred_obs, time_step):
        """Store a transition with ``pred_obs``, ``pre_action`` and a time step."""
        self._store(obs, RL_action, VLM_action, curr_reward, reward, next_obs, done,
                    pre_obses=pred_obs, pre_actions=pre_action, time_steps=time_step)

    def add_op(self, obs, RL_action, VLM_action, curr_reward, reward, next_obs, done, option):
        """Store a transition together with its (integer) option."""
        self._store(obs, RL_action, VLM_action, curr_reward, reward, next_obs, done,
                    option=option)

    # ------------------------------------------------------------------
    # samplers
    # ------------------------------------------------------------------
    def sample(self, aug_funcs, k=False):
        """Sample a batch, optionally encoding VLM actions with a CLIP model.

        If a ``model`` attribute has been attached to the buffer, the sampled
        ``vlm_actions`` entries are tokenised with ``clip.tokenize`` and
        replaced by ``model.encode_text(...)``.  Without such a model (the
        buffer never creates one itself) the stored numeric ``vlm_actions``
        are returned as a tensor, like in the other ``sample_*`` methods.

        Args:
            aug_funcs: Accepted for interface compatibility; currently unused.
            k: If true, return the 8-tuple that ends with ``k_obses``.
                NOTE(review): ``k_obses`` is not allocated by this class, so
                this path raises ``AttributeError`` unless a caller attached
                that array.

        Returns:
            ``(obses, actions, next_actions, vlm_actions, curr_rewards,
            rewards, next_obses, not_dones, options)``, or with ``k`` set
            ``(obses, actions, next_actions, vlm_actions, rewards,
            next_obses, not_dones, k_obses)``.
        """
        idxs = self._sample_idxs()

        encoder = getattr(self, "model", None)
        if encoder is None:
            vlm_actions = self._field("vlm_actions", idxs)
        else:
            prompts = self.vlm_actions[idxs].flatten().tolist()
            tokens = clip.tokenize(prompts).to(self.device)
            vlm_actions = encoder.encode_text(tokens)

        obses = self._field("obses", idxs)
        actions = self._field("actions", idxs)
        next_actions = self._field("next_actions", idxs)
        rewards = self._field("rewards", idxs)
        next_obses = self._field("next_obses", idxs)
        not_dones = self._field("not_dones", idxs)

        if k:
            return (obses, actions, next_actions, vlm_actions, rewards,
                    next_obses, not_dones, self._field("k_obses", idxs))
        return (obses, actions, next_actions, vlm_actions,
                self._field("curr_rewards", idxs), rewards, next_obses,
                not_dones, self._field("option", idxs))

    def sample_loss(self, k=False):
        """Return ``(obses, pre_obses, actions, next_actions, vlm_actions,
        curr_rewards, rewards, next_obses, not_dones)``.  ``k`` is unused."""
        return self._batch("obses", "pre_obses", "actions", "next_actions",
                           "vlm_actions", "curr_rewards", "rewards",
                           "next_obses", "not_dones")

    def sample_pre(self, k=False):
        """Return ``(obses, pre_obses, pre_actions, actions, next_actions,
        vlm_actions, curr_rewards, rewards, next_obses, not_dones)``.
        ``k`` is unused."""
        return self._batch("obses", "pre_obses", "pre_actions", "actions",
                           "next_actions", "vlm_actions", "curr_rewards",
                           "rewards", "next_obses", "not_dones")

    def sample_time(self, k=False):
        """Return ``(obses, pre_obses, pre_actions, actions, next_actions,
        vlm_actions, curr_rewards, rewards, next_obses, not_dones,
        time_steps)``.  ``k`` is unused."""
        return self._batch("obses", "pre_obses", "pre_actions", "actions",
                           "next_actions", "vlm_actions", "curr_rewards",
                           "rewards", "next_obses", "not_dones", "time_steps")

    def sample_mask(self, k=False):
        """Return ``(obses, actions, next_actions, vlm_actions, curr_rewards,
        rewards, next_obses, not_dones, mask_obses)``.  ``k`` is unused."""
        return self._batch("obses", "actions", "next_actions", "vlm_actions",
                           "curr_rewards", "rewards", "next_obses",
                           "not_dones", "mask_obs")

    def sample_weight(self, k=False):
        """Return ``(obses, pre_obses, actions, next_actions, pre_actions,
        vlm_actions, curr_rewards, rewards, next_obses, not_dones,
        mask_obses)``.  ``k`` is unused."""
        return self._batch("obses", "pre_obses", "actions", "next_actions",
                           "pre_actions", "vlm_actions", "curr_rewards",
                           "rewards", "next_obses", "not_dones", "mask_obs")

    def sample_raw(self, k=False):
        """Return ``(obses, actions, next_actions, vlm_actions, curr_rewards,
        rewards, next_obses, not_dones)``.  ``k`` is unused."""
        return self._batch("obses", "actions", "next_actions", "vlm_actions",
                           "curr_rewards", "rewards", "next_obses",
                           "not_dones")

    def sample_op(self, k=False):
        """Return ``(obses, actions, next_actions, vlm_actions, curr_rewards,
        rewards, next_obses, not_dones, options)``.  ``k`` is unused."""
        return self._batch("obses", "actions", "next_actions", "vlm_actions",
                           "curr_rewards", "rewards", "next_obses",
                           "not_dones", "option")

    # ------------------------------------------------------------------
    # persistence
    # ------------------------------------------------------------------
    def save(self, save_dir):
        """Write the slots added since the previous ``save`` to ``save_dir``.

        The chunk is stored as ``"<last_save>_<idx>.pt"`` and contains, in
        order: obses, next_obses, actions, rewards, curr_rewards, not_dones.
        Nothing is written when the cursor has not moved.

        NOTE(review): only those six fields are persisted, and when the
        cursor has wrapped since the last save the slice
        ``[last_save:idx]`` is empty, so those transitions are not saved.
        """
        if self.idx == self.last_save:
            return
        start, stop = self.last_save, self.idx
        path = os.path.join(save_dir, '%d_%d.pt' % (start, stop))
        payload = [
            self.obses[start:stop],
            self.next_obses[start:stop],
            self.actions[start:stop],
            self.rewards[start:stop],
            self.curr_rewards[start:stop],
            self.not_dones[start:stop]
        ]
        self.last_save = stop
        torch.save(payload, path)

    def load(self, save_dir):
        """Restore chunks previously written by ``save`` from ``save_dir``.

        Chunks are applied in ascending order of their start index and each
        one must begin where the previous one ended (checked with an
        ``assert``).  Every entry of ``save_dir`` is expected to be a chunk
        file named ``"<start>_<end>.pt"``.

        SECURITY: chunks are unpickled, so only load directories you trust.
        """
        chunks = sorted(os.listdir(save_dir), key=lambda x: int(x.split('_')[0]))
        for chunk in chunks:
            start, end = [int(x) for x in chunk.split('.')[0].split('_')]
            path = os.path.join(save_dir, chunk)
            try:
                # The payload is a list of NumPy arrays, which the
                # weights-only unpickler (default in recent torch) rejects.
                payload = torch.load(path, weights_only=False)
            except TypeError:
                # Older torch releases do not know the keyword.
                payload = torch.load(path)
            assert self.idx == start
            self.obses[start:end] = payload[0]
            self.next_obses[start:end] = payload[1]
            self.actions[start:end] = payload[2]
            self.rewards[start:end] = payload[3]
            self.curr_rewards[start:end] = payload[4]
            self.not_dones[start:end] = payload[5]
            self.idx = end
            self.len = min(self.capacity, max(self.len, end))


class FrameStack(gym.Wrapper):
    """Gym wrapper that returns the ``k`` most recent observations stacked.

    Observations are concatenated along axis 0, so the advertised
    observation space has ``k`` times the wrapped environment's first
    dimension.  On ``reset`` the initial observation is repeated ``k`` times.
    """

    def __init__(self, env, k):
        super().__init__(env)
        self._k = k
        self._frames = deque(maxlen=k)
        base_space = env.observation_space
        base_shape = base_space.shape
        stacked_shape = (base_shape[0] * k,) + base_shape[1:]
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=stacked_shape,
            dtype=base_space.dtype,
        )
        self._max_episode_steps = env._max_episode_steps

    def reset(self):
        first_obs = self.env.reset()
        # Fill the whole window with the first frame.
        self._frames.extend(first_obs for _ in range(self._k))
        return self._get_obs()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._frames.append(obs)
        stacked = self._get_obs()
        return stacked, reward, done, info

    def _get_obs(self):
        assert len(self._frames) == self._k
        return np.concatenate(tuple(self._frames), axis=0)


class FrameStack_Gym(gym.Wrapper):
    """Frame-stacking wrapper that also resizes observations.

    The wrapped environment is additionally wrapped in ``ResizeObservation``
    with shape ``(image_size, image_size)``.  Each observation is transposed
    with ``transpose(2, 0, 1)`` and the ``k`` most recent ones are
    concatenated along axis 0.
    """

    def __init__(self, env, k, image_size):
        super().__init__(env)
        self._k = k
        self._frames = deque(maxlen=k)
        # Read from the environment as passed in, before it is re-wrapped.
        self._max_episode_steps = env._max_episode_steps
        resized_env = ResizeObservation(env, shape=(image_size, image_size))
        resized_space = resized_env.observation_space
        shp = resized_space.shape
        # Advertised shape puts axis 2 (times k) in front of axes 0 and 1.
        stacked_shape = (shp[2] * k,) + shp[:2]
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=stacked_shape,
            dtype=resized_space.dtype,
        )
        self.env = resized_env

    @staticmethod
    def _axis2_first(obs):
        # Copy so the stored frame does not alias the environment's array.
        return obs.transpose(2, 0, 1).copy()

    def reset(self):
        first_obs = self._axis2_first(self.env.reset())
        self._frames.extend(first_obs for _ in range(self._k))
        return self._get_obs()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._frames.append(self._axis2_first(obs))
        stacked = self._get_obs()
        return stacked, reward, done, info

    def _get_obs(self):
        assert len(self._frames) == self._k
        return np.concatenate(tuple(self._frames), axis=0)