import numpy as np
import os
import collections
from os.path import dirname, abspath
from copy import deepcopy
from sacred import Experiment, SETTINGS
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds
import sys
import torch
import logging
import random
import yaml
from types import SimpleNamespace as SN
import pprint
from algos import *
from utils import *
from buffer import *
from data_process import *
from evaluation import w_offline_ab

# Base directory that results_path below is joined onto. It is empty here, so
# "results" ends up relative to the current working directory.
root_path = ""

def get_logger():
    """Reset the root logger to one console handler at DEBUG level.

    Any handlers previously attached to the root logger are discarded, so
    calling this repeatedly never stacks up duplicate handlers.

    Returns:
        The root ``logging.Logger``.
    """
    console = logging.StreamHandler()
    console.setFormatter(
        logging.Formatter(
            fmt='[%(levelname)s %(asctime)s] %(name)s %(message)s',
            datefmt='%H:%M:%S',
        )
    )

    root = logging.getLogger()
    # Replace (rather than extend) the handler list before attaching ours.
    root.handlers = []
    root.addHandler(console)
    root.setLevel('DEBUG')
    return root


# Sacred output-capture mode.
# set to "no" if you want to see stdout/stderr in console
SETTINGS['CAPTURE_MODE'] = "fd"
# Module-wide logger: the root logger as configured by get_logger().
logger = get_logger()

# The Sacred experiment; it is told to log through the logger created above.
ex = Experiment()
ex.logger = logger
# NOTE(review): going by the helper's name, this presumably normalises
# backspaces/linefeeds (e.g. progress-bar redraws) in the captured output --
# confirm against the Sacred documentation.
ex.captured_out_filter = apply_backspaces_and_linefeeds

# Parent directory of the per-config FileStorageObserver directory that the
# __main__ section builds from it.
results_path = os.path.join(root_path, "results")


@ex.main
def my_main(_run, _config, _log):
    """Main function of the experiment: seed the RNGs, then call ``run``.

    The same ``_config["seed"]`` value is applied to Python's ``random``,
    NumPy and PyTorch (in that order) before the three arguments are passed
    on unchanged to ``run``.
    """
    seed = _config["seed"]
    # Setting the random seed throughout the modules
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    # run the framework
    run(_run, _config, _log)


def parse_config_file(params, default="ddpg"):
    """Extract the ``--config`` option from ``params`` and return its value.

    The first ``--config`` occurrence is consumed: it is removed from
    ``params`` in place so the remaining arguments can be handed on to the
    command-line runner. Both ``--config=NAME`` and ``--config NAME`` are
    accepted.

    Args:
        params: Mutable list of command-line argument strings; modified in
            place when a ``--config`` option is found.
        default: Value returned when no ``--config`` option is present.

    Returns:
        The config name as a string (``default`` if the option is absent).

    Raises:
        ValueError: If ``--config`` is the last argument and carries no value.
    """
    for i, v in enumerate(params):
        # partition splits on the FIRST "=" only, so a value that itself
        # contains "=" is kept whole instead of being truncated.
        flag, sep, value = v.partition("=")
        if flag != "--config":
            continue
        if sep:
            del params[i]
            return value
        # Bare "--config": the value is the following argument.
        if i + 1 < len(params):
            value = params[i + 1]
            del params[i:i + 2]
            return value
        raise ValueError("--config requires a value, e.g. --config=ddpg")
    return default

def Create_Policy(args):
    """Instantiate the policy class selected by ``args.algo``.

    Args:
        args: Namespace-like object. ``args.algo`` picks the algorithm and
            ``args`` itself is passed to the chosen constructor.

    Returns:
        The constructed policy object.

    Raises:
        ValueError: If ``args.algo`` is not one of the supported names.
            (Previously this case silently returned ``None``.)
    """
    # Each entry is a thunk so that only the selected algorithm's module
    # attribute is looked up, just like the branch-per-algorithm form.
    builders = {
        "ddpg": lambda: DDPG.DDPG(args),
        "td3": lambda: TD3.TD3(args),
        "td3_bc": lambda: TD3_BC.TD3_BC(args),
        "bcq": lambda: BCQ.BCQ(args),
        "sl": lambda: SL.SL(args),
        "iql": lambda: IQL.IQL(args),
        "resact": lambda: ResAct.ResAct(args),
        "sl_vae": lambda: SL_VAE.SL_VAE(args),
    }
    try:
        build = builders[args.algo]
    except (KeyError, TypeError):
        # TypeError covers an unhashable args.algo, which can never match.
        raise ValueError(
            f"Unknown algo {args.algo!r}; expected one of {sorted(builders)}"
        ) from None
    return build()

