
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

#

# This source code is licensed under the MIT license found in the

# LICENSE file in the root directory of this source tree.

import random

import re

import time

 

import numpy as np

import h5py

from collections import deque

import dmc

from dm_env import StepType

from drqbc.numpy_replay_buffer import EfficientReplayBuffer

 

import torch

import torch.nn as nn

from torch import distributions as pyd

from torch.distributions.utils import _standard_normal

 

from torch.utils.data import DataLoader

import torchvision.datasets as datasets

import torchvision.transforms as transforms

 

class eval_mode:
    """Context manager that temporarily switches modules to evaluation mode.

    On entry each model's current ``training`` flag is saved and the model is
    switched with ``train(False)``; on exit every model is restored with
    ``train(saved_flag)``. Exceptions from the ``with`` body propagate.

    Args:
        *models: Objects exposing a ``training`` attribute and a ``train(mode)``
            method (e.g. ``torch.nn.Module`` instances).
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        saved = self.prev_states = []
        for m in self.models:
            # Record and switch one model at a time: a later model may be a
            # submodule whose flag was just changed by an earlier train(False).
            saved.append(m.training)
            m.train(False)

    def __exit__(self, *args):
        for m, was_training in zip(self.models, self.prev_states):
            m.train(was_training)
        # Never swallow exceptions raised inside the `with` body.
        return False

 

 

def set_seed_everywhere(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA devices).

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

 

 

def soft_update_params(net, target_net, tau):
    """Blend ``net``'s parameters into ``target_net`` in place.

    Each target parameter becomes ``tau * source + (1 - tau) * target``.
    Parameters are paired positionally via ``zip`` over ``.parameters()``.

    Args:
        net: Source module.
        target_net: Module updated in place.
        tau: Interpolation weight given to the source parameters.
    """
    keep = 1 - tau
    for src, dst in zip(net.parameters(), target_net.parameters()):
        blended = tau * src.data + keep * dst.data
        dst.data.copy_(blended)

 

 

def to_torch(xs, device):
    """Convert every item of ``xs`` with ``torch.as_tensor`` onto ``device``.

    Args:
        xs: Iterable of array-likes (NumPy arrays, lists, tensors, ...).
        device: Target torch device.

    Returns:
        A tuple of tensors, in the same order as ``xs``.
    """
    converted = [torch.as_tensor(item, device=device) for item in xs]
    return tuple(converted)

 

 

def weight_init(m):
    """Initialise one module in place; intended for ``model.apply(weight_init)``.

    * ``nn.Linear``: orthogonal weights with the default gain.
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: orthogonal weights scaled by
      ``nn.init.calculate_gain('relu')``.
    * Any other module is left untouched.

    For the handled layer types, the bias (when present) is filled with zeros.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        relu_gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data, relu_gain)
    else:
        return
    # `bias` is None (no `.data`) when the layer was built with bias=False.
    if hasattr(m.bias, 'data'):
        m.bias.data.fill_(0.0)

 

 

class Until:
    """Predicate that holds while ``step < until // action_repeat``.

    Args:
        until: Threshold before division by ``action_repeat``; ``None`` means
            the predicate is always true.
        action_repeat: Floor divisor applied to ``until`` before comparing.
    """

    def __init__(self, until, action_repeat=1):
        self._until = until
        self._action_repeat = action_repeat

    def __call__(self, step):
        return self._until is None or step < self._until // self._action_repeat

 

 

class Every:
    """Predicate that fires when ``step`` is a multiple of ``every // action_repeat``.

    Args:
        every: Period before division by ``action_repeat``; ``None`` disables
            the predicate (always false).
        action_repeat: Floor divisor applied to ``every`` before the modulo.
    """

    def __init__(self, every, action_repeat=1):
        self._every = every
        self._action_repeat = action_repeat

    def __call__(self, step):
        if self._every is None:
            return False
        period = self._every // self._action_repeat
        return bool(step % period == 0)

 

 

class Timer:
    """Wall-clock timer based on ``time.time()``.

    Tracks seconds since construction and seconds since the last ``reset``.
    """

    def __init__(self):
        self._start_time = time.time()
        self._last_time = time.time()

    def reset(self):
        """Restart the lap clock.

        Returns:
            ``(elapsed_time, total_time)``: seconds since the previous reset (or
            construction) and seconds since construction.
        """
        since_last = time.time() - self._last_time
        self._last_time = time.time()
        return since_last, self.total_time()

    def total_time(self):
        """Return seconds elapsed since the timer was created."""
        return time.time() - self._start_time

 

 

