import os
import platform

import numpy as np
from ray.rllib.agents.callbacks import DefaultCallbacks as RllibCallbackBase
from ray.tune.callback import Callback as TuneCallbackBase


class CustomMetricCallback(RllibCallbackBase):
    """RLlib callback that turns per-step agent ``info`` values into episode metrics.

    At every step, each configured metric key is looked up in the agents'
    ``info`` dicts and the mean over the agents that report it is stored in
    ``episode.user_data``.  At the end of the episode the collected series
    are summarised into ``episode.custom_metrics``.
    """

    # Metric keys collected when the caller does not supply its own list.
    DEFAULT_CUSTOM_METRICS = [
        'team_reward', 'individual_reward', 'total_reward',
        'vis_reward',
        'reconstruction_reward',

        'distance_reward',
        'centering_reward',
        'iot_reward',
        'obstruction_reward',
        'mpjpe_3d',
        'state_reward',
        # 'reward_peak',
        'anti_collision_reward',

        # 'pck3d_5',
        # 'pck3d_15',
        # 'pck3d_25',
        # 'pck3d_35',

        'pck3d_5',
        'pck3d_10',
        'pck3d_15',
        'pck3d_20',
        'pck3d_25',
        'pck3d_30',
        'pck3d_35',
        'pck3d_40',
        'pck3d_45',
        'pck3d_50',
        'pck3d_60',
        'pck3d_70',
        'pck3d_80',
        'pck3d_90',
        'pck3d_100',
        'pck3d_110',

        'ex_pck3d_5',
        'ex_pck3d_10',
        'ex_pck3d_15',
        'ex_pck3d_20',
        'ex_pck3d_25',
        'ex_pck3d_30',
        'ex_pck3d_35',
        'ex_pck3d_40',
        'ex_pck3d_45',
        'ex_pck3d_50',
        'ex_pck3d_60',
        'ex_pck3d_70',
        'ex_pck3d_80',
        'ex_pck3d_90',
        'ex_pck3d_100',
        'ex_pck3d_110',

        'min_partial_mpjpe_2c',
        'min_partial_mpjpe_3c',
        'min_partial_mpjpe_4c',

        'min_partial_pck20_2c',
        'min_partial_pck20_3c',
        'min_partial_pck20_4c',

        'mpjpe_0_best_rate', 'mpjpe_0_best_diff',
        'mpjpe_1_best_rate', 'mpjpe_1_best_diff',
        'mpjpe_2_best_rate', 'mpjpe_2_best_diff',

        'mpjpe_01_best_rate', 'mpjpe_01_best_diff',
        'mpjpe_12_best_rate', 'mpjpe_12_best_diff',
        'mpjpe_20_best_rate', 'mpjpe_20_best_diff',
        'mpjpe_012_best_rate', 'mpjpe_012_best_diff',

        'lost_joints_ratio',

        'reward_running_mean', 'reward_running_stddev', 'reward_normalized'
    ]

    # Thresholds for the ``mpjpe_3d`` success rates: 0..20 in steps of 1,
    # then 30..100 in steps of 10.
    SUCCESS_RATE_THRESHOLDS = tuple(range(0, 21)) + tuple(range(30, 110, 10))

    def __init__(self, custom_metrics=None):
        """Create the callback.

        Args:
            custom_metrics: Iterable of ``info`` keys to track.  Falls back to
                ``DEFAULT_CUSTOM_METRICS`` when ``None`` or empty.
        """
        super().__init__()

        self.custom_metrics = custom_metrics or self.DEFAULT_CUSTOM_METRICS

    def on_episode_start(self, *, worker, base_env, policies, episode, env_index, **kwargs):
        """Reset the per-episode buffers and detect whether agents are grouped.

        Reads ``policies.policy_config['env_config']['algo']``; agents are
        treated as grouped when it equals ``"QMIX"`` (case-insensitive).
        """
        algo = policies.policy_config['env_config']['algo']
        self.group_agents = algo.upper() == "QMIX"

        for key in self.custom_metrics:
            episode.user_data[key] = []
        # Filled in from the first info dict seen in ``on_episode_step``.
        episode.user_data['num_cameras'] = None
        episode.user_data['num_humans'] = None

    def on_episode_step(self, *, worker, base_env, episode, env_index, **kwargs):
        """Append the across-agent mean of every tracked metric for this step.

        Agents whose ``info`` lacks a metric key are skipped for that key; if
        no agent reports the key, nothing is appended for this step.

        Raises:
            KeyError: If the info dict used to capture ``num_cameras`` /
                ``num_humans`` does not contain those keys.
        """
        agent_infos = list(map(episode.last_info_for, episode.get_agents()))

        if self.group_agents:
            # For grouped agents the per-agent infos are read from the
            # '_group_info' entry of the last registered agent id.
            team = list(episode._agent_to_index.keys())[-1]
            agent_infos = episode._agent_to_last_info[team]['_group_info']

        user_data = episode.user_data

        # Capture the scene size once per episode; this does not depend on the
        # metric key, so it is done outside the per-metric loop.
        if self.custom_metrics:
            for info in agent_infos:
                if user_data['num_cameras'] is None:
                    user_data['num_cameras'] = info['num_cameras']
                if user_data['num_humans'] is None:
                    user_data['num_humans'] = info['num_humans']

        for key in self.custom_metrics:
            values = [info[key] for info in agent_infos if key in info]
            if values:
                user_data[key].append(np.mean(values))

    def on_episode_end(self, *, worker, base_env, policies, episode, env_index, **kwargs):
        """Summarise the collected series into ``episode.custom_metrics``.

        For every metric ``key`` this writes ``key`` (episode mean).  Keys
        containing ``reward`` or ``mpjpe`` additionally get ``episode_<key>``
        (sum) and ``last_<key>`` (mean of the final 10 steps).  ``mpjpe_3d``
        also gets a sample standard deviation and, for each threshold in
        ``SUCCESS_RATE_THRESHOLDS``, the fraction of steps at or below it.

        Every entry is written twice when both the camera and human counts
        are known: once under the plain name and once with a
        ``_<num_cameras>c<num_humans>h`` suffix.  A metric without any
        recorded samples yields NaN.
        """
        suffixes = ['']
        num_cameras = episode.user_data['num_cameras']
        num_humans = episode.user_data['num_humans']
        # Both counts must be known; the original condition tested
        # ``num_cameras`` twice and never looked at ``num_humans``.
        if num_cameras is not None and num_humans is not None:
            suffixes.append(f'_{num_cameras}c{num_humans}h')

        metrics = episode.custom_metrics
        for suffix in suffixes:
            for key in self.custom_metrics:
                series = episode.user_data[key]
                metrics[f'{key}{suffix}'] = float(np.mean(series))

                if ('reward' in key or 'mpjpe' in key) and not key.startswith('episode'):
                    metrics[f'episode_{key}{suffix}'] = float(np.sum(series))
                    metrics[f'last_{key}{suffix}'] = float(np.mean(series[-10:]))

                if key == 'mpjpe_3d':
                    metrics[f'{key}_stddev{suffix}'] = float(np.std(series, ddof=1))
                    errors = np.asarray(series)
                    for thresh in self.SUCCESS_RATE_THRESHOLDS:
                        success_rate = float(np.mean(errors <= thresh))
                        metrics[f'episode_mpjpe_3d_success_rate_{thresh}{suffix}'] = success_rate