def _mean_test_return(policy, test_json, test_users, immediate):
    """Average the offline evaluation return of ``policy`` over ``test_users``.

    Args:
        policy: Policy object handed unchanged to ``w_offline_ab``.
        test_json: Mapping from test-user key to that user's history.
        test_users: Keys of ``test_json`` to evaluate.
        immediate: Forwarded to ``w_offline_ab`` as its ``immediate`` argument.

    Returns:
        The mean of the kept returns as a Python scalar, or ``None`` when no
        user produced a truthy return.
    """
    returns = []
    for test_user in tqdm(test_users):
        his = test_json[test_user]
        R = w_offline_ab(policy, his, immediate=immediate)
        # Falsy results are treated as invalid users and dropped.
        # NOTE(review): this also drops a genuine return of exactly 0 --
        # confirm what w_offline_ab yields for invalid users.
        if R:
            returns.append(R)
    if not returns:
        return None
    return np.array(returns).mean().item()


def run(_run, _config, _log):
    """Train the configured policy from a stored buffer, evaluating and logging periodically.

    Builds an argument namespace from ``_config``, sets up the stat logger,
    creates the policy, loads the test JSON and the replay buffer, then runs
    ``args.n_epoch`` epochs of ``replay_buffer.size // args.batch_size``
    training steps each. Every ``args.test_every`` steps (only when
    ``args.save_model`` is set) the policy is evaluated on the first
    ``args.test_n_user`` test users and saved if the mean return improved;
    every ``args.log_every`` steps progress and recent stats are printed.

    Args:
        _run: Sacred run object; its ``_id`` names the per-run results folder.
        _config: Experiment configuration mapping.
        _log: Logger used for console output.
    """
    args = SN(**_config)
    args.ex_results_path = os.path.join(args.ex_results_path, str(_run._id))
    # setup loggers
    logger = Logger(_log)

    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config,
                                       indent=4,
                                       width=1)
    _log.info("\n\n" + experiment_params + "\n")

    if args.use_tensorboard:
        logger.setup_tb(args.ex_results_path)

    # sacred is on by default
    logger.setup_sacred(_run)

    start_time = time.time()
    last_time = start_time
    logger.console_logger.info("Beginning training for {} epochs.".format(args.n_epoch))

    # Start far enough in the past that the very first step triggers a test.
    last_test_step = -args.test_every - 1
    last_log_step = 0

    policy = Create_Policy(args)

    test_json_path = ""  # path to the test json
    with open(test_json_path, 'r', encoding="utf-8") as f:
        test_json = json.load(f)
    buffer_path = ""  # path to the buffer
    # -inf so that the first valid evaluation always counts as an improvement.
    best_return = float("-inf")
    steps_till_now = 0

    replay_buffer = ReplayBuffer(args)
    replay_buffer.load(buffer_path)

    # Loop invariants: the evaluated users and the model directory never change.
    test_users = list(test_json)[:args.test_n_user]
    save_path = os.path.join(args.ex_results_path, "models/")

    for _ in range(args.n_epoch):
        # The buffer is loaded once above and assumed not to change size
        # during an epoch, so the step count is computed once per epoch.
        steps_per_epoch = int(replay_buffer.size // args.batch_size)
        for e in range(steps_per_epoch):
            loss = policy.train(replay_buffer, args.batch_size)
            overall_step = steps_till_now + e

            for key, value in loss.items():
                logger.log_stat(key, value.item(), overall_step)

            if args.save_model and (overall_step - last_test_step) / args.test_every >= 1.0:
                mean_return = _mean_test_return(
                    policy, test_json, test_users, args.test_immediate)
                if mean_return is None:
                    # Nothing to average: skip the stat instead of logging NaN.
                    logger.console_logger.info(
                        f"Evaluation at {overall_step} steps skipped: no valid test users.")
                else:
                    logger.console_logger.info(
                        f"Perform evaluation at {overall_step} steps: {mean_return}.")
                    logger.log_stat("test_return", mean_return, overall_step)
                    if mean_return > best_return:
                        best_return = mean_return
                        os.makedirs(save_path, exist_ok=True)
                        policy.save(save_path)
                last_test_step = overall_step

            if (overall_step - last_log_step) / args.log_every >= 1.0:
                logger.console_logger.info("Estimated time left: {}. Time passed: {}".format(
                    time_left(last_time, last_log_step, overall_step,
                              args.n_epoch * steps_per_epoch),
                    time_str(time.time() - start_time)))
                last_time = time.time()
                logger.log_stat("steps", overall_step, overall_step)
                logger.print_recent_stats()
                last_log_step = overall_step
        steps_till_now += steps_per_epoch
            
if __name__ == '__main__':
    # Work on a copy of argv so consuming --config leaves sys.argv untouched.
    params = deepcopy(sys.argv)
    config_file = parse_config_file(params)
    # Results for this config live in their own sub-directory of results_path.
    file_obs_path = os.path.join(results_path, config_file)

    ex.add_config(f'') # add path to the config file
    logger.info(
        f"Saving to FileStorageObserver in {root_path}/results/{config_file}.")

    # Expose the config name and the output location to the experiment config.
    ex.add_config(name=config_file)
    ex.add_config(ex_results_path=file_obs_path)

    file_observer = FileStorageObserver.create(file_obs_path)
    ex.observers.append(file_observer)

    # Hand the remaining arguments (without --config) to Sacred.
    ex.run_commandline(params)