class TruncatedNormal(pyd.Normal):
    """Normal distribution whose ``sample`` is clamped into ``[low + eps, high - eps]``.

    The clamp uses a straight-through estimator: the returned value is
    clamped, but its gradient with respect to the unclamped value is identity.
    Argument validation is disabled (``validate_args=False``).

    Args:
        loc: Mean, forwarded to ``torch.distributions.Normal``.
        scale: Standard deviation, forwarded to ``Normal``.
        low: Lower bound before the ``eps`` margin.
        high: Upper bound before the ``eps`` margin.
        eps: Margin keeping samples strictly inside ``(low, high)``.
    """

    def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
        super().__init__(loc, scale, validate_args=False)
        self.low, self.high, self.eps = low, high, eps

    def _clamp(self, x):
        """Clamp ``x`` in the forward pass while keeping an identity gradient."""
        lo, hi = self.low + self.eps, self.high - self.eps
        bounded = torch.clamp(x, lo, hi).detach()
        # (x - x.detach()) is numerically zero but carries d/dx = 1.
        return x - x.detach() + bounded

    def sample(self, clip=None, sample_shape=torch.Size()):
        """Return ``loc + scale * noise``, clamped via ``_clamp``.

        Args:
            clip: When not None, the scaled noise is clamped to ``[-clip, clip]``
                before being added to ``loc``.
            sample_shape: Extra leading sample dimensions.
        """
        noise = _standard_normal(
            self._extended_shape(sample_shape),
            dtype=self.loc.dtype,
            device=self.loc.device,
        )
        # In-place on purpose: the result keeps `noise`'s (loc's) dtype.
        noise *= self.scale
        if clip is not None:
            noise = torch.clamp(noise, -clip, clip)
        return self._clamp(self.loc + noise)

 

 

def schedule(schdl, step):
    """Evaluate a scalar hyper-parameter schedule at ``step``.

    Supported forms of ``schdl``:
      * anything ``float()`` accepts: a constant value;
      * ``'linear(init,final,duration)'``: linear interpolation from ``init``
        to ``final`` over ``duration`` steps, then held at ``final``;
      * ``'step_linear(init,final1,duration1,final2,duration2)'``: ``init`` to
        ``final1`` over ``duration1`` steps, then ``final1`` to ``final2``
        over the following ``duration2`` steps.

    Raises:
        NotImplementedError: If ``schdl`` matches none of the forms above.
    """
    try:
        return float(schdl)
    except ValueError:
        linear = re.match(r'linear\((.+),(.+),(.+)\)', schdl)
        if linear:
            start, end, span = map(float, linear.groups())
            frac = np.clip(step / span, 0.0, 1.0)
            return (1.0 - frac) * start + frac * end

        two_phase = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl)
        if two_phase:
            start, middle, span1, end, span2 = map(float, two_phase.groups())
            if step > span1:
                frac = np.clip((step - span1) / span2, 0.0, 1.0)
                return (1.0 - frac) * middle + frac * end
            frac = np.clip(step / span1, 0.0, 1.0)
            return (1.0 - frac) * start + frac * middle
    raise NotImplementedError(schdl)

 

 

# Integer step-type codes stored in the offline datasets' 'step_type' arrays,
# mapped to dm_env StepType members (0 -> FIRST, 1 -> MID, 2 -> LAST).
step_type_lookup = dict(enumerate((StepType.FIRST, StepType.MID, StepType.LAST)))

 

 

# Number of leading timesteps kept from each distractor ("random") dataset.
_RANDOM_DATASET_MAX_STEPS = 10000


def _load_hdf5_episodes(path, max_steps=None):
    """Read every dataset of the HDF5 file at ``path`` into a dict of arrays.

    The file is closed before returning. When ``max_steps`` is given, only the
    first ``max_steps`` entries along axis 0 are read from each dataset.
    """
    with h5py.File(path, 'r') as f:
        return {k: f[k][:max_steps] for k in f.keys()}


