#!/usr/bin/env python3

"""
Generate a set of agent demonstrations.

The agent can either be a trained model or the heuristic expert (bot).

Demonstration generation can take a long time, but it can be parallelized
if you have a cluster at your disposal. Pass a script that launches
make_agent_demos.py on your cluster as --job-script, and the number of jobs
to split the generation into as --jobs.


"""

import argparse
import gym
import logging
import sys
import subprocess
import os
import time
import numpy as np
import blosc
import torch
import pdb
import random
import copy

import babyai.utils as utils

# Parse arguments
#
# NOTE: the arguments are parsed at import time and stored in the module-level
# `args`, which the functions below read directly.

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--env", required=True,
                    help="name of the environment to be run (REQUIRED)")
# --model is optional: it falls back to the heuristic expert ('BOT').
parser.add_argument("--model", default='BOT',
                    help="name of the trained model, or 'BOT' for the heuristic expert")
parser.add_argument("--demos", default=None,
                    help="path to save demonstrations (derived from --env by default)")
parser.add_argument("--episodes", type=int, default=1000000,
                    help="number of episodes to generate demonstrations for")
# NOTE(review): validation demo generation is commented out at the bottom of
# the script, so this value is only forwarded (as 0) to cluster jobs.
parser.add_argument("--valid-episodes", type=int, default=512,
                    help="number of validation episodes to generate demonstrations for")
parser.add_argument("--seed", type=int, default=0,
                    help="start random seed")
parser.add_argument("--argmax", action="store_true", default=False,
                    help="action with highest probability is selected")
parser.add_argument("--log-interval", type=int, default=100,
                    help="interval between progress reports")
parser.add_argument("--save-interval", type=int, default=2500,
                    help="interval between demonstrations saving")
parser.add_argument("--filter-steps", type=int, default=0,
                    help="filter out demos with number of steps more than filter-steps")
parser.add_argument("--on-exception", type=str, default='warn', choices=('warn', 'crash'),
                    help="How to handle exceptions during demo generation")

# Cluster options: when --jobs is non-zero the generation is delegated to
# --job-script instead of running in this process.
parser.add_argument("--job-script", type=str, default=None,
                    help="The script that launches make_agent_demos.py at a cluster.")
parser.add_argument("--jobs", type=int, default=0,
                    help="Split generation in that many jobs")

args = parser.parse_args()
logger = logging.getLogger(__name__)

# Seeding is done inside generate_demos() through utils.seed().


def print_demo_lengths(demos):
    """Log the mean and standard deviation of the lengths of `demos`.

    Args:
        demos: sequence of demonstration tuples. The length of one
            demonstration is taken to be ``len(demo[2])``.

    When `demos` is empty a short notice is logged instead of the statistics:
    the mean/std of an empty sample is NaN and makes NumPy emit a
    RuntimeWarning.
    """
    if len(demos) == 0:
        logger.info('Demo length: n/a (no demos)')
        return
    num_frames_per_episode = [len(demo[2]) for demo in demos]
    logger.info('Demo length: {:.3f}+-{:.3f}'.format(
        np.mean(num_frames_per_episode), np.std(num_frames_per_episode)))


# Number of demos written to one output file before a new file is started.
_DEMOS_PER_FILE = 50000
# Size of the per-episode cache of actions the agent has already produced;
# random deviations are drawn from the slots of this cache.
_NUM_NOISE_SLOTS = 6
# A random deviation is attempted when random.random() exceeds this value.
_NOISE_THRESHOLD = 0.6


def _action_index(action):
    """Return the integer index of `action` (enum member, 0-d tensor or int)."""
    if isinstance(action, torch.Tensor):
        return int(action.item())
    return getattr(action, 'value', action)


