import argparse
import os
import math
from typing import Dict, Tuple

import numpy as np
import ray
import torchvision

# from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
from ray.rllib.algorithms.callbacks import DefaultCallbacks, MultiCallbacks

from ray.rllib.models import ModelCatalog
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.typing import PolicyID, AgentID

from model_hetjoippo_mp import PolicyHetJOIPPO

# from ray.rllib.algorithms.ppo import PPO as PPOTrainer
from ray.rllib.agents.ppo import PPOTrainer

import wandb
from ray.rllib import RolloutWorker, Policy, SampleBatch
from ray.rllib.evaluation import Episode
from ray.tune import register_env

from ray.air.integrations.wandb import WandbLoggerCallback
# from ray.tune.integration.wandb import WandbLoggerCallback

from meltingpot.python import substrate
from examples.rllib import utils

from scenario_config import SCENARIO_CONFIG
from config import Config

import time
import torch

class RenderingCallbacks(DefaultCallbacks):
    """Attach a video built from the ``WORLD.RGB`` observations to the episode.

    Each time a trajectory is postprocessed, the ``WORLD.RGB`` entry of the
    batch observations is wrapped in a ``wandb.Video`` and stored under
    ``episode.media["rendering"]``, replacing any previous value.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the RLlib base class.
        super().__init__(*args, **kwargs)

    def on_postprocess_trajectory(
        self,
        *,
        worker: "RolloutWorker",
        episode: Episode,
        agent_id: AgentID,
        policy_id: PolicyID,
        policies: Dict[PolicyID, Policy],
        postprocessed_batch: SampleBatch,
        original_batches: Dict[AgentID, Tuple[Policy, SampleBatch]],
        **kwargs,
    ) -> None:
        """Store the batch's ``WORLD.RGB`` frames on the episode as a video.

        Only ``episode`` and ``postprocessed_batch`` are used; the remaining
        arguments are accepted to satisfy the callback signature.
        """
        # Work on a copy so the sample batch itself is never modified.
        world_frames = np.copy(postprocessed_batch["obs"]["WORLD.RGB"])
        # Axis 3 is moved to position 1 (presumably channels-last ->
        # channels-first, the layout wandb.Video expects -- TODO confirm).
        channels_first = world_frames.transpose(0, 3, 1, 2)
        episode.media["rendering"] = wandb.Video(
            channels_first,
            fps=5,
            format="mp4",
        )


class SAECheckpointCallbacks(DefaultCallbacks):
    """Checkpoint the PISA and CNN autoencoder weights of the acting policy.

    Each time a trajectory is postprocessed, the ``state_dict`` of
    ``model.pisa`` and ``model.cnn_autoencoder`` of the policy identified by
    ``policy_id`` is written to the ``weights/`` directory (relative to the
    current working directory of the worker process).
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the RLlib base class.
        super().__init__(*args, **kwargs)

    def on_postprocess_trajectory(
        self,
        *,
        worker: "RolloutWorker",
        episode: Episode,
        agent_id: AgentID,
        policy_id: PolicyID,
        policies: Dict[PolicyID, Policy],
        postprocessed_batch: SampleBatch,
        original_batches: Dict[AgentID, Tuple[Policy, SampleBatch]],
        **kwargs,
    ) -> None:
        """Save the policy's PISA and CNN autoencoder state dicts to disk.

        File names contain the worker index and a second-resolution
        timestamp, so two saves from the same worker within the same second
        write to the same paths.
        """
        time_str = time.strftime("%Y%m%d-%H%M%S")
        model = policies[policy_id].model
        pisa = model.pisa
        cnn = model.cnn_autoencoder

        # torch.save does not create missing parent directories; make sure
        # the target directory exists so a fresh working directory works.
        os.makedirs("weights", exist_ok=True)

        pisa_file = f"weights/pisa_policy_wk{worker.worker_index}_{time_str}.pt"
        torch.save(pisa.state_dict(), pisa_file)
        cnn_file = f"weights/cnn_policy_wk{worker.worker_index}_{time_str}.pt"
        torch.save(cnn.state_dict(), cnn_file)
        # Report both files; previously only the CNN path was printed.
        print(f"Saved CNN AE + PISA to {cnn_file} and {pisa_file}")


