from datetime import datetime
import random
import yaml

import torch
import numpy as np
from yacs.config import CfgNode as CN

# Copied from d4rl.infos
# Per-environment reference returns ('min' / 'max') used to rescale raw episode
# returns in get_d4rl_normalized_score. Each environment is registered under
# its bare name and under its -v2 and -v3 ids, all sharing the same range.
_D4RL_REF_RANGES = (
    ('Hopper', -20.272305, 3234.3),
    ('HalfCheetah', -280.178953, 12135.0),
    ('Walker2d', 1.629008, 4592.3),
)
REF_SCORES = {
    env_name + version_suffix: {'min': ref_min, 'max': ref_max}
    for env_name, ref_min, ref_max in _D4RL_REF_RANGES
    for version_suffix in ('', '-v2', '-v3')
}


def get_date_time_str(add_hash=True):
    """Return a filesystem-friendly timestamp string for the current local time.

    Format is ``date_DD_MM_YYYY_time_HH_MM``. When ``add_hash`` is true, the
    current microseconds are appended as ``_hash_FFFFFF`` so that two calls
    within the same minute usually yield different strings.

    Args:
        add_hash: Whether to append the microsecond suffix.

    Returns:
        The timestamp string.
    """
    stamp = datetime.now()
    pieces = [stamp.strftime('date_%d_%m_%Y_time_%H_%M')]
    if add_hash:
        pieces.append(stamp.strftime('hash_%f'))
    return '_'.join(pieces)


def set_seed(seed, fully_deterministic=True):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Seed value passed to every generator.
        fully_deterministic: If True, also force cuDNN to use deterministic
            kernels and disable its benchmarking autotuner. With the autotuner
            on, cuDNN may select different convolution algorithms between
            runs, which defeats seeding.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if fully_deterministic:
        # Setting these flags is harmless on CPU-only machines, so they are
        # applied unconditionally rather than only when CUDA is detected.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def save_config(config, path):
    """Write a yacs config to ``path`` as YAML and return it as a plain dict.

    Nested ``CfgNode`` objects are converted recursively to built-in dicts
    before dumping; any value that is not a ``CfgNode`` is kept as-is.

    Args:
        config: The ``CfgNode`` to save.
        path: Destination file path; it is opened in ``'w'`` mode, so an
            existing file is overwritten.

    Returns:
        The plain-dict form of ``config`` that was written to ``path``.
    """
    def convert_config_to_dict(cfg_node):
        # Non-CfgNode values (leaves) are returned unchanged.
        if not isinstance(cfg_node, CN):
            return cfg_node
        return {key: convert_config_to_dict(value) for key, value in cfg_node.items()}

    config_dict = convert_config_to_dict(config)
    with open(path, 'w') as f:
        yaml.dump(config_dict, f, default_flow_style=False)
    return config_dict


def get_d4rl_normalized_score(type, score):
    """Rescale a raw return using the reference range for an environment.

    Computes ``(score - min) / (max - min)`` with the ``'min'``/``'max'``
    entries stored in ``REF_SCORES[type]``, so ``min`` maps to 0 and ``max``
    maps to 1. ``score`` may be a scalar or a NumPy array (element-wise).
    If ``type`` has no entry in ``REF_SCORES``, ``score`` is returned as-is.

    Args:
        type: Environment name used as the ``REF_SCORES`` key. It shadows the
            builtin but is kept for compatibility with keyword callers.
        score: Raw return(s) to rescale.
    """
    ref_range = REF_SCORES.get(type)
    if ref_range is None:
        return score
    low = ref_range['min']
    return (score - low) / (ref_range['max'] - low)


def get_eval_statistics(all_rewards, env_type):
    """Summarize evaluation returns, both raw and D4RL-normalized.

    Args:
        all_rewards: Sequence of per-episode returns.
        env_type: Environment name looked up in ``REF_SCORES``.

    Returns:
        Tuple ``(avg_reward, std_reward, avg_norm_reward, std_norm_reward)``.
        When ``env_type`` has reference scores, the normalized pair is on a
        percentage scale: 0 at the reference min and 100 at the reference
        max. Otherwise there is nothing to normalize against, and the
        normalized pair equals the raw pair.
    """
    avg_reward = np.mean(all_rewards)
    std_reward = np.std(all_rewards)

    normalized_scores = get_d4rl_normalized_score(env_type, np.array(all_rewards))
    if env_type in REF_SCORES:
        # Only a real normalization yields a fraction worth expressing as a
        # percentage. Multiplying raw returns by 100 for unknown environments
        # would report meaningless "normalized" numbers.
        normalized_scores = normalized_scores * 100.0
    avg_norm_reward = np.mean(normalized_scores)
    std_norm_reward = np.std(normalized_scores)
    return avg_reward, std_reward, avg_norm_reward, std_norm_reward


def num(s):
    """Convert ``s`` to an ``int`` when possible, otherwise to a ``float``.

    Intended for numeric strings: ``"3"`` -> ``3``, ``"0.5"`` -> ``0.5``,
    ``"1e3"`` -> ``1000.0``. Float inputs are returned unchanged rather than
    passed through ``int()``, which would silently truncate them
    (``int(2.5) == 2``).

    Raises:
        ValueError: If ``s`` is a string that is neither an int nor a float
            literal.
    """
    if isinstance(s, float):
        return s
    try:
        return int(s)
    except ValueError:
        return float(s)


class Dict2Class(object):
    """Expose the entries of a mapping as instance attributes.

    Every ``-`` in a key is replaced by ``_`` so that, e.g., the key
    ``'learning-rate'`` becomes the attribute ``learning_rate``. Keys are
    expected to be strings. If two keys map to the same attribute name, the
    one iterated last wins.
    """

    def __init__(self, my_dict):
        for raw_key in my_dict:
            attr_name = raw_key.replace('-', '_')
            setattr(self, attr_name, my_dict[raw_key])
