# https://github.com/kzl/decision-transformer/blob/master/atari/create_dataset.py
import os, base64, copy, re
from pathlib import Path

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
from tqdm.auto import tqdm

from utils import logs_handler, helpers, misc
from utils import process_image as processing
from utils.fixed_replay_buffer import FixedReplayBuffer
from utils.base_data import BaseDataset, make_split, sample_indices, get_cumulative_rewards, gather_returns,\
    check_terminal_idxs, get_episode_length

from utils.temporal import TemporalAugmentation

# Per-game ALE metadata, read from `ale_envs.json` in the parent of this file's directory.
# This module reads two fields of each entry: `['actions']` (indexed by a replay-buffer
# action id to obtain the ALE action, see `to_ale_action`) and `['score']` (unpacked as a
# `(random_score, human_score)` pair, see `get_human_normalized_score`).
ALE_DATA = helpers.json_load(os.path.join(Path(__file__).parent.parent, 'ale_envs.json'))

logger = logs_handler.get_logger(__name__)

def to_ale_name(name):
    """Convert a CamelCase game name into its lower snake_case form.

    The name is split in front of every uppercase letter and the lower-cased
    pieces are joined with underscores, e.g. ``'SpaceInvaders'`` ->
    ``'space_invaders'``. Characters before the first uppercase letter are dropped.
    """
    words = re.findall('[A-Z][^A-Z]*', name)
    return '_'.join([word.lower() for word in words])

def to_ale_action(name, act):
    """Translate replay-buffer action index ``act`` of game ``name`` into its ALE action.

    The result is ``ALE_DATA[name]['actions'][act]``.

    Raises:
        AssertionError: if ``name`` has no ``ALE_DATA`` entry, or ``act`` is not
            smaller than the number of actions listed for it.
    """
    assert name in ALE_DATA, f'ALE_DATA is undefined for "{name}"'
    action_table = ALE_DATA[name]['actions']
    assert act < len(action_table), f'ALE_DATA is invalid for "{name}", action? {act}'
    return action_table[act]

def get_human_normalized_score(ale_name, score):
    """Normalize ``score`` against the game's random-agent and human reference scores.

    ``ALE_DATA[ale_name]['score']`` holds the ``(random, human)`` pair; the actual
    normalization is delegated to ``misc.get_normalized_score``.
    """
    reference_scores = ALE_DATA[ale_name]['score']
    random_score, human_score = reference_scores
    return misc.get_normalized_score(score, human_score, random_score)

def log_returns_stats(game, rewards, done_idxs):
    """Log the min / max / mean / standard deviation of a dataset's episode returns.

    ``done_idxs`` is validated with ``check_terminal_idxs`` first. Episode returns
    are then derived from ``rewards`` via ``get_cumulative_rewards`` followed by
    ``gather_returns``.
    """
    check_terminal_idxs(done_idxs)
    cumulative = get_cumulative_rewards(rewards, done_idxs)
    returns = gather_returns(cumulative, done_idxs)
    logger.info(f'{game} > Min Return: {min(returns)}')
    logger.info(f'{game} > Max Return: {max(returns)}')
    logger.info(f'{game} > Avg Return: {np.mean(returns)}')
    logger.info(f'{game} > Stdv Return: {np.std(returns)}')

def random_dataset(num_steps, stack_size):
    """Generate a synthetic dataset laid out like the output of ``create_dataset``.

    Returns:
        obs: float array of shape ``(num_steps, 84, 84, stack_size)`` drawn from N(0, 1).
        actions: int array of shape ``(num_steps,)`` with values in ``[0, 12)``.
        rewards: all-zero array with the shape and dtype of ``actions``.
        done_idxs: strictly increasing list of exclusive episode end indices whose
            last entry is always ``num_steps``.
    """
    obs = np.random.randn(num_steps, 84, 84, stack_size)
    actions = np.random.randint(low=0, high=12, size=(num_steps, ))
    rewards = np.zeros_like(actions)
    # Sampled episode lengths scale with the dataset size but are at least 10 steps.
    min_len = max(10, int(0.1 * num_steps))
    max_len = max(20, int(0.5 * num_steps))
    done_idxs = []
    idx = 0
    while True:
        k = np.random.randint(low=min_len, high=max_len)
        if idx + k >= num_steps:
            # A short (< 10 step) tail is merged into the previous episode. When no
            # episode has been closed yet (small num_steps), the whole dataset becomes
            # a single episode instead of indexing into an empty list.
            if done_idxs and num_steps - done_idxs[-1] < 10:
                done_idxs[-1] = num_steps
            else:
                done_idxs.append(num_steps)
            break
        idx += k
        done_idxs.append(idx)
    return obs, actions, rewards, done_idxs

