import datetime
import itertools
import os
import pickle
import time
from collections import OrderedDict
from collections.abc import Sequence
from pathlib import Path

import numpy as np
import torch

from .torch_utils import *


def run_steps(agent):
    """Drive ``agent`` until ``config.max_steps`` is reached, then evaluate and save.

    Every loop iteration performs, in order:

    1. a checkpoint save to ``data/<AgentClass>-<tag>-<steps>`` when
       ``config.save_interval`` is set and ``agent.total_steps`` is a
       multiple of it;
    2. a throughput log line via ``agent.logger.info`` when
       ``config.log_interval`` is set and ``agent.total_steps`` is a
       multiple of it;
    3. once ``config.max_steps`` is set and reached: a final evaluation, a
       save to ``model/<config.game_file>/<config.logtxtname>``,
       ``agent.close()`` and return;
    4. otherwise ``agent.step()`` followed by ``agent.switch_task()``.

    Args:
        agent: object exposing ``config``, ``total_steps``, ``logger``,
            ``save``, ``eval_episodes``, ``step``, ``switch_task`` and
            ``close``.

    Note:
        The loop never terminates when ``config.max_steps`` is falsy.
    """
    config = agent.config
    agent_name = agent.__class__.__name__
    t0 = time.time()
    while True:
        if config.save_interval and not agent.total_steps % config.save_interval:
            agent.save('data/%s-%s-%d' % (agent_name, config.tag, agent.total_steps))
        if config.log_interval and (agent.total_steps % config.log_interval == 0):
            elapsed = time.time() - t0
            # The first iteration (or a coarse clock) can yield a zero
            # interval; report an unbounded rate instead of dividing by zero.
            steps_per_sec = config.log_interval / elapsed if elapsed > 0 else float('inf')
            agent.logger.info('steps %d, %.2f steps/s' % (agent.total_steps, steps_per_sec))
            t0 = time.time()
        if config.max_steps and agent.total_steps >= config.max_steps:
            print('Final evaluation !!!!!!!!!!!!!!!!!!!!!!')
            current_return = agent.eval_episodes()['episodic_return_test']
            print('Evaluation steps: {}, return: {:.2f}'.format(agent.total_steps, float(current_return)))
            model_path = 'model/{}/{}'.format(config.game_file, config.logtxtname)
            # Create the full parent directory (model/<game_file>), not just
            # 'model', so the save target's directory is guaranteed to exist.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            agent.save(model_path)
            print('Model saved: {}'.format(model_path))
            agent.close()
            break
        agent.step()
        agent.switch_task()


def run_steps_evaluate(agent):
    """Restore a saved model into ``agent`` and run its TSNE routine.

    Sets ``config.eval_episodes`` to 10, loads the checkpoint stored at
    ``model/<config.game_file>/<config.logtxtname>``, prints a timestamped
    banner, calls ``agent.TSNE()`` and finally closes the agent.

    Args:
        agent: object exposing ``config``, ``load``, ``TSNE`` and ``close``.
    """
    cfg = agent.config
    cfg.eval_episodes = 10  # a larger setting (100) was used previously
    class_name = agent.__class__.__name__

    # (1) restore the saved weights; per the original note this also
    # restores the normalization layer -- confirm against agent.load
    checkpoint = 'model/{}/{}'.format(cfg.game_file, cfg.logtxtname)
    agent.load(checkpoint)

    # (2) evaluate in the environment; the episodic evaluation call
    # (agent.eval_episodes()) is currently disabled, only TSNE runs
    banner = '{}: Evaluation with {} eval episodes!'.format(class_name, cfg.eval_episodes)
    timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(banner, timestamp)
    print('Plotting TSNE!')
    agent.TSNE()
    agent.close()

def get_time_str():
    """Return the current local time formatted as ``YYMMDD-HHMMSS``."""
    now = datetime.datetime.now()
    return now.strftime("%y%m%d-%H%M%S")


def get_default_log_dir(name):
    """Return a log directory path of the form ``./log/<name>-<timestamp>``.

    The timestamp comes from ``get_time_str()``; the directory itself is
    not created here.
    """
    stamp = get_time_str()
    return './log/%s-%s' % (name, stamp)


def mkdir(path):
    """Create directory ``path`` together with any missing parents.

    Does nothing when the directory already exists.
    """
    target = Path(path)
    target.mkdir(parents=True, exist_ok=True)


def close_obj(obj):
    """Call ``obj.close()`` when ``obj`` has a ``close`` attribute; otherwise do nothing."""
    if not hasattr(obj, 'close'):
        return
    obj.close()


