import copy
import importlib
import os
import pickle
from pydoc import describe
import time
from glob import glob
import argparse

import ray
import json
import torch
from torch.utils.tensorboard import SummaryWriter

from models import muzero_models
import replay_buffer
import self_play
import shared_storage
import trainer
import evaluator
from games.base_game import Game

import io
import PIL.Image
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

# Pin OpenMP to a single thread per process. Presumably this avoids CPU
# oversubscription when many Ray worker processes share one machine (child
# processes inherit this environment). NOTE(review): it runs after `torch` is
# imported above, so it may not affect an OpenMP runtime already initialised
# in this process — confirm it is effective where intended.
os.environ.update(OMP_NUM_THREADS="1")


@ray.remote(num_cpus=0, num_gpus=0)
class CPUActor:
    """Ray actor used to build the model's initial weights on the CPU.

    The actor reserves zero CPUs and zero GPUs. Because it requests no GPU,
    Ray does not hand it one, so the model it constructs (and the weights it
    returns) stay on the CPU even when the machine has GPUs. This keeps
    DataParallel from placing the weights on a device.
    """

    def get_initial_weights(self, config):
        """Instantiate a fresh model and return its weights plus a text summary.

        Args:
            config: object exposing ``network`` and a ``model_config`` dict of
                keyword arguments for ``MuZeroLinesModel``.

        Returns:
            Tuple ``(weights, summary)``. ``weights`` is whatever
            ``model.get_weights()`` returns. ``summary`` is ``str(model)`` with
            every newline replaced by ``" \\n\\n"`` (presumably so each line
            renders as its own paragraph in TensorBoard's markdown text view).
        """
        net = muzero_models.MuZeroLinesModel(config.network, **config.model_config)
        return net.get_weights(), str(net).replace("\n", " \n\n")


def _build_parser():
    """Return the command-line parser for the training launcher.

    GPU counts are plain integers. ``--skip_load_optimizer`` is an integer
    flag; ``train`` treats any value > 0 as "skip loading optimizer state".
    """
    arg_parser = argparse.ArgumentParser(description="PyTorch Scalable Agent")
    # (flag, default, type, help) — order here is the order shown in --help.
    option_specs = (
        ("--game_name", "cityengine", str, "Game to play (choose dataset)."),
        ("--gpus_reanalyse", 1, int, "Num GPUs for reanalyze worker."),
        ("--gpus_actors", 0, int, "Number of GPUs actors."),
        ("--gpus_evaluator", 1, int, "Use GPU for the evaluator"),
        ("--gpus_training_worker", 1, int, "GPUs for the trainer worker."),
        ("--resume_checkpoint", None, str, "Checkpoint to resume."),
        ("--output_folder_name", "", str, None),
        ("--skip_load_optimizer", 0, int, None),
    )
    for flag, default_value, value_type, help_text in option_specs:
        arg_parser.add_argument(
            flag, default=default_value, type=value_type, help=help_text
        )
    return arg_parser


parser = _build_parser()


def _initial_checkpoint():
    """Return the default checkpoint dict used to seed every worker.

    ``train`` overlays a resumed checkpoint (if any) on top of these values.
    """
    return {
        "weights": None,
        "optimizer_state": None,
        "total_reward": 0,
        "muzero_reward": 0,
        "opponent_reward": 0,
        "episode_length": 0,
        "mean_value": 0,
        "training_step": 0,
        "slow_lr": 0,
        "fast_lr": 0,
        "total_loss": 0,
        "value_loss": 0,
        "reward_loss": 0,
        "policy_loss": 0,
        "consistency_loss": 0,
        "num_played_games": 0,
        "num_played_steps": 0,
        "num_reanalysed_games": 0,
        "terminate": False,
        "end_reward_percentage": 0.0,
        "negative_reward_percentage": 0.0,
    }


def _log_training_info(writer, info):
    """Write the worker/loss scalars from shared storage to TensorBoard at the current training step."""
    step = info["training_step"]
    scalars = (
        ("Workers/Played_games", info["num_played_games"]),
        ("Workers/Training_steps", step),
        ("Workers/played_steps", info["num_played_steps"]),
        ("Workers/Reanalysed_games", info["num_reanalysed_games"]),
        (
            "Workers/Training_steps_per_played_step_ratio",
            # max(1, ...) avoids dividing by zero before any step is played.
            step / max(1, info["num_played_steps"]),
        ),
        ("Workers/Slow_Learning_rate", info["slow_lr"]),
        ("Workers/Fast_Learning_rate", info["fast_lr"]),
        ("Workers/Replay_buffer_size", info["replay_buffer_size"]),
        ("Workers/Negative_reward_percentage", info["negative_reward_percentage"]),
        ("Workers/End_reward_percentage", info["end_reward_percentage"]),
        ("Loss/Total_weighted_loss", info["total_loss"]),
        ("Loss/consistency_loss", info["consistency_loss"]),
        ("Loss/Value_loss", info["value_loss"]),
        ("Loss/Reward_loss", info["reward_loss"]),
        ("Loss/Policy_loss", info["policy_loss"]),
    )
    for tag, value in scalars:
        writer.add_scalar(tag, value, step)