class SymlinkCheckpointCallback(TuneCallbackBase):
    """Tune callback that points a ``latest-checkpoint`` symlink at each new checkpoint."""

    def on_checkpoint(self, iteration, trials, trial, checkpoint, **info):
        """Refresh ``latest-checkpoint`` in the trial's ``logdir`` and ``local_dir``.

        The link target is ``checkpoint.value``.
        """
        source = checkpoint.value
        for target_dir in (trial.logdir, trial.local_dir):
            target = os.path.join(target_dir, 'latest-checkpoint')
            print('Symlink "{}" to "{}".'.format(source, target))
            self.symlink(source, target)

    @staticmethod
    def symlink(source, target):
        """Create or replace the symlink ``target`` so that it points to ``source``.

        The link is first created as ``<target>.temp`` and then moved over
        ``target`` with ``os.replace``, so an existing link is swapped rather
        than deleted and re-created.

        Args:
            source: Path the link should point to.
            target: Path of the link itself.

        Raises:
            OSError: If the link cannot be created or moved into place.
        """
        temp_target = '{}.temp'.format(target)

        # A link left behind by an interrupted earlier call would make the
        # creation below fail with FileExistsError on every later checkpoint.
        if os.path.lexists(temp_target):
            os.remove(temp_target)

        os_symlink = getattr(os, "symlink", None)
        if callable(os_symlink):
            os_symlink(source, temp_target)
        elif platform.system() == 'Windows':
            # Fallback for interpreters without os.symlink: call the Win32 API.
            import ctypes
            csl = ctypes.windll.kernel32.CreateSymbolicLinkW
            csl.argtypes = (ctypes.c_wchar_p, ctypes.c_wchar_p, ctypes.c_uint32)
            csl.restype = ctypes.c_ubyte
            flags = 1 if os.path.isdir(source) else 0  # 1 = directory link
            if csl(temp_target, source, flags) == 0:
                # WinError's first positional argument is the numeric error
                # code; the message must be passed as ``descr``.
                raise ctypes.WinError(
                    descr='Cannot create symlink "{}" to "{}".'.format(source, target))
        else:
            raise OSError('Cannot create symlink "{}" to "{}".'.format(source, target))

        os.replace(temp_target, target)