def _dataset_cache_path(replay_logs_dir, num_steps, start_buffer, num_buffers, trajectories_per_buffer,
                        stack_size, map_ale_actions, cache_dir, cachefile_prefix):
    """Return the pickle path caching a ``create_dataset`` result, or None when caching is off."""
    if cache_dir is None:
        return None
    prefix = cachefile_prefix or base64.b16encode(replay_logs_dir.encode()).decode()
    buffer_desc = f'{num_buffers}_{trajectories_per_buffer}'
    # Options that change the produced data must be part of the key, otherwise a cache built
    # with different options is silently reused. They are appended only when non-default so
    # cache files created with the default options keep their existing names.
    options = ''
    if start_buffer != 0:
        options += f'_start{start_buffer}'
    if map_ale_actions:
        options += '_ale'
    return os.path.join(cache_dir, f'{prefix}_{num_steps}_{buffer_desc}_{stack_size}{options}.pkl')


def _read_trajectories(frb, name, start, replay_capacity, trajectories_per_buffer, map_ale_actions,
                       obs, actions, rewards, done_idxs):
    """Append trajectories read from ``frb`` (from transition ``start`` on) to the given lists.

    ``done_idxs`` entries are global: ``len(obs)`` right after a terminal transition is
    appended. Reading stops at the ``trajectories_per_buffer + 1``-th terminal (this count
    is kept deliberately so generated datasets match earlier runs) or when the buffer runs
    out. If the buffer runs out mid-episode, the unfinished episode is kept when it is at
    least as long as the shortest episode collected so far, and discarded otherwise.

    Returns:
        (next_index, num_trajectories, exhausted): the next unread transition of the
        buffer, the number of episodes appended, and whether the buffer has nothing left.
    """
    i = start
    num_trajectories = 0
    while i < replay_capacity:
        (state, action, ret), terminal = frb.sample_transition(i)
        i += 1
        if map_ale_actions:
            action = to_ale_action(name, action)
        obs.append(state)
        actions.append(action)
        rewards.append(ret)
        if terminal:
            done_idxs.append(len(obs))
            num_trajectories += 1
            if num_trajectories > trajectories_per_buffer:
                return i, num_trajectories, i >= replay_capacity

    # The buffer ran out. Everything after the last terminal is an unfinished episode.
    # Indices are global (into `obs`), never relative to this buffer, so data collected
    # from earlier buffers is left untouched.
    last_terminal = done_idxs[-1] if done_idxs else 0
    tail_length = len(obs) - last_terminal
    if tail_length > 0:
        if done_idxs and tail_length >= min(get_episode_length(done_idxs)):
            done_idxs.append(len(obs))
            num_trajectories += 1
        else:
            del obs[last_terminal:], actions[last_terminal:], rewards[last_terminal:]
    return i, num_trajectories, True