def _render_video(video):
    """Stack the frames of ``video`` vertically into one image tensor.

    ``video`` is indexed along its first axis; each frame must support
    ``.cpu().numpy()`` (presumably a torch tensor — confirm against the
    evaluator). Returns ``None`` for an empty video instead of letting
    ``plt.subplots(nrows=0)`` raise.
    """
    num_frames = video.shape[0]
    if num_frames == 0:
        return None
    # squeeze=False guarantees a 2-D axes array even for a single frame;
    # otherwise subplots returns a bare Axes and zip() would fail on it.
    fig, axes = plt.subplots(
        nrows=num_frames, figsize=(12, 6 * num_frames), squeeze=False
    )
    try:
        for img, ax in zip(video, axes[:, 0]):
            ax.axis("off")
            ax.imshow(img.cpu().numpy())
        buf = io.BytesIO()
        fig.savefig(buf, format="jpeg")
    finally:
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)
    buf.seek(0)
    with PIL.Image.open(buf) as pil_image:
        return ToTensor()(pil_image)


def _log_evaluation_info(writer, evaluation_info, step):
    """Log one evaluation snapshot to TensorBoard.

    Keys containing ``"rendered_game"`` are rendered to images, keys ending in
    ``"hg"`` are logged as histograms, and everything else is logged as a scalar.
    """
    for key, value in evaluation_info.items():
        if "rendered_game" in key:
            image = _render_video(value)
            if image is not None:
                writer.add_image(key, image, step)
        elif key.endswith("hg"):
            writer.add_histogram(key, value, step)
        else:
            writer.add_scalar(key, value, step)


