# NOTE: Must source the mp_venv environment to use this file

import argparse
import time

import torch
from torchvision.utils import save_image

from math import prod
import random

import matplotlib
matplotlib.use("QtAgg")
# matplotlib.use('webagg')
import matplotlib.pyplot as plt
# plt.interactive(False)

from meltingpot.python import substrate
from examples.rllib import utils

import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../"))
from config import Config
from scenario_config import SCENARIO_CONFIG


def _generate_random_action(previous_act, n_actions, num_envs, drift=0.8):
    """Draw one uniformly random discrete action index.

    The returned integer lies in ``[0, n_actions]`` -- the upper bound is
    inclusive, so callers must pass the largest valid action index rather
    than the size of the action space.

    ``previous_act``, ``num_envs`` and ``drift`` are accepted but currently
    ignored; every draw is independent of the previous action.
    """
    # Half-open range [0, n_actions + 1) covers the same values as randint(0, n_actions).
    exclusive_upper = n_actions + 1
    return random.randrange(0, exclusive_upper)


def sample(
        scenario_name,
        steps,
        num_envs,
        render
):
    """Roll out random actions in a substrate and save each agent's RGB observation.

    For every step, one PNG per agent is written to
    ``samples/<scenario_name>_<timestamp>/<step>/<agent_index>.png``.

    Args:
        scenario_name: Key into ``SCENARIO_CONFIG``; also passed to
            ``substrate.get_config`` as the substrate name.
        steps: Number of environment steps to take.
        num_envs: Only forwarded to the action generator and echoed in the
            final timing message; a single environment is created here.
        render: If true, display ``env.render`` output with matplotlib after
            each step and keep the window open once sampling finishes.
    """
    from pathlib import Path

    init_time = time.time()

    scenario_cfg = SCENARIO_CONFIG[scenario_name]
    num_agents = scenario_cfg["num_agents"]
    reset_after = scenario_cfg["reset_after"]

    # Construct substrate
    player_roles = substrate.get_config(scenario_name).default_player_roles
    env_config = {"substrate": scenario_name, "roles": player_roles}
    env = utils.env_creator(env_config)

    try:
        # _generate_random_action uses an inclusive upper bound, hence the -1.
        num_actions = env.action_space['player_0'].n - 1

        timestr = time.strftime("%Y%m%d-%H%M%S")
        out_root = Path("samples") / f"{scenario_name}_{timestr}"

        # Gym-style environments are expected to be reset before the first step.
        env.reset()

        prev_act = [None] * num_agents
        for s in range(steps):

            # Generate action
            actions = {}
            for i in range(num_agents):
                act = _generate_random_action(prev_act[i], num_actions, num_envs)
                actions[f'player_{i}'] = act
                prev_act[i] = act

            obs, rewards, dones, _, _ = env.step(actions)
            # print(rewards)  # To get a basic understanding of how sparse the environment is

            step_dir = out_root / str(s)
            step_dir.mkdir(parents=True, exist_ok=True)
            for i in range(num_agents):
                # HWC uint8-range image -> CHW float in [0, 1], as save_image expects.
                img = torch.tensor(obs[f'player_{i}']['RGB']).permute(2, 0, 1) / 255.0
                save_image(img, str(step_dir / f"{i}.png"))

            # Reset when the episode is done. Truthiness (not `is True`) so that
            # non-bool truthy flags such as numpy.bool_ are honoured too.
            episode_over = bool(dones['__all__'])
            # Also reset periodically to ensure we don't sample crazily
            # out-of-distribution. A falsy reset_after (None or 0) disables this.
            periodic_reset = bool(reset_after) and s % reset_after == 0
            if episode_over or periodic_reset:
                env.reset()

            if render:
                rgb = env.render(mode="rgb_array")
                plt.imshow(rgb)
                plt.pause(0.01)

            if s % 10 == 0:
                print(f"{s}/{steps}")

        if render:
            plt.show()

        total_time = time.time() - init_time
        print(
            f"It took: {total_time}s for {steps} steps of {num_envs} parallel environments on device {Config.device}"
        )
    finally:
        # Release the substrate even if sampling was interrupted or failed.
        env.close()


if __name__ == "__main__":
    # Parse sampling arguments.
    # The summary sentence belongs in `description`; `prog` is the program
    # name printed in the usage line.
    parser = argparse.ArgumentParser(
        description='Sample observations randomly from Melting Pot substrates'
    )
    # Required: without a scenario the SCENARIO_CONFIG lookup in sample() fails.
    parser.add_argument('-c', '--scenario', required=True,
                        help='Melting Pot substrate name (must be a key of SCENARIO_CONFIG)')
    parser.add_argument('--steps', default=200, type=int, help='number of sampling steps')
    parser.add_argument('--num_envs', default=1, type=int, help='vectorized environments to sample from')
    parser.add_argument('--render', action='store_true', default=False, help='render scenario while sampling')
    parser.add_argument('-d', '--device', default='cuda')
    args = parser.parse_args()

    # Set global configuration
    Config.device = args.device

    sample(
        args.scenario,
        args.steps,
        args.num_envs,
        args.render,
    )
