# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import time
import os
import sys
import yaml
from subprocess import check_output, CalledProcessError
from utils.config import BaseFLAGS, expand, parse
from utils.Logger import logger, FileSink, CSVWriter


class FLAGS(BaseFLAGS):
    """Experiment configuration: global options plus per-algorithm sections.

    Values are stored as class attributes; each nested ``BaseFLAGS`` subclass
    groups the options of one component. ``finalize`` writes ``cls.as_dict()``
    to ``config.yml`` inside the run's log directory.

    NOTE(review): ``as_dict`` is provided by ``BaseFLAGS`` (not visible here);
    presumably it enumerates the public class attributes, so avoid adding
    non-configuration attributes to this class without checking it.
    """

    _initialized = False

    # Base random seed; 0 means "draw a random seed" (see ``set_seed``).
    seed = 1500
    # Output directory; derived from ``run_id`` in ``finalize`` when None.
    log_dir = None
    # Run identifier; auto-generated in ``finalize`` when None.
    run_id = None
    algorithm = 'IQL_fix'
    message = ''
    num_data_dict = {'CliffWalking': 409600, 'Bandit': 12800}
    # BC Bandit 20000 CliffWalking 6000000
    dis_num_data_dict = {'CliffWalking': 6000000, 'Bandit': 20000}
    dis_num_data = 6000000

    class env(BaseFLAGS):
        """Environment options."""

        id = 'CliffWalking'  # 'Bandit'
        # origin 40
        ns = 40
        na = 5
        # for Bandit, max_episode_steps = 3200, for CliffWalking, max_episode_steps = 800
        # NOTE(review): the default id is 'CliffWalking' but the value below is
        # the one the comment above lists for Bandit -- confirm which is intended.
        max_episode_steps = 3200
        init_dist_type = "NonUniform"

    class GTAL(BaseFLAGS):
        """Options of the GTAL section."""

        max_num_iterations = 100
        is_average = True

    class DisGTAL(BaseFLAGS):
        """Options of the DisGTAL section."""

        max_num_iter_dict = {'CliffWalking': 200, 'Bandit': 50}
        max_num_iterations = 1600
        train_policy_freq = 10
        reward_opt_type = 'PG'
        is_average = True

    class DisFEM(BaseFLAGS):
        """Options of the DisFEM section."""

        max_num_iterations = 1600
        train_policy_freq = 10

    class GAIL(BaseFLAGS):
        """Options of the GAIL section."""

        max_num_iterations = 400
        train_policy_freq = 10
        is_average = True

    class DAgger(BaseFLAGS):
        """Options of the DAgger section."""

        max_num_iterations = 100
        num_data_ratio_dict = {'CliffWalking': 4, 'Bandit': 2}

    @classmethod
    def set_seed(cls):
        """Seed NumPy's and Python's global random number generators.

        When ``cls.seed`` is 0 a fresh seed is drawn from ``os.urandom`` and
        stored back on the class, so the value that was actually used is the
        one that gets logged. Python's ``random`` module is seeded with
        ``seed + 2000`` so it does not share the exact seed value with NumPy.
        """
        if cls.seed == 0:  # auto seed
            # never use seed 0 for RNG, 0 is for `urandom`
            cls.seed = int.from_bytes(os.urandom(3), 'little') + 1
        logger.warning("Setting random seed to %s", cls.seed)

        import numpy as np
        # import tensorflow as tf  # we dont use tensorflow here
        import random
        np.random.seed(cls.seed)
        # tf.set_seed(cls.seed+1000)
        random.seed(cls.seed+2000)

    @classmethod
    def finalize(cls):
        """Resolve the log directory, dump the configuration and attach log sinks.

        If ``cls.log_dir`` is None it becomes ``logs/<run_id>``; if ``cls.run_id``
        is also None the run id is generated as
        ``<algorithm>-<env.id>-<seed>-<timestamp>``. The directory is created
        when missing, ``config.yml`` is written into it, and the logger gets a
        text sink (``log.txt``) and a CSV writer (``progress.csv``).
        """
        log_dir = cls.log_dir
        if log_dir is None:
            run_id = cls.run_id
            if run_id is None:
                # All algorithms share one naming scheme, so no per-algorithm
                # branching is needed here.
                run_id = '{}-{}-{}-{}'.format(
                    cls.algorithm,
                    cls.env.id,
                    cls.seed,
                    time.strftime('%Y-%m-%d-%H-%M-%S'),
                )

            log_dir = os.path.join("logs", run_id)
            cls.log_dir = log_dir

        # exist_ok avoids the check-then-create race when several runs start
        # at once and share a parent directory.
        os.makedirs(log_dir, exist_ok=True)

        # Disabled: snapshot of the git-tracked sources into <log_dir>/src/.
        # if os.path.exists('.git'):
        #     for t in range(60):
        #         try:
        #             check_output(['git', 'checkout-index', '-a', '--prefix={}/src/'.format(cls.log_dir)])
        #             break
        #         except CalledProcessError:
        #             pass
        #         time.sleep(1)
        #     else:
        #         raise RuntimeError('Failed after 60 trials.')

        # Use a context manager so the config file is flushed and closed
        # deterministically instead of whenever the handle is collected.
        with open(os.path.join(log_dir, 'config.yml'), 'w') as config_file:
            yaml.dump(cls.as_dict(), config_file, default_flow_style=False)
        # logger.add_sink(FileSink(os.path.join(log_dir, 'log.json')))
        logger.add_sink(FileSink(os.path.join(log_dir, 'log.txt')))
        logger.add_csvwriter(CSVWriter(os.path.join(log_dir, 'progress.csv')))
        logger.info("log_dir = %s", log_dir)

        # cls.set_frozen()


# Import-time side effect: hand the FLAGS class to `parse` from utils.config.
# NOTE(review): presumably this applies command-line / config-file overrides
# to the defaults declared above -- confirm against utils.config.parse.
parse(FLAGS)

