from ensurepip import version
import logging
import wandb

import numpy as np


# Named loggers shared by this module: 'train' receives the per-step training
# summary line, 'train_test' receives the evaluation score summary line.
train_logger = logging.getLogger(name='train')
test_logger = logging.getLogger(name='train_test')


def _log(config, step_count, log_data, replay_buffer, lr, shared_storage, wandb_run, vis_result):
    """Log one training step to the text loggers and, optionally, to wandb.

    Args:
        config: object providing ``batch_size`` and ``env_name``.
        step_count: training step; used in the text message and as the wandb
            ``step`` for every record emitted by this call.
        log_data: 5-item sequence ``(loss_data, td_data, priority_data,
            other_data, version_gap_dist)``.
            ``loss_data`` holds 8 loss values; ``other_data`` holds
            ``(re_policy_version_gap, re_value_version_gap, alive_ratio)``
            where the two gaps must support ``.mean()/.max()/.min()``;
            ``td_data`` (15 items) is only unpacked when ``vis_result`` is
            truthy; ``priority_data`` is not used; ``version_gap_dist`` may
            be ``None``.
        replay_buffer: object providing ``episodes_collected()``, ``size()``,
            ``get_total_len()``, ``get_fresh_log()`` and
            ``get_sampling_log()``.
        lr: current learning rate (logged as-is).
        shared_storage: object providing ``get_worker_logs()`` returning a
            12-item sequence.
        wandb_run: object with ``log(dict, step=...)``, or ``None`` to skip
            wandb logging entirely.
        vis_result: when truthy, value/value-prefix statistics and the extra
            ``other_loss`` / ``other_log`` entries from ``td_data`` are
            logged as well.

    Returns:
        None.
    """
    loss_data, td_data, priority_data, other_data, version_gap_dist = log_data
    (total_loss, weighted_loss, loss, reg_loss, policy_loss,
     value_prefix_loss, value_loss, consistency_loss) = loss_data
    re_policy_version_gap, re_value_version_gap, alive_ratio = other_data
    if vis_result:
        (new_priority, target_value_prefix, target_value,
         trans_target_value_prefix, trans_target_value,
         target_value_prefix_phi, target_value_phi,
         pred_value_prefix, pred_value, target_policies, predicted_policies,
         state_lst, other_loss, other_log, other_dist) = td_data

    # Replay buffer statistics (queried in a fixed order).
    replay_episodes_collected = replay_buffer.episodes_collected()
    replay_buffer_size = replay_buffer.size()
    total_num = replay_buffer.get_total_len()
    num_fresh_entries, fresh_entries_prob = replay_buffer.get_fresh_log()
    (sampling_entropy, high_ratio,
     (prob_pos_nstep_r, prob_zero_nstep_r, prob_neg_nstep_r)) = replay_buffer.get_sampling_log()

    # Self-play / evaluation statistics collected by the workers.
    # NOTE: `distributions` is unpacked but currently not logged anywhere.
    (worker_ori_reward, worker_reward, worker_reward_max, worker_eps_len,
     worker_eps_len_max, test_counter, test_dict, temperature, visit_entropy,
     priority_self_play, distributions, last_eps_logs) = shared_storage.get_worker_logs()
    last_eps_len, last_eps_ori_reward, last_eps_reward = last_eps_logs

    _msg = '#{:<10} Total Loss: {:<8.3f} [weighted Loss:{:<8.3f} Policy Loss: {:<8.3f} Value Loss: {:<8.3f} ' \
           'Reward Sum Loss: {:<8.3f} Consistency Loss: {:<8.3f} ] ' \
           'Replay Episodes Collected: {:<10d} Buffer Size: {:<10d} Transition Number: {:<8.3f}k ' \
           'Batch Size: {:<10d} Lr: {:<8.3f}'
    _msg = _msg.format(step_count, total_loss, weighted_loss, policy_loss, value_loss, value_prefix_loss, consistency_loss,
                       replay_episodes_collected, replay_buffer_size, total_num / 1000, config.batch_size, lr)
    train_logger.info(_msg)

    if test_dict is not None:
        mean_score = np.mean(test_dict['mean_score'])
        max_score = np.mean(test_dict['max_score'])
        min_score = np.mean(test_dict['min_score'])
        std_score = np.mean(test_dict['std_score'])
        test_msg = '#{:<10} Test Mean Score of {}: {:<10} (max: {:<10}, min:{:<10}, std: {:<10})' \
                   ''.format(test_counter, config.env_name, mean_score, max_score, min_score, std_score)
        test_logger.info(test_msg)

    if wandb_run is None:
        return

    tag = 'Train'
    # Always start from a fresh dict: the scalar metrics below must be logged
    # whether or not the optional `vis_result` statistics are included.
    metrics = dict()

    if vis_result:
        # Key names use the historical 'pre_' prefix for predicted values.
        for name, values in (('target_value_prefix', target_value_prefix),
                             ('pre_value_prefix', pred_value_prefix),
                             ('target_value', target_value),
                             ('pre_value', pred_value)):
            flat = values.flatten()
            metrics['{}_statistics/{}_mean'.format(tag, name)] = flat.mean()
            metrics['{}_statistics/{}_std'.format(tag, name)] = flat.std()

        for key, val in other_loss.items():
            # Negative entries are skipped (not logged).
            if val >= 0:
                metrics['{}_metric/'.format(tag) + key] = val

        for key, val in other_log.items():
            metrics['{}_weight/'.format(tag) + key] = val

    metrics['{}/total_loss'.format(tag)] = total_loss
    metrics['{}/loss'.format(tag)] = loss
    metrics['{}/weighted_loss'.format(tag)] = weighted_loss
    metrics['{}/reg_loss'.format(tag)] = reg_loss
    metrics['{}/policy_loss'.format(tag)] = policy_loss
    metrics['{}/value_loss'.format(tag)] = value_loss
    metrics['{}/value_prefix_loss'.format(tag)] = value_prefix_loss
    metrics['{}/consistency_loss'.format(tag)] = consistency_loss
    metrics['{}/episodes_collected'.format(tag)] = replay_episodes_collected
    metrics['{}/replay_buffer_len'.format(tag)] = replay_buffer_size
    metrics['{}/total_node_num'.format(tag)] = total_num
    metrics['{}/lr'.format(tag)] = lr

    metrics['{}/re_policy_version_gap'.format(tag)] = re_policy_version_gap.mean()
    metrics['{}/re_policy_version_gap_max'.format(tag)] = re_policy_version_gap.max()
    metrics['{}/re_policy_version_gap_min'.format(tag)] = re_policy_version_gap.min()
    metrics['{}/re_value_version_gap'.format(tag)] = re_value_version_gap.mean()
    metrics['{}/re_value_version_gap_max'.format(tag)] = re_value_version_gap.max()
    metrics['{}/re_value_version_gap_min'.format(tag)] = re_value_version_gap.min()
    metrics['{}/alive_ratio'.format(tag)] = alive_ratio

    metrics['{}_replay_buffer/num_fresh_entries'.format(tag)] = num_fresh_entries
    metrics['{}_replay_buffer/fresh_entries_prob'.format(tag)] = fresh_entries_prob
    metrics['{}_replay_buffer/sampling_entropy'.format(tag)] = sampling_entropy
    metrics['{}_replay_buffer/high_ratio'.format(tag)] = high_ratio
    metrics['{}_replay_buffer/prob_pos_nstep_r'.format(tag)] = prob_pos_nstep_r
    metrics['{}_replay_buffer/prob_zero_nstep_r'.format(tag)] = prob_zero_nstep_r
    metrics['{}_replay_buffer/prob_neg_nstep_r'.format(tag)] = prob_neg_nstep_r

    if worker_reward is not None:
        metrics['workers/ori_reward'] = worker_ori_reward
        metrics['workers/clip_reward'] = worker_reward
        metrics['workers/clip_reward_max'] = worker_reward_max
        metrics['workers/eps_len'] = worker_eps_len
        metrics['workers/eps_len_max'] = worker_eps_len_max
        metrics['workers/temperature'] = temperature
        metrics['workers/visit_entropy'] = visit_entropy
        metrics['workers/priority_self_play'] = priority_self_play

    # Last-episode statistics are logged even when `worker_reward` is None.
    metrics['workers/last_eps_len'] = last_eps_len
    metrics['workers/last_eps_ori_reward'] = last_eps_ori_reward
    metrics['workers/last_eps_reward'] = last_eps_reward

    wandb_run.log(metrics, step=step_count)

    if version_gap_dist is not None:
        # Normalise to a distribution; the epsilon guards against a zero sum.
        version_gap_dist = version_gap_dist / (version_gap_dist.sum() + 1e-6)
        # NOTE(review): wandb.Table expects row-shaped data (one row per
        # entry, one value per column) - confirm the caller passes that.
        table = wandb.Table(data=version_gap_dist, columns=["version_gap_dist"])
        # Log through the run handle with the explicit step: a step-less
        # wandb.log() here would advance wandb's internal step and the test
        # metrics logged below at `step_count` could then be rejected as
        # non-monotonic.
        wandb_run.log(
            {'version_gap_dist': wandb.plot.histogram(table, "version_gap_dist", title="version_gap_dist")},
            step=step_count)

    if test_dict is not None:
        test_metrics = {'train/test_counter': test_counter}
        for key, val in test_dict.items():
            test_metrics['train/{}'.format(key)] = np.mean(val)
        wandb_run.log(test_metrics, step=step_count)


def _log_worker(config, step_count, watchdog_server, replay_buffer, shared_storage, wandb_run):
    """Report worker-side counters to wandb.

    Reads the replay buffer's total length and the watchdog server's
    reanalyze batch count, then sends both to ``wandb_run.log`` when a run is
    given. Both values are queried even if ``wandb_run`` is ``None``.

    ``config``, ``step_count`` and ``shared_storage`` are accepted but not
    used here; in particular no explicit ``step`` is passed to wandb.

    Returns:
        None.
    """
    # The replay buffer is queried first, then the watchdog server.
    transition_total = replay_buffer.get_total_len()
    batch_count = watchdog_server.get_reanalyze_batch_count()

    if wandb_run is None:
        return

    wandb_run.log({
        "worker/reanalyze_batch_count": batch_count,
        "worker/total_transitions": transition_total,
    })
    