def generate_demos(n_episodes, valid, seed, shift=0):
    """Generate `n_episodes` noisy agent demonstrations and save them to disk.

    Each episode is driven by the agent loaded from ``args.model``, but at
    every step the action sent to the environment may be replaced by a
    randomly chosen action the agent already produced earlier in the same
    episode. Only episodes that end with a positive reward (and, when
    ``args.filter_steps`` is non-zero, take at most that many steps) are kept.

    A stored demonstration is the tuple
    ``(mission, packed_full_observations, cur_states, actions, next_states)``
    where every state is ``env.gen_agent_pos() + [direction] + goal``.

    Args:
        n_episodes: number of successful demonstrations to collect.
        valid: forwarded to ``utils.get_demos_path`` when building file paths.
        seed: base random seed; the environment is seeded with
            ``seed + <number of demos collected so far>``.
        shift: unused; kept so that existing callers keep working.

    Raises:
        Exception: any error raised during an episode, or a failed mission,
            when ``args.on_exception == 'crash'``. Otherwise such episodes
            are logged and skipped.
    """
    utils.seed(seed)

    # Generate environment
    env = gym.make(args.env)
    agent = utils.load_agent(env, args.model, args.demos, 'agent', args.argmax, args.env)
    demos_path = utils.get_demos_path(args.demos, args.env, 'agent_suboptimal', valid)
    demos = []
    count = 0       # demos collected so far, across all output files
    save_count = 0  # number of completed output files

    just_crashed = False
    while count < n_episodes:
        if just_crashed:
            logger.info("reset the environment to find a mission that the bot can solve")
            env.reset()
        else:
            env.seed(seed + count)

        # Keep resetting until the agent reports exactly one goal state.
        while True:
            obs = env.reset()
            agent.on_reset()
            goal = agent.get_goal_state()
            if len(goal) == 1:
                goal = list(goal[0])
                break

        mission = obs["mission"]
        full_obs = env.gen_full_obs()
        actions = []
        cur_states = []
        next_states = []
        images = []
        # action_temp[i] caches the first action with index i that the agent
        # produced in this episode; deviations are drawn from these entries.
        action_temp = [None] * _NUM_NOISE_SLOTS
        done = False
        demo_added = False
        try:
            while not done:
                cur_states.append(env.gen_agent_pos() + [obs['direction']] + goal)
                action = agent.act(obs)['action']

                # Both random numbers are drawn on every step, in this order,
                # so the random stream does not depend on the branch taken.
                random_slot = random.randint(0, _NUM_NOISE_SLOTS - 1)
                deviate = random.random() > _NOISE_THRESHOLD
                if deviate and action_temp[random_slot] is not None:
                    step_action = copy.deepcopy(action_temp[random_slot])
                else:
                    step_action = action
                    # Tensor actions have no `.value`, and an action index may
                    # lie outside the cache; neither should abort the episode.
                    index = _action_index(action)
                    if 0 <= index < len(action_temp) and action_temp[index] is None:
                        action_temp[index] = copy.deepcopy(action)

                if isinstance(action, torch.Tensor):
                    action = action.item()
                if isinstance(step_action, torch.Tensor):
                    step_action = step_action.item()

                new_obs, reward, done, _ = env.step(step_action)
                agent.analyze_feedback(reward, done)

                images.append(full_obs)
                next_states.append(env.gen_agent_pos() + [new_obs['direction']] + goal)
                # NOTE(review): the action recorded is the agent's own action,
                # not the (possibly random) action that was executed, while
                # next_states reflects the executed one -- confirm this
                # labelling is intended.
                actions.append(action)
                full_obs = env.gen_full_obs()
                obs = new_obs

            if reward > 0 and (args.filter_steps == 0 or len(actions) <= args.filter_steps):
                demos.append((mission, blosc.pack_array(np.array(images)),
                              cur_states, actions, next_states))
                just_crashed = False
                count += 1
                demo_added = True

            if reward == 0:
                if args.on_exception == 'crash':
                    raise Exception("mission failed, the seed is {}".format(seed + count))
                just_crashed = True
                logger.info("mission failed")
        except Exception:
            if args.on_exception == 'crash':
                raise
            just_crashed = True
            logger.exception("error while generating demo #{}".format(count))
            continue

        # Saving is tied to progress: without a new demo `count` is unchanged,
        # and re-running the checks below would write empty files and start a
        # new output file after every failed mission at a boundary.
        if not demo_added:
            continue

        if count % _DEMOS_PER_FILE == 0:
            # The current file is full: write it and start a new one.
            # NOTE(review): if utils.get_demos_path ignores the origin argument
            # when args.demos is set, every new file would reuse the same path
            # and overwrite the previous one -- confirm.
            logger.info("Saving demos...")
            utils.save_demos(demos, demos_path)
            logger.info("{} demos saved".format(count))
            print_demo_lengths(demos[-100:])
            save_count += 1
            demos_path = utils.get_demos_path(
                args.demos, args.env, 'agent' + str(save_count), valid)
            demos = []
            print(f"save {_DEMOS_PER_FILE} {save_count}")
        elif args.save_interval > 0 and count < n_episodes and count % args.save_interval == 0:
            # Intermediate checkpoint of the file being filled.
            logger.info("Saving demos...")
            utils.save_demos(demos, demos_path)
            logger.info("{} demos saved".format(count))
            # print statistics for the last 100 demonstrations
            print_demo_lengths(demos[-100:])

    # Save the remaining demonstrations. An empty list is only written when
    # no file has been produced at all, so a trailing empty file is not
    # created when the last demo completed a full file.
    if demos or save_count == 0:
        logger.info("Saving demos...")
        utils.save_demos(demos, demos_path)
        logger.info("{} demos saved".format(count))
        if demos:
            print_demo_lengths(demos[-100:])