def load_offline_dataset_into_buffer(offline_dir, replay_buffer, random_dataset1, random_dataset2, random_dataset3, frame_stack, replay_buffer_size, megaenv=True):
    """Load ``*.hdf5`` episode files from ``offline_dir`` into ``replay_buffer``.

    Files are processed in sorted order. Loading stops after the first file
    that brings the cumulative number of loaded timesteps to at least
    ``replay_buffer_size``, so the final count may overshoot it.

    Args:
        offline_dir: Directory object supporting ``.glob`` (e.g. ``pathlib.Path``).
        replay_buffer: Buffer that receives timesteps via ``.add``.
        random_dataset1, random_dataset2, random_dataset3: HDF5 paths of the
            distractor datasets; only read when ``megaenv`` is true, and only
            their first ``_RANDOM_DATASET_MAX_STEPS`` timesteps are kept.
        frame_stack: Number of frames stacked per observation.
        replay_buffer_size: Timestep budget after which loading stops.
        megaenv: If true, use ``add_offline_data_to_buffer_megaenv`` (tiles
            distractor frames around each observation); otherwise
            ``add_offline_data_to_buffer``.
    """
    filenames = sorted(offline_dir.glob('*.hdf5'))
    random_epi_list = None
    num_steps = 0
    for filename in filenames:
        print("filename is", filename, replay_buffer_size)
        episodes = _load_hdf5_episodes(filename)

        if megaenv:
            if random_epi_list is None:
                # The distractor datasets do not depend on `filename`: load them
                # once. Each one is truncated into its own dict (previously
                # datasets 2 and 3 were written into dataset 1's dict by mistake).
                random_epi_list = [
                    _load_hdf5_episodes(path, max_steps=_RANDOM_DATASET_MAX_STEPS)
                    for path in (random_dataset1, random_dataset2, random_dataset3)
                ]
            add_offline_data_to_buffer_megaenv(episodes, random_epi_list, replay_buffer, framestack=frame_stack)
        else:
            add_offline_data_to_buffer(episodes, replay_buffer, framestack=frame_stack)

        num_steps += episodes['reward'].shape[0]
        print("Loaded {} offline timesteps so far...".format(int(num_steps)))
        if num_steps >= replay_buffer_size:
            break

    print("Finished, loaded {} timesteps.".format(int(num_steps)))

 

 

def add_offline_data_to_buffer(offline_data: dict, replay_buffer: EfficientReplayBuffer, framestack: int = 3):
    """Add every timestep of ``offline_data`` to ``replay_buffer`` with frame stacking.

    Args:
        offline_data: Dict of arrays sharing the same leading length, holding
            at least the keys read by ``get_timestep_from_idx``. It must begin
            with a FIRST step and contain at least two LAST steps
            (``step_type == 2``). The spacing of the first two LAST steps is
            used as ``interval``, presumably a constant episode length (TODO
            confirm for variable-length datasets).
        replay_buffer: Buffer receiving each stacked timestep via ``.add``.
        framestack: Number of most recent frames concatenated along axis 0.
            At an episode start the first frame is repeated ``framestack`` times.

    Each added timestep also gets a ``k_step`` attribute (``max_k = 15``):
      * first step of an episode: ``random.randint(0, min(interval - 1, 14)) + 1``;
      * later steps: ``min(interval - 1, now + 15) - now``, with ``now`` counting
        0, 1, 2, ... from the second step of the episode.

    Raises:
        ValueError: If the data does not start with a FIRST step.
    """
    offline_data_length = offline_data['reward'].shape[0]
    for v in offline_data.values():
        assert v.shape[0] == offline_data_length

    # flatnonzero gives plain indices, so `interval` is a Python int. np.argwhere
    # produced a shape-(1,) array, which random.randint rejects on Python 3.12+.
    done_idx = np.flatnonzero(offline_data['step_type'] == 2)
    assert len(done_idx) > 1
    interval = int(done_idx[1] - done_idx[0])

    now = -1
    max_k = 15
    stacked_frames = None

    for idx in range(offline_data_length):
        time_step = get_timestep_from_idx(offline_data, idx)

        if time_step.first():
            now = -1
            stacked_frames = deque([time_step.observation] * framestack, maxlen=framestack)
            rindex = random.randint(now + 1, min(interval - 1, now + max_k))
        else:
            if stacked_frames is None:
                raise ValueError("offline_data must start with a FIRST step (step_type == 0)")
            now += 1
            stacked_frames.append(time_step.observation)
            rindex = min(interval - 1, now + max_k)

        time_step.observation = np.concatenate(stacked_frames, axis=0)
        time_step.k_step = rindex - now
        replay_buffer.add(time_step)

 

 