class ReconstructionLossCallbacks(DefaultCallbacks):
    """Log CNN-autoencoder / PISA reconstruction quality for the acting policy.

    Each time a trajectory is postprocessed, the joint observations in the
    batch are pushed through the policy model's ``cnn_autoencoder`` and
    ``pisa`` modules. The resulting losses are stored in
    ``episode.custom_metrics`` and a grid of sample reconstructions is stored
    in ``episode.media["sample_reconstruction"]``.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the RLlib base class.
        super().__init__(*args, **kwargs)

    def on_postprocess_trajectory(
            self,
            *,
            worker: "RolloutWorker",
            episode: Episode,
            agent_id: AgentID,
            policy_id: PolicyID,
            policies: Dict[PolicyID, Policy],
            postprocessed_batch: SampleBatch,
            original_batches: Dict[AgentID, Tuple[Policy, SampleBatch]],
            **kwargs,
    ) -> None:
        """Compute reconstruction losses and a visualisation for one batch.

        Writes ``"{policy_id}/sae_loss"`` (the ``"loss"`` entry returned by
        ``pisa.loss()``) and ``"{policy_id}/vae_loss"`` (MSE between the
        scaled observations and the CNN autoencoder's reconstruction) to
        ``episode.custom_metrics``.
        """
        pi = policies[policy_id]
        n_agents = pi.model.n_agents

        # This callback only logs metrics, so run everything without autograd
        # to avoid building (and holding on to) an unused computation graph.
        with torch.no_grad():
            obs = torch.stack(
                [
                    torch.tensor(
                        postprocessed_batch["obs"]["all"][f"player_{i}"]
                    ).permute(0, 3, 1, 2).float()
                    for i in range(n_agents)
                ],
                dim=1
            )  # [batches, agents, d3, d1, d2] of each per-player array
            # Presumably maps 8-bit pixel values into [0, 1] -- TODO confirm.
            obs /= 255.0
            n_batches = obs.shape[0]

            # Merge the batch and agent axes: [batches * agents, ...]
            obs = torch.flatten(obs, start_dim=0, end_dim=1)
            # Index of the originating batch row for every flattened sample.
            batch = torch.arange(
                n_batches, device=obs.device
            ).repeat_interleave(n_agents)

            cnn_autoencoder = pi.model.cnn_autoencoder
            pisa = pi.model.pisa
            enc_obs = cnn_autoencoder.encode(obs)
            x_recon, _ = pisa(enc_obs, batch=batch)
            dec_obs = cnn_autoencoder.decode(enc_obs)
            dec_xr = cnn_autoencoder.decode(x_recon)
            cnn_loss = torch.nn.functional.mse_loss(obs, dec_obs)
            # Read after the forward pass above; presumably pisa.loss()
            # reports on the most recent call -- verify against the model.
            sae_loss = pisa.loss()["loss"]

            # Visualise reconstruction: b images per grid row, in the order
            # inputs, CNN reconstructions, PISA reconstructions.
            b = 4
            viz = torchvision.utils.make_grid(
                torch.cat([obs[-b:], dec_obs[-b:], dec_xr[-b:]], dim=0), nrow=b
            )

        if torch.is_tensor(sae_loss):
            sae_loss = sae_loss.item()
        if torch.is_tensor(cnn_loss):
            cnn_loss = cnn_loss.item()

        # Stored per policy on the episode's custom metrics as plain floats.
        episode.custom_metrics[f"{policy_id}/sae_loss"] = sae_loss
        episode.custom_metrics[f"{policy_id}/vae_loss"] = cnn_loss

        episode.media["sample_reconstruction"] = wandb.Image(viz)

def setup_callbacks(**kwargs):
    """Select the RLlib callback classes for the requested run configuration.

    Reads the ``excalibur``, ``no_comms`` and ``train_specific`` keys of
    ``kwargs`` (``no_comms`` is only consulted when ``excalibur`` is falsy).

    Returns:
        A list of callback classes, ordered as: ``SAECheckpointCallbacks``
        (when ``train_specific``), ``ReconstructionLossCallbacks`` (when
        neither ``excalibur`` nor ``no_comms``), ``RenderingCallbacks``
        (when not ``excalibur``).
    """
    selected = []

    if not kwargs["excalibur"]:
        # Rendering / reconstruction logging is only enabled off-HPC.
        selected.append(RenderingCallbacks)
        if not kwargs["no_comms"]:
            # Log AE / PISA loss when they are being used
            selected.insert(0, ReconstructionLossCallbacks)

    if kwargs["train_specific"]:
        # Checkpoint PISA when trained with policy loss
        selected.insert(0, SAECheckpointCallbacks)

    return selected

def setup_meltingpot_policies(**kwargs):
    """Register the MeltingPot env and build one policy spec per player role.

    Reads ``kwargs["scenario"]`` as the substrate name, registers
    ``utils.env_creator`` under the env name ``"meltingpot"`` and creates a
    throwaway environment to read the observation and action spaces.

    Returns:
        A 4-tuple ``(policies, player_roles, single_obs_size,
        policy_mapping_fn)`` where ``policies`` maps ``"agent_<i>"`` to a
        ``PolicySpec``, ``player_roles`` are the substrate's default player
        roles, ``single_obs_size`` is the product of the shape of
        ``player_0``'s ``RGB`` observation, and ``policy_mapping_fn`` maps
        ``"player_<i>"`` to ``"agent_<i>"``.
    """
    scenario = kwargs["scenario"]
    player_roles = substrate.get_config(scenario).default_player_roles

    register_env("meltingpot", utils.env_creator)
    # Probe environment, used only to inspect the spaces below.
    test_env = utils.env_creator(
        {"substrate": scenario, "roles": player_roles}
    )

    rgb_shape = test_env.observation_space['player_0']['RGB'].shape
    single_obs_size = math.prod(rgb_shape)

    # Setup PPO with policies, one per entry in default player roles.
    n_players = len(player_roles)
    player_to_agent = {
        f"player_{idx}": f"agent_{idx}" for idx in range(n_players)
    }
    policies = {
        agent_name: PolicySpec(
            policy_class=None,  # Use default policy
            observation_space=test_env.observation_space[player_name],
            action_space=test_env.action_space[player_name],
        )
        for player_name, agent_name in player_to_agent.items()
    }

    def policy_mapping_fn(agent_id, episode, worker, **_unused):
        # Any extra keyword arguments are accepted and ignored.
        return player_to_agent[agent_id]

    return policies, player_roles, single_obs_size, policy_mapping_fn

def policy(**kwargs):
    """Configure and launch a multi-agent PPO run through ``ray.tune.run``.

    Registers ``PolicyHetJOIPPO`` as the custom model ``"policy_net"``,
    initialises Ray if necessary, builds the MeltingPot policies, prints a
    summary of the run and starts training with a wandb logger attached.

    ``kwargs`` must contain the experiment-mode flags (``task_agnostic``,
    ``task_specific``, ``train_specific``), ``excalibur``, ``scenario``,
    ``seed``, ``cnn_path``, ``pisa_path``, ``pisa_dim``, ``resume``,
    ``training_iterations``, ``train_batch_size``,
    ``rollout_fragment_length``, ``sgd_minibatch_size``, ``num_workers``,
    ``num_cpus_per_worker`` and ``eval_interval``, plus whatever
    ``setup_callbacks`` reads. All of ``kwargs`` is also forwarded to the
    model as ``custom_model_config``.
    """
    ModelCatalog.register_custom_model("policy_net", PolicyHetJOIPPO)
    callbacks = setup_callbacks(**kwargs)

    if not ray.is_initialized():
        # With --excalibur, ray.init is given address="auto"; otherwise it
        # is called with no arguments.
        ray_init_args = {"address": "auto"} if kwargs["excalibur"] else {}
        ray.init(**ray_init_args)
        print("ray intialised.")

    (
        policies,
        player_roles,
        single_obs_size,
        policy_mapping_fn,
    ) = setup_meltingpot_policies(**kwargs)

    # One GPU for the driver when running on CUDA, none otherwise.
    num_gpus = 1 if Config.device == 'cuda' else 0

    # Determine mode: the first flag set wins, in this priority order.
    mode = next(
        (
            flag
            for flag in ("task_agnostic", "task_specific", "train_specific")
            if kwargs[flag]
        ),
        "no_comms",
    )

    banner = "\n\n-----------------------------------------------------------\n\n"
    print(banner)
    print(f"experiment type = {mode}")
    print(f"device = {Config.device}")
    print(f"substrate = {kwargs['scenario']}")
    print(f"seed = {kwargs['seed']}")
    print(f"cnn path = {kwargs['cnn_path']}")
    print(f"pisa path = {kwargs['pisa_path']}")
    print(f"pisa latent dim = {kwargs['pisa_dim']}")
    print(f"excalibur = {kwargs['excalibur']}")
    print(banner)

    # A resumed run reuses its name; a fresh run gets a timestamped one.
    resume_name = kwargs["resume"]
    resuming = resume_name is not None
    if resuming:
        run_name = resume_name
    else:
        run_name = f"PPO_{time.strftime('%Y%m%d-%H%M%S')}"

    stop_criteria = {"training_iteration": kwargs["training_iterations"]}

    wandb_logger = WandbLoggerCallback(
        project=Config.WANDB_PROJECT,
        name=f"{kwargs['scenario']}+{mode}+{kwargs['seed']}",
        entity=Config.WANDB_ENTITY,
        api_key="",
    )

    def _abs_or_none(path):
        # Make weight paths absolute; None is passed through unchanged.
        return os.path.abspath(path) if path is not None else path

    model_config = {
        "custom_model": "policy_net",
        "custom_model_config": {
            **kwargs,
            "cnn_path": _abs_or_none(kwargs["cnn_path"]),
            "pisa_path": _abs_or_none(kwargs["pisa_path"]),
            "single_obs_size": single_obs_size,
            "wandb_grouping": f"{kwargs['scenario']}+{mode}",
        },
    }
    multiagent_config = {
        "policies": policies,
        "policy_mapping_fn": policy_mapping_fn,
    }
    env_config = {
        "substrate": kwargs["scenario"],
        "roles": player_roles,
    }
    # The selected callbacks are attached to the evaluation config only.
    evaluation_config = {
        "num_envs_per_worker": 1,
        "callbacks": MultiCallbacks(callbacks),
    }

    trainer_config = {
        # "_enable_rl_module_api": False,
        # "_enable_learner_api": False,

        "seed": kwargs["seed"],
        "framework": "torch",
        "env": "meltingpot",
        "render_env": False,
        "train_batch_size": kwargs["train_batch_size"],
        "rollout_fragment_length": kwargs["rollout_fragment_length"],
        "sgd_minibatch_size": kwargs["sgd_minibatch_size"],
        "num_gpus": num_gpus,
        "num_workers": kwargs["num_workers"],
        "num_cpus_per_worker": kwargs["num_cpus_per_worker"],
        "model": model_config,
        "preprocessor_pref": None,  # Use raw observations
        "multiagent": multiagent_config,
        "env_config": env_config,
        "evaluation_interval": kwargs["eval_interval"],
        "evaluation_duration": 1,
        "evaluation_num_workers": 1,
        "evaluation_parallel_to_training": True,
        "evaluation_config": evaluation_config,
    }

    ray.tune.run(
        PPOTrainer,
        local_dir="~/ray_results",
        name=run_name,
        resume=resuming,
        stop=stop_criteria,
        checkpoint_freq=1,
        keep_checkpoints_num=2,
        checkpoint_at_end=True,
        checkpoint_score_attr="episode_reward_mean",
        callbacks=[wandb_logger],
        config=trainer_config,
    )


if __name__ == "__main__":
    # Command-line entry point: parse the experiment options, validate them,
    # set the global device and launch training via policy().
    parser = argparse.ArgumentParser(prog='Train policy with SAE')

    # Modes
    parser.add_argument('--task_agnostic', action='store_true', default=False, help='Task-agnostic pre-trained PISA experiment')
    parser.add_argument('--task_specific', action='store_true', default=False, help='Reused pre-trained PISA experiment')
    parser.add_argument('--train_specific', action='store_true', default=False, help='Train PISA with policy losses experiment')
    parser.add_argument('--no_comms', action='store_true', default=False, help='No communications experiment')

    # Required
    parser.add_argument('--scenario', type=str, default=None, help='MeltingPot scenario')
    parser.add_argument('--pisa_dim', type=int, default=None, help='PISA latent state dimensionality') # FIXME: Is this required? Can't we infer it?
    parser.add_argument('--cnn_path', type=str, default=None, help='Path to CNN autoencoder state dict')
    parser.add_argument('--pisa_path', type=str, default=None, help='Path to PISA autoencoder state dict')
    parser.add_argument('--seed', type=int, default=None)

    # Optional
    parser.add_argument('--resume', type=str, default=None, help="Name of run to resume")
    parser.add_argument('--excalibur', action='store_true', default=False, help='Disable callbacks for compatibility on excalibur/HPC')
    parser.add_argument('--train_batch_size', default=6400, type=int, help='Train batch size')
    parser.add_argument('--sgd_minibatch_size', default=128, type=int, help='SGD minibatch size')
    parser.add_argument('--training_iterations', default=5000, type=int, help='Number of training iterations')
    parser.add_argument('--rollout_fragment_length', default=100, type=int, help='Rollout fragment length')
    parser.add_argument('--eval_interval', default=10, type=int, help='Evaluation interval')
    parser.add_argument('--num_workers', default=2, type=int)
    parser.add_argument('--num_cpus_per_worker', default=1, type=int)
    parser.add_argument('--device', default='cuda', type=str)

    args = parser.parse_args()

    # Check valid argument configuration.
    # parser.error() is used instead of assert because assert statements are
    # removed when Python runs with -O, which would silently skip validation.
    # It prints the usage message and exits with status 2.
    if not (args.task_agnostic or args.task_specific or args.train_specific or args.no_comms):
        parser.error("No experiment mode specified")
    if args.scenario is None:
        parser.error("--scenario not specified")
    if args.pisa_dim is None:
        parser.error("--pisa_dim not specified")
    if args.seed is None:
        parser.error("--seed not specified")
    if not args.cnn_path:
        parser.error("--cnn_path not specified")
    # Pre-trained PISA weights are only mandatory for the two modes that
    # load a pre-trained PISA.
    if (args.task_agnostic or args.task_specific) and not args.pisa_path:
        parser.error("--pisa_path not specified")

    # Set global configuration
    Config.device = args.device

    # Every parsed argument is forwarded as a keyword, plus the scenario's
    # max_steps looked up from SCENARIO_CONFIG.
    policy(max_steps=SCENARIO_CONFIG[args.scenario]["max_steps"], **vars(args))