def generate_demos_cluster():
    """Split demo generation into ``args.jobs`` jobs and merge their results.

    Every job is started by running ``args.job_script`` with this script's
    own command-line arguments followed by per-job overrides (seed, output
    name, episode count, ``--jobs 0`` and ``--valid-episodes 0``). The
    function then polls the shard files once a minute until every shard
    holds ``args.episodes // args.jobs`` demos, and finally saves the
    concatenation of all shards.

    NOTE(review): the polling below implies that the job script returns
    once a job has been submitted rather than when it finishes -- confirm.
    """
    demos_per_job = args.episodes // args.jobs
    leftover = args.episodes - demos_per_job * args.jobs
    if leftover:
        # Integer division drops the remainder; make that visible.
        logger.warning(
            "{} episodes cannot be split evenly into {} jobs, "
            "{} episodes will not be generated".format(
                args.episodes, args.jobs, leftover))

    demos_path = utils.get_demos_path(args.demos, args.env, 'agent')
    job_demo_names = [os.path.realpath(demos_path + '.shard{}'.format(i))
                      for i in range(args.jobs)]
    # Remove shard files left by a previous run so that the polling loop
    # below cannot load stale results.
    for demo_name in job_demo_names:
        job_demos_path = utils.get_demos_path(demo_name)
        if os.path.exists(job_demos_path):
            os.remove(job_demos_path)

    command = [args.job_script]
    command += sys.argv[1:]
    for i in range(args.jobs):
        # The overrides come after the original arguments on the command line.
        cmd_i = list(map(str,
            command
              + ['--seed', args.seed + i * demos_per_job]
              + ['--demos', job_demo_names[i]]
              + ['--episodes', demos_per_job]
              + ['--jobs', 0]
              + ['--valid-episodes', 0]))
        logger.info('LAUNCH COMMAND')
        logger.info(cmd_i)
        output = subprocess.check_output(cmd_i)
        logger.info('LAUNCH OUTPUT')
        logger.info(output.decode('utf-8'))

    job_demos = [None] * args.jobs
    while True:
        jobs_done = 0
        for i in range(args.jobs):
            if job_demos[i] is None or len(job_demos[i]) < demos_per_job:
                try:
                    logger.info("Trying to load shard {}".format(i))
                    job_demos[i] = utils.load_demos(utils.get_demos_path(job_demo_names[i]))
                    logger.info("{} demos ready in shard {}".format(
                        len(job_demos[i]), i))
                except Exception:
                    # The shard may not exist yet or may be partially written;
                    # try again on the next polling round.
                    logger.exception("Failed to load the shard")
            # Compare against None instead of relying on truthiness: an empty
            # shard is complete when demos_per_job is 0, and treating it as
            # "not loaded" would make this loop run forever.
            if job_demos[i] is not None and len(job_demos[i]) == demos_per_job:
                jobs_done += 1
        logger.info("{} out of {} shards done".format(jobs_done, args.jobs))
        if jobs_done == args.jobs:
            break
        logger.info("sleep for 60 seconds")
        time.sleep(60)

    # Training demos
    all_demos = []
    for demos in job_demos:
        all_demos.extend(demos)
    utils.save_demos(all_demos, demos_path)

# Script body: set up logging, report the parsed arguments and generate the
# training demos, either in this process or through cluster jobs.

logging.basicConfig(level='INFO', format="%(asctime)s: %(levelname)s: %(message)s")
logger.info(args)
if args.jobs:
    # A non-zero --jobs delegates the generation to the cluster job script.
    generate_demos_cluster()
else:
    generate_demos(args.episodes, False, args.seed)
# Validation demos are currently not generated; the original call is kept
# below for reference.
#if args.valid_episodes:
#    generate_demos(args.valid_episodes, True, int(1e9))