def _sample_megaenv_distractors(random_offline_data_list: list, count: int) -> list:
    """Draw ``count`` timesteps, each from a uniformly chosen distractor dataset.

    For every slot a dataset is picked uniformly from ``random_offline_data_list``,
    then a timestep index is drawn uniformly over that dataset's own length.
    """
    distractors = []
    for _ in range(count):
        dataset = random_offline_data_list[np.random.randint(len(random_offline_data_list))]
        # Use an array length: len(dataset) would count dict keys, not timesteps.
        num_steps = dataset['reward'].shape[0]
        distractors.append(get_timestep_from_idx(dataset, np.random.randint(num_steps)))
    return distractors


def _tile_megaenv_observation(main_obs, distractor_obs, grid_size, main_pos):
    """Tile ``grid_size`` x ``grid_size`` channel-first observations into one uint8 image.

    ``main_obs`` (shape ``(C, H, W)``) goes to flat row-major cell ``main_pos``.
    The remaining cells are filled in row-major order from ``distractor_obs``,
    whose items must have the same shape as ``main_obs``.
    """
    channels, height, width = main_obs.shape
    mega = np.zeros((channels, grid_size * height, grid_size * width), dtype=np.uint8)
    distractors = iter(distractor_obs)
    for cell in range(grid_size * grid_size):
        row, col = divmod(cell, grid_size)
        tile = main_obs if cell == main_pos else next(distractors)
        mega[:, row * height:(row + 1) * height, col * width:(col + 1) * width] = tile
    return mega


def add_offline_data_to_buffer_megaenv(offline_data: dict, random_offline_data_list: list, replay_buffer: EfficientReplayBuffer, framestack: int = 3):
    """Add frame-stacked timesteps of ``offline_data`` to ``replay_buffer`` on a distractor grid.

    Every observation is frame-stacked (as in ``add_offline_data_to_buffer``)
    and placed in one cell of a 3x3 grid. The other 8 cells hold frame-stacked
    frames from timesteps sampled from ``random_offline_data_list``, with fresh
    samples at every step.

    Args:
        offline_data: Dict of arrays sharing the same leading length. It must
            start with a FIRST step and contain at least two LAST steps
            (``step_type == 2``). Their spacing is used as ``interval``
            (presumably a constant episode length; TODO confirm).
        random_offline_data_list: Non-empty list of distractor dataset dicts
            with the same keys as ``offline_data``.
        replay_buffer: Buffer receiving each tiled timestep via ``.add``.
        framestack: Number of frames concatenated along axis 0 per cell.

    ``k_step`` is assigned exactly as in ``add_offline_data_to_buffer``.

    Raises:
        ValueError: If the data does not start with a FIRST step.
    """
    offline_data_length = offline_data['reward'].shape[0]
    for v in offline_data.values():
        assert v.shape[0] == offline_data_length

    # Python-int interval (np.argwhere produced a shape-(1,) array).
    done_idx = np.flatnonzero(offline_data['step_type'] == 2)
    assert len(done_idx) > 1
    interval = int(done_idx[1] - done_idx[0])

    now = -1
    max_k = 15

    env_size = 9  # grid cells (real env + distractors); must be a perfect square
    grid_size = int(round(np.sqrt(env_size)))
    assert grid_size * grid_size == env_size
    num_distractors = env_size - 1
    # Fixed cell for the real observation (row 0, column 2 of the 3x3 grid).
    # A random cell, np.random.randint(env_size), was used at some point.
    time_step_pos = 2

    stacked_frames = None
    rand_stacked_frames = None

    for idx in range(offline_data_length):
        time_step = get_timestep_from_idx(offline_data, idx)
        rand_time_steps = _sample_megaenv_distractors(random_offline_data_list, num_distractors)

        if time_step.first():
            now = -1
            # Episode start: fill each stack with copies of the current frame.
            stacked_frames = deque([time_step.observation] * framestack, maxlen=framestack)
            rand_stacked_frames = [
                deque([ts.observation] * framestack, maxlen=framestack)
                for ts in rand_time_steps
            ]
            rindex = random.randint(now + 1, min(interval - 1, now + max_k))
        else:
            if stacked_frames is None:
                raise ValueError("offline_data must start with a FIRST step (step_type == 0)")
            now += 1
            stacked_frames.append(time_step.observation)
            for frames, ts in zip(rand_stacked_frames, rand_time_steps):
                frames.append(ts.observation)
            rindex = min(interval - 1, now + max_k)

        distractor_obs = [np.concatenate(frames, axis=0) for frames in rand_stacked_frames]
        time_step.observation = _tile_megaenv_observation(
            np.concatenate(stacked_frames, axis=0), distractor_obs, grid_size, time_step_pos
        )
        time_step.k_step = rindex - now
        replay_buffer.add(time_step)

 

