





from typing import Any, Callable, Iterator, List, Optional, Tuple, Union, cast

import gym
import numpy as np
import torch
from typing_extensions import Protocol

from d3rlpy.dataset import Episode, TransitionMiniBatch
from d3rlpy.preprocessing.reward_scalers import RewardScaler
from d3rlpy.preprocessing.stack import StackedObservation

from d3rlpy.metrics.scorer import AlgoProtocol

import pandas as pd
import os

from copy import deepcopy

import pickle

from utils import ReplayBuffer

WINDOW_SIZE = 1024









class ReplayBuffer(ReplayBuffer):
    """Replay buffer that can be filled from offline-RL datasets.

    Extends the ``ReplayBuffer`` imported from ``utils`` with two loaders that
    overwrite the buffer's ``state``, ``action``, ``next_state``, ``reward``,
    ``not_done`` and ``size`` attributes from a pre-collected dataset.

    The class shares its name with the imported base class, so after this
    definition the module-level name ``ReplayBuffer`` refers to this subclass.
    """

    def convert_D4RL(self, dataset):
        """Load a D4RL-style mapping of arrays into the buffer.

        Args:
            dataset: Mapping providing the keys ``'observations'``,
                ``'actions'``, ``'next_observations'``, ``'rewards'`` and
                ``'terminals'``. The rewards and terminals values must
                support ``reshape``.

        Observations, actions and next observations are stored as given (no
        copy, no dtype conversion). Rewards and ``not_done`` are reshaped to
        column vectors. ``self.flag`` is not set by this loader.
        """
        self.state = dataset['observations']
        self.action = dataset['actions']
        self.next_state = dataset['next_observations']
        self.reward = dataset['rewards'].reshape(-1, 1)
        # not_done is the complement of the terminal indicator.
        self.not_done = 1. - dataset['terminals'].reshape(-1, 1)
        self.size = self.state.shape[0]

    def convert_D3RL(self, dataset):
        """Load a dataset exposing flat per-step arrays into the buffer.

        Args:
            dataset: Object with ``observations``, ``actions``, ``rewards``
                and ``terminals`` attributes, each indexable per step
                (presumably a d3rlpy ``MDPDataset`` -- verify against the
                caller).

        Step ``i`` is paired with ``observations[i + 1]`` as its next state,
        so the final step has no successor and is dropped: the buffer holds
        ``len(observations) - 1`` transitions. Observations, actions and
        rewards are converted to ``float32``; ``flag`` keeps the raw terminal
        values and ``not_done`` is ``1 - bool(terminal)``, both as column
        vectors.
        """
        print('OTIL Replay Buffer')
        observations = np.asarray(dataset.observations)
        actions = np.asarray(dataset.actions)
        rewards = np.asarray(dataset.rewards)
        terminals = np.asarray(dataset.terminals)

        # Number of (step, next step) pairs; clamped so an empty dataset
        # yields empty slices instead of a negative slice bound.
        n = max(len(observations) - 1, 0)

        # NOTE(review): transitions are NOT skipped at terminal steps, so when
        # the dataset concatenates several episodes the next_state of an
        # episode's last step is the first observation of the following
        # episode -- confirm this is intended.
        # astype() always copies, so the buffer never aliases the dataset.
        self.state = observations[:n].astype(np.float32)
        self.next_state = observations[1:n + 1].astype(np.float32)
        self.action = actions[:n].astype(np.float32)
        self.reward = rewards[:n].astype(np.float32).reshape(-1, 1)

        # Any non-zero terminal value counts as done.
        done = terminals[:n].astype(bool)
        self.not_done = 1. - done.reshape(-1, 1)
        # flag preserves the terminal values in their original dtype.
        self.flag = terminals[:n].copy().reshape(-1, 1)
        self.size = self.state.shape[0]