def create_dataset(replay_logs_dir, num_steps, start_buffer=0, num_buffers=50, trajectories_per_buffer=10, 
                   stack_size=4, map_ale_actions=False, cache_dir=None, cachefile_prefix=None):
    """Build an offline dataset from DQN replay logs, or load it from the pickle cache.

    Whole trajectories are read from randomly chosen replay buffers
    ``start_buffer .. start_buffer + num_buffers - 1`` (both inclusive) until at least
    ``num_steps`` transitions have been collected. Every transition is read at most once.

    Args:
        replay_logs_dir: directory with the replay buffer checkpoints; its third-to-last
            path component is used as the game name.
        num_steps: minimum number of transitions to collect.
        start_buffer: first replay buffer index to sample from (>= 0).
        num_buffers: number of consecutive buffers to sample from (>= 1);
            ``start_buffer + num_buffers`` must be <= 50.
        trajectories_per_buffer: each visit of a buffer reads
            ``trajectories_per_buffer + 1`` trajectories (fewer if the buffer runs out).
        stack_size: frame-stack size passed to ``FixedReplayBuffer``.
        map_ale_actions: translate actions with ``to_ale_action``.
        cache_dir: directory of the pickle cache; None disables caching.
        cachefile_prefix: cache file name prefix; defaults to the base16-encoded
            ``replay_logs_dir``.

    Returns:
        ``(obs, actions, rewards, done_idxs)`` numpy arrays; ``done_idxs`` holds the
        exclusive end index of every episode.

    Raises:
        RuntimeError: if every buffer in range is exhausted or fails to load before
            ``num_steps`` transitions are collected.
    """
    # Path() tolerates trailing or repeated separators, unlike a plain split('/').
    # NOTE(review): presumably the layout is `.../<Game>/<run>/<logs dir>` -- confirm.
    name = Path(replay_logs_dir).parts[-3]

    logger.info(f'{name} > REPLAY_LOGS_DIR - {replay_logs_dir}')
    logger.info(f'{name} > NUM_STEPS - {num_steps}')

    # Buffers start_buffer .. end_buffer are sampled, both ends inclusive.
    end_buffer = start_buffer + num_buffers - 1
    assert start_buffer >= 0, f'{name} > start_buffer={start_buffer}, expected: start_buffer >= 0'
    assert num_buffers >= 1, f'{name} > num_buffers={num_buffers}, expected: num_buffers >= 1'
    assert end_buffer < 50, f'{name} > end_buffer={end_buffer} is too large, expected: end_buffer < 50'

    cache_path = _dataset_cache_path(replay_logs_dir, num_steps, start_buffer, num_buffers, trajectories_per_buffer,
                                     stack_size, map_ale_actions, cache_dir, cachefile_prefix)
    logger.info(f'{name} > CACHE_PATH - {cache_path}')

    if (cache_path is not None) and os.path.exists(cache_path):
        logger.info(f'"{name}" > Found cached dataset!')
        obs, actions, rewards, done_idxs = helpers.pickle_load(cache_path)
        log_returns_stats(name, rewards, done_idxs)
        return obs, actions, rewards, done_idxs

    replay_capacity = 100000
    obs, actions, rewards, done_idxs = [], [], [], []
    # Next unread transition of every buffer (a buffer is selected by its checkpoint suffix,
    # e.g. $store$_action_ckpt.{BUFFER_INDEX}.gz). Buffers leave `available` once exhausted
    # or when they fail to load, so the loop cannot re-read data or spin forever.
    next_transition = {buffer_num: 0 for buffer_num in range(start_buffer, end_buffer + 1)}
    available = list(next_transition)
    num_trajectories = 0
    num_iterations = 0

    with tqdm(total=num_steps) as pbar:
        while len(obs) < num_steps:
            if not available:
                raise RuntimeError(f'{name} > replay buffers {start_buffer}..{end_buffer} are exhausted after '
                                   f'collecting {len(obs)}/{num_steps} transitions')
            num_iterations += 1
            num_obs_preload = len(obs)
            buffer_num = available[np.random.randint(len(available))]
            i = next_transition[buffer_num]
            logger.info(f'{name} > Loading from buffer {buffer_num} which has {i} already loaded')
            frb = FixedReplayBuffer(data_dir=replay_logs_dir, 
                                    replay_suffix=buffer_num,
                                    observation_shape=(84, 84),
                                    stack_size=stack_size, 
                                    update_horizon=1, 
                                    gamma=0.99, 
                                    observation_dtype=np.uint8, 
                                    batch_size=32, 
                                    replay_capacity=replay_capacity)
            if frb.loaded:
                i, loaded, exhausted = _read_trajectories(frb, name, i, replay_capacity, trajectories_per_buffer,
                                                          map_ale_actions, obs, actions, rewards, done_idxs)
                num_trajectories += loaded
                next_transition[buffer_num] = i
            else:
                logger.info(f'{name} > Could not load buffer {buffer_num}; it will not be sampled again')
                exhausted = True
            # Release the (large) buffer before the next one is constructed.
            del frb
            if exhausted:
                available.remove(buffer_num)

            log = f'This buffer has {i} loaded transitions and there are now {len(obs)} transitions total divided into {num_trajectories} trajectories'
            pbar.set_description(f'[{name}]: num_iterations={num_iterations} | {log}', refresh=False)
            pbar.update(len(obs) - num_obs_preload)

    obs, actions, rewards, done_idxs = \
        np.array(obs), np.array(actions), np.array(rewards), np.array(done_idxs)

    data = (obs, actions, rewards, done_idxs)
    log_returns_stats(name, rewards, done_idxs)

    if cache_path is not None:
        helpers.pickle_save(data, cache_path, protocol=4)

    num_empty_frames = (obs.max((1, 2, 3)) == 0.0).sum()
    logger.info(f'{name} > No. Empty Frames: {num_empty_frames}')
    return data