def train(flags):
    """Launch all Ray workers for a training run and monitor them until done.

    Steps:
    1. Import ``games.<flags.game_name>`` and build its ``MuZeroConfig``.
    2. Check the requested GPUs against ``torch.cuda.device_count()`` and
       start Ray.
    3. Seed the trainer, shared storage, replay buffer, evaluator, optional
       reanalyse worker and self-play workers from one checkpoint.
    4. Poll shared storage every 100 s, logging to TensorBoard. Polling stops
       when ``training_step`` reaches ``config.training_steps`` or on Ctrl-C.
    5. Signal termination and, if ``config.save_model`` is set, pickle the
       replay buffer to ``results_path``.
    6. Shut Ray down.

    Args:
        flags: parsed command-line namespace. Reads ``game_name``, the four
            ``gpus_*`` counts, ``resume_checkpoint``, ``output_folder_name``
            and ``skip_load_optimizer``. The last is converted in place to a
            bool (true when > 0).

    Raises:
        RuntimeError: if more GPUs are requested than CUDA reports.
    """
    flags.skip_load_optimizer = flags.skip_load_optimizer > 0

    game_module = importlib.import_module("games." + flags.game_name)
    config = game_module.MuZeroConfig(flags.output_folder_name)

    total_gpus = (
        flags.gpus_reanalyse
        + flags.gpus_actors
        + flags.gpus_evaluator
        + flags.gpus_training_worker
    )
    available_gpus = torch.cuda.device_count()
    # Explicit raise rather than `assert` so the check survives `python -O`.
    if total_gpus > available_gpus:
        raise RuntimeError(
            f"Not enough GPUs: {total_gpus} requested, {available_gpus} available."
        )
    print("A total of {} gpus requested".format(total_gpus))

    ray.init(num_gpus=total_gpus, ignore_reinit_error=True)

    # Checkpoint and replay buffer used to initialize workers.
    checkpoint = _initial_checkpoint()
    if flags.resume_checkpoint is not None:
        # Point the model config at the resumed file and overlay its stored
        # fields on the defaults.
        config.model_config["checkpoint"] = flags.resume_checkpoint
        checkpoint.update(torch.load(flags.resume_checkpoint, map_location="cpu"))
    checkpoint["replay_buffer_size"] = 0
    replay_buffer_dict = {}

    # Build weights inside a GPU-less actor so they are CPU tensors.
    # NOTE(review): this overwrites any "weights" loaded above. Presumably the
    # model loads config.model_config["checkpoint"] itself — confirm.
    cpu_actor = CPUActor.remote()
    cpu_weights = cpu_actor.get_initial_weights.remote(config)
    checkpoint["weights"], summary = copy.deepcopy(ray.get(cpu_weights))

    os.makedirs(config.results_path, exist_ok=True)
    with open(os.path.join(config.results_path, "config.json"), "w") as f:
        json.dump(vars(config), f)

    ### Initialize workers
    training_worker = trainer.Trainer.options(
        num_cpus=0, num_gpus=flags.gpus_training_worker
    ).remote(
        checkpoint,
        config,
        skip_load_optimizer=flags.skip_load_optimizer,
    )

    shared_storage_worker = shared_storage.SharedStorage.remote(checkpoint, config)
    shared_storage_worker.set_info.remote("terminate", False)

    replay_buffer_worker = replay_buffer.ReplayBuffer.remote(
        checkpoint, replay_buffer_dict, config
    )

    evaluation_worker = evaluator.Evaluate.options(
        num_cpus=0,
        num_gpus=flags.gpus_evaluator,
    ).remote(Game, config)

    reanalyse_worker = None
    if config.use_last_model_value:
        reanalyse_worker = replay_buffer.Reanalyse.options(
            num_cpus=0,
            num_gpus=flags.gpus_reanalyse,
        ).remote(checkpoint, config)

    # The first `num_gpu_actors` self-play workers get one GPU each; the rest
    # run without a GPU. Each worker's seed is offset by its index.
    num_gpu_actors = min(config.num_workers, flags.gpus_actors)
    self_play_workers = [
        self_play.SelfPlay.options(num_cpus=0, num_gpus=1).remote(
            Game, config, config.seed + seed
        )
        for seed in range(num_gpu_actors)
    ] + [
        self_play.SelfPlay.options(num_cpus=0, num_gpus=0).remote(
            Game, config, config.seed + seed
        )
        for seed in range(num_gpu_actors, config.num_workers)
    ]

    ### Launch workers
    evaluation_worker.continuous_self_evaluation.remote(shared_storage_worker)
    for self_play_worker in self_play_workers:
        self_play_worker.continuous_self_play.remote(
            shared_storage_worker, replay_buffer_worker
        )
    training_worker.continuous_update_weights.remote(
        replay_buffer_worker, shared_storage_worker
    )
    if reanalyse_worker is not None:
        reanalyse_worker.reanalyse.remote(replay_buffer_worker, shared_storage_worker)

    # Write everything in TensorBoard.
    writer = SummaryWriter(config.results_path)
    hp_table = [f"| {key} | {value} |" for key, value in config.__dict__.items()]
    writer.add_text(
        "Hyperparameters",
        "| Parameter | Value |\n|-------|-------|\n" + "\n".join(hp_table),
    )
    writer.add_text("Model summary", summary)

    # Loop for updating the training performance.
    keys = [
        "training_step",
        "slow_lr",
        "fast_lr",
        "total_loss",
        "value_loss",
        "reward_loss",
        "policy_loss",
        "consistency_loss",
        "num_played_games",
        "num_played_steps",
        "num_reanalysed_games",
        "negative_reward_percentage",
        "end_reward_percentage",
        "replay_buffer_size",
    ]
    info = ray.get(shared_storage_worker.get_info.remote(keys))
    last_evaluation_step = -1
    try:
        while info["training_step"] < config.training_steps:
            info = ray.get(shared_storage_worker.get_info.remote(keys))
            _log_training_info(writer, info)

            print(
                f'Training step: {info["training_step"]}/{config.training_steps}. Played games: {info["num_played_games"]}. Loss: {info["total_loss"]:.2f}',
                end="\r",
            )

            evaluation_info = ray.get(
                shared_storage_worker.get_evaluation_info.remote()
            )
            # Only log an evaluation snapshot the first time we see it.
            if (
                "evaluation_last_step" in evaluation_info
                and evaluation_info["evaluation_last_step"] > last_evaluation_step
            ):
                last_evaluation_step = evaluation_info["evaluation_last_step"]
                _log_evaluation_info(writer, evaluation_info, last_evaluation_step)

            time.sleep(100)
    except KeyboardInterrupt:
        # Ctrl-C is the intended way to stop early; continue to shutdown below.
        pass
    finally:
        # Flush pending TensorBoard events.
        writer.close()

    shared_storage_worker.set_info.remote("terminate", True)
    checkpoint = ray.get(shared_storage_worker.get_checkpoint.remote())
    replay_buffer_dict = ray.get(replay_buffer_worker.get_buffer.remote())

    if config.save_model:
        # Persist replay buffer to disk.
        print("\n\nPersisting replay buffer games to disk...")
        buffer_path = os.path.join(config.results_path, "replay_buffer.pkl")
        with open(buffer_path, "wb") as buffer_file:
            pickle.dump(
                {
                    "buffer": replay_buffer_dict,
                    "num_played_games": checkpoint["num_played_games"],
                    "num_played_steps": checkpoint["num_played_steps"],
                    "num_reanalysed_games": checkpoint["num_reanalysed_games"],
                },
                buffer_file,
            )

    ray.shutdown()


if __name__ == "__main__":
    # Script entry point: parse the command line and start a training run.
    train(parser.parse_args())
