import glob
import json
import os
import pickle
import re
import subprocess
import tempfile

import click
import numpy as np
import pandas as pd
from PIL import Image

from env import Env

def _sorted_world_state_paths(epoch_path):
    """Return the epoch's saved world-state pickles, lowest file number first.

    Args:
        epoch_path: ``submissions/.../seed_<seed>/epoch_<epoch>`` directory.

    Raises:
        FileNotFoundError: if ``<epoch_path>/world_state`` holds no ``*.pickle``.
    """
    state_dir = os.path.join(epoch_path, 'world_state')
    paths = glob.glob(os.path.join(state_dir, '*.pickle'))
    if not paths:
        raise FileNotFoundError(f"No world state (*.pickle) found in {state_dir}")
    # Order numerically on the digits of the *file name* only, so digits in
    # parent directories (seed, epoch, agent name, ...) never affect the order.
    # A raw string is used for the regex (a plain '\D' is an invalid escape).
    paths.sort(key=lambda p: int(re.sub(r'\D', '', os.path.basename(p)) or 0))
    return paths


def _load_world_state(path, world_state):
    """Load the pickled world at *path* into *world_state* and reset its step count.

    The given dict is updated in place and returned, so a single
    ``world_state`` dict is reused across restarts.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only replay
    # traces that come from a trusted source.
    with open(path, 'rb') as handle:
        world_state['world'] = pickle.load(handle)
    world_state['steps'] = 0
    return world_state


def _encode_video(frame_dir, video_path, framerate=26):
    """Encode the ``img<N>.png`` frames in *frame_dir* into an MPEG-4 video with ffmpeg.

    Failures are reported on stdout rather than raised, keeping the
    best-effort behaviour of the former ``os.system`` call.
    """
    command = ['ffmpeg', '-framerate', str(framerate),
               '-i', os.path.join(frame_dir, 'img%01d.png'),
               '-vcodec', 'mpeg4', '-y', video_path]
    # Argument list, no shell: paths with spaces or shell metacharacters
    # are passed through verbatim.
    try:
        result = subprocess.run(command, check=False)
    except FileNotFoundError:
        print('ffmpeg executable not found; the video was not written.')
        return
    if result.returncode != 0:
        print(f'ffmpeg exited with code {result.returncode}; '
              f'{video_path} may be missing or incomplete.')


def make_video(submission, agent, seed, epoch, save_video):
    """Replay a recorded trace step by step in ``Env``, optionally saving a video.

    Reads, under ``submissions/<submission>/mbrl_outputs/<agent>/seed_<seed>/
    epoch_<epoch>/``:

    * ``world_state/*.pickle`` -- saved worlds. The lowest-numbered one is the
      starting state, and the next one is loaded every time the trace's
      ``original_reward`` exceeds 1.0.
    * ``trace.csv`` -- one row per step with ``state_*`` columns, the action
      columns listed under ``"action"`` in ``data/metadata.json``, plus
      ``cost`` and ``original_reward``. Rows containing NaN are dropped.

    Args:
        submission: Model submission directory name (e.g. ``'real_system'``).
        agent: Agent directory name.
        seed: Seed of the trace; also passed to ``env.seed``.
        epoch: Epoch of the trace.
        save_video: If true, render every step as a 720x480 RGB array and
            encode the frames to ``<epoch dir>/video.mp4`` with ffmpeg.
            Otherwise call ``env.render()`` with its default arguments.

    Raises:
        FileNotFoundError: if no world-state pickle exists for the epoch.
    """
    epoch_path = os.path.join('submissions', submission, 'mbrl_outputs',
                              agent, f'seed_{seed}', f'epoch_{epoch}')

    # Consumed in order: the first path seeds the replay, each reward spike
    # pulls the next one.
    world_state_paths = iter(_sorted_world_state_paths(epoch_path))
    data_path = os.path.join(epoch_path, 'trace.csv')

    world_state = _load_world_state(next(world_state_paths), {})

    env = Env()
    try:
        env.seed(seed)
        env.reset()
        env.rebuild_sim_to_state(world_state)
        env.save_world = False

        data = pd.read_csv(data_path)
        metadata_path = os.path.join("data", "metadata.json")
        with open(metadata_path) as f:
            metadata = json.load(f)
        action_names = metadata["action"]

        X = data.dropna(axis=0)
        # Convert to numpy once instead of once per step (was O(n^2)).
        states = X.loc[:, X.columns.str.startswith("state_")].to_numpy()
        actions = X[action_names].to_numpy()
        logged_rewards = X['original_reward'].to_numpy()

        total_cost = 0
        total_reward = 0.0
        print(f"Cost episode: {X['cost'].sum()}")
        print(f"Reward episode: {X['original_reward'].sum()}")

        # Frames go to a private directory, so concurrent runs cannot clash and
        # frames left over from an earlier, longer replay are never encoded.
        with tempfile.TemporaryDirectory(prefix='make_video_') as frame_dir:
            steps = zip(states, actions, logged_rewards)
            for episode_step, (current_state, current_action, logged_reward) in enumerate(steps):
                if episode_step % 10 == 0:
                    print(f'step: {episode_step}')

                if save_video:
                    img = env.render(mode='rgb_array', camera_id=1, width=720,
                                     height=480)
                    Image.fromarray(img.astype(np.uint8)).save(
                        os.path.join(frame_dir, f'img{episode_step}.png'))
                else:
                    env.render()

                env.set_numpy_state(current_state)
                _, _, _, info = env.step(current_action)
                total_reward += info['original_reward']
                if info['cost'] > 0:
                    # Counts steps with positive cost, not the cost magnitude.
                    total_cost += 1

                if logged_reward > 1.0:
                    print('Restarting the world')
                    next_path = next(world_state_paths, None)
                    if next_path is None:
                        # More reward spikes than saved world states: keep
                        # replaying in the current world instead of crashing.
                        print('No further world state to restart from; '
                              'continuing with the current world.')
                    else:
                        _load_world_state(next_path, world_state)
                        env.rebuild_sim_to_state(world_state, old_state=True)

            if save_video:
                _encode_video(frame_dir, os.path.join(epoch_path, 'video.mp4'))

        print(f'Total cost for this epoch: {total_cost}')
        print(f'Total reward for this epoch: {total_reward}')
    finally:
        # Release the simulator even when the replay fails part-way.
        env.close()



@click.command()
@click.option("--submission", default="real_system", show_default=True,
              type=click.STRING,
              help="Model submission. Choose 'real_system' if you want to "
                   "use the real environment.")
@click.option('--agent', default='random_shooting', show_default=True,
              type=click.STRING, help="Agent.")
@click.option("--seed", default=0, show_default=True, type=click.INT,
              help="The seed used to generate the trace.")
# A single --epoch option (it was previously declared twice).
@click.option("--epoch", default=0, show_default=True, type=click.INT,
              help="The epoch of the trace used to generate the video.")
@click.option("--save_video", default=False, show_default=True,
              type=click.BOOL,
              help="Whether to save the video. If True, the video is saved in "
                   "submissions/<submission>/mbrl_outputs/<agent>/seed_<seed>/"
                   "epoch_<epoch>/video.mp4.")
def make_video_command(submission, agent, seed, epoch, save_video):
    # CLI entry point: forwards the parsed options unchanged to make_video.
    return make_video(submission, agent, seed, epoch, save_video)


if __name__ == '__main__':
    make_video_command()