class AtariReplayDataset(BaseDataset):
    """Atari dataset built from one or more DQN replay-log directories.

    Each directory is converted by ``create_dataset`` into
    ``(frames, actions, rewards, done_idxs)``; with a list of directories the
    per-directory results are concatenated. A random subset of actions can be hidden
    (relabelled with an "unknown" action), and the data can be restricted to one side
    of a deterministic ``make_split``. Everything else is handled by ``BaseDataset``.
    """

    def __init__(self, replay_logs_dir, num_steps, start_buffer, num_buffers, stack_size, image_size=(84, 84), num_frames=16, 
                 num_clips=None, channel_last=False, use_dynamic_range=False, split_size=None, first_split=False, overlap_ratios=None, 
                 unlabeled_ratio=None, unknown_action=None, map_ale_actions=False, use_strl=False, low_contrast_mode=None, cache_dir=None, cachefile_prefix=None):
        """
        Args:
            replay_logs_dir: a replay-log directory, or a list of them.
            num_steps: minimum number of transitions to collect per directory; indexed
                per directory when ``replay_logs_dir`` is a list.
            start_buffer, num_buffers, stack_size, map_ale_actions, cache_dir: forwarded
                to ``create_dataset``.
            cachefile_prefix: forwarded to ``create_dataset``; indexed per directory when
                ``replay_logs_dir`` is a list.
            unlabeled_ratio: ratio given to ``sample_indices`` to choose actions to hide;
                indexed per directory (entries may be None) when ``replay_logs_dir`` is a list.
            unknown_action: label given to hidden actions; defaults to one past the
                largest action id.
            split_size, first_split: when ``split_size`` is set, keep the first
                (``first_split=True``) or second part of ``make_split(done_idxs, split_size, seed=0)``.
            low_contrast_mode: None or ``'keep_semantics'``.
            The remaining arguments are forwarded to ``BaseDataset``.
        """
        assert (low_contrast_mode is None) or (low_contrast_mode in {'keep_semantics'})
        self.use_strl = use_strl
        self.low_contrast_mode = low_contrast_mode
        # NOTE(review): the augmenter is always built with channel_last=False and room for
        # 2 * num_frames frames, independent of the `channel_last` argument -- confirm intended.
        self.augmenter =\
            TemporalAugmentation(max_num_frames=2*num_frames, channel_last=False, 
                                    p_scale=0.95, p_geo=0.95, max_value=255.0)

        self.replay_logs_dir = replay_logs_dir
        self.num_steps = num_steps
        self.unlabeled_ratio = unlabeled_ratio

        if isinstance(replay_logs_dir, list):
            frames, actions, original_actions, rewards, done_idxs = [], [], [], [], []
            # Number of frames concatenated so far; shifts each directory's terminal indices
            # into the coordinates of the concatenated arrays.
            frame_offset = 0
            for i, logs_dir in enumerate(replay_logs_dir):
                cachefile = cachefile_prefix[i] if cachefile_prefix else None
                iframes, iactions, irewards, idone_idxs = create_dataset(logs_dir, num_steps[i], start_buffer=start_buffer,
                                                                         num_buffers=num_buffers, stack_size=stack_size,
                                                                         map_ale_actions=map_ale_actions, cache_dir=cache_dir, 
                                                                         cachefile_prefix=cachefile)

                ioriginal_actions = copy.deepcopy(iactions)
                if (unlabeled_ratio is not None) and (unlabeled_ratio[i] is not None):
                    unlabeled_indices = sample_indices(len(iactions), unlabeled_ratio[i])
                    iactions[unlabeled_indices] = -1  # temporary marker, relabelled below

                frames.append(iframes)
                actions.append(iactions)
                original_actions.append(ioriginal_actions)
                rewards.append(irewards)
                # Offset by the number of frames actually concatenated rather than by the last
                # terminal index: identical when the data ends on a terminal, but it stays aligned
                # otherwise and does not fail on an empty `idone_idxs`.
                done_idxs.append(np.asarray(idone_idxs, dtype=np.int64) + frame_offset)
                frame_offset += len(iframes)

            frames = np.concatenate(frames, axis=0)
            actions = np.concatenate(actions, axis=0)
            original_actions = np.concatenate(original_actions, axis=0)
            rewards = np.concatenate(rewards, axis=0)
            done_idxs = np.concatenate(done_idxs, axis=0)

        else:
            frames, actions, rewards, done_idxs = create_dataset(replay_logs_dir, num_steps, start_buffer=start_buffer, 
                                                                 num_buffers=num_buffers, stack_size=stack_size, 
                                                                 map_ale_actions=map_ale_actions, cache_dir=cache_dir, 
                                                                 cachefile_prefix=cachefile_prefix)
            original_actions = copy.deepcopy(actions)
            if unlabeled_ratio is not None:
                unlabeled_indices = sample_indices(len(actions), unlabeled_ratio)
                actions[unlabeled_indices] = -1  # temporary marker, relabelled below

        unlabeled_count = int((actions == -1).sum())
        if unlabeled_count > 0:
            logger.info(f'Unlabled: {unlabeled_count}')
            logger.info(f'Labeled: {len(actions) - unlabeled_count}')

        # Hidden actions get their own label: one past the largest action id unless given.
        actions[actions == -1] = (actions.max() + 1) if (unknown_action is None) else unknown_action

        self.action_space = list(np.unique(actions))
        logger.info(f'Actions Space: {self.action_space}')

        indices = None
        if (split_size is not None) and first_split:
            logger.info('First Split ...')
            (indices, done_idxs), _ = make_split(done_idxs, split_size, seed=0)
        elif (split_size is not None) and (not first_split):
            logger.info('Second Split ...')
            _, (indices, done_idxs) = make_split(done_idxs, split_size, seed=0)

        # The un-hidden actions stay available to consumers under 'original_action'.
        extra_data_dict = {'original_action': original_actions}

        super().__init__(frames=frames, actions=actions, rewards=rewards, extra_data_dict=extra_data_dict, terminal_idxs=done_idxs, 
                         image_size=image_size, num_frames=num_frames, frame_rate=1, skip_frames=0, num_clips=num_clips, overlap_ratios=overlap_ratios, 
                         channel_last=channel_last, use_dynamic_range=use_dynamic_range, indices=indices, actions_discrete=True)

    def process_observation(self, index):
        """Return frame ``index`` passed through ``processing.cv_process``.

        Uses ``self.image_size`` and ``self.channel_last`` (presumably set by
        ``BaseDataset.__init__`` -- confirm).
        """
        frame = processing.cv_process(self.frames[index], self.image_size, self.channel_last)
        return frame
