import os
from typing import Dict

import attr
import numpy as np

import rlf.rl.utils as rutils



@attr.s(auto_attribs=True, slots=True)
class RunResult:
    """Result record returned by the evaluation-only paths of a run
    (heatmap generation, eval-only, minigrid path visualization).
    """

    # Run prefix, copied from ``args.prefix`` by the caller.
    prefix: str
    # Evaluation output; empty when no evaluation was performed. A factory is
    # used so each instance gets its own dict; a literal ``{}`` default would
    # be a single object shared (and mutated) across all instances.
    eval_result: Dict = attr.Factory(dict)


def run_policy(run_settings, runner=None):
    """Train a policy, optionally warm-starting it with behavioral cloning.

    When the default argument parser reports ``bc_init``, a behavioral-cloning
    runner is first trained to completion (without closing its log). The
    resulting policy and step count are then handed to the main run.

    Args:
        run_settings: Settings object used to create runners.
        runner: Currently ignored; a fresh runner is always created via
            ``run_settings.create_runner()``.

    Returns:
        Whatever ``run_policy_from_args`` returns for the main run.
    """
    # Only ``bc_init`` is needed here; arguments unknown to the default
    # parser are tolerated and discarded.
    from rlf.args import get_default_parser
    parser = get_default_parser()
    parsed_defaults, _unknown = parser.parse_known_args()

    bc_policy = None
    pretrain_steps = 0
    if parsed_defaults.bc_init:
        from rlf.algos.il.bc import BehavioralCloning
        bc_runner = run_settings.create_runner(algo=BehavioralCloning())
        bc_args = bc_runner.args
        bc_updates = bc_runner.updater.get_num_updates()
        # Keep the log open so the main run can continue writing to it.
        bc_policy, pretrain_steps = run_policy_from_args(
            run_settings,
            bc_runner,
            bc_updates,
            bc_args,
            close_log=False,
        )

    main_runner = run_settings.create_runner()
    main_args = main_runner.args
    main_updates = main_runner.updater.get_num_updates()
    return run_policy_from_args(
        run_settings,
        main_runner,
        main_updates,
        main_args,
        close_log=True,
        policy_load=bc_policy,
        num_steps_pretrain=pretrain_steps,
    )



def run_policy_from_args(run_settings, runner, end_update, args, policy_load=None, close_log=True, num_steps_pretrain=0):
    """Drive a single run (training or an evaluation-only mode) with ``runner``.

    Args:
        run_settings: Settings object; its ``create_traj_saver`` is passed to
            the evaluation and path-visualization helpers.
        runner: Runner that performs training iterations, logging, evaluation,
            saving and cleanup.
        end_update: Exclusive upper bound of the training-update loop.
        args: Parsed arguments. Only ``args.ray`` is read from this parameter;
            every other flag is read from ``runner.args``.
        policy_load: If not None, assigned to ``runner.policy`` before any
            checkpoint loading (e.g. a policy pre-trained with behavioral
            cloning).
        close_log: Forwarded to ``runner.close``.
        num_steps_pretrain: Forwarded to ``runner.log_vals``; presumably the
            number of steps already consumed by a pre-training stage (TODO
            confirm).

    Returns:
        A ``RunResult`` when heatmap generation, eval-only or minigrid path
        visualization is requested. Otherwise a ``(policy, num_steps)`` tuple,
        where ``num_steps`` is
        ``runner.updater.get_completed_update_steps(end_update + 1)``.

    Raises:
        NotImplementedError: If ``args.ray`` is set.
    """
    if args.ray:
        # The Ray/Tune launch path is not supported; the previously unreachable
        # launch code that followed this raise has been removed.
        raise NotImplementedError("Ray is not supported in this version of the code.")

    if policy_load is not None:
        runner.policy = policy_load
    args = runner.args

    if runner.should_load_from_checkpoint():
        runner.load_from_checkpoint()

    if args.eval_gen_heatmap:
        runner.gen_heatmap()
        return RunResult(prefix=args.prefix)

    if args.eval_only:
        eval_result = runner.full_eval(run_settings.create_traj_saver)
        return RunResult(prefix=args.prefix, eval_result=eval_result)

    if args.visualize_minigrid_path:
        eval_result = runner.visualize_minigrid_path(run_settings.create_traj_saver)
        return RunResult(prefix=args.prefix, eval_result=eval_result)

    start_update = 0
    if args.resume:
        start_update = runner.resume()

    runner.setup()
    print("RL Training (%d/%d)" % (start_update, end_update))

    if runner.should_start_with_eval:
        runner.eval(-1)

    # Initialize outside the loop in case no updates run (e.g. resuming a
    # finished run). With this value, the final ``j + 1`` below equals
    # ``start_update`` when the loop is skipped, and ``end_update`` otherwise.
    j = start_update - 1
    for j in range(start_update, end_update):
        updater_log_vals = runner.training_iter(j)
        if args.log_interval > 0 and (j + 1) % args.log_interval == 0:
            runner.log_vals(updater_log_vals, j, num_steps_pretrain=num_steps_pretrain)
        if args.save_interval > 0 and (j + 1) % args.save_interval == 0:
            runner.save(j)
        if args.eval_interval > 0 and (j + 1) % args.eval_interval == 0:
            runner.eval(j)

    num_steps_pretrain_final = runner.updater.get_completed_update_steps(end_update + 1)

    if args.eval_interval > 0:
        runner.eval(j + 1)
    if args.save_interval > 0:
        runner.save(j + 1)
    policy = runner.policy
    runner.close(close_log=close_log)
    return policy, num_steps_pretrain_final