def random_sample(indices, batch_size):
    """Yield shuffled mini-batches drawn from ``indices``.

    ``indices`` is permuted with ``np.random.permutation``. Full batches of
    ``batch_size`` elements are yielded first; if the length is not a
    multiple of ``batch_size``, the leftover elements follow as one final,
    shorter batch.
    """
    shuffled = np.asarray(np.random.permutation(indices))
    total = len(shuffled)
    # Number of elements that fit into complete batches.
    full = total // batch_size * batch_size
    yield from shuffled[:full].reshape(-1, batch_size)
    leftover = total % batch_size
    if leftover:
        yield shuffled[-leftover:]


def is_plain_type(x):
    """Return True when ``x`` is an instance of ``str``, ``int``, ``float`` or ``bool``."""
    return isinstance(x, (str, int, float, bool))


def generate_tag(params):
    """Add a descriptive ``'tag'`` entry to ``params`` unless one already exists.

    The tag has the form ``<game>-<k1>_<v1>-<k2>_<v2>-...-run-<run>``, where
    the ``k_v`` parts cover every entry other than ``'game'`` and ``'run'``,
    sorted by key. Plain values (see ``is_plain_type``) are rendered as-is;
    other values are rendered by their ``__name__``, falling back to their
    string form when they have none.

    Args:
        params: mutable mapping that must contain ``'game'``. ``'run'`` is
            set to 0 when missing. Existing entries are never removed, so
            the mapping stays intact even if building the tag fails.

    Raises:
        KeyError: if ``'game'`` is missing from ``params``.
    """
    if 'tag' in params:
        return
    game = params['game']
    run = params.setdefault('run', 0)
    parts = [
        '%s_%s' % (key, value if is_plain_type(value) else getattr(value, '__name__', value))
        for key, value in sorted(params.items())
        if key not in ('game', 'run')
    ]
    params['tag'] = '%s-%s-run-%d' % (game, '-'.join(parts), run)


def translate(pattern):
    """Return ``pattern`` with every ``.`` prefixed by a backslash.

    This makes the dots literal when the result is used as a regular
    expression. All other characters are left untouched.
    """
    # A raw string avoids the invalid escape sequence that the plain
    # literal would contain; the resulting two-character text is the same.
    return pattern.replace('.', r'\.')


def split(a, n):
    """Lazily split sequence ``a`` into ``n`` contiguous, near-equal chunks.

    The first ``len(a) % n`` chunks receive one extra element. Returns a
    generator of slices of ``a``.
    """
    base, extra = divmod(len(a), n)

    def _start(i):
        # Offset of chunk i: i base-sized chunks plus one extra element
        # for each of the earlier chunks that got one.
        return i * base + min(i, extra)

    return (a[_start(i):_start(i + 1)] for i in range(n))


class HyperParameter:
    """A single hyper-parameter combination tagged with an identifier."""

    def __init__(self, id, param):
        """Store ``id`` and collect ``param``'s (key, value) pairs into a dict."""
        self.id = id
        self.param = {name: value for name, value in param}

    def __str__(self):
        """Return the identifier rendered as a string."""
        return str(self.id)

    def dict(self):
        """Return the stored key-to-value mapping (the same object, not a copy)."""
        return self.param


class HyperParameters(Sequence):
    """Read-only sequence over the Cartesian product of hyper-parameter values.

    Built from an ``OrderedDict`` that maps each parameter name to an
    iterable of candidate values. Element ``i`` is a ``HyperParameter``
    whose id is ``i`` and whose contents are the ``i``-th combination, with
    combinations ordered as produced by ``itertools.product``.
    """

    def __init__(self, ordered_params):
        """Enumerate all combinations of the values in ``ordered_params``.

        Raises:
            NotImplementedError: if ``ordered_params`` is not an ``OrderedDict``.
        """
        if not isinstance(ordered_params, OrderedDict):
            raise NotImplementedError(
                'ordered_params must be an OrderedDict, got %s' % type(ordered_params).__name__)
        # One axis per parameter: the list of [name, value] pairs it can take.
        axes = [[[key, value] for value in values] for key, values in ordered_params.items()]
        self.params = list(itertools.product(*axes))

    def __getitem__(self, index):
        """Return the combination at ``index``; a slice yields a list of them."""
        if isinstance(index, slice):
            # Resolve the slice to concrete positions so that every returned
            # HyperParameter carries its own integer id.
            positions = range(*index.indices(len(self.params)))
            return [HyperParameter(i, self.params[i]) for i in positions]
        return HyperParameter(index, self.params[index])

    def __len__(self):
        """Return the number of combinations."""
        return len(self.params)