def _paste_cifar_distractor(observation, training_data):
    """Overwrite the top-left corner of ``observation`` in place with a random CIFAR image.

    ``observation`` is channel-first ``(C, H, W)``. The CIFAR image is tiled
    along the channel axis ``C // image_channels`` times, so every stacked
    frame receives the same distractor.
    """
    index = np.random.randint(len(training_data))
    image, _ = training_data[index]
    # ToTensor scaled pixels into [0, 1]; round back to integers instead of
    # truncating (which would turn e.g. 254.99998 into 254).
    image = (image.float() * 255).round().numpy()
    reps = observation.shape[0] // image.shape[0]
    image = np.tile(image, (reps, 1, 1))
    _, height, width = image.shape
    observation[:, :height, :width] = image


def add_offline_data_to_buffer_cifar(offline_data: dict, replay_buffer: EfficientReplayBuffer, framestack: int = 3):
    """Add frame-stacked timesteps to ``replay_buffer`` with CIFAR-10 distractors.

    Like ``add_offline_data_to_buffer``, but after stacking, a randomly chosen
    CIFAR-10 training image is pasted into the top-left corner of every
    observation. The dataset is downloaded into ``./data`` if it is not
    already present.

    Args:
        offline_data: Dict of arrays sharing the same leading length. It must
            start with a FIRST step and contain at least two LAST steps
            (``step_type == 2``). Their spacing is used as ``interval``
            (presumably a constant episode length; TODO confirm).
        replay_buffer: Buffer receiving each timestep via ``.add``.
        framestack: Number of frames concatenated along axis 0.

    ``k_step`` is assigned exactly as in ``add_offline_data_to_buffer``.

    Raises:
        ValueError: If the data does not start with a FIRST step.
    """
    offline_data_length = offline_data['reward'].shape[0]
    for v in offline_data.values():
        assert v.shape[0] == offline_data_length

    # Python-int interval (np.argwhere produced a shape-(1,) array).
    done_idx = np.flatnonzero(offline_data['step_type'] == 2)
    assert len(done_idx) > 1
    interval = int(done_idx[1] - done_idx[0])

    now = -1
    max_k = 15

    training_data = datasets.CIFAR10(root="data", train=True, download=True,
                                     transform=transforms.Compose([
                                         transforms.ToTensor(),
                                     ]))

    stacked_frames = None
    for idx in range(offline_data_length):
        time_step = get_timestep_from_idx(offline_data, idx)

        if time_step.first():
            now = -1
            stacked_frames = deque([time_step.observation] * framestack, maxlen=framestack)
            rindex = random.randint(now + 1, min(interval - 1, now + max_k))
        else:
            if stacked_frames is None:
                raise ValueError("offline_data must start with a FIRST step (step_type == 0)")
            now += 1
            stacked_frames.append(time_step.observation)
            rindex = min(interval - 1, now + max_k)

        # np.concatenate returns a fresh array, so pasting in place does not
        # touch offline_data.
        observation = np.concatenate(stacked_frames, axis=0)
        _paste_cifar_distractor(observation, training_data)
        time_step.observation = observation
        time_step.k_step = rindex - now
        replay_buffer.add(time_step)

       

 

def get_timestep_from_idx(offline_data: dict, idx: int):
    """Build a ``dmc.ExtendedTimeStep`` from row ``idx`` of an offline dataset dict.

    Args:
        offline_data: Dict with at least the keys 'step_type', 'reward',
            'observation', 'discount' and 'action', each indexable by ``idx``.
        idx: Row to read.

    Returns:
        An ``ExtendedTimeStep`` whose ``step_type`` is mapped via
        ``step_type_lookup`` and whose ``k_step`` is initialised to ``idx``
        (the buffer-loading functions in this module overwrite ``k_step``).
    """
    step_type = step_type_lookup[offline_data['step_type'][idx]]
    fields = {
        name: offline_data[name][idx]
        for name in ('reward', 'observation', 'discount', 'action')
    }
    return dmc.ExtendedTimeStep(step_type=step_type, k_step=idx, **fields)